# Binary Model using Pretrained ViT

In [1]:
import os
import shutil
import evaluate
import numpy as np
import time
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

In [4]:
from transformers import ViTConfig, ViTModel, ViTImageProcessor, ViTForImageClassification
from transformers import AutoImageProcessor
from transformers import TrainingArguments, Trainer

In [5]:
from model_training_script import train_model

## Preparing Data Loaders

In [6]:
output_dir = '/Users/jasminecjwchen/Documents/GitHub/COMS-4995-ACV-Project/split_data'

In [7]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms, datasets

# after normalization, must convert to PIL for input into VIT
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.ToPILImage()
])

train_dataset = datasets.ImageFolder(root=os.path.join(output_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(root=os.path.join(output_dir, 'val'), transform=transform)
test_dataset = datasets.ImageFolder(root=os.path.join(output_dir, 'test'), transform=transform)

In [8]:
def collate_fn(batch):
    # Filter failed images first
    batch = list(filter(lambda x: x is not None, batch))

    images = [sample[0] for sample in batch]
    labels = torch.LongTensor([sample[1] for sample in batch])
    
    return images, labels

In [9]:
phases = ["train", "val", "test"]

In [10]:
image_datasets = {x: datasets.ImageFolder(root = os.path.join(output_dir, x), transform = transform) for x in phases}
dataset_sizes = {x: len(image_datasets[x]) for x in phases}
dataloaders = {x: DataLoader(image_datasets[x], batch_size = 32, shuffle = (x == "train"), collate_fn=collate_fn) for x in phases}

In [11]:
print(dataset_sizes)

{'train': 27892, 'val': 5976, 'test': 5978}


## VIT

In [12]:
class VITModel(nn.Module):
    # vit config params are passed through kwargs
    def __init__(self, freeze_vit = True, **kwargs):
        super().__init__()
        
        self.tokenizer = ViTImageProcessor()
        
        # if vit config are passed, then create model using configs
        if len(kwargs) == 0:
            vit_config = ViTConfig(**kwargs)
            self.model = ViTModel(vit_config)
        else:
            # no config passed, then use pretrained with the option to freeze
            self.model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
            if freeze_vit: 
                for param in self.model.parameters():
                    param.requires_grad = False
        self.logistic = nn.Linear(151296, 2)
    
    def forward(self, image):
        encoded_image = self.tokenizer(image, return_tensors = "pt")
        model_output = self.model(**encoded_image)
        flattened_output = model_output.last_hidden_state.flatten(start_dim = 1)
        return self.logistic(flattened_output)

## Training

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")
use_gpu = (device != "cpu")

Using cpu


## ViT Pretrained

In [14]:
pretrained_model = VITModel()
pretrained_optimizer = optim.Adam(pretrained_model.parameters(), lr = 0.001)

In [15]:
train_model(pretrained_model, pretrained_optimizer, dataloaders, dataset_sizes, num_epochs = 10, patience = 2, output_filename = "vit_binary_pretrained.pth")

# ViT Custom Parameters

In [None]:
custom_model = VITModel(hidden_size = 10, num_hidden_layers = 2, num_attention_heads = 2, intermediate_size = 100, num_classes = 2)
custom_optimizer = optim.Adam(custom_model.parameters(), lr = 0.001)

In [None]:
train_model(custom_model, custom_optimizer, dataloaders, num_epochs = 10, patience = 2, output_filename = "vit_binary_custom.pth")

Epoch 0/0 train complete in 2820.9151 seconds. train loss: 0.3212 acc: 0.9392. Patience left: 10
Epoch 0/0 val complete in 3502.6977 seconds. val loss: 0.2149 acc: 0.9635. Patience left: 10
Training complete in 58m 24s
Best val accuracy: 0.963521


VITModel(
  (model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_feature