## Cell 1 — Environment & Data Setup

This cell prepares the workspace for the project:

* Creates a local data/ directory to store all dataset files.

* Downloads a dataset from Google Drive .

* Clones the required GitHub repository (challenge) which contains additional code and utilities used throughout this notebook (evaluation or model and dataset loading functions).

To load translator of the submission go to "Model, Training, and Dataset Parameters" section and set:
-TRANSLATOR_PATH = "/kaggle/input/translator903/other/default/1/translator.pth"

In [1]:
import kagglehub
print("Downloading inputs")
# Download latest version
path = kagglehub.dataset_download("ferruccioliu/openai-clip-vit-large-patch14")
path = kagglehub.dataset_download("niccolosici/newdataset40000-50000")
path = kagglehub.dataset_download("niccolosici/newdataset10000-20000")
path = kagglehub.dataset_download("niccolosici/aml-dataset")
path = kagglehub.dataset_download("niccolosici/newdataset")
path = kagglehub.model_download("niccolosici/translator903/other/default")
path = kagglehub.model_download("niccolosici/dec903/other/default")
path = kagglehub.model_download("niccolosici/enc903/other/default")


Downloading inputs


In [2]:
!mkdir data
!gdown 1CVAQDuPOiwm8h9LJ8a_oOs6zOWS6EgkB
!gdown 1ykZ9fjTxUwdiEwqagoYZiMcD5aG-7rHe
!unzip -o test.zip -d data
!unzip -o train.zip -d data
!git clone https://github.com/Mamiglia/challenge.git

Downloading...
From: https://drive.google.com/uc?id=1CVAQDuPOiwm8h9LJ8a_oOs6zOWS6EgkB
To: /kaggle/working/test.zip
100%|██████████████████████████████████████| 5.80M/5.80M [00:00<00:00, 18.8MB/s]
Failed to retrieve file url:

	Too many users have viewed or downloaded this file recently. Please
	try accessing the file again later. If the file you are trying to
	access is particularly large or is shared with many people, it may
	take up to 24 hours to be able to view or download the file. If you
	still can't access a file after 24 hours, contact your domain
	administrator.

You may still be able to access the file from the browser:

	https://drive.google.com/uc?id=1ykZ9fjTxUwdiEwqagoYZiMcD5aG-7rHe

but Gdown can't. Please check connections and permissions.
Archive:  test.zip
   creating: data/test/
  inflating: data/test/captions.txt  
  inflating: data/test/test.clean.npz  
unzip:  cannot find or open train.zip, train.zip.zip or train.zip.ZIP.
Cloning into 'challenge'...
remote: Enumera

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
from tqdm import tqdm
import torch.nn.functional as F
import random
from challenge.src.common import load_data, prepare_train_data, generate_submission
from challenge.src.eval import evaluate_retrieval
import numpy as np
import math
import pickle

from transformers import CLIPModel,CLIPProcessor
import torch
import numpy as np
from tqdm import tqdm
import transformers.utils.hub as hub_utils 
from PIL import Image
import pandas as pd
from collections import defaultdict

2025-11-17 12:45:47.609492: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763383547.810463      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763383547.866797      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

In [4]:

def compute_mrr_at_k_batched(text_proj, image_emb, k=100, batch_size=128):
    text_proj = F.normalize(text_proj, dim=-1)
    image_emb = F.normalize(image_emb, dim=-1)
    
    N = text_proj.shape[0]
    reciprocal_ranks = []
    
    with torch.no_grad():
        for start in range(0, N, batch_size):
            end = min(start + batch_size, N)
            batch_t = text_proj[start:end]
            
            sims = torch.matmul(batch_t, image_emb.T)
            top_k_values, top_k_indices = torch.topk(sims, k=min(k, N), dim=1)
            
            for i in range(end - start):
                true_idx = start + i
                top_k_for_query = top_k_indices[i].cpu().numpy()
                position = (top_k_for_query == true_idx).nonzero()
                
                if len(position[0]) > 0:
                    rank = position[0][0] + 1
                    reciprocal_ranks.append(1.0 / rank)
                else:
                    reciprocal_ranks.append(0.0)
    
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


def compute_recall_at_k(text_proj, image_emb, k=1, batch_size=128):
    text_proj = F.normalize(text_proj, dim=-1)
    image_emb = F.normalize(image_emb, dim=-1)
    
    N = text_proj.shape[0]
    correct = 0
    
    with torch.no_grad():
        for start in range(0, N, batch_size):
            end = min(start + batch_size, N)
            batch_t = text_proj[start:end]
            
            sims = torch.matmul(batch_t, image_emb.T)
            top_k_indices = torch.topk(sims, k=min(k, N), dim=1)[1]
            
            for i in range(end - start):
                true_idx = start + i
                if true_idx in top_k_indices[i].cpu().numpy():
                    correct += 1
    
    return correct / N

## Main Dataset Construction Components

This cell defines the two core functions responsible for preparing the datasets used in the two-stage training pipeline:
- Stage 1: Building the Encoder Dataset
- Stage 2: Building the Decoder Dataset

### extract_dataset_CLIP_EM() — Building the Encoder Training Dataset

Prepares the dataset used to train the Stage 1 encoder MLP, whose purpose is to map external caption embeddings (e.g., SBERT, DINO, or other text encoders) into the CLIP text-embedding space:
- Loads existing caption embeddings from the original dataset.
- Loads several external datasets containing additional caption embeddings and caption texts.
- Merges all caption embeddings into one unified collection.
- Uses a local CLIP model to compute CLIP text embeddings for every caption (these become the targets).
- Optionally applies data augmentation at the embedding level (noise, dropout, mixup).
- Splits the dataset into training and validation sets.
- Builds PyTorch DataLoader objects for efficient batching.

### extract_dataset_from() — Building the Decoder Training Dataset

Prepares the dataset used to train the Stage 2 decoder MLP, whose role is to map CLIP-aligned caption embeddings into the image embedding space:
- Loads the original training dataset (with image embeddings and caption embeddings).
- Loads multiple external datasets containing additional image and caption embeddings.
- Concatenates all old and new caption embeddings into one dataset.
- Uses the previously trained encoder model (from Stage 1) to convert caption embeddings into the CLIP-aligned space.
- Matches each caption embedding to the correct image embedding (targets).
- Optionally applies augmentation on the caption embeddings.
- Splits the transformed dataset into training and validation sets.
- Wraps everything into PyTorch dataloaders.


In [5]:
def load_data_from_dataset_img_emb(new_dataset_file = "/kaggle/input/newdataset/coco_sbert_dino.pkl"):
    
    with open(new_dataset_file, "rb") as f:
        new_dataset = pickle.load(f)
    
    new_caption_emb = np.vstack([entry['caption_embedding'] for entry in new_dataset])
    new_image_emb_np = np.stack([entry["img_embedding"] for entry in new_dataset])  # shape: (N_new, 1536)
    return new_caption_emb, torch.from_numpy(new_image_emb_np).float()

