# Import necessary libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import pandas as pd

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torch.utils.tensorboard import SummaryWriter

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device: ", device)

# Load the dataset

In [None]:
# Directory with the raw training PNGs. Only the disabled cache builder
# below reads it.
folder_path = 'normal_train/'

# Row accumulator for the cache builder; it stays empty while that code is
# commented out.
data = []

# One-off cache builder (disabled). When enabled it reads every PNG in
# folder_path, takes the class name from the file-name prefix before the
# first '_', and pickles the resulting DataFrame to data.pkl.
# for f in os.listdir(folder_path):
#     if f.endswith('.png'):
#         img = plt.imread(os.path.join(folder_path, f))
#         img_class = f.split('.')[0].split('_')[0]
#         data.append({'image':img, 'class':img_class})
#         print(folder_path + f)

# df = pd.DataFrame(data)

# with open('data.pkl', 'wb') as f:
#     pickle.dump(df, f)

# Load the cached training DataFrame (columns 'image' and 'class').
# NOTE: pickle.load can execute arbitrary code; only load a cache that was
# produced locally by the builder above.
with open('data.pkl', 'rb') as f:
    df = pickle.load(f)

In [None]:
# Directory with the raw test PNGs. Only the disabled cache builder below
# reads it.
folder_path = 'normal_test/'

# Row accumulator for the cache builder; it stays empty while that code is
# commented out.
test_data = []

# One-off cache builder (disabled). When enabled it reads every PNG in
# folder_path, takes the class name from the file-name prefix before the
# first '_', and pickles the resulting DataFrame to test_data.pkl.
# for f in os.listdir(folder_path):
#     if f.endswith('.png'):
#         img = plt.imread(os.path.join(folder_path, f))
#         img_class = f.split('.')[0].split('_')[0]
#         test_data.append({'image':img, 'class':img_class})
#         print(folder_path + f)

# test_df = pd.DataFrame(test_data)

# with open('test_data.pkl', 'wb') as f:
#     pickle.dump(test_df, f)

# Load the cached test DataFrame (columns 'image' and 'class').
# NOTE: pickle.load can execute arbitrary code; only load a cache that was
# produced locally by the builder above.
with open('test_data.pkl', 'rb') as f:
    test_df = pickle.load(f)

# Display the first image

In [None]:
import matplotlib.pyplot as plt

# Preview the first training sample, using its class label as the title.
first_image = df['image'].iloc[0]
first_class = df['class'].iloc[0]

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(first_image)
ax.set_title(f'Class: {first_class}')
ax.axis('off')
plt.show()

# Print the array shape so the image layout can be checked by eye.
print(first_image.shape)

