# Binary Model using Pretrained ViT

In [62]:
import os
import shutil
import evaluate
import numpy as np
import time
import copy

In [3]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import recall_score

In [64]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

In [5]:
from transformers import ViTConfig, ViTModel, ViTImageProcessor, ViTForImageClassification
from transformers import AutoImageProcessor
from transformers import TrainingArguments, Trainer

## Preparing Data Loaders

In [None]:
output_dir = '/Users/jasminecjwchen/Documents/GitHub/COMS-4995-ACV-Project/split_data'

In [71]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms, datasets

# after normalization, must convert to PIL for input into VIT
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.ToPILImage()
])

train_dataset = datasets.ImageFolder(root=os.path.join(output_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(root=os.path.join(output_dir, 'val'), transform=transform)
test_dataset = datasets.ImageFolder(root=os.path.join(output_dir, 'test'), transform=transform)

In [72]:
def collate_fn(batch):
    # Filter failed images first
    batch = list(filter(lambda x: x is not None, batch))

    images = [sample[0] for sample in batch]
    labels = torch.LongTensor([sample[1] for sample in batch])
    
    return images, labels

In [74]:
phases = ["train", "val", "test"]

In [189]:
image_datasets = {x: datasets.ImageFolder(root = os.path.join(output_dir, x), transform = transform) for x in phases}
dataset_sizes = {x: len(image_datasets[x]) for x in phases}
dataloaders = {x: DataLoader(image_datasets[x], batch_size = 32, shuffle = (x == "train"), collate_fn=collate_fn) for x in phases}

In [190]:
print(dataset_sizes)

{'train': 27892, 'val': 5976, 'test': 5978}


## VIT

In [191]:
class VITModel(nn.Module):
    # vit config params are passed through kwargs
    def __init__(self, freeze_vit = True, **kwargs):
        super().__init__()
        
        self.tokenizer = ViTImageProcessor()
        
        # if vit config are passed, then create model using configs
        if len(kwargs) == 0:
            vit_config = ViTConfig(**kwargs)
            self.model = ViTModel(vit_config)
        else:
            # no config passed, then use pretrained with the option to freeze
            self.model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
            if freeze_vit: 
                for param in self.model.parameters():
                    param.requires_grad = False
        self.logistic = nn.Linear(151296, 2)
    
    def forward(self, image):
        encoded_image = self.tokenizer(image, return_tensors = "pt")
        model_output = self.model(**encoded_image)
        flattened_output = model_output.last_hidden_state.flatten(start_dim = 1)
        return self.logistic(flattened_output)

## Training

In [192]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")
use_gpu = (device != "cpu")

Using cpu


In [193]:
def train_model(model, optimizer, scheduler = None, dataloaders = dataloaders, num_epochs = 1, patience = 10, output_filename = "best_model.pth"):    
    since = time.time()
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    # use patience for early stopping when validation isnt getting better
    patience_left = patience

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0
            
            for image, labels in dataloaders[phase]:
                if use_gpu:
                    labels = labels.to(device)
                
                optimizer.zero_grad()
                
                # forward
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(image)
                    preds = torch.argmax(outputs, 1)
                    loss = F.cross_entropy(outputs, labels)
                    
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                
                running_loss += loss.item() * len(image)
                running_corrects += torch.sum(preds == labels.data)
                #print(running_loss, running_corrects)
            
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            epoch_time = time.time() - epoch_start_time

            # deep copy the model if it's best so far
            if phase == "val" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), output_filename)
                patience_left = patience
            elif phase == "val":
                patience_left -= 1
            
            print(f'Epoch {epoch}/{num_epochs - 1} {phase} complete in {epoch_time:.4f} seconds. {phase} loss: {epoch_loss:.4f} acc: {epoch_acc:.4f}. Patience left: {patience_left}')
            
        if patience_left <= 0:
            print("Ran out of patience. Stopping early")
            break
        
        if scheduler:
            scheduler.step()
    
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val accuracy: {best_acc:4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [194]:
model = VITModel(hidden_size = 10, num_hidden_layers = 2, num_attention_heads = 2, intermediate_size = 100, num_classes = 2)
optimizer = optim.Adam(model.parameters(), lr = 0.001)

In [195]:
train_model(model, optimizer)

Epoch 0/0 train complete in 2820.9151 seconds. train loss: 0.3212 acc: 0.9392. Patience left: 10
Epoch 0/0 val complete in 3502.6977 seconds. val loss: 0.2149 acc: 0.9635. Patience left: 10
Training complete in 58m 24s
Best val accuracy: 0.963521


VITModel(
  (model): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_feature