def extract_dataset_from(dataset_path,dataset_par, augmentation_par,encoder_model_par, encoder_path):

    train_data = load_data(dataset_path)
    
    TRAIN_SIZE = dataset_par["TRAIN_SIZE"]
    BATCH_SIZE = dataset_par["BATCH_SIZE"]
    USE_AUGMENTATION =  augmentation_par["USE_AUGMENTATION"]
    NUM_AUGMENTED_COPIES = augmentation_par["NUM_AUGMENTED_COPIES"]
    NOISE_STD = augmentation_par["NOISE_STD"]
    DROPOUT_AUG_PROB = augmentation_par["DROPOUT_AUG_PROB"]
    MIXUP_ALPHA = augmentation_par["MIXUP_ALPHA"]
    MIXUP_PROB_BATCH = augmentation_par["MIXUP_PROB_BATCH"]
    
    print("\nLoading data...")
    
    new_caption_emb0,new_image_emb0 = load_data_from_dataset_img_emb("/kaggle/input/newdataset/coco_sbert_dino.pkl")
    new_caption_emb1,new_image_emb1 = load_data_from_dataset_img_emb("/kaggle/input/newdataset10000-20000/coco_sbert_dino_10000_to_20000.pkl")
    new_caption_emb2,new_image_emb2 = load_data_from_dataset_img_emb("/kaggle/input/newdataset40000-50000/coco_sbert_dino_40000_to_50000.pkl")
    
    
    new_caption_emb = [new_caption_emb0,new_caption_emb1,new_caption_emb2]
    new_image_emb = [new_image_emb0, new_image_emb1,new_image_emb2]

    # Convert to torch tensor
    
    old_embeddings = train_data['captions/embeddings']
    image_embeddings = torch.from_numpy(train_data['images/embeddings']).float()

    old_embeddings = np.vstack([old_embeddings]+new_caption_emb)
    
    old_embeddings = torch.from_numpy(old_embeddings).float()
    caption_label = train_data['captions/label']
    
    caption_to_image_idx = np.argmax(caption_label, axis=1)
    target_image_embeddings = image_embeddings[caption_to_image_idx]

    target_image_embeddings = torch.cat([target_image_embeddings]+ new_image_emb, dim=0)
    encoder_model = encoder_model_par["CREATE_MODEL"](encoder_model_par)
    encoder_model.load_state_dict(torch.load(encoder_path))
    encoder_model.eval()

    
    print("Transforming embeddings through Stage 1...")
    with torch.no_grad():
        batch_size_transform = 10000
        clip_embeddings_list = []
    
        for i in tqdm(range(0, len(old_embeddings), batch_size_transform)):
            batch = old_embeddings[i:i+batch_size_transform].to(DEVICE)
            clip_batch = encoder_model(batch)
            clip_embeddings_list.append(clip_batch.cpu())
    
        clip_embeddings = torch.cat(clip_embeddings_list, dim=0)
    
    
    n_train = len(clip_embeddings)-int((len(clip_embeddings)*0.1 // 100) * 100)
    indices = torch.randperm(len(clip_embeddings))
   
    X_train_orig, X_val = clip_embeddings[indices[:n_train]], clip_embeddings[indices[n_train:]]
    y_train_orig, y_val = target_image_embeddings[indices[:n_train]], target_image_embeddings[indices[n_train:]]
    
    print(f"Original Train: {len(X_train_orig)}, Val: {len(X_val)}")
    
    # Initialize augmenter
    X_train, y_train = X_train_orig, y_train_orig
    
    if USE_AUGMENTATION and NUM_AUGMENTED_COPIES > 0:
        augmenter = EmbeddingAugmenter(
            noise_std=NOISE_STD,
            dropout_prob=DROPOUT_AUG_PROB,
            mixup_alpha=MIXUP_ALPHA
        )
        print(f"\n Augmentation Configuration:")
        print(f"   - Augmented copies per sample: {NUM_AUGMENTED_COPIES}")
        print(f"   - Gaussian noise: std={NOISE_STD}")
        print(f"   - Embedding dropout: prob={DROPOUT_AUG_PROB}")
        print(f"   - Mixup (batch-level): alpha={MIXUP_ALPHA}, prob={MIXUP_PROB_BATCH}")
        
        # Create augmented training set
        X_train, y_train = create_augmented_dataset(
            X_train_orig, 
            y_train_orig, 
            augmenter, 
            NUM_AUGMENTED_COPIES
        )
    else:
        print("\n Augmentation disabled")
    
    print(f"\n Final Training Size: {len(X_train)}")
    print(f" Validation Size: {len(X_val)} (no augmentation)")
    
    # Clean up CLIP model to free memory
    # The original code had clip_tokenizer here, but it's not defined. Removed it.
    torch.cuda.empty_cache()
    
    # Create datasets
    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)
    
    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                             num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, 
                           num_workers=0, pin_memory=True)
    
    return train_loader, val_loader




def image_caption_lists():
    captions_file = "/kaggle/input/aml-competition/train/train/captions.txt"
    
    # Read the file
    df = pd.read_csv(captions_file, sep=',', header=None, names=['image', 'caption'])
    
    # Get lists directly from dataframe (preserves duplicates and order)
    image_names = df['image'].tolist()
    text_captions = df['caption'].tolist()
    
    return image_names[1:], text_captions[1:]
def load_img_clip_dict():
    file_pt= "/kaggle/input/img-to-clip/y_clip(1).npz"
    
    # Load the saved .npz file
    data = np.load(file_pt, allow_pickle=True)
    
    # Extract arrays
    embeddings = data['embeddings']       # shape: (N, D)
    image_names = data['image_names']     # shape: (N,)
    
    # Convert image names to plain Python strings (if they’re bytes)
    image_names = [str(name) for name in image_names]
    
    # Create dictionary: {image_name: emb_vec}
    image_to_emb = {
        img_name: torch.from_numpy(emb_vec).float() 
        for img_name, emb_vec in zip(image_names, embeddings)
    }
    return image_to_emb
def load_data_from_dataset(new_dataset_file = "/kaggle/input/newdataset/coco_sbert_dino.pkl"):
    
    with open(new_dataset_file, "rb") as f:
        new_dataset = pickle.load(f)
    
    new_caption_emb = np.vstack([entry['caption_embedding'] for entry in new_dataset])
    new_caption_texts = [entry['caption'] for entry in new_dataset]
    return new_caption_emb, new_caption_texts
    
