<a href="https://colab.research.google.com/github/matteomrz/20242R0136COSE47402/blob/main/final/final_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install openai-clip
!pip install datasets
!pip install torch
!pip install tqdm

In [None]:
from datasets import load_dataset
from torch.utils.data import random_split

ds = load_dataset("bazyl/GTSRB")

train_full = ds['train']
test_full = ds['test']

# Map used Street Sign IDs to text descriptions
id_to_description = {
    18: "General caution",
    19: "Dangerous curve left",
    20: "Dangerous curve right",
    21: "Winding road",
    22: "Bumpy road",
    23: "Slippery road",
    24: "Road narrows on the right",
    25: "Road work",
    26: "Traffic lights",
    27: "Pedestrians",
    28: "Children crossing",
    29: "Bike crossing",
    30: "Beware of ice/snow",
    31: "Wild animals crossing",
}

# Filter for warning signs
train_full = [example for example in train_full if example['ClassId'] in id_to_description]
test_full = [example for example in test_full if example['ClassId'] in id_to_description]

# Add Text Description
for instance in train_full:
    instance['Description'] = id_to_description[instance['ClassId']]

len_train = int(0.8 * len(train_full))
train, val = random_split(train_full, [len_train, len(train_full) - len_train])

In [None]:
import clip
import torch

model, preprocess = clip.load("ViT-B/32", jit=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image

In [None]:
from torchvision import transforms
from torch.utils.data import Dataset

class WarningSignDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = Image.open(BytesIO(item['Path']['bytes']))
        return self.transform(image), item['ClassId'] - 18

        # returns the number of the fucking description

In [None]:
from torch.utils.data import DataLoader

# Create DataLoader for training and validation sets
train_loader = DataLoader(WarningSignDataset(train), batch_size=32, shuffle=True)
val_loader = DataLoader(WarningSignDataset(val), batch_size=32, shuffle=False)
test_loader = DataLoader(WarningSignDataset(test_full), batch_size=32, shuffle=False)

In [None]:
import torch.nn as nn

# Modify the model to include a classifier for subcategories
class CLIPFineTuner(nn.Module):
    def __init__(self, model, num_classes):
        super(CLIPFineTuner, self).__init__()
        self.model = model
        self.classifier = nn.Linear(model.visual.output_dim, num_classes)

    def forward(self, x):
        with torch.no_grad():
            features = self.model.encode_image(x).float()  # Convert to float32
        return self.classifier(features)

In [None]:
num_classes = len(id_to_description)
model_ft = CLIPFineTuner(model, num_classes).to(device)

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_ft.classifier.parameters(), lr=5e-4)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

In [12]:
from tqdm import tqdm

# Number of epochs for training
num_epochs = 20

# Training loop
for epoch in range(num_epochs):
    model_ft.train()  # Set the model to training mode
    running_loss = 0.0  # Initialize running loss for the current epoch
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}, Loss: 0.0000")  # Initialize progress bar

    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)  # Move images and labels to the device (GPU or CPU)
        optimizer.zero_grad()  # Clear the gradients of all optimized variables
        outputs = model_ft(images)  # Forward pass: compute predicted outputs by passing inputs to the model
        loss = criterion(outputs, labels)  # Calculate the loss
        loss.backward()  # Backward pass: compute gradient of the loss with respect to model parameters
        optimizer.step()  # Perform a single optimization step (parameter update)

        running_loss += loss.item()  # Update running loss
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}")  # Update progress bar with current loss

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')  # Print average loss for the epoch

    scheduler.step()  # Update learning rate scheduler

    # Validation
    model_ft.eval()  # Set the model to evaluation mode
    correct = 0  # Initialize correct predictions counter
    total = 0  # Initialize total samples counter

    with torch.no_grad():  # Disable gradient calculation for validation
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)  # Move images and labels to the device
            outputs = model_ft(images)  # Forward pass: compute predicted outputs by passing inputs to the model
            _, predicted = torch.max(outputs.data, 1)  # Get the class label with the highest probability
            total += labels.size(0)  # Update total samples
            correct += (predicted == labels).sum().item()  # Update correct predictions

    print(f'Validation Accuracy: {100 * correct / total}%')  # Print validation accuracy for the epoch

# Save the fine-tuned model
torch.save(model_ft.state_dict(), 'clip_finetuned.pth')  # Save the model's state dictionary

Epoch 1/20, Loss: 2.1224: 100%|██████████| 192/192 [00:19<00:00,  9.63it/s]


Epoch [1/20], Loss: 2.1224
Validation Accuracy: 52.35294117647059%


Epoch 2/20, Loss: 1.5708: 100%|██████████| 192/192 [00:17<00:00, 11.10it/s]


