### Imports

In [ ]:
import json
import numpy as np
import os

from collections import defaultdict
from PIL import Image
from sklearn.linear_model import Ridge
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from torchvision.transforms import v2
from torchinfo import summary

### Device

In [ ]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Bare expression so the notebook cell displays the selected device.
device

### Constants

In [ ]:
# Image directories (relative paths) for the training data and the final test set.
IMAGE_DIR = "../data/train/"
TEST_IMAGE_DIR = "../data/test_set_final_image_set/"

# Samples per DataLoader batch and upper bound on training epochs.
BATCH_SIZE = 32
NUM_EPOCHS = 100

### Load and split data

### Custom dataset

In [ ]:
class ImageQuery(Dataset):
    """Dataset yielding ``(image, tokenized_query)`` pairs.

    Args:
        dataset: Mapping of image file name -> query text. The keys are
            joined onto ``img_dir``; the values are tokenized once, up front,
            with ``clip.tokenize``.
        img_dir: Directory that contains the image files.
        augmentations: Optional callable applied to the opened PIL image
            before ``preprocessing``.
        preprocessing: Optional callable applied last (in this notebook,
            CLIP's preprocessing function). When ``None`` the image is
            returned without this step instead of raising ``TypeError``.

    Note:
        ``clip`` and ``Image`` must already be available in the enclosing
        namespace; this class does not import them itself.
    """

    def __init__(self, dataset, img_dir, augmentations=None, preprocessing=None):
        # os.path.join works whether or not img_dir ends with a separator,
        # unlike plain string concatenation.
        self.image_paths = [os.path.join(img_dir, img) for img in dataset]
        # Tokenize text using CLIP's tokenizer
        self.queries = clip.tokenize(list(dataset.values()))
        self.augmentations = augmentations
        self.preprocessing = preprocessing

    def __len__(self):
        """Return the number of (image, query) pairs."""
        return len(self.queries)

    def __getitem__(self, idx):
        """Load image ``idx``, apply the configured transforms, and pair it with its query."""
        image = Image.open(self.image_paths[idx])
        # Compare against None explicitly: a transform object that defines
        # __len__ (e.g. an empty nn.Sequential) would otherwise test falsy.
        if self.augmentations is not None:
            image = self.augmentations(image)
        # Both transforms are optional, matching the constructor defaults.
        if self.preprocessing is not None:
            image = self.preprocessing(image)
        return image, self.queries[idx]