def extract_dataset_CLIP_EM(dataset_path, dataset_par, augmentation_par, clip_model_name="openai/clip-vit-large-patch14-336"):
    print("NEW vertotot")
    """
    Extract dataset with pre-existing caption embeddings as input (X) 
    and CLIP-generated embeddings as targets (y)
    
    Args:
        dataset_path: Path to dataset
        dataset_par: Dataset parameters
        augmentation_par: Augmentation parameters
        clip_model_name: CLIP model to use
    
    Returns:
        train_loader, val_loader with (caption_emb, clip_text_emb) pairs
    """
    
    TRAIN_SIZE = dataset_par["TRAIN_SIZE"]
    BATCH_SIZE = dataset_par["BATCH_SIZE"]
    USE_AUGMENTATION = augmentation_par["USE_AUGMENTATION"]
    NUM_AUGMENTED_COPIES = augmentation_par["NUM_AUGMENTED_COPIES"]
    NOISE_STD = augmentation_par["NOISE_STD"]
    DROPOUT_AUG_PROB = augmentation_par["DROPOUT_AUG_PROB"]
    MIXUP_ALPHA = augmentation_par["MIXUP_ALPHA"]
    MIXUP_PROB_BATCH = augmentation_par["MIXUP_PROB_BATCH"]
    
    
    new_caption_emb0,new_caption_texts0 = load_data_from_dataset("/kaggle/input/newdataset/coco_sbert_dino.pkl")
    new_caption_emb1,new_caption_texts1 = load_data_from_dataset("/kaggle/input/newdataset10000-20000/coco_sbert_dino_10000_to_20000.pkl")
    new_caption_emb2,new_caption_texts2 = load_data_from_dataset("/kaggle/input/newdataset40000-50000/coco_sbert_dino_40000_to_50000.pkl")
    
    new_caption_emb = [new_caption_emb0,new_caption_emb1, new_caption_emb2]
    new_caption_texts = new_caption_texts0+ new_caption_texts1+new_caption_texts2
    
    train_data = np.load(dataset_path)
    
    old_embeddings = train_data['captions/embeddings']
    caption_text = train_data['captions/text']
    
    print(f"Old embeddings: {old_embeddings.shape}")
    print(f"Captions: {len(caption_text)}")

    old_embeddings = np.vstack([old_embeddings]+ new_caption_emb)
    caption_text  =list(caption_text) + new_caption_texts
    
    print("Concatenated shapes:", len(old_embeddings), len(caption_text))

    print("\nGenerating CLIP embeddings...")
    model_clip = CLIPModel.from_pretrained(clip_model_name,local_files_only=True)
    processor = CLIPProcessor.from_pretrained(clip_model_name,local_files_only=True)
    model_clip = model_clip.to(DEVICE)
    model_clip.eval()
    
    all_clip_embeddings = []
    caption_list = [str(cap) for cap in caption_text]
    
    with torch.no_grad():
        for i in tqdm(range(0, len(caption_list), 256), desc="CLIP embeddings"):
            batch_captions = caption_list[i:i+256]
            inputs = processor(text=batch_captions, return_tensors="pt", padding=True,
                              truncation=True, max_length=77).to(DEVICE)
            text_embeddings = model_clip.get_text_features(**inputs)
            text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
            all_clip_embeddings.append(text_embeddings.cpu())
    
            if (i // BATCH_SIZE) % 10 == 0:
                torch.cuda.empty_cache()
    
    clip_embeddings = torch.cat(all_clip_embeddings, dim=0)
    print(f"Generated CLIP embeddings: {clip_embeddings.shape}")
    
    np.savez("clip_txt_emb.npz", clip_embeddings=clip_embeddings.numpy(),
             caption_ids=train_data['captions/ids'] if 'captions/ids' in train_data else np.arange(len(clip_embeddings)))
    
    del model_clip, processor
    torch.cuda.empty_cache()
    
    print("\nTraining Stage 1 MLP...")
    X = torch.from_numpy(old_embeddings).float()
    y = clip_embeddings.float()
    
    n_train = len(X)-int((len(X)*0.1 // 100) * 100)
    indices = torch.randperm(len(X))
    X_train, X_val = X[indices[:n_train]], X[indices[n_train:]]
    y_train, y_val = y[indices[:n_train]], y[indices[n_train:]]
    
    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)
    return train_loader, val_loader



##  Loss Functions for Text-Image Embeddings

This cell defines loss functions for embeddings alignedment:

- `symmetric_contrastive_loss`: computes contrastive loss between text and image embeddings in both directions.  
- `cosine_loss`: computes the mean cosine similarity loss between predicted and target embeddings.  
- `triplet_loss`: computes triplet loss with hardest negative mining.  
- `combined_loss`: combines contrastive and triplet losses with a weighted sum using specified hyperparameters.

The cosine loss is used by the encoder to map the CLIP text encoder embedding space. While the combined loss is used by the decoder to map CLIP embedding into 1536 dimesnional image embedding space


In [6]:


def symmetric_contrastive_loss(text_proj, image_emb, temperature=0.05):
    text_proj = F.normalize(text_proj, dim=-1)
    image_emb = F.normalize(image_emb, dim=-1)
    logits = torch.matmul(text_proj, image_emb.T) / temperature
    batch_size = logits.shape[0]
    labels = torch.arange(batch_size, device=logits.device)
    loss_t2i = F.cross_entropy(logits, labels)
    loss_i2t = F.cross_entropy(logits.T, labels)
    return (loss_t2i + loss_i2t) / 2

def cosine_loss(pred, target, loss_arg=None):
    pred_norm = F.normalize(pred, dim=-1)
    target_norm = F.normalize(target, dim=-1)
    cos_sim = (pred_norm * target_norm).sum(dim=-1)
    return (1 - cos_sim).mean()


    
def triplet_loss(text_proj, image_emb, margin=0.4):
    text_proj = F.normalize(text_proj, dim=-1)
    image_emb = F.normalize(image_emb, dim=-1)
    batch_size = text_proj.shape[0]
    sims = torch.matmul(text_proj, image_emb.T)
    pos_sims = sims.diagonal()
    mask = 1.0 - torch.eye(batch_size, device=sims.device)
    neg_sims = sims * mask + torch.eye(batch_size, device=sims.device) * -1e9
    hard_neg_sims, _ = neg_sims.max(dim=1)
    loss = F.relu(margin - pos_sims + hard_neg_sims).mean()
    return loss


def combined_loss(text_proj, image_emb, loss_arg):
    temperature=loss_arg["TEMPERATURE"]
    alpha=loss_arg["ALPHA"]
    margin=loss_arg["MARGIN"]
    contrastive =  symmetric_contrastive_loss(text_proj, image_emb, temperature)
    triplet = triplet_loss(text_proj, image_emb, margin=margin)
    
    return contrastive*alpha + (1-alpha)*triplet
    #return (text_proj - image_emb).norm(dim=1)


# MLP and Transformer Model Constructors

This cell defines functions and classes to create various neural network models used in the pipeline:

-  `create_MLP_contrastive`,`create_Dec`: helper functions to instantiate models with parameters from `model_par` and move them to the device.
- `ProjectionMLP`: Encoder MLP for projecting embeddings into a shared space, with skip connections and BatchNorm.
- `DecoderMLP`: decoder MLP mapping latent embeddings to image embeddings, with skip connections.
- (Commented-out code): alternative residual-block-based `DecoderMLP` architecture.


In [7]:


def create_MLP_contrastive(model_par):
    return ProjectionMLP(model_par["MLP_T_EM"]["INPUT_DIM"], 
                          768, 
                          model_par["MLP_T_EM"]["HIDDEN_DIMS"],
                          model_par["MLP_T_EM"]["DROPOUT"]
    ).to(DEVICE)

def create_Dec(model_par):
    return DecoderMLP(
        input_dim=model_par["INPUT_DIM"],
        output_dim=model_par["OUTPUT_DIM"],
        hidden_dims = model_par["HIDDEN_DIMS"],
        dropout = model_par["DROPOUT"],
    ).to(DEVICE)
    
def create_AE_Translator(encoder, decoder):
    return Translator(encoder, decoder).to(DEVICE)
    
class Translator(nn.Module):
    
    def __init__(self, encoder, decoder):
        super().__init__()
    
        self.encoder = encoder
        self.decoder = decoder

        
    def forward(self, x):
        x = self.encoder(x)
        out = self.decoder(x)
        return out


class ProjectionMLP(nn.Module):
    """MLP for projecting embeddings to shared space"""
    
    def __init__(self, input_dim, output_dim, hidden_dims, dropout):
        super().__init__()
        print("Creating ProjectionMLP")
        layers = []
        prev_dim = input_dim
        
        for i, hidden_dim in enumerate(hidden_dims):
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout if i < len(hidden_dims) - 1 else dropout * 0.5)
            ])
            prev_dim = hidden_dim
        
        layers.extend([
            nn.Linear(prev_dim, output_dim),
            nn.BatchNorm1d(output_dim)
        ])
        
        self.network = nn.Sequential(*layers)
        self.skip = nn.Linear(input_dim, output_dim)
        self.skip_weight = nn.Parameter(torch.tensor(0.1))
        
    def forward(self, x):
        out = self.network(x)
        out = F.normalize(out, dim=-1)
        skip = self.skip(x)
        return out + self.skip_weight * skip





