In [1]:
# Mount Google Drive (Colab only) so the zipped dataset extracted in the next
# cell can be read from /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!unzip "/content/drive/My Drive/Colab Notebooks/AI/ViT/satellite_images.zip"

In [9]:
# Patch extraction: patch_size x patch_size windows taken every `stride` pixels
# (non-overlapping when stride == patch_size).
patch_size = 64
stride = 64
# Side length every image is resized to before patching (the Resize((256,256))
# transforms in the training and evaluation cells).
image_size = 256
# Patches per image along each axis is (image_size - patch_size) // stride + 1,
# i.e. 4 x 4 = 16 for the values above. Derived rather than hard-coded so it
# stays consistent if patch_size, stride or image_size change.
num_of_patches = ((image_size - patch_size) // stride + 1) ** 2
# Width of every token inside the transformer.
latent_dimension = 512
# Number of output classes of the classification head.
num_of_classes = 4

In [10]:
from math import sqrt

import torch
from torch.nn import Module, Flatten, Parameter, Linear
from torch.nn import ModuleList
from torch.nn import Softmax
from torch.nn import Tanh, Sigmoid
from torch.nn import GELU
from torch.nn import LayerNorm

class InputLayer(Module):
  """Flatten the patches of one image, prepend a class token and project to latent space.

  forward() takes the patches of a single image (no batch axis); with the
  SatelliteImagesPartition pipeline that is a
  (num_of_patches, 3, patch_size, patch_size) tensor. Each patch is flattened
  to patch_size*patch_size*3 values, a learnable class-token vector is
  prepended as row 0, and every row is projected by one shared Linear layer,
  giving a (num_of_patches + 1, latent_dimension) tensor.
  """

  def __init__(self):
    super(InputLayer, self).__init__()

    self.flattening = Flatten(start_dim=1)
    # Learnable class token in the same flattened-pixel space as the patches,
    # so it is projected by dense_1 together with them. As a Parameter it is
    # updated by the optimizer and moves with the module on .to(device)
    # (previously a fresh torch.rand vector was drawn on every forward call).
    self.class_embedding = Parameter(torch.rand(patch_size*patch_size*3))
    self.dense_1 = Linear(patch_size*patch_size*3, latent_dimension)

  def forward(self, input):
    # Cast to the parameter dtype so uint8 image patches can be concatenated
    # with the float class token without relying on implicit type promotion.
    flattened_patches = self.flattening(input).to(self.class_embedding.dtype)
    tokens = torch.cat((self.class_embedding.unsqueeze(0), flattened_patches), dim=0)
    return self.dense_1(tokens)

class PositionalEncoding(Module):
  """Add a learned positional embedding to every token.

  The embedding table E_pos has one row per token (num_of_patches patches plus
  the class token) and latent_dimension columns, initialised from a standard
  normal. It is added element-wise to the input, so the input must broadcast
  against shape (num_of_patches + 1, latent_dimension).
  """

  def __init__(self):
    super(PositionalEncoding, self).__init__()
    table_shape = (num_of_patches + 1, latent_dimension)
    self.E_pos = Parameter(torch.randn(*table_shape), requires_grad=True)

  def forward(self, input):
    positional_table = self.E_pos
    return torch.add(input, positional_table)

class ScaledDotProductAttention(Module):
  """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

  forward() takes a sequence [Q, K, V]. The last two axes of each tensor are
  (tokens, features); any leading axes are treated as batch axes. d_k is the
  feature size of K, i.e. the per-head projection size when this layer is used
  inside MultiHeadAttention (previously the scores were divided by
  sqrt(latent_dimension), which over-flattens the attention weights).
  """

  def __init__(self):
    super(ScaledDotProductAttention, self).__init__()
    # Normalise over the key axis, the last axis of the score matrix.
    self.softmax = Softmax(dim=-1)

  def forward(self, input):
    Q, K, V = input[0], input[1], input[2]
    d_k = K.shape[-1]
    # Swap only the last two axes so leading batch axes are preserved.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / sqrt(d_k)
    weights = self.softmax(scores)
    return torch.matmul(weights, V)

class MultiHeadAttention(Module):
  """Multi-head self-attention with num_of_stacked_blocks (16) heads.

  Each head projects the input to projected_dim = latent_dimension // 16
  features for Q, K and V with its own Linear layers and applies scaled
  dot-product attention. The head outputs are concatenated along the feature
  axis and projected back to latent_dimension by `dense`.
  """

  def __init__(self):
    super(MultiHeadAttention, self).__init__()
    self.num_of_stacked_blocks = 16
    self.projected_dim = latent_dimension // self.num_of_stacked_blocks
    heads = range(self.num_of_stacked_blocks)

    self.linear_layers_Q = ModuleList([Linear(latent_dimension, self.projected_dim) for _ in heads])
    self.linear_layers_K = ModuleList([Linear(latent_dimension, self.projected_dim) for _ in heads])
    self.linear_layers_V = ModuleList([Linear(latent_dimension, self.projected_dim) for _ in heads])

    self.scaled_dot_prod_attention_layers = ModuleList([ScaledDotProductAttention() for _ in heads])

    self.dense = Linear(self.projected_dim * self.num_of_stacked_blocks, latent_dimension)

  def forward(self, input):
    # Self-attention: Q, K and V are all projections of the same input.
    # Head outputs are kept in a local list; storing them on the module would
    # keep the previous call's autograd graph alive between forward passes.
    head_outputs = []
    for q_proj, k_proj, v_proj, attention in zip(self.linear_layers_Q,
                                                 self.linear_layers_K,
                                                 self.linear_layers_V,
                                                 self.scaled_dot_prod_attention_layers):
      head_outputs.append(attention([q_proj(input), k_proj(input), v_proj(input)]))

    # Concatenate along the feature (last) axis, which also works with leading batch axes.
    concat_attention_values = torch.cat(head_outputs, dim=-1)
    return self.dense(concat_attention_values)


class MlpLayer(Module):
  """Token-wise feed-forward block: Linear -> GELU -> Linear -> GELU.

  Expands each token from latent_dimension to 4096 hidden units and projects it
  back to latent_dimension, applying a GELU after each of the two Linear layers.

  NOTE(review): the reference ViT MLP applies GELU only between the two Linear
  layers; the trailing GELU on the output looks unintended — confirm before
  changing it.
  """

  def __init__(self):
    super(MlpLayer, self).__init__()
    hidden_units = 4096

    self.gelu_1, self.gelu_2 = GELU(), GELU()

    self.linear_1 = Linear(latent_dimension, hidden_units)
    self.linear_2 = Linear(hidden_units, latent_dimension)

  def forward(self, input):
    hidden = self.gelu_1(self.linear_1(input))
    projected = self.linear_2(hidden)
    return self.gelu_2(projected)

class TransformerEncoder(Module):
  """One pre-norm Transformer encoder block.

  y = x + MultiHeadAttention(LayerNorm(x))
  out = y + MlpLayer(LayerNorm(y))
  """

  def __init__(self):
    super(TransformerEncoder, self).__init__()

    self.multi_head_att = MultiHeadAttention()

    # LayerNorm is applied per token over the feature axis only. The previous
    # normalized_shape [num_of_patches+1, latent_dimension] computed one mean and
    # variance across ALL tokens together, mixing statistics between patches
    # and tying the block to a fixed token count.
    self.layer_norm_1 = LayerNorm(latent_dimension)
    self.layer_norm_2 = LayerNorm(latent_dimension)

    self.mlp = MlpLayer()

  def forward(self, input):
    attention_values = self.multi_head_att(self.layer_norm_1(input))
    # First residual connection around the attention sub-layer.
    mlp_input = input + attention_values
    out = self.mlp(self.layer_norm_2(mlp_input))
    # Second residual connection around the MLP sub-layer.
    return out + mlp_input

class MlpHead(Module):
  """Classification head applied to the class token.

  Takes the encoder output and selects the token at position 0 of the token
  axis, which is the class token prepended by InputLayer. It maps that token
  through Linear(latent_dimension, 1024) -> Tanh -> Linear(1024, num_of_classes)
  and returns raw class scores (logits).
  """

  def __init__(self):
    super(MlpHead, self).__init__()

    self.linear_1 = Linear(latent_dimension, 1024)
    self.tanh = Tanh()
    self.linear_2 = Linear(1024, num_of_classes)

  def forward(self, input):
    # `...` keeps this equivalent to input[0] for a (tokens, features) input
    # while also accepting leading batch axes.
    class_embedding = input[..., 0, :]
    # No Sigmoid: CrossEntropyLoss applies log-softmax to raw logits itself.
    # Squashing scores into (0, 1) first caps how confident the model can get
    # and keeps the loss far from zero. argmax-based predictions are unaffected.
    return self.linear_2(self.tanh(self.linear_1(class_embedding)))

class VisionTransformer(Module):
  """Vision Transformer over the patches of a single image.

  Pipeline: InputLayer (flatten patches, prepend class token, project) ->
  PositionalEncoding -> 24 TransformerEncoder blocks applied in sequence ->
  MlpHead on the class token.
  """

  def __init__(self):
    super(VisionTransformer, self).__init__()
    self.num_of_encoders = 24

    self.input_layer = InputLayer()
    self.positional_encoding = PositionalEncoding()
    self.transformer_encoders = ModuleList(
        TransformerEncoder() for _ in range(self.num_of_encoders))
    self.mlp_head = MlpHead()

  def forward(self, input):
    hidden = self.positional_encoding(self.input_layer(input))

    # Each encoder consumes the previous encoder's output.
    for encoder in self.transformer_encoders:
      hidden = encoder(hidden)

    return self.mlp_head(hidden)


In [11]:
import glob
import cv2
import random
import torch

from torch.utils.data import Dataset

# Seed both RNGs this notebook draws from: Python's `random` (train_test_split
# and shuffle_data) and torch's (weight initialisation, RandomSampler and other
# torch-based sampling), so Restart & Run All reproduces the same run.
random.seed(42)
torch.manual_seed(42)

class SatelliteImagesDataset(Dataset):
  """Per-class index of the images under `imgs_path`.

  Every sub-directory of imgs_path is one class, named after the directory, and
  its *.jpg files are that class's samples. dataset[i] returns
  (class_name, [image paths]) for class i, so len() is the number of classes.
  The `size` attribute counts images across all classes.

  Args:
    imgs_path: root folder holding one sub-directory per class (default "data/").
  """

  def __init__(self, imgs_path="data/"):
    import os  # local import: this cell's top-level imports do not include os

    self.size = 0
    self.num_of_classes = 0
    self.imgs_path = imgs_path
    self.classes_indexes = {}
    self.data = []

    # Sorting makes the class -> index mapping and the per-class file order
    # independent of the filesystem's listing order, so seeded splits are
    # reproducible. Plain files directly under imgs_path are not classes.
    class_paths = sorted(path for path in glob.glob(os.path.join(imgs_path, "*"))
                         if os.path.isdir(path))

    for class_index, class_path in enumerate(class_paths):
      class_name = os.path.basename(os.path.normpath(class_path))
      file_names = sorted(glob.glob(os.path.join(class_path, "*.jpg")))

      self.classes_indexes[class_name] = class_index
      self.num_of_classes += 1
      self.size += len(file_names)
      self.data.append((class_name, file_names))

  def __len__(self):
    return self.num_of_classes

  def __getitem__(self, idx):
    return self.data[idx]

  def get_classes_indexes(self):
    """Return the {class_name: class_index} mapping."""
    return self.classes_indexes

class SatelliteImagesPartition(Dataset):
  """One split (train or test) of the satellite images, returned as patches.

  Args:
    data: list of (class_name, image_path) pairs, e.g. one of the lists
      returned by train_test_split.
    transform: optional callable applied to the (C, H, W) uint8 image tensor
      before patching (e.g. a torchvision Compose), or None.

  dataset[i] returns (class_name, patches). `patches` has shape
  (num_windows, C, patch_size, patch_size): the image cut into
  patch_size x patch_size windows every `stride` pixels, in row-major order.
  uint8 patches are converted to float32 in [0, 1].
  """

  def __init__(self, data, transform):
    self.data = data
    self.transform = transform

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    class_name, img_path = self.data[idx][0], self.data[idx][1]

    img = cv2.imread(img_path)
    # cv2.imread does not raise on a missing or unreadable file; it returns None.
    if img is None:
      raise FileNotFoundError(f"cv2.imread could not read image: {img_path}")

    # OpenCV yields an (H, W, C) array with channels in BGR order; move the
    # channel axis first to get (C, H, W).
    img_tensor = torch.from_numpy(img).permute(2, 0, 1)

    if self.transform is not None:
      img_tensor = self.transform(img_tensor)

    # (C, H, W) -> (C, nH, nW, p, p) -> (nH, nW, C, p, p) -> (nH * nW, C, p, p)
    patches = img_tensor.unfold(1, patch_size, stride).unfold(2, patch_size, stride)
    patches = patches.permute(1, 2, 0, 3, 4).flatten(start_dim=0, end_dim=1)

    # Rescale only after the transform: torchvision's AugMix (used on the
    # training split) works on uint8 tensors. The model's Linear layers then
    # see values in [0, 1] instead of raw 0-255 intensities.
    if patches.dtype == torch.uint8:
      patches = patches.float().div(255.0)

    return class_name, patches

  def shuffle_data(self):
    """Shuffle the (class_name, image_path) pairs in place with `random`."""
    random.shuffle(self.data)


In [12]:
import math 

def train_test_split(dataset, train_ratio, test_ratio, num_of_classes):
  """Split a per-class dataset into class-balanced train and test lists.

  Args:
    dataset: object with a `size` attribute (total number of samples) whose
      dataset[i] returns (class_name, list_of_samples) for class i, e.g.
      SatelliteImagesDataset.
    train_ratio: fraction of ALL samples used for training. The budget
      floor(size * train_ratio) is divided equally among the classes, so each
      class contributes the same number of training samples. A class smaller
      than its share contributes all of its samples.
    test_ratio: kept for backward compatibility and not used. Every sample not
      used for training goes to the test set.
    num_of_classes: number of classes to read from `dataset` (indices 0..n-1).

  Returns:
    (train_set, test_set): lists of (class_name, sample) tuples grouped by
    class. Samples within each class are shuffled with the module-level
    `random` generator. The dataset's own lists are not modified.
  """
  train_size = math.floor(dataset.size * train_ratio)
  per_class_train = train_size // num_of_classes

  # Built by appending: the old pre-sized lists could keep None holes or
  # overflow, and an off-by-one index dropped one test sample.
  train_set, test_set = [], []

  for class_index in range(num_of_classes):
    class_name, samples = dataset[class_index]
    # Shuffle a copy so the caller's dataset keeps its original order.
    samples = list(samples)
    random.shuffle(samples)

    n_train = min(per_class_train, len(samples))
    train_set.extend((class_name, sample) for sample in samples[:n_train])
    test_set.extend((class_name, sample) for sample in samples[n_train:])

  return train_set, test_set

In [13]:
import numpy as np
import matplotlib.pyplot as plt

from torch.optim import Adam
from torch.nn import CrossEntropyLoss


def train(num_of_iter, batch_size, optimizer, device_type, data_loader, classes_indexes, vit):
  """Train `vit` for `num_of_iter` optimizer steps, drawing one batch per step.

  Args:
    num_of_iter: number of batches drawn from `data_loader`, one optimizer step each.
    batch_size: nominal batch size, kept for backward compatibility. The summed
      loss is divided by the number of images actually in each batch, which
      equals batch_size when the loader yields full batches.
    optimizer: torch optimizer over vit's parameters.
    device_type: device the model and batches are moved to (e.g. "cuda:0").
    data_loader: iterable yielding (class_names, batch) pairs, where batch
      holds one model input per image along its first axis.
    classes_indexes: mapping class_name -> target class index.
    vit: model mapping ONE image's input to a 1-D tensor of class logits.

  Returns:
    list of float: the mean per-image loss of every iteration.
  """
  loss = CrossEntropyLoss(reduction='sum')
  # moving the model to the target device
  vit.to(device_type)
  iterator = iter(data_loader)
  loss_values = []

  for iteration in range(num_of_iter):
    optimizer.zero_grad()
    class_names, batch = next(iterator)
    batch = batch.to(device_type)

    # The model takes one image at a time: run it per image and stack the
    # logits into (batch, classes). Class-index targets give the same
    # cross-entropy as one-hot probabilities, with no hard-coded class count.
    outputs = torch.stack([vit(image) for image in batch])
    targets = torch.tensor([classes_indexes[name] for name in class_names],
                           device=outputs.device)
    loss_value = loss(outputs, targets) / len(batch)

    loss_values.append(loss_value.item())
    print("loss_value is ", loss_value.item(), "at iteration ", iteration)

    # computing gradients and updating the model's parameters
    loss_value.backward()
    optimizer.step()

  return loss_values

In [None]:
import torchvision
from torch.utils import data
from torch.utils.data import RandomSampler

satellite_images_dataset = SatelliteImagesDataset()
num_of_iter = 50000
batch_size = 256

# The classification head outputs num_of_classes scores; fail early if the
# data folder holds a different number of class directories.
assert satellite_images_dataset.num_of_classes == num_of_classes, (
    f"found {satellite_images_dataset.num_of_classes} class folders, "
    f"the model expects {num_of_classes}")

train_set, test_set = train_test_split(satellite_images_dataset, 0.8, 0.2, num_of_classes)
classes_indexes = satellite_images_dataset.get_classes_indexes()

dev_type = "cuda:0" if torch.cuda.is_available() else "cpu"

# AugMix augmentation is applied to the training split only.
train_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256,256)),
        torchvision.transforms.AugMix(),
    ])