Epoch [2/20], Loss: 1.5708
Validation Accuracy: 67.25490196078431%


Epoch 3/20, Loss: 1.2535: 100%|██████████| 192/192 [00:17<00:00, 10.74it/s]


Epoch [3/20], Loss: 1.2535
Validation Accuracy: 74.83660130718954%


Epoch 4/20, Loss: 1.0587: 100%|██████████| 192/192 [00:18<00:00, 10.43it/s]


Epoch [4/20], Loss: 1.0587
Validation Accuracy: 76.73202614379085%


Epoch 5/20, Loss: 0.9165: 100%|██████████| 192/192 [00:18<00:00, 10.59it/s]


Epoch [5/20], Loss: 0.9165
Validation Accuracy: 81.63398692810458%


Epoch 6/20, Loss: 0.8172: 100%|██████████| 192/192 [00:17<00:00, 10.77it/s]


Epoch [6/20], Loss: 0.8172
Validation Accuracy: 82.94117647058823%


Epoch 7/20, Loss: 0.7378: 100%|██████████| 192/192 [00:17<00:00, 11.00it/s]


Epoch [7/20], Loss: 0.7378
Validation Accuracy: 85.62091503267973%


Epoch 8/20, Loss: 0.6733: 100%|██████████| 192/192 [00:17<00:00, 11.00it/s]


Epoch [8/20], Loss: 0.6733
Validation Accuracy: 85.94771241830065%


Epoch 9/20, Loss: 0.6206: 100%|██████████| 192/192 [00:16<00:00, 11.42it/s]


Epoch [9/20], Loss: 0.6206
Validation Accuracy: 87.45098039215686%


Epoch 10/20, Loss: 0.5769: 100%|██████████| 192/192 [00:16<00:00, 11.35it/s]


Epoch [10/20], Loss: 0.5769
Validation Accuracy: 88.03921568627452%


Epoch 11/20, Loss: 0.5477: 100%|██████████| 192/192 [00:17<00:00, 10.75it/s]


Epoch [11/20], Loss: 0.5477
Validation Accuracy: 88.16993464052288%


Epoch 12/20, Loss: 0.5421: 100%|██████████| 192/192 [00:17<00:00, 11.07it/s]


Epoch [12/20], Loss: 0.5421
Validation Accuracy: 88.30065359477125%


Epoch 13/20, Loss: 0.5390: 100%|██████████| 192/192 [00:17<00:00, 11.26it/s]


Epoch [13/20], Loss: 0.5390
Validation Accuracy: 88.56209150326798%


Epoch 14/20, Loss: 0.5352: 100%|██████████| 192/192 [00:17<00:00, 10.90it/s]


Epoch [14/20], Loss: 0.5352
Validation Accuracy: 88.75816993464052%


Epoch 15/20, Loss: 0.5319: 100%|██████████| 192/192 [00:17<00:00, 10.84it/s]


Epoch [15/20], Loss: 0.5319
Validation Accuracy: 88.62745098039215%


Epoch 16/20, Loss: 0.5281: 100%|██████████| 192/192 [00:17<00:00, 11.24it/s]


Epoch [16/20], Loss: 0.5281
Validation Accuracy: 88.88888888888889%


Epoch 17/20, Loss: 0.5242: 100%|██████████| 192/192 [00:17<00:00, 11.09it/s]


Epoch [17/20], Loss: 0.5242
Validation Accuracy: 88.75816993464052%


Epoch 18/20, Loss: 0.5197: 100%|██████████| 192/192 [00:17<00:00, 10.77it/s]


Epoch [18/20], Loss: 0.5197
Validation Accuracy: 88.82352941176471%


Epoch 19/20, Loss: 0.5151: 100%|██████████| 192/192 [00:17<00:00, 10.82it/s]


Epoch [19/20], Loss: 0.5151
Validation Accuracy: 89.15032679738562%


Epoch 20/20, Loss: 0.5111: 100%|██████████| 192/192 [00:17<00:00, 11.15it/s]


Epoch [20/20], Loss: 0.5111
Validation Accuracy: 89.01960784313725%


In [13]:
# Testing
with torch.no_grad():  # Disable gradient calculation for validation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)  # Move images and labels to the device
            outputs = model_ft(images)  # Forward pass: compute predicted outputs by passing inputs to the model
            _, predicted = torch.max(outputs.data, 1)  # Get the class label with the highest probability
            total += labels.size(0)  # Update total samples
            correct += (predicted == labels).sum().item()  # Update correct predictions

print(f'Test Accuracy: {100 * correct / total}%')  # Print validation accuracy for the epoch

Test Accuracy: 80.64102564102564%
