In [None]:
import torch
from torch import nn
import numpy as np
import pandas as pd
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import random
import timeit
from tqdm import tqdm


In [None]:
## Hyper Parameters

# --- Run control -----------------------------------------------------------
RANDOM_SEED = 42
EPOCHS = 40
BATCH_SIZE = 512

# --- Optimiser (Adam) ------------------------------------------------------
LEARNING_RATE = 1e-4
ADAM_BETAS = (0.9, 0.999)
ADAM_WEIGHT_DECAY = 0

# --- Input geometry --------------------------------------------------------
IMG_SIZE = 28       # side length of the square input image
IN_CHANNELS = 1     # number of image channels
PATCH_SIZE = 4      # side length of one square patch
NUM_CLASSES = 10

# --- Transformer -----------------------------------------------------------
NUM_ENCODER = 4     # number of stacked encoder layers
NUM_HEADS = 8       # attention heads per encoder layer
HIDDEN_DIM = 768    # handed to ViT as `hiden_dim`
DROPOUT = 0.001
ACTVATION = 'gelu'  # (sic) the name is referenced by later cells, keep the spelling

# --- Derived sizes ---------------------------------------------------------
# One embedding feature per scalar value contained in a patch.
EMBED_DIM = IN_CHANNELS * PATCH_SIZE ** 2
# Patches per image when the image is tiled without overlap.
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2

# --- Reproducibility -------------------------------------------------------
# Seed every random number generator the notebook touches.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [None]:
class PatchEmbeddings(nn.Module):
    """Turn a batch of images into a sequence of patch embeddings.

    A strided convolution cuts each image into non-overlapping
    ``patch_size`` x ``patch_size`` patches and projects every patch to
    ``embed_dim`` features. A learnable [CLS] token is prepended, learnable
    positional embeddings are added and dropout is applied.

    Args:
        embed_dim: Number of features per token (conv output channels).
        patch_size: Side length of a square patch; used as kernel size and stride.
        num_patches: Number of patches per image. The positional table has
            ``num_patches + 1`` rows, the extra one being for the [CLS] token.
        dropout: Dropout probability applied to the returned embeddings.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super(PatchEmbeddings, self).__init__()
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size
            ),
            # Collapse the spatial grid into a single "patch" axis.
            nn.Flatten(2),
        )

        # Exactly one [CLS] token per sequence. Its length along the token
        # axis must be 1 regardless of in_channels, otherwise the sequence
        # no longer matches the (num_patches + 1) positional embeddings.
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)
        self.positional_embeddings = nn.Parameter(torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, embed_dim)`` whose
            first token is the [CLS] token.
        """
        # Broadcast the single learnable token across the batch (no copy).
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        # The channel axis must be kept: Conv2d needs (batch, channels, H, W).
        # (batch, embed_dim, n_patches) -> (batch, n_patches, embed_dim)
        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.positional_embeddings
        x = self.dropout(x)
        return x

In [None]:
# Smoke test: push one random batch through the patch-embedding block.
model = PatchEmbeddings(embed_dim=EMBED_DIM, patch_size=PATCH_SIZE, num_patches=NUM_PATCHES, dropout=DROPOUT, in_channels=IN_CHANNELS).to(device)

# The dummy batch has to live on the same device as the module, otherwise the
# forward pass raises a device-mismatch error whenever CUDA is available.
# Sampling on the CPU and moving afterwards keeps the CPU RNG stream unchanged.
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMG_SIZE, IMG_SIZE).to(device)

