In [158]:
import math
import torch 
import torch.nn as nn
import torchvision 
from torchvision.transforms import v2 
from PIL import Image
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import json, os
import matplotlib.pyplot as plt
from tqdm import tqdm 
import numpy as np
from torch import optim

In [159]:
class GELUActivation(nn.Module):
    def __init__(self):
        super(GELUActivation, self).__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    
class PatchEmbeddings(nn.Module):
    def __init__(self , config):
        super().__init__()
        self.img_size = config.image_size 
        self.patch_size = config.patch_size
        self.in_channels = config.in_channels
        self.hidden_size = config.hidden_size
        self.num_patches = (self.img_size // self.patch_size) * (self.img_size // self.patch_size)
        self.project = nn.Conv2d(self.in_channels, self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size)
        
    def forward(self, x):
        x = self.project(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x
    
class Embeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.patch_embeddings = PatchEmbeddings(config)
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.positional_embeddings = nn.Parameter(torch.randn(1, self.patch_embeddings.num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.dropout_rate)
        
    def forward(self, x):
        x = self.patch_embeddings(x)
        batch_size = x.shape[0]
        clas_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((clas_tokens, x), dim=1)
        x = x + self.positional_embeddings
        x = self.dropout(x)
        return x


In [160]:
class AttentionHead(nn.Module):
    def __init__(self, hidden_size , attention_head_size, dropout , bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = attention_head_size
        self.dropout = nn.Dropout(dropout)
        self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # Attention scores = softmax (Q * K. T/ sqrt (head_size) )*V
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.dropout(attention_probs)
        attention_output = torch.matmul(attention_probs, value)
        return (attention_output, attention_probs)
    
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.qkv_bias = config.qkv_bias
        self.heads = nn.ModuleList([])
        for _ in range(self.num_attention_heads):
            head = AttentionHead(
                self.hidden_size,
                self.attention_head_size,
                config.attention_probs_dropout_prob,
                self.qkv_bias
            )
            self.heads.append(head)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = nn.Dropout(config.dropout_rate)
        
    def forward(self, x , output_attentions=False):
        
        attention_outputs = []
        for attention_head in self.heads:
            attention_output = attention_head(x)
            attention_outputs.append(attention_output)
            
        attention_output = torch.cat([attention_output for attention_output, _ in attention_outputs], dim=-1)
        attention_output = self.dense(attention_output)
        attention_output = self.dropout(attention_output)
        
        if not output_attentions:
            return (attention_output, None)
        else:
            attention_probs = torch.stack([attention_probs for _, attention_probs in attention_outputs], dim=1)
            return (attention_output, attention_probs)

In [161]:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.dropout = nn.Dropout(config.dropout_rate)
        self.fc1 = nn.Linear(self.hidden_size, self.intermediate_size)
        self.fc2 = nn.Linear(self.intermediate_size, self.hidden_size)
        self.activation = GELUActivation()
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
    
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.layernorm1 = nn.LayerNorm(config.hidden_size)
        self.mlp = MLP(config)
        self.layernorm2 = nn.LayerNorm(config.hidden_size)
        
    def forward(self, x, output_attentions=False):
        attention_output, attention_probs = self.attention(self.layernorm1(x), output_attentions)
        x = x + attention_output
        mlp_output = self.mlp(self.layernorm2(x))
        x = x + mlp_output
        if not output_attentions:
            return (x, None)
        else:
            return (x, attention_probs)
        
class Encoder(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(config.num_hidden_layers):
            block = Block(config)
            self.blocks.append(block)
            
    def forward(self, x, output_attentions=False):
        attentions = []
        for block in self.blocks:
            x, attention_probs = block(x, output_attentions)
            if output_attentions:
                attentions.append(attention_probs)
        if not output_attentions:
            return (x, None)
        else:
            return (x, attentions)
        
class VisionTransformer(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config
        self.image_size = config.image_size
        self.hidden_size = config.hidden_size
        self.num_classes = config.num_classes
        self.embeddings = Embeddings(config)
        self.encoder = Encoder(config)
        self.classifier = nn.Linear(self.hidden_size, self.num_classes)
        self.apply(self.init_weights)
        
    def forward(self, y, output_attentions=False):
        x = self.embeddings(y)
        x, attentions = self.encoder(x, output_attentions)
        x = x[:, 0,: ]
        x = self.classifier(x)
        if not output_attentions:
            return (x, None)
        else:
            return (x, attentions)
        
    def init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Embeddings):
            module.positional_embeddings.data = nn.init.trunc_normal_(
                module.positional_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.positional_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)

In [162]:
def create_one_hot_encoding(label, classes):
    one_hot = torch.zeros(len(classes), dtype=torch.float32)
    
    if label in classes:
        index = classes.index(label)
        one_hot[index] = 1.0
    return one_hot

In [163]:
class CustomFERDataset(Dataset):
    def __init__(self, image_parent_directory, data_directory,classes , transform=None):
        self.image_parent_directory = image_parent_directory
        self.df = pd.read_csv(data_directory)
        self.transform = transform
        self.classes = classes

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_name = self.image_parent_directory + self.df.iloc[idx, 0]
        image = Image.open(img_name).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        label = self.df.iloc[idx, 1]
        if label in self.classes:
            index = self.classes.index(label)
        # label_tensor = torch.tensor(create_one_hot_encoding(label, self.classes))
        
        return image, index

In [164]:
def prepareData():
    transform = v2.Compose([
        v2.ToImage() ,
        v2.ToDtype(torch.uint8, scale=True),
        v2.CenterCrop((96, 96)),
        v2.ToTensor(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    classes = ('anger', 'contempt', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise')
    train_data = CustomFERDataset("data/archive/" , "data/train.csv" ,classes , transform=transform)
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
    test_data = CustomFERDataset("data/archive/" , "data/test.csv" ,classes , transform=transform )
    test_loader = DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2)
    return train_loader, test_loader ,classes


In [165]:
train_loader, test_loader ,classes = prepareData()

for _ ,batch in enumerate(train_loader):
    if _==0:
        print(batch)



[tensor([[[[-1.7754, -1.7754, -1.7754,  ..., -2.0665, -2.0837, -2.1179],
          [-1.7754, -1.7754, -1.7754,  ..., -2.0665, -2.0837, -2.1008],
          [-1.7925, -1.7754, -1.7583,  ..., -2.0837, -2.0665, -2.0665],
          ...,
          [ 0.5193,  0.5536,  0.5878,  ..., -0.3369, -0.3712, -0.4739],
          [ 0.5707,  0.6049,  0.5878,  ..., -0.3369, -0.3712, -0.4568],
          [ 0.5022,  0.5022,  0.5878,  ..., -0.3198, -0.2856, -0.4054]],

         [[-1.6856, -1.6856, -1.6856,  ..., -1.9832, -2.0007, -2.0357],
          [-1.6856, -1.6856, -1.6856,  ..., -1.9832, -2.0007, -2.0182],
          [-1.7031, -1.6856, -1.6681,  ..., -2.0007, -1.9832, -1.9832],
          ...,
          [-1.7031, -1.6681, -1.7206,  ..., -0.1975, -0.1800, -0.2850],
          [-1.6155, -1.5805, -1.6681,  ..., -0.1975, -0.1975, -0.2850],
          [-1.6331, -1.6856, -1.6331,  ..., -0.1800, -0.1450, -0.2325]],

         [[-1.4559, -1.4559, -1.4559,  ..., -1.7522, -1.7696, -1.8044],
          [-1.4559, -1.4559, 

In [166]:

def save_checkpoint(experiment_name, model, epoch, base_dir="experiments"):
    outdir = os.path.join(base_dir, experiment_name)
    os.makedirs(outdir, exist_ok=True)
    cpfile = os.path.join(outdir, f'model_{epoch}.pt')
    torch.save(model.state_dict(), cpfile)

def save_experiment(experiment_name , config , model , train_losses , test_losses , accuracies , base_dir= "experiments"):
    outdir = os.path.join(base_dir, experiment_name)
    os.makedirs(outdir, exist_ok=True)
    configfile = os.path.join(outdir, 'config.json')
    with open(configfile, 'w') as f:
        json.dump(config, f, sort_keys=True, indent=4)
        
    jsonfile = os.path.join(outdir, 'metrics.json')
    with open(jsonfile, 'w') as f:
        data = {
            'train_losses': train_losses,
            'test_losses': test_losses,
            'accuracies': accuracies,
        }
        json.dump(data, f, sort_keys=True, indent=4)

    save_checkpoint(experiment_name, model, "final", base_dir=base_dir)
    
def load_experiment(experiment_name, checkpoint_name="model_final.pt", base_dir="experiments"):
    outdir = os.path.join(base_dir, experiment_name)
    configfile = os.path.join(outdir, 'config.json')
    with open(configfile, 'r') as f:
        config = json.load(f)
    jsonfile = os.path.join(outdir, 'metrics.json')
    with open(jsonfile, 'r') as f:
        data = json.load(f)
    train_losses = data['train_losses']
    test_losses = data['test_losses']
    accuracies = data['accuracies']
    model = VisionTransformer(config)
    cpfile = os.path.join(outdir, checkpoint_name)
    model.load_state_dict(torch.load(cpfile))
    return config, model, train_losses, test_losses, accuracies

In [167]:
class Config:
    def __init__(self, config_dict):
        for key, value in config_dict.items():
            setattr(self, key, value)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.__dict__})"
    
config_dict = {
    "patch_size": 4,  
    "hidden_size": 64,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "intermediate_size": 4 * 64,
    "dropout_rate": 0.1,
    "attention_probs_dropout_prob": 0.0,
    "initializer_range": 0.04,
    "image_size": 96,
    "num_classes": 8,
    "num_channels": 3,
    "in_channels": 3,
    "qkv_bias": True,
}

config = Config(config_dict)

In [168]:
exp_name = "vision_transformer"
batch_size = 32
epochs = 10
lr = 0.001
save_model_every = 1

In [169]:
class Trainer:
    def __init__(self , model , optimizer , loss_fn , exp_name):
        self.model = model 
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.exp_name = exp_name
        
    def train(self, trainloader, testloader, epochs, save_model_every_n_epochs=2):
        train_losses, test_losses, accuracies = [], [], []
        for i in range(epochs):
            train_loss = self.train_epoch(trainloader)
            accuracy, test_loss = self.evaluate(testloader)
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            print(f"Epoch: {i+1}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")
            if save_model_every_n_epochs > 0 and (i+1) % save_model_every_n_epochs == 0 and i+1 != epochs:
                print('\tSave checkpoint at epoch', i+1)
                save_checkpoint(self.exp_name, self.model, i+1)
                
        save_experiment(self.exp_name, config, self.model, train_losses, test_losses, accuracies)
        
    def train_epoch(self, trainloader):
        self.model.train()
        total_loss = 0
        for batch in tqdm(trainloader ,  total=len(train_loader)):
            images, labels = batch
            self.optimizer.zero_grad()
            loss = self.loss_fn(self.model(images)[0], labels)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item() * len(images)
        return total_loss / len(trainloader.dataset)
    
    @torch.no_grad()
    def evaluate(self, testloader):
        self.model.eval()
        total_loss = 0
        correct = 0
        with torch.no_grad():
            for batch in testloader:
                images, labels = batch
                logits, _ = self.model(images)
                loss = self.loss_fn(logits, labels)
                total_loss += loss.item() * len(images)
                predictions = torch.argmax(logits, dim=1)
                correct += torch.sum(predictions == labels).item()
        accuracy = correct / len(testloader.dataset)
        avg_loss = total_loss / len(testloader.dataset)
        return accuracy, avg_loss


In [170]:
def main():
    save_model_every_n_epochs = save_model_every
    trainloader, testloader, _ = prepareData()
    model = VisionTransformer(config)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    loss_fn = nn.CrossEntropyLoss()
    trainer = Trainer(model, optimizer, loss_fn, exp_name)
    trainer.train(trainloader, testloader, epochs, save_model_every_n_epochs=save_model_every_n_epochs)


if __name__ == '__main__':
    main()

100%|██████████| 705/705 [19:11<00:00,  1.63s/it]


Epoch: 1, Train loss: 2.0715, Test loss: 2.0653, Accuracy: 0.1561
	Save checkpoint at epoch 1


100%|██████████| 705/705 [19:06<00:00,  1.63s/it]


Epoch: 2, Train loss: 2.0587, Test loss: 2.0602, Accuracy: 0.1838
	Save checkpoint at epoch 2


100%|██████████| 705/705 [23:42<00:00,  2.02s/it]


Epoch: 3, Train loss: 2.0198, Test loss: 1.9469, Accuracy: 0.2562
	Save checkpoint at epoch 3


100%|██████████| 705/705 [22:46<00:00,  1.94s/it]


Epoch: 4, Train loss: 1.8844, Test loss: 1.8190, Accuracy: 0.3165
	Save checkpoint at epoch 4


100%|██████████| 705/705 [22:26<00:00,  1.91s/it]


Epoch: 5, Train loss: 1.8028, Test loss: 1.7643, Accuracy: 0.3201
	Save checkpoint at epoch 5


100%|██████████| 705/705 [22:29<00:00,  1.91s/it]


Epoch: 6, Train loss: 1.7513, Test loss: 1.7549, Accuracy: 0.3314
	Save checkpoint at epoch 6


100%|██████████| 705/705 [22:20<00:00,  1.90s/it]


Epoch: 7, Train loss: 1.7140, Test loss: 1.7136, Accuracy: 0.3463
	Save checkpoint at epoch 7


100%|██████████| 705/705 [26:09<00:00,  2.23s/it]


Epoch: 8, Train loss: 1.6835, Test loss: 1.6544, Accuracy: 0.3730
	Save checkpoint at epoch 8


100%|██████████| 705/705 [29:26<00:00,  2.51s/it]


Epoch: 9, Train loss: 1.6556, Test loss: 1.6691, Accuracy: 0.3659
	Save checkpoint at epoch 9


100%|██████████| 705/705 [29:15<00:00,  2.49s/it]


Epoch: 10, Train loss: 1.6291, Test loss: 1.6078, Accuracy: 0.3879


TypeError: Object of type Config is not JSON serializable