# Training batches: augmented, then preprocessed, reshuffled every epoch.
train_dataloader = DataLoader(
    dataset=ImageQuery(
        train_dataset,
        IMAGE_DIR,
        augmentations=augment,
        preprocessing=clip_preprocess,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Validation batches: preprocessing only, iterated in dataset order.
val_dataloader = DataLoader(
    dataset=ImageQuery(val_dataset, IMAGE_DIR, preprocessing=clip_preprocess),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

### Model

In [ ]:
class Embedding(nn.Module):
    """Frozen ResNet-50 backbone followed by a trainable projection head.

    The backbone is the ImageNet-pretrained torchvision ResNet-50 with its
    last child module (the classification layer) removed. Its parameters are
    frozen and it is always kept in eval mode, so its BatchNorm layers use
    the pretrained running statistics both while training and while
    evaluating the head. The head maps the 2048 backbone features to 768
    dimensions, and the result is L2-normalized along dim 1.
    """

    def __init__(self):
        super().__init__()

        # Keep every child of the pretrained network except the last one.
        backbone = models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V2')
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])

        # Freeze the backbone: only the projection head below is trained.
        for param in self.resnet.parameters():
            param.requires_grad = False
        self.resnet.eval()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 768),
        )

    def train(self, mode=True):
        """Set the training mode of the head while keeping the backbone in eval mode.

        ``requires_grad = False`` freezes weights only. Without this override,
        ``.train()`` would switch the backbone's BatchNorm layers to batch
        statistics and keep updating their running averages, so the head
        would train on different features than it is evaluated on.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, input):
        """Return unit-norm embeddings for a batch of images."""
        x = self.resnet(input)
        output = self.model(x)
        # F.normalize defaults to the L2 norm over dim=1 (the feature axis).
        return F.normalize(output)


embedding = Embedding().to(device)

### Optimizer

In [ ]:
# NOTE(review): `siamese` is not defined in the visible cells, while the
# training loop below drives `model` — confirm that these are the parameters
# that should be optimized.
optimizer = torch.optim.Adam(params=siamese.parameters(), lr=1e-3, eps=1e-6, weight_decay=1e-3)

# Multiply the learning rate by 0.9 on every scheduler.step() call.
scheduler = ExponentialLR(optimizer=optimizer, gamma=0.9)

### Early stopping

In [ ]:
class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    Args:
        patience: Number of non-improving checks tolerated before
            ``early_stop`` returns ``True``.
        min_delta: Margin above the best loss seen so far that a loss must
            exceed to count as non-improving.

    Attributes:
        counter: Non-improving checks since the last new best loss.
        min_validation_loss: Best (lowest) loss seen so far, as a float.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return ``True`` once patience is exhausted.

        A loss below the best so far resets the counter. A loss more than
        ``min_delta`` above the best (or a NaN loss) increments it. A loss
        in between leaves the counter unchanged.
        """
        # Store a plain float so the stopper never keeps a (possibly
        # on-device) tensor alive as its best loss.
        validation_loss = float(validation_loss)

        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        # Written as `not (<=)` rather than `>`: identical for ordinary
        # numbers, but also true for NaN, so a diverged run still stops.
        elif not (validation_loss <= self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False


early_stopper = EarlyStopper(patience=3, min_delta=1e-3)

### Training loop

In [ ]:
train_batches = len(train_dataloader)
val_batches = len(val_dataloader)

for epoch in range(NUM_EPOCHS):
    # Running sums of the per-batch losses, kept as plain Python floats.
    running_train_loss = 0.0
    running_val_loss = 0.0

    pbar = tqdm(train_dataloader, total=train_batches)
    model.train()
    for idx, batch in enumerate(pbar):
        optimizer.zero_grad()

        train_images, train_texts = batch

        train_images = train_images.to(device)
        train_texts = train_texts.to(device)

        # Forward pass
        logits_per_image, logits_per_text = model(train_images, train_texts)

        # Compute loss: the target for row i is index i, i.e. each image is
        # matched with the text at the same position in the batch.
        ground_truth = torch.arange(len(train_images), dtype=torch.long, device=device)
        train_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
        # .item() detaches the value; adding the tensor itself would keep
        # every batch's autograd graph alive until the end of the epoch.
        running_train_loss += train_loss.item()

        # Backward pass
        train_loss.backward()
        optimizer.step()
        # NOTE(review): presumably converts the model weights back to half
        # precision after the update — confirm that the optimizer step itself
        # is meant to run on the converted weights.
        clip.model.convert_weights(model)

        pbar.set_description(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Training loss: {running_train_loss / (idx + 1):.2E}")
    scheduler.step()

    pbar_val = tqdm(val_dataloader, total=val_batches)
    model.eval()
    with torch.no_grad():
        for idx, batch_val in enumerate(pbar_val):
            val_images, val_texts = batch_val

            val_images = val_images.to(device)
            val_texts = val_texts.to(device)

            # Forward pass
            val_logits_per_image, val_logits_per_text = model(val_images, val_texts)

            # Compute loss
            val_ground_truth = torch.arange(len(val_images), dtype=torch.long, device=device)
            val_loss = (loss_img(val_logits_per_image, val_ground_truth) + loss_txt(val_logits_per_text, val_ground_truth)) / 2
            running_val_loss += val_loss.item()

            pbar_val.set_description(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Validation loss: {running_val_loss / (idx + 1):.2E}")

    # One checkpoint per epoch. The stored losses are the epoch sums of the
    # batch losses (not means), saved as floats so the file does not embed
    # device tensors.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': running_train_loss,
        'val_loss': running_val_loss,
    }, f"../data/clip_ft_{epoch + 1}.pt")