In [None]:
#Download dataset and install packages
%%capture
!wget https://madm.dfki.de/files/sentinel/EuroSATallBands.zip --no-check-certificate
!unzip -q EuroSATallBands.zip
!pip install rasterio
!pip install einops

# Required Packages

In [None]:
#Import packages
import os
import glob
import numpy as np
import rasterio as rio
from rasterio.plot import reshape_as_image
import matplotlib.pyplot as plt
import torchvision
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import torch.nn as nn
import tqdm 
from torch import optim
from math import sqrt
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import re

# Prepare Dataset

In [None]:
# Root folder of the extracted archive; EurosatDataset looks for one
# sub-folder per class name directly underneath it.
eurosat_dir = "/content/ds/images/remote_sensing/otherDatasets/sentinel_2/tif"

# Class names; the position of a name in this list is its integer label.
classes = [
    "AnnualCrop",
    "Forest",
    "HerbaceousVegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "PermanentCrop",
    "Residential",
    "River",
    "SeaLake",
]

# Integer label -> class name, built from the list above (0: "AnnualCrop", ...).
class_labels = dict(enumerate(classes))

# Data augmentation pipeline; it is handed to EurosatDataset, which calls it on
# every sample it returns.
# NOTE(review): the flips and the rotation act on the last two axes of the tensor
# that ToTensor produces, so the array given to this pipeline has to be laid out
# (H, W, C) for them to operate on the image plane - confirm against the caller.
transform = transforms.Compose([
# NumPy array -> tensor; per the torchvision docs an ndarray is read as (H, W, C)
# and returned as (C, H, W).
transforms.ToTensor(),
# Random left/right mirror (torchvision's default probability).
transforms.RandomHorizontalFlip(),
# Random up/down mirror (torchvision's default probability).
transforms.RandomVerticalFlip(),
# Rotation by a random angle; a single number d means the range (-d, +d) degrees.
transforms.RandomRotation(90)
])

#Prepare dataset
class EurosatDataset(Dataset):
    """Image/label dataset over a folder tree with one sub-folder per class.

    Args:
        data_dir: Root directory; ``data_dir/<class name>/`` holds the raster
            files of that class.
        classes: Mapping ``integer label -> class name`` (iterated with
            ``.items()``).
        transform: Optional callable applied to every image. It receives a
            NumPy array laid out ``(rows, cols, bands)`` - the layout
            torchvision's ``ToTensor`` expects - and its return value is
            returned unchanged, so a ``ToTensor``-led pipeline yields a
            ``(bands, rows, cols)`` tensor.

    Each item is ``(image, label)``. The image holds bands 1..3 (0-based) of the
    file, divided by 10000. Without a transform it is a ``(3, rows, cols)``
    NumPy array.
    """

    def __init__(self, data_dir, classes, transform=None):
        self.data_dir = data_dir
        self.classes = classes
        self.transform = transform
        self.file_paths = []
        self.labels = []

        for i, class_name in self.classes.items():
            class_dir = os.path.join(self.data_dir, class_name)
            # Sorted so the sample order (and therefore any index-based split)
            # does not depend on the order the filesystem lists entries in.
            for file_name in sorted(os.listdir(class_dir)):
                file_path = os.path.join(class_dir, file_name)
                self.file_paths.append(file_path)
                self.labels.append(i)

    def __len__(self):
        """Number of files found across all class folders."""
        return len(self.file_paths)

    def __getitem__(self, index):
        """Load, scale and band-select the file at ``index``; return ``(image, label)``."""
        file_path = self.file_paths[index]
        with rio.open(file_path, "r") as src:
            # read all the bands in the file -> (bands, rows, cols)
            img = src.read()
        img = img.astype(np.int32)
        img = img / 10000
        # NOTE(review): the original comment says "move channel 13 to after
        # channel 8", but this selection yields bands [0..7, 12, 9, 11, 12]
        # (8 and 10 dropped, 12 duplicated). It does not affect the returned
        # image, which only keeps bands 1..3 below - confirm the intent before
        # widening the band selection.
        img = np.concatenate((img[:8,:,:], img[12:13,:,:], img[9:10,:,:], img[11:13,:,:]), axis=0)
        # calculate NDVI from band 8 (after the reordering) and band 3
        ndvi = (img[8,:,:] - img[3,:,:]) / (img[8,:,:] + img[3,:,:] + 1e-10)
        ndvi = ndvi.astype(np.float32)
        # append NDVI as an extra band
        img = np.concatenate((img, ndvi[np.newaxis,:,:]), axis=0)

        # keep bands 0..3 plus NDVI, then narrow down to bands 1..3
        img = np.concatenate((img[0:4,:,:], img[12:13,:,:]), axis=0)
        img = img[1:4,:,:]
        if self.transform is not None:
            # ToTensor reads a NumPy array as (H, W, C). Handing it the
            # channel-first array made the flips/rotation act on a
            # (cols, bands, rows) tensor, i.e. they mixed bands with image rows.
            # Passing (rows, cols, bands) lets the augmentations work on the
            # image plane and already yields (bands, rows, cols), so no
            # permute is needed afterwards.
            img = self.transform(np.transpose(img, (1, 2, 0)))
        label = self.labels[index]
        return img, label

