In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import Compose, ToTensor, Lambda, Resize, Normalize
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Dataset
import pandas as pd

from sklearn.metrics import accuracy_score,f1_score

from tqdm import tqdm
from transformers import ViTForImageClassification, TrainingArguments, Trainer

In [14]:
DIRECTROY = 'data'
MODEL_PATH = 'models'
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 20
LR = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
df_train = pd.read_csv(f'{DIRECTROY}/train.csv') 
df_test = pd.read_csv(f'{DIRECTROY}/test_kaggletest.csv') 
num_classes = len(df_train['class'].unique())

In [4]:
df_test_public = df_test[df_test['Usage'] == 'Public']
df_test_private = df_test[df_test['Usage'] == 'Private']

In [5]:
len(df_test_public), len(df_test_private)

(18957, 36419)

# Divide the train into multiple chunks

Due to lack of ram, i will have to divide the dataloader into multiple dataloader

In [6]:
image_transforms = Compose([
    Resize((IMG_SIZE, IMG_SIZE)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [10]:
class CustomDataset(Dataset):
    def __init__(self, df, transforms, directory):
        self.df = df
        self.transforms = transforms
        self.directory = directory
        self.labels = torch.Tensor(df['class'].values).long()
        self.imgs = torch.cat([ self.transforms(Image.open(f'{DIRECTROY}/{self.directory}/{x}').convert('RGB')).reshape(1,3,IMG_SIZE,IMG_SIZE) for x in tqdm(df['name'].values)])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        label = self.labels[idx]
        return img, label

In [8]:
import math

In [9]:
for i in range(math.ceil(len(df_train)/8096)):
    train_dataset = CustomDataset(df_train[i*8096:(i+1)*8096], image_transforms, 'train')
    torch.save(train_dataset, f'{DIRECTROY}/train_dataset/train_dataset_{i}.pth')
    del train_dataset

100%|██████████| 8096/8096 [01:12<00:00, 112.31it/s]
100%|██████████| 8096/8096 [01:25<00:00, 94.83it/s] 
100%|██████████| 8096/8096 [01:25<00:00, 94.24it/s] 
100%|██████████| 8096/8096 [01:34<00:00, 86.03it/s] 
100%|██████████| 8096/8096 [01:48<00:00, 74.91it/s] 
100%|██████████| 8096/8096 [01:36<00:00, 84.26it/s] 
100%|██████████| 8096/8096 [01:41<00:00, 80.07it/s] 
100%|██████████| 8096/8096 [01:40<00:00, 80.79it/s] 
100%|██████████| 8096/8096 [01:33<00:00, 86.24it/s] 
100%|██████████| 8096/8096 [01:39<00:00, 81.54it/s] 
100%|██████████| 8096/8096 [01:38<00:00, 82.36it/s] 
100%|██████████| 8096/8096 [01:36<00:00, 83.99it/s] 
100%|██████████| 8096/8096 [01:33<00:00, 86.18it/s] 
100%|██████████| 8096/8096 [01:33<00:00, 86.25it/s] 
100%|██████████| 8096/8096 [01:34<00:00, 85.57it/s] 
100%|██████████| 8096/8096 [01:29<00:00, 90.48it/s] 
100%|██████████| 8096/8096 [01:32<00:00, 87.55it/s] 
100%|██████████| 4299/4299 [00:50<00:00, 85.60it/s] 


In [11]:
for i in range(math.ceil(len(df_test_public)/8096)):
    test_dataset = CustomDataset(df_test_public[i*8096:(i+1)*8096], image_transforms, 'test')
    torch.save(test_dataset, f'{DIRECTROY}/test_public_dataset/test_public_dataset_{i}.pth')
    del test_dataset

# for i in range(4):
#     test_dataset = CustomDataset(df_test_private[i*10000:(i+1)*10000], image_transforms, 'test')
#     torch.save(test_dataset, f'{DIRECTROY}/test_private_dataset_{i}.pth')
#     del test_dataset


100%|██████████| 8096/8096 [01:05<00:00, 123.96it/s]
100%|██████████| 8096/8096 [01:19<00:00, 101.22it/s]
100%|██████████| 2765/2765 [00:28<00:00, 98.73it/s]


In [6]:
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=num_classes)
model = model.to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
model.parameters

<bound method Module.parameters of ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
      

In [8]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = LR)
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=EPOCHS)

In [16]:
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    max_accuracy = 0.0
    # Training loop
    print('Training epoch:', epoch+1)
    len_train = 0
    for i in range(18):
        train_dataset = torch.load(f'{DIRECTROY}/train_dataset/train_dataset_{i}.pth')
        train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        for inputs, labels in tqdm(train_dataloader):
            optimizer.zero_grad()
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs.logits, labels)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
        len_train += len(train_dataset)
        del train_dataset
        
    scheduler.step()    
    train_loss/=len_train
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {train_loss}')
    
    eval_loss = 0.0
    model.eval()
    
    true_labels = []
    pred_labels = []
    
    print('Evaluating epoch:', epoch+1)
    with torch.no_grad():
        len_test = 0
        for i in range(3):
            test_dataset = torch.load(f'{DIRECTROY}/test_public_dataset/test_public_dataset_{i}.pth')
            test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

            for inputs, labels in tqdm(test_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs.logits, labels)
                eval_loss += loss.item()
                
                outputs = torch.argmax(outputs.logits, 1).flatten().cpu().numpy()
                labels = labels.flatten().cpu().numpy()
                
                true_labels.extend(labels)
                pred_labels.extend(outputs)
            
            len_test += len(test_dataset)
            del test_dataset
        
        print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {eval_loss/len_test}')
        print(f'Accuracy: {accuracy_score(true_labels, pred_labels)}')
        print(f'F1 Score Weighted: {f1_score(true_labels, pred_labels, average="weighted")}')
        print(f'F1 Score Macro: {f1_score(true_labels, pred_labels, average="macro")}')
        if accuracy_score(true_labels, pred_labels) > max_accuracy:
            max_accuracy = accuracy_score(true_labels, pred_labels)
            torch.save(model.state_dict(), f'{MODEL_PATH}/vit_model_{epoch+1}.pth')
            

Training epoch: 1


100%|██████████| 253/253 [01:00<00:00,  4.17it/s]
100%|██████████| 253/253 [01:01<00:00,  4.12it/s]
  0%|          | 0/253 [00:00<?, ?it/s]

In [None]:
train_dataset = torch.load(f'{DIRECTROY}/train_dataset_0.pth')
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
print(len(train_dataset))

10000