class DecoderMLP(nn.Module):
    """MLP decoder from latent space to image embedding space"""
    
    def __init__(self, input_dim=768, output_dim=1536, hidden_dims=[1024, 1280, 1408], dropout=0.2):
        super().__init__()
        # Load encoder (NO FREEZING)
        print("LayerNorm", hidden_dims)
        # That's it. Encoder is trainable.
        
        layers = []
        prev_dim = input_dim
        
        for i, hidden_dim in enumerate(hidden_dims):
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout if i < len(hidden_dims) - 1 else dropout * 0.5)
            ])
            prev_dim = hidden_dim
        
        layers.extend([nn.Linear(prev_dim, output_dim)])
        
        self.network = nn.Sequential(*layers)
        self.skip = nn.Linear(input_dim, output_dim)
        self.skip_weight = nn.Parameter(torch.tensor(0.1))
        
    def forward(self, x):

        out = self.network(x)
        out = F.normalize(out, dim=-1)
        skip = self.skip(x)
        return out + self.skip_weight * skip

def create_model(model_par):
    print("\n2. Building model...")

    model = model_par["CREATE_MODEL"](model_par)
    
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
    return model

In [8]:
def train_model(augmentation_par,train_par, loss_par,model, train_loader, val_loader, device, model_path,augmenter=None):
    LR =train_par["LR"]
    WEIGHT_DEC = train_par["WEIGHT_DEC"]
    WARMUP = train_par["WARMUP"]
    USE_AUGMENTATION = augmentation_par["USE_AUGMENTATION"]
    MIXUP_PROB_BATCH = augmentation_par["MIXUP_PROB_BATCH"]
    FLOW = train_par["FLOW"]
    epochs = train_par["EPOCHS"]
    
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DEC)
    warmup_epochs = WARMUP
    scheduler_warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
    scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs-warmup_epochs, eta_min=1e-7)
    
    best_mrr = 0.0
    patience_counter = 0
    patience = 5
    loss_fn = loss_par["FUNC"]
    for epoch in range(epochs):
        model.train()
        train_loss = 0
    
        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = loss_fn(outputs, y_batch,loss_par["ARG"])
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
    
        train_loss /= len(train_loader)
    
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
                outputs = model(X_batch)
                loss =  loss_fn(outputs, y_batch,loss_par["ARG"])
                val_loss += loss.item()
    
        val_loss /= len(val_loader)
    
        if (epoch + 1) % 2 == 0 or epoch == epochs - 1:
            all_preds = []
            all_targets = []
            model.eval()
            with torch.no_grad():
                for X_batch, y_batch in val_loader:
                    X_batch = X_batch.to(DEVICE)
                    y_batch = y_batch.to(DEVICE)
                    pred_batch = model(X_batch)
                    all_preds.append(pred_batch.cpu())
                    all_targets.append(y_batch.cpu())
    
            all_preds = torch.cat(all_preds, dim=0)
            all_targets = torch.cat(all_targets, dim=0)
    
            mrr_100 = compute_mrr_at_k_batched(all_preds.to(DEVICE), all_targets.to(DEVICE), k=100, batch_size=128)
            model.eval()
            print(all_preds.shape,all_targets.shape)
            print(f"Epoch {epoch+1}: Train={train_loss:.4f}, Val={val_loss:.4f}, MRR@100={mrr_100:.4f}", evaluate_retrieval(all_preds.cpu(), all_targets.cpu(), np.arange(len(all_preds))))
            
            if train_loss > val_loss:
                best_mrr = mrr_100
                patience_counter = 0
                Path(model_path).parent.mkdir(parents=True, exist_ok=True)
                torch.save(model.state_dict(), model_path)
                print(f"  New best: {mrr_100:.4f}")
            else:
                patience_counter += 1
    
            del all_preds, all_targets
            torch.cuda.empty_cache()
        else:
            print(f"Epoch {epoch+1}: Train={train_loss:.4f}, Val={val_loss:.4f}")
    
        if epoch < warmup_epochs:
            scheduler_warmup.step()
        else:
            scheduler_cosine.step()
    
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
    
    

# Model, Training, and Dataset Parameters

This cell defines all the **hyperparameters and configuration dictionaries** used in the pipeline:

- **Model parameters:** `MLP_T_EM`,  `MLP_ENC_PAR`, `MLP_DEC_PAR` define encoder and decoder MLP architectures and how they are created.  
- **Loss parameters:** `ENC_LOSS_PAR` and `DEC_LOSS_PAR` define the loss functions and their arguments for encoder and decoder training.  
- **Training parameters:** `ENC_TRAINING_PAR` and `DEC_TRAINING_PAR` define learning rates, warmup, epochs, weight decay, and whether to use flow.  
- **Dataset parameters:** `ENC_DATASET_PAR` and `DEC_DATASET_PAR` define batch sizes and training/validation split ratios.  
- **Augmentation parameters:** `AUGMENTATION_PAR` controls optional embedding-level augmentations like noise, dropout, and mixup.  
- **File and device settings:** `DATASET_PATH`, `DEVICE`, and `GRAD_CLIP`.