train_set = SatelliteImagesPartition(train_set, train_transform)

# Sample with replacement exactly num_of_iter full batches, so train() can
# draw one batch per iteration without exhausting the loader.
random_sampler = RandomSampler(train_set, replacement=True, num_samples=num_of_iter * batch_size)
# Pinned host memory only helps host-to-GPU copies.
data_loader = data.DataLoader(train_set, batch_size = batch_size, sampler=random_sampler,
                              pin_memory=torch.cuda.is_available())
vit = VisionTransformer()
optimizer = Adam(vit.parameters(), lr=0.0001)
train(num_of_iter, batch_size, optimizer, dev_type, data_loader, classes_indexes, vit)

In [None]:
test_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256,256)),
    ])

# New name instead of rebinding test_set: otherwise re-running this cell would
# wrap the partition inside another partition.
test_partition = SatelliteImagesPartition(test_set, test_transform)
# Every image is scored individually, so the batch size only affects memory use.
test_loader = data.DataLoader(test_partition, batch_size = 256,
                              pin_memory=torch.cuda.is_available())

vit.to(dev_type)
vit.eval()

correct_predictions = 0
num_evaluated = 0

# Inference only: no_grad stops every forward pass from building and holding
# an autograd graph. All batches are evaluated, not just the first one.
with torch.no_grad():
  for class_names, batch in test_loader:
    batch = batch.to(dev_type)
    for actual_class_name, image in zip(class_names, batch):
      actual_class_index = classes_indexes[actual_class_name]
      predicted_class_index = int(torch.argmax(vit(image)))
      if predicted_class_index == actual_class_index:
        correct_predictions += 1
      num_evaluated += 1

assert num_evaluated > 0, "the test set is empty"
# Normalise by the number of images actually scored (previously a hard-coded
# 3000, although the test split is smaller).
accuracy = correct_predictions / num_evaluated
print(accuracy)
