<a href="https://colab.research.google.com/github/matteomrz/20242R0136COSE47402/blob/main/final/final_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
!pip install openai-clip
!pip install datasets
!pip install torch
!pip install tqdm



In [25]:
from datasets import load_dataset
from torch.utils.data import random_split

ds = load_dataset("bazyl/GTSRB")

train_full = ds['train']
test_full = ds['test']

# Base label for text description
base_label = 'a picture of a street sign warning about '

# Map used Street Sign IDs to text descriptions
id_to_description = {
    18: "General caution",
    19: "Dangerous curve left",
    20: "Dangerous curve right",
    21: "Winding road",
    22: "Bumpy road",
    23: "Slippery road",
    24: "Road narrows on the right",
    25: "Road work",
    26: "Traffic lights",
    27: "Pedestrians",
    28: "Children crossing",
    29: "Bike crossing",
    30: "Beware of ice/snow",
    31: "Wild animals crossing",
}

# Filter for warning signs
train_full = [example for example in train_full if example['ClassId'] in id_to_description]
test_full = [example for example in test_full if example['ClassId'] in id_to_description]

# Add Text Description
for instance in train_full:
    instance['Description'] = base_label + id_to_description[instance['ClassId']]

len_train = int(0.8 * len(train_full))
train, val = random_split(train_full, [len_train, len(train_full) - len_train])

In [26]:
import clip
import torch
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
from torch.utils.data import Dataset
import torch.nn as nn

In [27]:
class WarningSignDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = Image.open(BytesIO(item['Path']['bytes']))
        return preprocess(image), item['Description'], item['ClassId'] - 18

In [28]:
from torch.utils.data import DataLoader
import torch.optim as optim
from torch import nn

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32",device=device,jit=False)

train_dataloader = DataLoader(WarningSignDataset(train),batch_size = 32, shuffle=True) #Define your own dataloader
val_dataloader = DataLoader(WarningSignDataset(val),batch_size = 32, shuffle=False) #Define your own dataloader

# without this the loss is NaN
#https://github.com/openai/CLIP/issues/57
def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()
        p.grad.data = p.grad.data.float()

In [29]:
from tqdm import tqdm

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

params = model.parameters()
optimizer = optim.Adam(params, lr=1e-3) #Params used from paper, the lr is smaller, more safe for fine tuning to new dataset

# add your own code to track the training progress.
for epoch in range(8):

  # TRAINING
  model.train()
  running_loss = 0.0
  pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{8}, Loss: 0.0000")
  for images, texts, _ in pbar :
      optimizer.zero_grad()

      images= images.to(device)
      texts = clip.tokenize(texts).to(device)

      logits_per_image, logits_per_text = model(images, texts)

      ground_truth = torch.arange(len(images),dtype=torch.long,device=device)

      total_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
      total_loss.backward()
      running_loss += total_loss.item()

      # Fixes NaN loss
      if device == "cpu":
         optimizer.step()
      else :
        convert_models_to_fp32(model)
        optimizer.step()
        clip.model.convert_weights(model)

      pbar.set_description(f"Epoch {epoch+1}/{8}, Loss: {running_loss/len(train_dataloader):.4f}")
  print(f"Epoch {epoch+1}/{8}, Loss: {running_loss/len(train_dataloader):.4f}")

  # VALIDATION
  model.eval()
  total = 0
  correct = 0
  with torch.no_grad():
    for images, texts, classId in val_dataloader:
      images= images.to(device)
      texts = clip.tokenize(texts).to(device)
      logits_per_image, logits_per_text = model(images, texts)
      pred = torch.argmax(logits_per_image, dim=-1).cpu().numpy()
      total += len(images)
      correct += (pred[0] == classId).sum().item()
  print(f"Validation Accuracy: {100*correct/total}%")

torch.save(model.state_dict(), 'clip_finetuned.pth')

Epoch 1/8, Loss: 3.5870: 100%|██████████| 192/192 [00:50<00:00,  3.81it/s]


Epoch 1/8, Loss: 3.5870
Validation Accuracy: 4.901960784313726%


Epoch 2/8, Loss: 3.4594: 100%|██████████| 192/192 [00:43<00:00,  4.41it/s]


Epoch 2/8, Loss: 3.4594
Validation Accuracy: 5.751633986928105%


Epoch 3/8, Loss: 3.4580: 100%|██████████| 192/192 [00:43<00:00,  4.38it/s]


Epoch 3/8, Loss: 3.4580
Validation Accuracy: 6.7973856209150325%


Epoch 4/8, Loss: 0.8114:  23%|██▎       | 45/192 [00:10<00:33,  4.37it/s]


KeyboardInterrupt: 

In [None]:
# TESTING
model.eval()
total = 0
correct = 0
with torch.no_grad():
  for images, texts, classId in val_dataloader:
    images= images.to(device)
    texts = clip.tokenize(texts).to(device)
    logits_per_image, logits_per_text = model(images, texts)
    pred = torch.argmax(logits_per_image, dim=-1).cpu().numpy()
    total += len(images)
    correct += (pred[0] == classId).sum().item()
print(f"Test Accuracy: {100*correct/total}%")