# Image classification 

For classification, the output layer must be a fully connected (FC) layer whose output size equals the number of classes.

ViT needs an extra FC layer (head) appended because its architecture was originally designed for NLP; for CNN-based models, the last layer can instead be replaced in place.

Fine-tuning can:
1. Update the parameters of the whole model
2. Only update the last layer (or certain layers), by setting `requires_grad=False` on the parameters that should stay frozen.

Notes:
1. By default, pretrained PyTorch (Hugging Face) models expect batched input of shape (N, C, H, W) at their first (embedding) layer. The weights themselves are defined per sample (C, H, W) and shared across the whole batch, so the parameters do not need to be duplicated for each sample. For example, the ViT model used in this notebook handles the batch dimension in its embedding layer.
2. A PyTorch DataLoader yields data in batches of shape (N, C, H, W), so you can simply write `outputs = model(batch_data)`. The model applies the same weights to every sample in the batch, computes a score per sample, and returns the results as a batch.
3. The criterion then reduces these outputs to a single loss value per minibatch, and the weights are updated once per minibatch (rather than once per pass over the whole dataset).
4. PyTorch Lightning handles batching inside its `Trainer` class.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

# Training
## Parameters

In [None]:
# Architecture to run. The original options were resnet, alexnet, vgg,
# squeezenet, densenet and inception; this run uses a ViT variant
# ("vit_p16"; "vit_p32" is also used in this notebook).
# NOTE(review): confirm the accepted names in model_utils.myModels.
model_name = 'vit_p16'

# How many target classes the classifier head must output.
num_classes = 2

# Mini-batch size shared by every DataLoader built in this notebook.
batch_size = 16

# Number of full passes over the training set.
num_epochs = 50

# True  -> feature extraction: only the freshly reshaped layer is trained.
# False -> full fine-tuning: every parameter of the model is updated.
feature_extract = True

# GPUs to use; setting this to 0 forces CPU even if CUDA is present.
ngpu = 1

# Use the first CUDA device only when CUDA exists AND GPUs were requested.
_use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device("cuda:0" if _use_cuda else "cpu")
print(device)

# Destination for checkpoints and history plots. Keep the trailing slash:
# later cells build file paths by plain string concatenation onto it.
out_dir = './discriminator_data/checkpoints/'

## Dataset

In [None]:
from model_utils import myDataloader

# Root of the labelled image folder (from_subfolders=True below suggests one
# subfolder per class — NOTE(review): confirm in model_utils).
# Alternative local root: "./discriminator_data/data"
dataroot = "/lovelace/zhuowen/diffusers/als/40k_generated/als_2400_labeled"
# NOTE(review): the transform is built with 224 here, before
# initialize_model returns its own input_size; they must agree.
input_size = 224

# Preprocessing pipeline: resize the shorter side, center-crop to a square,
# convert to a [0, 1] tensor, then map each channel to [-1, 1].
_preprocess_steps = [
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocess_steps)

# Presumably a 70/30 train/validation split of dataroot.
dataloaders = myDataloader(
    root=dataroot,
    batch_size=batch_size,
    transform=transform,
    num_workers=2,
    split=[0.7, 0.3],
    from_subfolders=True,
)
train_dataloader, val_dataloader = dataloaders.dataloader()

# Lengths of the two loaders (batch counts, if these are torch DataLoaders).
print(len(train_dataloader), len(val_dataloader))

# Presumably a visual sanity check of training samples.
dataloaders.plot_dataloader(train_dataloader, device)

# Phase name -> loader mapping handed to the Trainer.
dataloaders_dict = {'train': train_dataloader, 'validation': val_dataloader}


## Training example

In [None]:
from model_utils import myModels

# Build the model for this run. initialize_model also returns an input size
# (presumably the image resolution the model expects).
# NOTE(review): the returned input_size is not fed back into `transform`.
model_ft, input_size = myModels.initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

# Uncomment to inspect the architecture.
#print(model_ft)

# Parameters the optimizer will update:
#  - fine-tuning (feature_extract=False): every model parameter;
#  - feature extraction (feature_extract=True): only parameters that still
#    have requires_grad=True (NOTE(review): assumes initialize_model froze
#    the backbone and left the new head trainable — confirm).
# Always materialise a list: model.parameters() is a one-shot generator,
# so reusing it (e.g. for a second optimizer) would silently yield nothing.
if feature_extract:
    params_to_update = [param for param in model_ft.parameters() if param.requires_grad]
else:
    params_to_update = list(model_ft.parameters())

# Report which parameters are trainable.
print("Params to learn:")
for name, param in model_ft.named_parameters():
    if param.requires_grad:
        print("\t", name)

