In [1]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image, ImageDraw
from sklearn.metrics import accuracy_score,f1_score
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Lambda, Resize, Normalize
from tqdm import tqdm
from transformers import CLIPTokenizerFast

## **[MetaFormer](https://arxiv.org/abs/2210.13452)** for Image Classification

This notebook uses **CAFormer**, the best-performing variant of **MetaFormer**, for image classification.

Original paper: **[MetaFormer Baselines for Vision](https://arxiv.org/abs/2210.13452)**

Github: **[MetaFormer](https://github.com/sail-sg/metaformer)**

In [2]:
# --- Paths ---
# NOTE: the misspelled name DIRECTROY is referenced by other cells, so it is kept as is.
DIRECTROY = 'data'      # root folder for the CSV files and the saved datasets
MODEL_PATH = 'models'   # root folder for model / optimizer checkpoints

# --- Hyperparameters ---
IMG_SIZE, BATCH_SIZE, EPOCHS = 224, 32, 100
LR = 1e-4

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Display the selected device to confirm whether CUDA was picked up.
device

device(type='cuda')

In [4]:
# Read the train / test metadata tables from the data folder.
df_train = pd.read_csv(f'{DIRECTROY}/reduced_train.csv')
df_test = pd.read_csv(f'{DIRECTROY}/reduced_test.csv')

# One output unit per distinct value of the 'newid' column in the training table.
# NOTE(review): this assumes the ids are contiguous in [0, num_classes) — confirm.
num_classes = len(pd.unique(df_train['newid']))

### Preparing Dataset

In [5]:
# Preprocessing pipeline: resize to a square IMG_SIZE x IMG_SIZE image, convert
# to a tensor, then normalise each channel.
# NOTE(review): these mean/std values look like the CLIP preprocessing
# statistics rather than the usual ImageNet ones — confirm this is intended.
image_transforms = Compose(
    [
        Resize(size=(IMG_SIZE, IMG_SIZE)),
        ToTensor(),
        Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ]
)

In [6]:
class CustomDataset(Dataset):
    """In-memory dataset of preprocessed images, integer labels and tokenized label text.

    Every image listed in ``df['name']`` is loaded and transformed eagerly in
    ``__init__`` and kept in a single half-precision tensor, so construction
    is slow and memory-hungry but item access is a plain tensor index.

    Args:
        df: DataFrame with the columns ``'name'`` (image file name),
            ``'newid'`` (integer class id) and ``'label'`` (label text).
        transforms: Callable applied to each RGB PIL image; it must return a
            tensor, and all returned tensors must share one shape.
        directory: Sub-folder of ``DIRECTROY`` that contains the image files.
    """

    def __init__(self, df, transforms, directory):
        self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")
        self.df = df
        self.transforms = transforms
        self.directory = directory
        # Build the labels as int64 directly; going through a float tensor
        # first would round large ids.
        self.labels = torch.tensor(df['newid'].to_numpy(), dtype=torch.long)
        # Stacking adds the leading batch dimension without depending on the
        # global IMG_SIZE; the output size is decided by `transforms`.
        self.imgs = torch.stack([self._load_image(name) for name in tqdm(df['name'].values)])
        self.tokenized = self.tokenizer(df['label'].tolist(), padding=True, truncation=True, return_tensors="pt")
        self.input_ids = self.tokenized['input_ids']
        self.attention_mask = self.tokenized['attention_mask']

    def _load_image(self, name):
        """Open one image file, force RGB, apply the transforms and cast to half precision."""
        # The original code called an undefined `self.resize_img` here, which
        # raised AttributeError; resizing is left to `transforms`.
        # The context manager closes the file handle once the pixels are read.
        with Image.open(f'{DIRECTROY}/{self.directory}/{name}') as img:
            return self.transforms(img.convert('RGB')).half()

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label, input_ids, attention_mask)`` for row ``idx``."""
        img = self.imgs[idx]
        label = self.labels[idx]
        input_ids = self.input_ids[idx]
        attention_mask = self.attention_mask[idx]
        return img, label, input_ids, attention_mask
    

## Load CAFormer
Approximate parameter counts of the CAFormer variants imported below:
- Small (S36): 39M
- Medium (M36): 56M
- Big (B36): 99M

In [7]:
# All three CAFormer constructors are imported so the variant can be swapped
# by changing a single line below.
from metaformer import (
    caformer_s36_in21ft1k,
    caformer_m36_in21ft1k,
    caformer_b36_in21ft1k,
)

# Build the M36 variant with pretrained weights and move it to the selected
# device; the trailing expression also displays the module structure.
model = caformer_m36_in21ft1k(pretrained=True)
model.to(device)

MetaFormer(
  (downsample_layers): ModuleList(
    (0): Downsampling(
      (pre_norm): Identity()
      (conv): Conv2d(3, 96, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
      (post_norm): LayerNormGeneral()
    )
    (1): Downsampling(
      (pre_norm): LayerNormGeneral()
      (conv): Conv2d(96, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (post_norm): Identity()
    )
    (2): Downsampling(
      (pre_norm): LayerNormGeneral()
      (conv): Conv2d(192, 384, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (post_norm): Identity()
    )
    (3): Downsampling(
      (pre_norm): LayerNormGeneral()
      (conv): Conv2d(384, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (post_norm): Identity()
    )
  )
  (stages): ModuleList(
    (0): Sequential(
      (0): MetaFormerBlock(
        (norm1): LayerNormWithoutBias()
        (token_mixer): SepConv(
          (pwconv1): Linear(in_features=96, out_features=192, bias=False)
          (act1): Star

### Changing output layers

In [8]:
# Swap the final classification layer for one sized to this dataset.
# A freshly built nn.Linear lives on the CPU, so it is moved to `device`
# explicitly; otherwise the head would sit on a different device from the
# rest of the model until a later model.to(device) call.
model.head.fc2 = nn.Linear(model.head.fc2.in_features, num_classes).to(device)

### Training parameters

In [9]:
# Loss for multi-class classification on raw logits.
criterion = nn.CrossEntropyLoss()

# Adam over every model parameter, starting at the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=LR)

# Linearly scale the learning rate from 1.0 * LR down to 0.1 * LR over EPOCHS
# scheduler steps.
scheduler = optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0,
    end_factor=0.1,
    total_iters=EPOCHS,
)

### Loading dataset

In [10]:
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# files from a trusted source. The files presumably hold pickled dataset
# objects, in which case their class must be defined before this cell runs
# — confirm.
train_dataset = torch.load(f'{DIRECTROY}/train_dataset/train_dataset_reduced_all.pth')
test_dataset = torch.load(f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_0.pth')

# Shuffle only the training batches; keep evaluation order fixed.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [11]:
# Display the device again right before training (same check as earlier in the notebook).
device

device(type='cuda')

### Training

In [None]:
# Fine-tune the model. After every epoch the model is evaluated on the
# evaluation dataloader, and the model / optimizer states are checkpointed
# whenever accuracy improves on the best value seen so far.
max_accuracy = 0.0

# torch.save does not create missing folders, so make sure both checkpoint
# directories exist before the first save.
os.makedirs(f'{MODEL_PATH}/optimizer', exist_ok=True)

# Moving the model once is enough; it stays on `device` for the whole run.
model.to(device)

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    len_train = 0

    # Training loop
    print('Training epoch:', epoch+1)
    for inputs, labels, _, _ in tqdm(train_dataloader):
        optimizer.zero_grad()
        # Cast to float32 in a device-agnostic way; torch.cuda.FloatTensor
        # would fail when running on the CPU fallback.
        inputs = inputs.to(device, dtype=torch.float32)
        labels = labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # The criterion returns the mean over the batch, so weight it by the
        # batch size; dividing by the sample count below then yields a true
        # per-sample average, even with a smaller final batch.
        n_samples = labels.size(0)
        train_loss += loss.item() * n_samples
        len_train += n_samples

    scheduler.step()
    train_loss /= max(len_train, 1)
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {train_loss}')

    # Evaluation loop
    eval_loss = 0.0
    len_test = 0
    true_labels = []
    pred_labels = []
    model.eval()

    print('Evaluating epoch:', epoch+1)
    with torch.no_grad():
        for inputs, labels, _, _ in tqdm(test_dataloader):
            inputs = inputs.to(device, dtype=torch.float32)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            n_samples = labels.size(0)
            eval_loss += loss.item() * n_samples
            len_test += n_samples

            # Predicted class = index of the largest logit.
            pred_labels.extend(torch.argmax(outputs, 1).flatten().cpu().tolist())
            true_labels.extend(labels.flatten().cpu().tolist())

    eval_loss /= max(len_test, 1)

    # Compute accuracy once and reuse it for reporting and checkpointing.
    accuracy = accuracy_score(true_labels, pred_labels)
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {eval_loss}')
    print(f'Accuracy: {accuracy}')
    print(f'F1 Score Weighted: {f1_score(true_labels, pred_labels, average="weighted")}')
    print(f'F1 Score Macro: {f1_score(true_labels, pred_labels, average="macro")}')

    if accuracy > max_accuracy:
        max_accuracy = accuracy
        torch.save(model.state_dict(), f'{MODEL_PATH}/CAFormer_reduced_model_{epoch+1}.pth')
        torch.save(optimizer.state_dict(), f'{MODEL_PATH}/optimizer/CAFormer_vit_reduced_optimizer_{epoch+1}.pth')
            

In [21]:
# Bare expression: this displays the generator object returned by
# model.parameters(), not the parameter values themselves.
model.parameters()

<generator object Module.parameters at 0x0000015BEC91F3E0>

In [22]:
def count_parameters(model):
    """Return the total element count of the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        # Parameters with requires_grad=False are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

# Display the number of trainable parameters in the current model.
count_parameters(model)

97674991