In [9]:
#DEFINE PARAMETERS
#MODEL
#Contrastive into 780 space
MLP_T_EM = {"NAME":"ModelTextToEMb", "HIDDEN_DIMS":[1024, 1024, 896],"INPUT_DIM":1024, "OUTPUT_DIM":768, "DROPOUT": 0.2}
MLP_ENC_PAR = {"NAME":"ModelContrastive", "MLP_T_EM": MLP_T_EM, "CREATE_MODEL": create_MLP_contrastive}

MLP_DEC_PAR = {"NAME":"ModelDecoder", "HIDDEN_DIMS":[896, 1000, 1200, 1400],"INPUT_DIM":768, "OUTPUT_DIM":1536, "DROPOUT": 0.2, "ENC_MODEL_PATH":"models/mlp_enc_txt_cos_loss.pth","DEC_MODEL_PATH":"models/mlp_dec_txt_cos_loss.pth", "ENCODER_PAR":MLP_ENC_PAR, "CREATE_MODEL":create_Dec}
#Decoder fron space 780
ENC_LOSS_PAR  = {"NAME":"COSINE_LOSS", "FUNC": cosine_loss, "ARG":None}
DEC_LOSS_PAR = {"NAME":"COMB_TRIPLET_AND_CONTR", "FUNC": combined_loss, "ARG":{"TEMPERATURE":0.007, "ALPHA":0.1, "MARGIN":0.4}}

#TRAINING
ENC_TRAINING_PAR = {"LR":0.002, "WARMUP":3, "EPOCHS": 30, "WEIGHT_DEC":0.01, "FLOW":False}
DEC_TRAINING_PAR = {"LR":0.006, "WARMUP":3, "EPOCHS": 30, "WEIGHT_DEC":0.01, "FLOW":False}


#DATASET
ENC_DATASET_PAR = {"BATCH_SIZE":4096*2, "TRAIN_SIZE":0.9}
DEC_DATASET_PAR = {"BATCH_SIZE":4096*5, "TRAIN_SIZE":0.9}

AUGMENTATION_PAR = {"USE_AUGMENTATION":False, "NUM_AUGMENTED_COPIES":2, "NOISE_STD":0.02, "DROPOUT_AUG_PROB":0.1,"MIXUP_ALPHA":0.2, "MIXUP_PROB_BATCH": 0.3}



#FILES
DATASET_PATH = "/kaggle/input/aml-dataset/train/train/train.npz"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

TRANSLATOR_PATH ="models/translator.pth"

GRAD_CLIP = None



## Create Encoder Dataset

This cell creates the **training and validation dataloaders** for the Stage 1 encoder by calling `extract_dataset_CLIP_EM` with:

- `DATASET_PATH`: path to the raw dataset  
- `ENC_DATASET_PAR`: dataset parameters (batch size, train/val split)  
- `AUGMENTATION_PAR`: optional embedding-level augmentations  
- CLIP model path: to compute CLIP text embeddings  


In [10]:
# CREATE DATASET
train_loader_clip, val_loader_clip = extract_dataset_CLIP_EM(DATASET_PATH, ENC_DATASET_PAR, AUGMENTATION_PAR,"/kaggle/input/openai-clip-vit-large-patch14/clip-vit-large-patch14")


NEW vertotot
Old embeddings: (125000, 1024)
Captions: 125000
Concatenated shapes: 275078 275078

Generating CLIP embeddings...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
CLIP embeddings: 100%|██████████| 1075/1075 [05:07<00:00,  3.50it/s]


Generated CLIP embeddings: torch.Size([275078, 768])

Training Stage 1 MLP...


## Train Encoder Model

- Instantiates the encoder model using `create_model` with `MLP_ENC_PAR`.  
- Trains the encoder using `train_model` with specified augmentation, training, and loss parameters, along with the encoder dataloaders.  
- Saves the trained encoder weights to `MLP_DEC_PAR["ENC_MODEL_PATH"]`.  
- Deletes the model from memory to free GPU resources.


In [11]:
encoder_model = create_model(MLP_ENC_PAR)

encoder_model = train_model(AUGMENTATION_PAR,ENC_TRAINING_PAR, ENC_LOSS_PAR, encoder_model,train_loader_clip, val_loader_clip, DEVICE, MLP_DEC_PAR["ENC_MODEL_PATH"])
del encoder_model


2. Building model...
Creating ProjectionMLP
Parameters: 4,501,121


Epoch 1/30: 100%|██████████| 31/31 [00:05<00:00,  5.70it/s]


Epoch 1: Train=0.6193, Val=0.4307