# Instantiate the dataset used for training (with random augmentation)
dataset = EurosatDataset(eurosat_dir, class_labels, transform)

# Second view over the same files that only converts to a tensor, so that
# evaluation is not perturbed by random flips/rotations.
eval_dataset = EurosatDataset(eurosat_dir, class_labels, transforms.ToTensor())
# The split below is index based, so both views must list the files identically.
assert eval_dataset.file_paths == dataset.file_paths, "dataset views disagree on file order"

# Split the data into training and testing sets (stratified by label;
# fixed random_state so the split is the same on every run)
train_indices, test_indices = train_test_split(
    range(len(dataset)), test_size=0.2, stratify=dataset.labels, random_state=42
)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
test_dataset = torch.utils.data.Subset(eval_dataset, test_indices)

# Define the data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Vision Transformer Model Design

In [None]:
# helpers
def pair(t):
    """Return ``t`` itself when it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes
class PreNorm(nn.Module):
    """Apply LayerNorm over the last ``dim`` features, then call the wrapped ``fn``.

    Args:
        dim: Size of the feature dimension that is normalised.
        fn: Callable (typically a module) invoked on the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and forward it, with any keyword arguments, to ``fn``."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Built in this order so the layer indices (net.0 ... net.4) stay fixed.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension keeps size ``dim``."""
        return self.net(x)

class LSA(nn.Module):
    """Multi-head self-attention with a learnable temperature and a masked diagonal.

    The attention scores are scaled by ``exp(temperature)``, where
    ``temperature`` is a parameter initialised to ``log(dim_head ** -0.5)``.
    The score of every token with itself is replaced by a very large negative
    value before the softmax.

    Args:
        dim: Feature size of the input and output tokens.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        dropout: Dropout probability on the attention weights and the output.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        self.heads = heads
        self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # Combined width of all heads.
        width = heads * dim_head
        self.to_qkv = nn.Linear(dim, 3 * width, bias = False)
        self.to_out = nn.Sequential(nn.Linear(width, dim), nn.Dropout(dropout))

    def forward(self, x):
        """Attend over the token axis of ``x`` (batch, tokens, dim); same shape out."""
        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = self.heads)

        q, k, v = (split_heads(part) for part in self.to_qkv(x).chunk(3, dim = -1))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.temperature.exp()

        # Mask the diagonal so a token does not attend to itself.
        n_tokens = scores.shape[-1]
        diagonal = torch.eye(n_tokens, device = scores.device, dtype = torch.bool)
        scores = scores.masked_fill(diagonal, -torch.finfo(scores.dtype).max)

        weights = self.dropout(self.attend(scores))

        merged = rearrange(torch.matmul(weights, v), 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` blocks, each a pre-normalised LSA followed by a
    pre-normalised FeedForward, both wrapped in a residual connection.

    Args:
        dim: Token feature size.
        depth: Number of blocks.
        heads: Attention heads per LSA layer.
        dim_head: Feature size per attention head.
        mlp_dim: Hidden width of each FeedForward.
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attention = PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout))
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            blocks.append(nn.ModuleList([attention, feed_forward]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Pass ``x`` through every block in order and return the result."""
        for attn, ff in self.layers:
            residual = x
            x = attn(x) + residual
            residual = x
            x = ff(x) + residual
        return x

