# Binary Model using Pretrained ViT

In [34]:
import os
import shutil
import evaluate
import numpy as np
import time
import copy

In [35]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

In [37]:
from transformers import ViTConfig, ViTModel, ViTImageProcessor, ViTForImageClassification
from transformers import AutoImageProcessor
from transformers import TrainingArguments, Trainer

In [38]:
from model_training_script import train_model

## Preparing Data Loaders

In [39]:
output_dir = '/Users/jasminecjwchen/Documents/GitHub/COMS-4995-ACV-Project/split_data'

In [40]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms, datasets

# after normalization, must convert to PIL for input into VIT
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.ToPILImage()
])

train_dataset = datasets.ImageFolder(root=os.path.join(output_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(root=os.path.join(output_dir, 'val'), transform=transform)
test_dataset = datasets.ImageFolder(root=os.path.join(output_dir, 'test'), transform=transform)

In [41]:
def collate_fn(batch):
    # Filter failed images first
    batch = list(filter(lambda x: x is not None, batch))

    images = [sample[0] for sample in batch]
    labels = torch.LongTensor([sample[1] for sample in batch])
    
    return images, labels

In [42]:
phases = ["train", "val", "test"]

In [43]:
image_datasets = {x: datasets.ImageFolder(root = os.path.join(output_dir, x), transform = transform) for x in phases}
dataset_sizes = {x: len(image_datasets[x]) for x in phases}
dataloaders = {x: DataLoader(image_datasets[x], batch_size = 32, shuffle = (x == "train"), collate_fn=collate_fn) for x in phases}

In [44]:
print(dataset_sizes)

{'train': 27892, 'val': 5976, 'test': 5978}


## VIT

In [58]:
class VITModel(nn.Module):
    # vit config params are passed through kwargs
    def __init__(self, freeze_vit = True, **kwargs):
        super().__init__()
        
        self.tokenizer = ViTImageProcessor()
        self.use_pretrained = (len(kwargs) == 0)
        
        # if vit config are passed, then create model using configs
        if len(kwargs) != 0:
            vit_config = ViTConfig(**kwargs)
            self.model = ViTModel(vit_config)
            print("using custom model")
        else:
            # no config passed, then use pretrained with the option to freeze
            self.model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
            if freeze_vit: 
                for param in self.model.parameters():
                    param.requires_grad = False
            print("using pretrained model")
        self.linear = nn.Linear(151296, 2)
    
    def forward(self, image):
        encoded_image = self.tokenizer(image, return_tensors = "pt")
        model_output = self.model(**encoded_image)
        flattened_output = model_output.last_hidden_state.flatten(start_dim = 1)
        if not self.use_pretrained:
            self.linear = nn.Linear(flattened_output.shape[1], 2)
        return self.linear(flattened_output)

## Training

In [46]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")
use_gpu = (device != "cpu")

Using cpu


## ViT Pretrained

In [47]:
pretrained_model = VITModel()
pretrained_optimizer = optim.Adam(pretrained_model.parameters(), lr = 0.001)

using pretrained model


In [48]:
train_model(pretrained_model, pretrained_optimizer, dataloaders, dataset_sizes, num_epochs = 10, patience = 2, output_filename = "vit_binary_pretrained.pth")

Epoch 0/9 train complete in 2635.3168 seconds. train loss: 0.2963 recall: 0.9406. Patience left: 2
Epoch 0/9 val complete in 555.4542 seconds. val loss: 0.2121 recall: 0.9617. Patience left: 2
Epoch 1/9 train complete in 2683.0672 seconds. train loss: 0.1666 recall: 0.9728. Patience left: 2
Epoch 1/9 val complete in 565.0933 seconds. val loss: 0.4871 recall: 0.9424. Patience left: 1
Epoch 2/9 train complete in 2788.3781 seconds. train loss: 0.0787 recall: 0.9865. Patience left: 1
Epoch 2/9 val complete in 542.9855 seconds. val loss: 0.2312 recall: 0.9781. Patience left: 2
Epoch 3/9 train complete in 2799.5167 seconds. train loss: 0.0735 recall: 0.9885. Patience left: 2
Epoch 3/9 val complete in 611.2034 seconds. val loss: 0.2252 recall: 0.9787. Patience left: 2
Epoch 4/9 train complete in 4755.8138 seconds. train loss: 0.0668 recall: 0.9907. Patience left: 2
Epoch 4/9 val complete in 616.8248 seconds. val loss: 0.3648 recall: 0.9742. Patience left: 1
Epoch 5/9 train complete in 2867.68

VITModel(
  (model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_feature

# ViT Custom Parameters

In [63]:
custom_model = VITModel(hidden_size = 96, num_hidden_layers = 6, num_attention_heads = 6, intermediate_size = 1000, num_classes = 2)
custom_optimizer = optim.Adam(custom_model.parameters(), lr = 0.001)

using custom model


In [64]:
train_model(custom_model, custom_optimizer, dataloaders, dataset_sizes, num_epochs = 10, patience = 2, output_filename = "vit_binary_custom.pth")

Epoch 0/9 train complete in 715.4610 seconds. train loss: 0.7615 recall: 0.5033. Patience left: 2
Epoch 0/9 val complete in 80.2966 seconds. val loss: 0.7450 recall: 0.4933. Patience left: 2
Epoch 1/9 train complete in 689.5075 seconds. train loss: 0.7276 recall: 0.5001. Patience left: 2
Epoch 1/9 val complete in 90.0884 seconds. val loss: 0.6929 recall: 0.5308. Patience left: 2
Epoch 2/9 train complete in 790.4634 seconds. train loss: 0.7019 recall: 0.4997. Patience left: 2
Epoch 2/9 val complete in 96.8933 seconds. val loss: 0.6978 recall: 0.4829. Patience left: 1
Epoch 3/9 train complete in 759.6567 seconds. train loss: 0.6964 recall: 0.5016. Patience left: 1
Epoch 3/9 val complete in 89.8401 seconds. val loss: 0.6963 recall: 0.4906. Patience left: 0
Ran out of patience. Stopping early
Training complete in 55m 12s
Best Metrics at Epoch 2:
Val_loss: 0.6929
Val_accuracy: 0.5308
Val_precision: 0.5308
Val_recall: 0.5308
Val_f1: 0.5306


VITModel(
  (model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 96, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-5): 6 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=96, out_features=96, bias=True)
              (key): Linear(in_features=96, out_features=96, bias=True)
              (value): Linear(in_features=96, out_features=96, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=96, out_features=96, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=96, out_features=1000, bias