Epoch 2/30: 100%|██████████| 31/31 [00:05<00:00,  6.05it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 2: Train=0.3096, Val=0.2369, MRR@100=0.5344 {'mrr': 0.9594460624281227, 'ndcg': 0.9695585807895192, 'recall_at_1': 0.9309818181818181, 'recall_at_3': 0.9875636363636363, 'recall_at_5': 0.9937818181818182, 'recall_at_10': 0.9975636363636363, 'recall_at_50': 0.9997090909090909, 'l2_dist': 1.4779030084609985}
  New best: 0.5344


Epoch 3/30: 100%|██████████| 31/31 [00:05<00:00,  5.95it/s]


Epoch 3: Train=0.2099, Val=0.1834


Epoch 4/30: 100%|██████████| 31/31 [00:04<00:00,  6.62it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 4: Train=0.1793, Val=0.1666, MRR@100=0.6550 {'mrr': 0.973597339426007, 'ndcg': 0.9800774088632642, 'recall_at_1': 0.9561090909090909, 'recall_at_3': 0.9901090909090909, 'recall_at_5': 0.9942909090909091, 'recall_at_10': 0.9973818181818181, 'recall_at_50': 0.9996727272727273, 'l2_dist': 1.4246119260787964}
  New best: 0.6550


Epoch 5/30: 100%|██████████| 31/31 [00:04<00:00,  6.46it/s]


Epoch 5: Train=0.1666, Val=0.1563


Epoch 6/30: 100%|██████████| 31/31 [00:04<00:00,  6.62it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 6: Train=0.1581, Val=0.1500, MRR@100=0.7250 {'mrr': 0.9836648191857408, 'ndcg': 0.9877039895691545, 'recall_at_1': 0.9725090909090909, 'recall_at_3': 0.9942181818181818, 'recall_at_5': 0.9970909090909091, 'recall_at_10': 0.9987272727272727, 'recall_at_50': 0.9998181818181818, 'l2_dist': 1.4177603721618652}
  New best: 0.7250


Epoch 7/30: 100%|██████████| 31/31 [00:04<00:00,  6.48it/s]


Epoch 7: Train=0.1527, Val=0.1454


Epoch 8/30: 100%|██████████| 31/31 [00:04<00:00,  6.60it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 8: Train=0.1486, Val=0.1418, MRR@100=0.7511 {'mrr': 0.9866502150329218, 'ndcg': 0.9899549880791348, 'recall_at_1': 0.9773454545454545, 'recall_at_3': 0.9956, 'recall_at_5': 0.9977454545454545, 'recall_at_10': 0.9987636363636364, 'recall_at_50': 0.9998545454545454, 'l2_dist': 1.4085776805877686}
  New best: 0.7511


Epoch 9/30: 100%|██████████| 31/31 [00:04<00:00,  6.51it/s]


Epoch 9: Train=0.1447, Val=0.1385


Epoch 10/30: 100%|██████████| 31/31 [00:05<00:00,  6.14it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 10: Train=0.1420, Val=0.1359, MRR@100=0.7727 {'mrr': 0.9883567936598342, 'ndcg': 0.9912438904471653, 'recall_at_1': 0.9801090909090909, 'recall_at_3': 0.9962181818181818, 'recall_at_5': 0.9980363636363636, 'recall_at_10': 0.9989818181818182, 'recall_at_50': 0.9998181818181818, 'l2_dist': 1.378578782081604}
  New best: 0.7727


Epoch 11/30: 100%|██████████| 31/31 [00:04<00:00,  6.56it/s]


Epoch 11: Train=0.1393, Val=0.1338


Epoch 12/30: 100%|██████████| 31/31 [00:04<00:00,  6.49it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 12: Train=0.1372, Val=0.1317, MRR@100=0.7897 {'mrr': 0.9899128055110019, 'ndcg': 0.9924102002233594, 'recall_at_1': 0.9827272727272728, 'recall_at_3': 0.9967636363636364, 'recall_at_5': 0.9982545454545455, 'recall_at_10': 0.9988727272727272, 'recall_at_50': 0.9998545454545454, 'l2_dist': 1.3487788438796997}
  New best: 0.7897


Epoch 13/30: 100%|██████████| 31/31 [00:04<00:00,  6.55it/s]


Epoch 13: Train=0.1350, Val=0.1298


Epoch 14/30: 100%|██████████| 31/31 [00:04<00:00,  6.55it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 14: Train=0.1334, Val=0.1282, MRR@100=0.7965 {'mrr': 0.9897978690495881, 'ndcg': 0.992329038556231, 'recall_at_1': 0.9824727272727273, 'recall_at_3': 0.9968, 'recall_at_5': 0.9982909090909091, 'recall_at_10': 0.9991636363636364, 'recall_at_50': 0.9998545454545454, 'l2_dist': 1.323569655418396}
  New best: 0.7965


Epoch 15/30: 100%|██████████| 31/31 [00:04<00:00,  6.55it/s]


Epoch 15: Train=0.1315, Val=0.1268


Epoch 16/30: 100%|██████████| 31/31 [00:04<00:00,  6.59it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 16: Train=0.1301, Val=0.1255, MRR@100=0.8098 {'mrr': 0.9913270416280935, 'ndcg': 0.993471662878524, 'recall_at_1': 0.9851636363636364, 'recall_at_3': 0.9971272727272728, 'recall_at_5': 0.9984, 'recall_at_10': 0.9992, 'recall_at_50': 0.999890909090909, 'l2_dist': 1.2947150468826294}
  New best: 0.8098


Epoch 17/30: 100%|██████████| 31/31 [00:04<00:00,  6.59it/s]


Epoch 17: Train=0.1285, Val=0.1241


Epoch 18/30: 100%|██████████| 31/31 [00:04<00:00,  6.54it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 18: Train=0.1273, Val=0.1233, MRR@100=0.8139 {'mrr': 0.991759578990813, 'ndcg': 0.9938004527428591, 'recall_at_1': 0.9858545454545454, 'recall_at_3': 0.9975272727272727, 'recall_at_5': 0.9985818181818182, 'recall_at_10': 0.9992363636363636, 'recall_at_50': 0.9998545454545454, 'l2_dist': 1.2777440547943115}
  New best: 0.8139


Epoch 19/30: 100%|██████████| 31/31 [00:04<00:00,  6.59it/s]


Epoch 19: Train=0.1263, Val=0.1225


Epoch 20/30: 100%|██████████| 31/31 [00:04<00:00,  6.53it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 20: Train=0.1252, Val=0.1217, MRR@100=0.8203 {'mrr': 0.9921074257901399, 'ndcg': 0.9940662092606237, 'recall_at_1': 0.9864363636363637, 'recall_at_3': 0.9977454545454545, 'recall_at_5': 0.9986545454545455, 'recall_at_10': 0.9993090909090909, 'recall_at_50': 0.999890909090909, 'l2_dist': 1.2656447887420654}
  New best: 0.8203


Epoch 21/30: 100%|██████████| 31/31 [00:04<00:00,  6.63it/s]


Epoch 21: Train=0.1246, Val=0.1210


Epoch 22/30: 100%|██████████| 31/31 [00:05<00:00,  6.07it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 22: Train=0.1235, Val=0.1203, MRR@100=0.8240 {'mrr': 0.992204534128123, 'ndcg': 0.9941358100264702, 'recall_at_1': 0.9865818181818182, 'recall_at_3': 0.9977090909090909, 'recall_at_5': 0.9986545454545455, 'recall_at_10': 0.9993090909090909, 'recall_at_50': 0.999890909090909, 'l2_dist': 1.2479870319366455}
  New best: 0.8240


Epoch 23/30: 100%|██████████| 31/31 [00:05<00:00,  6.19it/s]


Epoch 23: Train=0.1228, Val=0.1198


Epoch 24/30: 100%|██████████| 31/31 [00:05<00:00,  6.02it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 24: Train=0.1222, Val=0.1194, MRR@100=0.8277 {'mrr': 0.9920826748762205, 'ndcg': 0.9940434520086442, 'recall_at_1': 0.9864363636363637, 'recall_at_3': 0.9974545454545455, 'recall_at_5': 0.9985818181818182, 'recall_at_10': 0.9993090909090909, 'recall_at_50': 0.999890909090909, 'l2_dist': 1.238354206085205}
  New best: 0.8277


Epoch 25/30: 100%|██████████| 31/31 [00:04<00:00,  6.24it/s]


Epoch 25: Train=0.1216, Val=0.1190


Epoch 26/30: 100%|██████████| 31/31 [00:04<00:00,  6.44it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 26: Train=0.1212, Val=0.1188, MRR@100=0.8290 {'mrr': 0.9922682333523836, 'ndcg': 0.9941832706074266, 'recall_at_1': 0.9867272727272727, 'recall_at_3': 0.9976727272727273, 'recall_at_5': 0.9986181818181818, 'recall_at_10': 0.9993090909090909, 'recall_at_50': 0.999890909090909, 'l2_dist': 1.2339528799057007}
  New best: 0.8290


Epoch 27/30: 100%|██████████| 31/31 [00:04<00:00,  6.60it/s]


Epoch 27: Train=0.1208, Val=0.1186


Epoch 28/30: 100%|██████████| 31/31 [00:04<00:00,  6.48it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 28: Train=0.1206, Val=0.1185, MRR@100=0.8310 {'mrr': 0.9924028156160418, 'ndcg': 0.9942853365210274, 'recall_at_1': 0.9869454545454546, 'recall_at_3': 0.9977818181818182, 'recall_at_5': 0.9986181818181818, 'recall_at_10': 0.9993090909090909, 'recall_at_50': 0.9998545454545454, 'l2_dist': 1.2293031215667725}
  New best: 0.8310


Epoch 29/30: 100%|██████████| 31/31 [00:04<00:00,  6.68it/s]


Epoch 29: Train=0.1205, Val=0.1184


Epoch 30/30: 100%|██████████| 31/31 [00:04<00:00,  6.47it/s]


torch.Size([27500, 768]) torch.Size([27500, 768])
Epoch 30: Train=0.1203, Val=0.1184, MRR@100=0.8315 {'mrr': 0.9923810542736158, 'ndcg': 0.9942691468197736, 'recall_at_1': 0.986909090909091, 'recall_at_3': 0.9977454545454545, 'recall_at_5': 0.9986181818181818, 'recall_at_10': 0.9992727272727273, 'recall_at_50': 0.9998545454545454, 'l2_dist': 1.2275773286819458}
  New best: 0.8315


## Create Decoder Dataset

This cell creates the **training and validation dataloaders** for the Stage 2 decoder by calling `extract_dataset_from` with:

- `DATASET_PATH`: path to the raw dataset  
- `DEC_DATASET_PAR`: dataset parameters (batch size, train/val split)  
- `AUGMENTATION_PAR`: optional embedding-level augmentations  
- `MLP_ENC_PAR` and `MLP_DEC_PAR["ENC_MODEL_PATH"]`: use the trained encoder to convert captions into CLIP-aligned embeddings for the decoder.  


In [12]:
train_loader_dec, val_loader_dec = extract_dataset_from(DATASET_PATH, DEC_DATASET_PAR, AUGMENTATION_PAR,MLP_ENC_PAR, MLP_DEC_PAR["ENC_MODEL_PATH"] )



Loading data...
Creating ProjectionMLP
Transforming embeddings through Stage 1...


100%|██████████| 28/28 [00:01<00:00, 22.81it/s]


Original Train: 247578, Val: 27500

 Augmentation disabled

 Final Training Size: 247578
 Validation Size: 27500 (no augmentation)


## Train Decoder Model

- Instantiates the decoder model using `create_model` with `MLP_DEC_PAR`.  
- Trains the decoder using `train_model` with specified augmentation, training, and loss parameters, along with the decoder dataloaders.  
- Saves the trained decoder weights to `MLP_DEC_PAR["DEC_MODEL_PATH"]`.  
- Deletes the model from memory to free GPU resources.

After training the encoder is loaded and decoder and encoder are saved into one model called translator

In [13]:
decoder_model = create_model(MLP_DEC_PAR)

decoder_model = train_model(AUGMENTATION_PAR,DEC_TRAINING_PAR, DEC_LOSS_PAR, decoder_model,train_loader_dec, val_loader_dec, DEVICE, MLP_DEC_PAR["DEC_MODEL_PATH"])



#SAVING FINAL TRANSLATOR
encoder_model = MLP_ENC_PAR["CREATE_MODEL"](MLP_ENC_PAR)
encoder_model.load_state_dict(torch.load(MLP_DEC_PAR["ENC_MODEL_PATH"]))

decoder_model = MLP_DEC_PAR["CREATE_MODEL"](MLP_DEC_PAR)
decoder_model.load_state_dict(torch.load(MLP_DEC_PAR["DEC_MODEL_PATH"]))



translator = create_AE_Translator(encoder_model, decoder_model)
Path(TRANSLATOR_PATH).parent.mkdir(parents=True, exist_ok=True)

torch.save(translator.state_dict(), TRANSLATOR_PATH)

del decoder_model, encoder_model, translator



2. Building model...
LayerNorm [896, 1000, 1200, 1400]
Parameters: 7,810,737


Epoch 1/30: 100%|██████████| 13/13 [00:15<00:00,  1.20s/it]


Epoch 1: Train=1.3298, Val=0.9870


Epoch 2/30: 100%|██████████| 13/13 [00:15<00:00,  1.18s/it]


torch.Size([27500, 1536]) torch.Size([27500, 1536])
Epoch 2: Train=0.9286, Val=0.8156, MRR@100=0.1399 {'mrr': 0.8408355700510836, 'ndcg': 0.880123241239267, 'recall_at_1': 0.7389818181818182, 'recall_at_3': 0.9378181818181818, 'recall_at_5': 0.9701454545454545, 'recall_at_10': 0.9895636363636363, 'recall_at_50': 0.9989818181818182, 'l2_dist': 26.111242294311523}
  New best: 0.1399


Epoch 3/30: 100%|██████████| 13/13 [00:15<00:00,  1.20s/it]


Epoch 3: Train=0.8192, Val=0.7531


Epoch 4/30: 100%|██████████| 13/13 [00:15<00:00,  1.20s/it]


torch.Size([27500, 1536]) torch.Size([27500, 1536])
Epoch 4: Train=0.7684, Val=0.7202, MRR@100=0.2551 {'mrr': 0.898170450478819, 'ndcg': 0.9235439175924549, 'recall_at_1': 0.8282545454545455, 'recall_at_3': 0.9662545454545455, 'recall_at_5': 0.9865090909090909, 'recall_at_10': 0.9946545454545455, 'recall_at_50': 0.9991272727272728, 'l2_dist': 26.001779556274414}
  New best: 0.2551


Epoch 5/30: 100%|██████████| 13/13 [00:15<00:00,  1.18s/it]


Epoch 5: Train=0.7389, Val=0.7016


Epoch 6/30: 100%|██████████| 13/13 [00:15<00:00,  1.18s/it]


torch.Size([27500, 1536]) torch.Size([27500, 1536])
Epoch 6: Train=0.7193, Val=0.6903, MRR@100=0.2991 {'mrr': 0.915180987751522, 'ndcg': 0.9363660907900478, 'recall_at_1': 0.8549818181818182, 'recall_at_3': 0.9750181818181818, 'recall_at_5': 0.9888727272727272, 'recall_at_10': 0.9951636363636364, 'recall_at_50': 0.9993090909090909, 'l2_dist': 26.01835060119629}
  New best: 0.2991


Epoch 7/30: 100%|██████████| 13/13 [00:15<00:00,  1.21s/it]


Epoch 7: Train=0.7091, Val=0.6830


Epoch 8/30: 100%|██████████| 13/13 [00:15<00:00,  1.20s/it]


torch.Size([27500, 1536]) torch.Size([27500, 1536])
Epoch 8: Train=0.6953, Val=0.6747, MRR@100=0.3199 {'mrr': 0.9268492956631826, 'ndcg': 0.9451406324602721, 'recall_at_1': 0.8738181818181818, 'recall_at_3': 0.9796363636363636, 'recall_at_5': 0.9903272727272727, 'recall_at_10': 0.9960363636363636, 'recall_at_50': 0.9992727272727273, 'l2_dist': 26.05483627319336}
  New best: 0.3199


Epoch 9/30: 100%|██████████| 13/13 [00:15<00:00,  1.18s/it]


Epoch 9: Train=0.6845, Val=0.6689


Epoch 10/30: 100%|██████████| 13/13 [00:15<00:00,  1.18s/it]


torch.Size([27500, 1536]) torch.Size([27500, 1536])
Epoch 10: Train=0.6795, Val=0.6691, MRR@100=0.3356 {'mrr': 0.9330453550379051, 'ndcg': 0.9497733969833199, 'recall_at_1': 0.8849090909090909, 'recall_at_3': 0.9802545454545455, 'recall_at_5': 0.990909090909091, 'recall_at_10': 0.996, 'recall_at_50': 0.9994181818181819, 'l2_dist': 26.088123321533203}
  New best: 0.3356


Epoch 11/30: 100%|██████████| 13/13 [00:15<00:00,  1.21s/it]


Epoch 11: Train=0.6726, Val=0.6628


Epoch 12/30: 100%|██████████| 13/13 [00:15<00:00,  1.20s/it]


torch.Size([27500, 1536]) torch.Size([27500, 1536])
Epoch 12: Train=0.6649, Val=0.6593, MRR@100=0.3460 {'mrr': 0.9352789850075122, 'ndcg': 0.95148012284062, 'recall_at_1': 0.888109090909091, 'recall_at_3': 0.981890909090909, 'recall_at_5': 0.9920363636363636, 'recall_at_10': 0.9963272727272727, 'recall_at_50': 0.9994545454545455, 'l2_dist': 26.09223175048828}
  New best: 0.3460


Epoch 13/30: 100%|██████████| 13/13 [00:15<00:00,  1.18s/it]


Epoch 13: Train=0.6581, Val=0.6547


Epoch 14/30: 100%|██████████| 13/13 [00:15<00:00,  1.18s/it]


torch.Size([27500, 1536]) torch.Size([27500, 1536])
Epoch 14: Train=0.6515, Val=0.6531, MRR@100=0.3553 {'mrr': 0.9393077548053709, 'ndcg': 0.9544826456575993, 'recall_at_1': 0.8951636363636364, 'recall_at_3': 0.9826181818181818, 'recall_at_5': 0.9919636363636364, 'recall_at_10': 0.9965454545454545, 'recall_at_50': 0.9994909090909091, 'l2_dist': 26.099769592285156}


Epoch 15/30: 100%|██████████| 13/13 [00:15<00:00,  1.19s/it]


Epoch 15: Train=0.6483, Val=0.6517


Epoch 16/30: 100%|██████████| 13/13 [00:15<00:00,  1.18s/it]


torch.Size([27500, 1536]) torch.Size([27500, 1536])
Epoch 16: Train=0.6440, Val=0.6493, MRR@100=0.3599 {'mrr': 0.9403679479564251, 'ndcg': 0.9552811249932324, 'recall_at_1': 0.8969454545454545, 'recall_at_3': 0.9836727272727273, 'recall_at_5': 0.9921818181818182, 'recall_at_10': 0.9962545454545455, 'recall_at_50': 0.9994181818181819, 'l2_dist': 26.1003360748291}


Epoch 17/30: 100%|██████████| 13/13 [00:15<00:00,  1.21s/it]


Epoch 17: Train=0.6387, Val=0.6480


Epoch 18/30: 100%|██████████| 13/13 [00:15<00:00,  1.21s/it]


torch.Size([27500, 1536]) torch.Size([27500, 1536])
Epoch 18: Train=0.6349, Val=0.6448, MRR@100=0.3703 {'mrr': 0.9429032805668198, 'ndcg': 0.9571669553383735, 'recall_at_1': 0.9013818181818182, 'recall_at_3': 0.9837454545454546, 'recall_at_5': 0.992, 'recall_at_10': 0.9965818181818182, 'recall_at_50': 0.9995272727272727, 'l2_dist': 26.103736877441406}


Epoch 19/30: 100%|██████████| 13/13 [00:15<00:00,  1.21s/it]


Epoch 19: Train=0.6301, Val=0.6437


Epoch 20/30: 100%|██████████| 13/13 [00:15<00:00,  1.21s/it]


torch.Size([27500, 1536]) torch.Size([27500, 1536])
Epoch 20: Train=0.6260, Val=0.6416, MRR@100=0.3765 {'mrr': 0.9439414502222384, 'ndcg': 0.9579480087010414, 'recall_at_1': 0.9032363636363636, 'recall_at_3': 0.9844727272727273, 'recall_at_5': 0.9922909090909091, 'recall_at_10': 0.9965090909090909, 'recall_at_50': 0.9993818181818181, 'l2_dist': 26.110321044921875}


Epoch 21/30: 100%|██████████| 13/13 [00:15<00:00,  1.18s/it]


Epoch 21: Train=0.6222, Val=0.6401


Epoch 22/30: 100%|██████████| 13/13 [00:15<00:00,  1.18s/it]


torch.Size([27500, 1536]) torch.Size([27500, 1536])
Epoch 22: Train=0.6193, Val=0.6389, MRR@100=0.3807 {'mrr': 0.9437667990907244, 'ndcg': 0.9578132980797136, 'recall_at_1': 0.9028727272727273, 'recall_at_3': 0.9842545454545455, 'recall_at_5': 0.9924727272727273, 'recall_at_10': 0.9961454545454546, 'recall_at_50': 0.9993818181818181, 'l2_dist': 26.11033821105957}
Early stopping at epoch 22
Creating ProjectionMLP
LayerNorm [896, 1000, 1200, 1400]


In [14]:


print("\nGenerating test predictions...")
test_data = load_data("/kaggle/input/aml-dataset/test/test/test.clean.npz")
test_embds = torch.from_numpy(test_data['captions/embeddings']).float()

translator  = create_AE_Translator(create_model(MLP_ENC_PAR), create_model(MLP_DEC_PAR))
translator.load_state_dict(torch.load(TRANSLATOR_PATH))

translator.eval()
with torch.no_grad():
    test_clip_list = []
    pred_embds = translator(test_embds.to(DEVICE)).cpu()
    
submission = generate_submission(test_data['captions/ids'], pred_embds, 'submission.csv')
print(f"\nSubmission saved to: submission.csv")


Generating test predictions...

2. Building model...
Creating ProjectionMLP
Parameters: 4,501,121

2. Building model...
LayerNorm [896, 1000, 1200, 1400]
Parameters: 7,810,737
Generating submission file...
✓ Saved submission to submission.csv

Submission saved to: submission.csv