# CustomDataset

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Dataset view over a DataFrame with 'image' and 'class' columns.

    Each item is ``(image, label)``: the stored NumPy image converted to a
    float tensor (passed through ``transform`` when one is given) and the
    label exactly as stored in the frame.
    """

    def __init__(self, df, transform=None):
        self.transform = transform
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Positional lookup: the frame may carry a shuffled, non-contiguous
        # index after a train/validation split.
        raw_image = self.df['image'].iloc[idx]
        label = self.df['class'].iloc[idx]

        tensor = torch.from_numpy(raw_image).float()
        if not self.transform:
            return tensor, label
        return self.transform(tensor), label

# Constants & Parameters

In [None]:
# Hyper-parameters shared by the model, data loaders and training loop.
n_image = 9600  # Not referenced elsewhere in the visible code
chw = (4, 256, 256)  # Image size as (channels, height, width)

n_patches=16 # Number of patches in each row and column
n_blocks=4 # Number of stacked ViT blocks
hidden_d=64 # Hidden dimension in patch embedding(token dimension)
n_heads=8 # Number of heads in multi-head self-attention; hidden_d must be divisible by n_heads
out_d=15 # Number of classes
mlp_ratio = 16  # Width of each block's MLP hidden layer, as a multiple of hidden_d
N_EPOCHS = 200  # Number of training epochs
LR = 1e-3  # Learning rate passed to Adam
BATCH_SIZE = 512  # Batch size for the train/validation/test loaders

# Train-Validation Split

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Map class-name strings to integer ids. The encoder is fitted on the
# training frame only and pickled to le.pkl.
le = LabelEncoder()
le.fit(df['class'])
with open('le.pkl', 'wb') as f:
    pickle.dump(le, f)
for frame in (df, test_df):
    frame['class'] = le.transform(frame['class'])

# Hold out 20% of the training frame for validation, with a fixed seed.
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

# The transform is created here but not handed to the datasets below.
transform = ToTensor()
train_set, val_set, test_set = (
    CustomDataset(split) for split in (train_df, val_df, test_df)
)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)


# Patchify

In [None]:
def patchify(images, num_patches=None):
    """Split a batch of channel-last images into flattened grid patches.

    Args:
        images: Tensor of shape (N, H, W, C).
        num_patches: Number of patches along each spatial axis. Defaults to
            the module-level ``n_patches``.

    Returns:
        Tensor of shape (N, num_patches ** 2, patch_h * patch_w * C) on the
        same device as ``images``. Patches are ordered row-major over the
        grid, and every patch is flattened in (row, column, channel) order.

    Raises:
        ValueError: If H or W is not divisible by ``num_patches``.
    """
    if num_patches is None:
        num_patches = n_patches

    n, h, w, c = images.shape
    if h % num_patches or w % num_patches:
        raise ValueError(
            f"image size {h}x{w} is not divisible into "
            f"{num_patches}x{num_patches} patches"
        )
    patch_h, patch_w = h // num_patches, w // num_patches

    # (N, grid_row, patch_row, grid_col, patch_col, C)
    #   -> (N, grid_row, grid_col, patch_row, patch_col, C)
    # so that each patch becomes one contiguous chunk when flattened.
    grid = images.reshape(n, num_patches, patch_h, num_patches, patch_w, c)
    grid = grid.permute(0, 1, 3, 2, 4, 5)
    patches = grid.reshape(n, num_patches ** 2, patch_h * patch_w * c)

    # The previous loop-based version filled a torch.zeros buffer, so its
    # result always had the default dtype; keep that for callers.
    return patches.to(torch.get_default_dtype())

# Positional Embedding

In [None]:
def get_positional_embeddings(sequence_length, d):
    """Build the fixed sinusoidal position table.

    Even columns hold ``sin(pos / 10000 ** (col / d))``; each odd column
    holds the cosine at the frequency of the even column just before it.

    Returns:
        Tensor of shape (sequence_length, d).
    """
    table = torch.ones(sequence_length, d)
    for col in range(d):
        # An odd column reuses the exponent of the even column to its left.
        scale = 10000 ** ((col - col % 2) / d)
        wave = np.sin if col % 2 == 0 else np.cos
        for pos in range(sequence_length):
            table[pos][col] = wave(pos / scale)
    return table

# Multi-Head Self Attention

In [None]:
class MyMSA(nn.Module):
    """Multi-head self-attention over slices of the token dimension.

    The token vector is cut into ``n_heads`` equal slices. Each head has its
    own query/key/value linear maps, attends over its slice only, and the
    head outputs are concatenated back to the original width.

    Args:
        d: Token dimension. Defaults to the module-level ``hidden_d``.
        num_heads: Number of heads. Defaults to the module-level ``n_heads``.

    Raises:
        ValueError: If ``d`` is not divisible by ``num_heads``.
    """

    def __init__(self, d=None, num_heads=None):
        super().__init__()

        if d is None:
            d = hidden_d
        if num_heads is None:
            num_heads = n_heads
        if d % num_heads != 0:
            # A truncated head width would silently drop trailing features
            # and break the residual addition in the enclosing block.
            raise ValueError(
                f"token dimension {d} is not divisible by {num_heads} heads"
            )

        self.n_heads = num_heads
        d_head = d // num_heads
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Apply attention to ``sequences`` of shape (N, seq_len, d).

        Returns:
            Tensor with the same shape as the input.
        """
        head_outputs = []
        for head in range(self.n_heads):
            start = head * self.d_head
            chunk = sequences[..., start:start + self.d_head]
            q = self.q_mappings[head](chunk)
            k = self.k_mappings[head](chunk)
            v = self.v_mappings[head](chunk)

            # Batched scaled dot-product attention: every sample in the
            # batch is handled by one matmul instead of a Python loop.
            attention = self.softmax(q @ k.transpose(-2, -1) / (self.d_head ** 0.5))
            head_outputs.append(attention @ v)
        return torch.cat(head_outputs, dim=-1)

# ViT Block

In [None]:
class MyViTBlock(nn.Module):
    """Transformer encoder block with layer norm applied before each sub-layer.

    Self-attention and a two-layer GELU MLP are each wrapped in a residual
    connection. Sizes come from the module-level ``hidden_d`` and
    ``mlp_ratio``.
    """

    def __init__(self):
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA()
        self.norm2 = nn.LayerNorm(hidden_d)

        mlp_width = mlp_ratio * hidden_d
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, hidden_d),
        )

    def forward(self, x):
        # Residual around attention, then residual around the MLP.
        after_attention = x + self.mhsa(self.norm1(x))
        return after_attention + self.mlp(self.norm2(after_attention))

# Model