class SPT(nn.Module):
    """Patch tokeniser that sees the image together with four shifted copies.

    The input is concatenated along the channel axis with four copies of
    itself, each padded by one pixel on one side and trimmed by one pixel on
    the opposite side. The 5x-channel result is cut into square patches,
    layer-normalised and projected to ``dim`` features.

    Args:
        dim: Output feature size per patch token.
        patch_size: Side length of a (square) patch, in pixels.
        channels: Number of channels of the input image.
    """

    def __init__(self, *, dim, patch_size, channels = 3):
        super().__init__()
        # 5 = the image itself + 4 shifted copies stacked on the channel axis.
        patch_dim = patch_size * patch_size * 5 * channels

        cut_into_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)
        self.to_patch_tokens = nn.Sequential(
            cut_into_patches,
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim)
        )

    def forward(self, x):
        """Turn images ``x`` (batch, channels, height, width) into patch tokens."""
        # (left, right, top, bottom) paddings; +1 pads, -1 trims.
        shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
        views = [x]
        for shift in shifts:
            views.append(F.pad(x, shift))
        return self.to_patch_tokens(torch.cat(views, dim = 1))

class ViT(nn.Module):
    """Vision transformer classifier built from SPT patch tokens and a Transformer stack.

    Args:
        image_size: Image side length, or a ``(height, width)`` tuple.
        patch_size: Patch side length, or a ``(height, width)`` tuple for the
            divisibility check. The value is passed unchanged to SPT, which
            multiplies it with itself.
        num_classes: Number of output logits.
        dim: Token feature size.
        depth: Number of transformer blocks.
        heads: Attention heads per block.
        mlp_dim: Hidden width of the feed-forward layers.
        pool: ``'cls'`` to classify from the class token, ``'mean'`` to average
            all tokens.
        channels: Number of input image channels.
        dim_head: Feature size per attention head.
        dropout: Dropout probability inside the transformer.
        emb_dropout: Dropout probability applied to the embedded tokens.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patches_down = image_height // patch_height
        patches_across = image_width // patch_width
        num_patches = patches_down * patches_across
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = SPT(dim = dim, patch_size = patch_size, channels = channels)

        # One learned position vector per patch, plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, img):
        """Return class logits of shape (batch, num_classes) for a batch of images."""
        tokens = self.to_patch_embedding(img)
        b, n, _ = tokens.shape

        # Prepend the class token to every sequence in the batch.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = tokens + self.pos_embedding[:, :(n + 1)]
        tokens = self.dropout(tokens)

        tokens = self.transformer(tokens)

        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        return self.mlp_head(self.to_latent(pooled))

# Use the GPU when one is available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the classifier: 64x64 inputs cut into 4x4 patches, 10 output classes
model = ViT(
    image_size = 64,
    patch_size = 4,
    num_classes = 10,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 1024,
    # Regularization
    dropout = 0.01,
    emb_dropout = 0.01
)

# Move the model to the compute device (as the last expression of the cell,
# this also displays the module tree)
model.to(device)

ViT(
  (to_patch_embedding): SPT(
    (to_patch_tokens): Sequential(
      (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)
      (1): LayerNorm((240,), eps=1e-05, elementwise_affine=True)
      (2): Linear(in_features=240, out_features=1024, bias=True)
    )
  )
  (dropout): Dropout(p=0.01, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fn): LSA(
            (attend): Softmax(dim=-1)
            (dropout): Dropout(p=0.01, inplace=False)
            (to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1024, out_features=1024, bias=True)
              (1): Dropout(p=0.01, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (f

# Model Training

In [None]:
num_epochs = 2
# Per-epoch history: mean of the mini-batch losses / accuracies of that epoch
train_epoch_losses = []
train_epoch_accs = []

# Define the optimization criterion / loss function
cross_entropy = nn.CrossEntropyLoss()

# Define learning rate and optimization strategy
learning_rate = 0.001

optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Train for n epochs
for epoch in range(num_epochs):

    # (Re-)enter training mode every epoch so the loop still trains correctly
    # if the model was switched to eval mode in between.
    model.train()

    # Init collection of mini-batch losses / accuracies
    train_mini_batch_losses = []
    train_mini_batch_accs = []
    train_loader_progress = tqdm.tqdm(train_loader)

    # Update for each mini-batch
    for x, y in train_loader_progress:

        # Cast to float32, then transfer data to the compute device
        x, y = x.float().to(device), y.to(device)
        # Reset the gradients of the parameters being optimised
        optimizer.zero_grad()
        # Forward pass
        pred = model(x)
        # Compute loss and mini-batch accuracy (in percent)
        loss = cross_entropy(pred, y)
        acc = (pred.argmax(dim=1) == y).float().mean() * 100
        # Run backward pass
        loss.backward()
        # Update network parameters
        optimizer.step()

        # Store mini-batch metrics and show them on the progress bar
        train_mini_batch_losses.append(loss.item())
        train_mini_batch_accs.append(acc.item())
        train_loader_progress.set_description(f"Loss: {loss.item():0.5f} - Acc: {acc.item():0.2f}%")

    # Compute epoch metrics
    train_epoch_loss = np.mean(train_mini_batch_losses)
    train_epoch_losses.append(train_epoch_loss)
    train_epoch_acc = np.mean(train_mini_batch_accs)
    train_epoch_accs.append(train_epoch_acc)

    print(f"Epoch {epoch} - Loss: {train_epoch_loss:0.5f} - Acc: {train_epoch_acc:0.2f}%")
    # Get the last accuracy value
    last_acc = train_epoch_accs[-1]
    # Checkpoint whenever this epoch beats every earlier epoch
    # (always true for the first epoch)
    if all(last_acc > previous_acc for previous_acc in train_epoch_accs[:-1]):
        # Save best model so far
        torch.save(model.state_dict(), '/content/model.pt')

# Save the weights as they are after the last epoch
torch.save(model.state_dict(), '/content/model_final.pt')

Loss: 2.27870 - Acc: 12.50%: 100%|██████████| 675/675 [09:29<00:00,  1.19it/s]


Epoch 0 - Loss: 2.40991 - Acc: 11.43%


Loss: 2.31996 - Acc: 15.62%: 100%|██████████| 675/675 [09:24<00:00,  1.20it/s]


Epoch 1 - Loss: 2.24696 - Acc: 15.20%


# Model Evaluation

In [None]:
# Loss criterion (mean over the samples of a batch)
criterion = nn.CrossEntropyLoss()

# Set the model to evaluation mode
model.eval()
# Running totals: correctly classified samples, samples seen, and the
# per-sample loss summed over all samples
correct = 0
total = 0
test_loss = 0

# Turn off gradients
with torch.no_grad():
    # Loop through the test data in batches
    for images, labels in test_loader:
        # Convert the images and labels to the correct datatype
        images = images.float().to(device)
        labels = labels.long().to(device)
        # Make predictions using the model
        outputs = model(images)
        # The criterion returns the batch mean; weight it by the batch size so
        # a smaller final batch does not count as much as a full one.
        batch_count = labels.size(0)
        loss = criterion(outputs, labels)
        test_loss += loss.item() * batch_count
        # Get the predicted classes
        predicted = outputs.argmax(dim=1)
        # Update the correct and total counts
        total += batch_count
        correct += (predicted == labels).sum().item()

# Print the accuracy and the per-sample average loss
print('Accuracy on test dataset: {:.2f}%'.format(100 * correct / total))
print('Average test loss: {:.4f}'.format(test_loss / total))

Accuracy on test dataset: 14.78%
Average test loss: 2.2011