print(model(x).shape) # with the hyper-parameters above: torch.Size([512, 50, 16])

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier built on ``PatchEmbeddings``.

    The image is embedded as a token sequence, passed through a stack of
    pre-norm Transformer encoder layers, and the [CLS] token (position 0)
    is fed to a LayerNorm + Linear classification head.

    Args:
        num_patches: Number of patches per image.
        img_size: Accepted for signature compatibility; not used by this class.
        num_classes: Number of output logits.
        patch_size: Side length of a square patch.
        embed_dim: Token feature size (``d_model`` of the encoder).
        num_encoders: Number of stacked encoder layers.
        num_heads: Attention heads per layer; must divide ``embed_dim``.
        hiden_dim: Width of the feed-forward sub-layer inside each encoder layer.
        dropout: Dropout probability for embeddings and encoder layers.
        activation: Activation of the feed-forward sub-layer (e.g. ``'gelu'``).
        in_channels: Number of channels of the input images.
    """

    def __init__(self,
                 num_patches,
                 img_size,
                 num_classes,
                 patch_size,
                 embed_dim,
                 num_encoders,
                 num_heads,
                 hiden_dim,
                 dropout,
                 activation,
                 in_channels
            ):
        super(ViT, self).__init__()
        self.embeddings_block = PatchEmbeddings(embed_dim, patch_size, num_patches, dropout, in_channels)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            # Without this the layer silently falls back to PyTorch's default
            # feed-forward width and the hiden_dim argument has no effect.
            dim_feedforward=hiden_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True
        )

        self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes)
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        x = self.embeddings_block(x)
        x = self.encoder_blocks(x)
        # Classify from the [CLS] token only.
        x = self.mlp_head(x[:, 0, :])
        return x



In [None]:
# Build the classifier that the training cell optimises.
model = ViT(
    num_patches=NUM_PATCHES,
    img_size=IMG_SIZE,
    num_classes=NUM_CLASSES,
    patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM,
    num_encoders=NUM_ENCODER,
    num_heads=NUM_HEADS,
    hiden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
    activation=ACTVATION,
    in_channels=IN_CHANNELS,
).to(device)
# Smoke test. The dummy batch must be on the model's device, otherwise the
# forward pass fails with a device mismatch whenever CUDA is available.
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMG_SIZE, IMG_SIZE).to(device)
print(model(x).shape) # with the hyper-parameters above: torch.Size([512, 10])

Dataset MNIST

In [None]:
# Load the digit-recognizer CSV files (paths are relative to the working directory).
# Later cells read column 0 of train_df as the label and the remaining columns
# as pixels; every column of test_df is read as pixels.
train_df = pd.read_csv('data/digit_recognizer/train.csv')
test_df = pd.read_csv('data/digit_recognizer/test.csv')

In [None]:
# Preview the first rows of the training frame.
train_df.head()

In [None]:
# Preview the first rows of the test frame.
test_df.head()

In [None]:
# Hold out 10% of the labelled rows for validation (seeded, shuffled split).
# NOTE: train_df is reassigned from itself, so re-running this cell without
# reloading the CSV splits the already-reduced training set again.
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=RANDOM_SEED, shuffle=True)

In [None]:
class MNISTTrainDataset(Dataset):
    """Training samples with light augmentation.

    Args:
        images: Indexable collection of flattened images; each item is
            reshaped to 28x28 and cast to ``uint8`` before transformation.
        labels: Indexable collection of class labels aligned with ``images``.
        indicies: Indexable collection of row identifiers aligned with ``images``.
        transform: Optional callable applied to each 28x28 ``uint8`` array.
            When ``None`` the default pipeline is used: PIL conversion,
            ``RandomRotation(15)``, tensor conversion and normalisation with
            mean 0.5 / std 0.5.
    """

    def __init__(self, images, labels, indicies, transform=None):
        self.images = images
        self.labels = labels
        self.indicies = indicies
        # Honour a caller-supplied transform; previously the argument was
        # accepted but always overwritten by the default pipeline.
        if transform is not None:
            self.transform = transform
        else:
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                #transforms.RandomHorizontalFlip(),
                #transforms.RandomVerticalFlip(),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])

    def __len__(self):
        """Number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return a dict with keys ``'image'``, ``'label'`` and ``'index'`` for sample ``idx``."""
        image = self.images[idx].reshape(28, 28).astype(np.uint8)
        image = self.transform(image)
        label = self.labels[idx]
        index = self.indicies[idx]

        sample = {
            'image': image,
            'label': label,
            'index': index
        }
        return sample
    
class MNISTValDataset(Dataset):
    """Validation samples, without augmentation.

    Args:
        images: Indexable collection of flattened images; each item is
            reshaped to 28x28 and cast to ``uint8`` before transformation.
        labels: Indexable collection of class labels aligned with ``images``.
        indicies: Indexable collection of row identifiers aligned with ``images``.
        transform: Optional callable applied to each 28x28 ``uint8`` array.
            When ``None`` the default pipeline is used: tensor conversion
            followed by normalisation with mean 0.5 / std 0.5.
    """

    def __init__(self, images, labels, indicies, transform=None):
        self.images = images
        self.labels = labels
        self.indicies = indicies
        # Honour a caller-supplied transform; previously the argument was
        # accepted but always overwritten by the default pipeline.
        if transform is not None:
            self.transform = transform
        else:
            self.transform = transforms.Compose([
                #transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])

    def __len__(self):
        """Number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return a dict with keys ``'image'``, ``'label'`` and ``'index'`` for sample ``idx``."""
        image = self.images[idx].reshape(28, 28).astype(np.uint8)
        image = self.transform(image)
        label = self.labels[idx]
        index = self.indicies[idx]

        sample = {
            'image': image,
            'label': label,
            'index': index
        }
        return sample
    
class MNISTTestDataset(Dataset):
    """Unlabelled test samples, without augmentation.

    Args:
        images: Indexable collection of flattened images; each item is
            reshaped to 28x28 and cast to ``uint8`` before transformation.
        indicies: Indexable collection of row identifiers aligned with ``images``.
        transform: Optional callable applied to each 28x28 ``uint8`` array.
            When ``None`` the default pipeline is used: tensor conversion
            followed by normalisation with mean 0.5 / std 0.5.
    """

    def __init__(self, images, indicies, transform=None):
        self.images = images
        self.indicies = indicies
        # Honour a caller-supplied transform; previously the argument was
        # accepted but always overwritten by the default pipeline.
        if transform is not None:
            self.transform = transform
        else:
            self.transform = transforms.Compose([
                #transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])

    def __len__(self):
        """Number of samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return a dict with keys ``'image'`` and ``'index'`` for sample ``idx``."""
        image = self.images[idx].reshape(28, 28).astype(np.uint8)
        image = self.transform(image)
        index = self.indicies[idx]

        sample = {
            'image': image,
            'index': index
        }
        return sample

In [None]:
# One figure with three panels: a sample from each split, side by side.
# (A separate plt.figure() call is not needed; it only adds an empty figure.)
f, axarr = plt.subplots(1, 3)

# Column 0 holds the label, the remaining columns hold the pixels.
train_dataset = MNISTTrainDataset(train_df.iloc[:, 1:].values.astype(np.uint8), train_df.iloc[:, 0].values, train_df.index.values)
print(len(train_dataset))
sample = train_dataset[0]
axarr[0].imshow(sample['image'].squeeze(), cmap='gray')
axarr[0].set_title("Train Image")
print("----------------")

val_dataset = MNISTValDataset(val_df.iloc[:, 1:].values.astype(np.uint8), val_df.iloc[:, 0].values, val_df.index.values)
print(len(val_dataset))
sample = val_dataset[0]
axarr[1].imshow(sample['image'].squeeze(), cmap='gray')
axarr[1].set_title("Validation Image")
print("----------------")

# The test frame has no label column: every column is pixel data.
test_dataset = MNISTTestDataset(test_df.iloc[:, :].values.astype(np.uint8), test_df.index.values)
print(len(test_dataset))
sample = test_dataset[0]
axarr[2].imshow(sample['image'].squeeze(), cmap='gray')
axarr[2].set_title("Test Image")
print("----------------")

In [None]:
# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation loaders keep a fixed order. The validation loss is an average of
# per-batch means, so shuffling would change which samples land in the final
# (smaller) batch and make the reported loss vary for an unchanged model.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Call the instance once per epoch with the current validation loss and the
    model. Whenever the loss improves by more than ``delta`` the model's
    ``state_dict`` is written to ``path`` and the patience counter is reset;
    otherwise the counter is incremented. Once it reaches ``patience`` the
    ``early_stop`` flag is set.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: When true, print a message on every call.
        delta: Minimum decrease of the loss that counts as an improvement.
        path: File the best ``state_dict`` is saved to.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; the lowercase alias works everywhere.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        """Record ``val_loss`` for this epoch and update the stopping state."""
        # Work with a score where higher is better.
        score = -val_loss

        if self.best_score is None:
            # First call: anything is an improvement.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decreases.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

    def load_checkpoint(self, model):
        '''Loads the saved model.'''
        model.load_state_dict(torch.load(self.path))


# # Initialize the early stopping object
# early_stopping = EarlyStopping(patience=5, verbose=True)

# for epoch in range(1, n_epochs + 1):
#     # Train your model
#     train(...)
    
#     # Validate your model
#     val_loss = validate(...)
    
#     # Check early stopping
#     early_stopping(val_loss, model)
    
#     if early_stopping.early_stop:
#         print("Early stopping")
#         break

# # Load the last checkpoint with the best model
# early_stopping.load_checkpoint(model)

In [None]:
# Loss and optimiser for the `model` created in the ViT cell.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=ADAM_WEIGHT_DECAY, betas=ADAM_BETAS)

start = timeit.default_timer()

for epoch in tqdm(range(EPOCHS), position=0, leave=True):
    # ---- Training pass ----------------------------------------------------
    model.train()
    train_labels = []
    train_preds = []
    train_running_loss = 0.0

    for data_instance in tqdm(train_dataloader, position=0, leave=True):
        img = data_instance['image'].float().to(device)
        label = data_instance['label'].to(device)
        y_pred = model(img)
        y_pred_label = torch.argmax(y_pred, dim=1)

        # Keep plain Python ints rather than 0-d tensors: one bulk transfer
        # per batch and cheap integer comparisons when computing accuracy.
        train_labels.extend(label.cpu().tolist())
        train_preds.extend(y_pred_label.cpu().tolist())

        loss = criterion(y_pred, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()

    # Mean of the per-batch losses. Using len() instead of the leaked loop
    # index keeps this well-defined even if the loader yields no batches.
    train_loss = train_running_loss / max(len(train_dataloader), 1)

    # ---- Validation pass --------------------------------------------------
    model.eval()
    val_labels = []
    val_preds = []
    val_running_loss = 0.0

    with torch.no_grad():
        for data_instance in tqdm(val_dataloader, position=0, leave=True):
            img = data_instance['image'].float().to(device)
            label = data_instance['label'].to(device)
            y_pred = model(img)
            y_pred_label = torch.argmax(y_pred, dim=1)

            val_labels.extend(label.cpu().tolist())
            val_preds.extend(y_pred_label.cpu().tolist())

            loss = criterion(y_pred, label)
            val_running_loss += loss.item()

    val_loss = val_running_loss / max(len(val_dataloader), 1)

    # val_labels / val_preds hold plain ints, so they can be handed directly
    # to confusion_matrix if a per-class breakdown is wanted.
    train_accuracy = sum(t == p for t, p in zip(train_labels, train_preds)) / max(len(train_labels), 1)
    val_accuracy = sum(t == p for t, p in zip(val_labels, val_preds)) / max(len(val_labels), 1)

    print(f"Epoch: {epoch+1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"Train Accuracy: {train_accuracy:.4f}, Validation Accuracy: {val_accuracy:.4f}")

stop = timeit.default_timer()
print(f'Training_Time: {stop - start}s')
    

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator before inference.
torch.cuda.empty_cache()



In [None]:
# Predictions on the unlabelled test set, collected in loader order.
labels = []   # predicted class index per image
ids = []      # 1-based row identifier per image
imgs = []     # the (transformed) input images, moved back to the CPU

model.eval()

with torch.no_grad():
    for idx, data_instance in enumerate(tqdm(test_dataloader, position=0, leave=True)):
        img = data_instance['image'].float().to(device)
        # Row indices are 0-based; ids are shifted to start at 1.
        ids.extend([int(x)+1 for x in data_instance['index']])

        y_pred = model(img) # output shape: (batch_size, num_classes)

        imgs.extend(img.detach().cpu())
        # The argmax position already is the class label (0 .. NUM_CLASSES-1);
        # adding 1 here would report every digit one too high.
        labels.extend(torch.argmax(y_pred, dim=1).cpu().tolist())


In [None]:
# Show the first test images with their predicted labels on a 2x3 grid.
# (A separate plt.figure() call is not needed; it only adds an empty figure.)
f, axarr = plt.subplots(2, 3)

# axarr.flat walks the grid row by row; zip stops early if fewer than six
# predictions are available instead of raising an IndexError.
for ax, panel_img, panel_label in zip(axarr.flat, imgs, labels):
    ax.imshow(panel_img.squeeze(), cmap='gray')
    ax.set_title(f"Predicted Label: {panel_label}")

plt.show()