In [None]:
from pipeline_utils import Trainer

# Cross-entropy loss for classification.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over the parameters selected in the previous cell.
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

# The Trainer receives a flag for inception models (NOTE(review): presumably
# to handle its auxiliary output — confirm in pipeline_utils).
_is_inception = model_name == "inception"

# Train and evaluate over the 'train' / 'validation' loaders.
trainer = Trainer(
    model_ft,
    dataloaders_dict,
    criterion,
    optimizer_ft,
    device=device,
    num_epochs=num_epochs,
    is_inception=_is_inception,
    verbose=True,
)
# train() returns the (trained) model and a history object.
model_ft, hist = trainer.train()

In [None]:
import os

# Create the output directory up front: both calls below write into out_dir,
# and saving a figure/file into a missing directory fails.
os.makedirs(out_dir, exist_ok=True)

# Save the training-history plot and the training results under out_dir.
# NOTE(review): the meaning of plot_history's first argument (10) is defined
# in pipeline_utils — confirm. Also note the names differ ("50epochs" vs
# "50_epochs"), and the evaluation/inference cells load "vit_p16_50epochs.pth".
trainer.plot_history(10, save_fig=True, out_name=f"{out_dir}history_{model_name}_{num_epochs}epochs_classifier.png")
trainer.save_results(out_dir, model_name=f"{model_name}_{num_epochs}_epochs_classifier")

# Evaluation example (vit_16x16)

In [None]:
import torchvision
from pipeline_utils import Evaluation

# Evaluation set: the same labelled folder used for training. Reuse `dataroot`
# rather than repeating the literal path so the two cannot drift apart.
# NOTE(review): this presumably includes the images of the 70% training split,
# so the resulting metrics are optimistic; evaluate on val_dataloader for a
# held-out estimate.
dataset = torchvision.datasets.ImageFolder(root=dataroot, transform=transform)
# No shuffling: keeps evaluation runs reproducible (assumes Evaluation
# computes aggregate metrics, which do not depend on sample order).
eval_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Load the trained model checkpoint (apparently a whole pickled model, since
# it is passed straight to Evaluation as the model).
# - map_location=device: a checkpoint saved on GPU can be loaded on a CPU-only host.
# - weights_only=False: PyTorch >= 2.6 defaults to weights_only=True, which
#   refuses pickled nn.Module objects. This unpickles arbitrary objects, so
#   only load checkpoints you trust.
# NOTE(review): the training cell saves under "<model>_<epochs>_epochs_classifier";
# confirm this filename is the checkpoint you mean to evaluate.
vit_p16 = torch.load(out_dir + 'vit_p16_50epochs.pth', map_location=device, weights_only=False)
# Decision threshold 0.56 (presumably applied to the positive-class score).
vit_p16_eval = Evaluation(vit_p16, eval_dataloader, threshold=0.56, device=device)
vit_p16_eval.evaluate()

## Inference example (vit_16x16)

In [None]:
from model_utils import myDataset
from pipeline_utils import Inference
import glob

# Folder of images to classify.
test_data_path = '/lovelace/zhuowen/diffusers/als/40k_generated/prompt1/'
# Every entry directly inside that folder; passed to Inference below.
# NOTE(review): glob order is arbitrary — confirm Inference does not rely on
# this list lining up with the dataset's own file order.
filepaths = glob.glob(test_data_path + '/*')

test_dataset = myDataset(root=test_data_path, transform=transform, return_filepath=True)

# Deterministic, single-process loading that keeps the final partial batch.
_loader_options = dict(
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=False,
    prefetch_factor=None,
)
test_dataloader = torch.utils.data.DataLoader(test_dataset, **_loader_options)

In [None]:
# Same folder string as test_data_path; not used in this cell.
inference_data_path = '/lovelace/zhuowen/diffusers/als/40k_generated/prompt1/'

# Load the trained model checkpoint.
# - map_location=device: a checkpoint saved on GPU can be loaded on a CPU-only host.
# - weights_only=False: PyTorch >= 2.6 defaults to weights_only=True, which
#   refuses pickled nn.Module objects. This unpickles arbitrary objects, so
#   only load checkpoints you trust.
vit_p16 = torch.load(out_dir + 'vit_p16_50epochs.pth', map_location=device, weights_only=False)
vit_p16_infer = Inference(vit_p16, test_dataloader, threshold=0.56, device=device, filepaths=filepaths)
# Presumably runs the model on the next batch from test_dataloader.
vit_p16_infer.next_batch()

## Ensemble classification and evaluation 