In [None]:
class MyViT(nn.Module):
    """Vision Transformer classifier for channel-last image batches.

    Pipeline: patchify -> linear patch embedding -> prepend a learnable
    class token -> add fixed sinusoidal position embeddings -> ``n_blocks``
    encoder blocks -> linear head on the class token.

    All sizes are read from the module-level hyper-parameters (``chw``,
    ``n_patches``, ``hidden_d``, ``n_blocks``, ``out_d``).
    """

    def __init__(self):
        super().__init__()
        self.batch_size = BATCH_SIZE

        # Patch extent along height and width, in pixels.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # Length of one flattened patch: channels * patch_h * patch_w.
        self.input_d = chw[0] * self.patch_size[0] * self.patch_size[1]
        self.linear_mapper = nn.Linear(self.input_d, hidden_d)

        self.class_token = nn.Parameter(torch.rand(1, hidden_d))

        # Fixed (non-trainable) table with one row per patch plus one for
        # the class token. Kept as a frozen parameter so it follows
        # .to(device) and stays in the state dict under the same key.
        self.pos_embed = nn.Parameter(
            get_positional_embeddings(n_patches ** 2 + 1, hidden_d),
            requires_grad=False,
        )

        self.blocks = nn.ModuleList([MyViTBlock() for _ in range(n_blocks)])

        # The head emits raw logits. CrossEntropyLoss applies log-softmax
        # itself, so a Softmax layer here would normalise twice and flatten
        # the gradients. Apply softmax to the output if probabilities are
        # needed.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, out_d),
        )

    def forward(self, images):
        """Classify ``images`` of shape (N, H, W, C).

        Returns:
            Tensor of shape (N, out_d) holding unnormalised class logits.
        """
        # Kept up to date for any code that reads the attribute.
        self.batch_size = images.shape[0]

        patches = patchify(images)
        tokens = self.linear_mapper(patches)

        # Prepend the shared class token to every sequence in the batch.
        class_tokens = self.class_token.unsqueeze(0).expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((class_tokens, tokens), dim=1)

        # The position table broadcasts over the batch dimension.
        out = tokens + self.pos_embed
        for block in self.blocks:
            out = block(out)

        # Only the class token feeds the classification head.
        return self.mlp(out[:, 0])

# Train

In [None]:
# Train for N_EPOCHS, logging loss/accuracy to TensorBoard and checkpointing
# the whole model to ViT.pth whenever validation accuracy improves.
model = MyViT().to(device)
optimizer = Adam(model.parameters(), lr=LR)
criterion = CrossEntropyLoss()
max_acc = 0.0

writer = SummaryWriter()

try:
    for epoch in range(N_EPOCHS):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0
        for batch in train_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)

            loss = criterion(y_hat, y)
            # Dividing by the number of batches makes this the epoch mean.
            train_loss += loss.item() / len(train_loader)

            predicted = y_hat.argmax(dim=1)
            total += y.size(0)
            correct += (predicted == y).sum().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        train_acc = 100 * correct / total

        # ---- validation pass ----
        model.eval()
        correct = 0
        total = 0
        val_loss = 0.0
        # No gradients are needed here; skipping the autograd graph saves
        # memory and time.
        with torch.no_grad():
            for batch in val_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                y_hat = model(x)

                loss = criterion(y_hat, y)
                val_loss += loss.item() / len(val_loader)

                predicted = y_hat.argmax(dim=1)
                total += y.size(0)
                correct += (predicted == y).sum().item()

        val_acc = 100 * correct / total
        print(f"Epoch {epoch + 1}/{N_EPOCHS} val_accuracy: {val_acc:.2f}%")
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/validation', val_loss, epoch)
        writer.add_scalar('Accuracy/train', train_acc, epoch)
        writer.add_scalar('Accuracy/validation', val_acc, epoch)

        # Keep only the best-performing model on disk.
        if max_acc < val_acc:
            max_acc = val_acc
            torch.save(model, 'ViT.pth')
finally:
    # Flush pending events even when training is interrupted.
    writer.close()


# Evaluation

In [None]:
from PIL import Image
from torchvision import transforms
# Evaluate the best checkpoint on the test set, dump per-sample predictions
# to predictions.csv and save every misclassified image under false/.
transform = transforms.ToPILImage()

# NOTE(review): ViT.pth is a fully pickled module, so loading it runs the
# pickle machinery; only load checkpoints written by this notebook. PyTorch
# versions whose torch.load defaults to weights_only=True need
# weights_only=False passed here.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model = torch.load('ViT.pth', map_location=device).to(device)
model.eval()
correct = 0
total = 0
data = []

os.makedirs('false', exist_ok=True)
count = 0

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for batch in test_loader:
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model(x)

        predicted = y_hat.argmax(dim=1)
        total += y.size(0)
        correct += (predicted == y).sum().item()

        data.extend(
            {'y': y_item, 'y_hat': pred_item}
            for y_item, pred_item in zip(y.tolist(), predicted.tolist())
        )

        # Indices of the misclassified samples in this batch.
        for i in (predicted != y).nonzero(as_tuple=True)[0].tolist():
            # Samples are height x width x channel; ToPILImage expects
            # channel x height x width.
            image = transform(x[i].permute(2, 0, 1))
            # inverse_transform returns an array; take the single class
            # name so the file name does not contain the array repr.
            image_class = le.inverse_transform(np.array([predicted[i].item()]))[0]
            image_path = os.path.join('false', f'false_{image_class}_{count}.png')
            image.save(image_path)
            count += 1

test_acc = 100 * correct / total
df = pd.DataFrame(data)
print(df)
df.to_csv('predictions.csv', index=False)
print(test_acc)


# Reference:

https://keras.io/examples/vision/image_classification_with_vision_transformer/  
https://medium.com/@brianpulfer/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c


