# Train the embedding model

## Processing Setup

In [None]:
# Google colab
# import os
# from google.colab import drive
# drive.mount('/content/drive')
# project_home = '/content/drive/MyDrive/Projects/verge'
# os.chdir(project_home)

In [None]:
# Local processing setup
# Project root, given relative to the notebook's working directory (presumably
# the notebook lives one directory below the project root — confirm). The
# commented-out Colab cell sets this same variable for a Google Drive layout.
project_home = '..'

## Notebook Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from typing import List, Tuple, Optional

import pickle
import pandas as pd

from embedderv5 import *

## Parameters

In [None]:
# Name of the ROI whose data this run uses.
roi_name = 'ne-laptop'

# General-purpose data directory, and the ROI-specific directory inside it.
data_home = f'{project_home}/data'
roi_home = f'{data_home}/{roi_name}'

# Unique identifier of the model to be used.
run_id = '102'

# Train on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('using device', device)


## Load and organize data
We have two data sources that we need to associate with one another:
a set of initial embeddings ("initials") and a set of feature vectors
to be used for similarity assessments ("features"). Both are stored one
pickle file per AOI, and records are matched on their AOI and tile tags.

In [None]:
# We will divide into training and validation sets based on AOI.
# The splits were already determined before training the initial MGM.
# TODO: look the splits up here and re-organize the data accordingly. This is
# not implemented yet, so every sequence below is currently used for training.


In [None]:
# Read the tile inventory for this ROI and show its first few rows.
fname = f'{roi_home}/tiles.csv'
tile_info = pd.read_csv(fname)
print(f'{tile_info.shape[0]} tiles')
tile_info.head(3)

In [None]:
# Sorted, de-duplicated AOI tags; the loaders below open one pickle per tag.
aoi_tags = np.unique(tile_info['aoi_tag'])
print(f'{aoi_tags.size} unique AOIs')

In [None]:
# Load the initial embeddings, one pickle per AOI, into a lookup table keyed
# by '<aoi_tag> : <tile_tag>'. Each pickle is iterated as a collection of
# records carrying 'aoi_tag', 'tile_tag' and 'embedding' entries.
# NOTE: pickle.load can execute arbitrary code; only load trusted local files.
embeddings_lookup = {}

for aoi_tag in aoi_tags:
    fname = f'{roi_home}/initials/{aoi_tag}.pkl'
    with open(fname, 'rb') as source:
        records = pickle.load(source)
    for record in records:
        key = f"{record['aoi_tag']} : {record['tile_tag']}"
        e = record['embedding']
        embeddings_lookup[key] = e

print(f'{len(embeddings_lookup)} total embeddings')


In [None]:
# Infer the embedding dimension from the trailing axis of a loaded embedding.
# Read it from the lookup table itself rather than from a loop variable left
# over from an earlier cell, so running cells out of order cannot give a
# stale value (and an empty load fails with a clear message).
if not embeddings_lookup:
    raise ValueError('no initial embeddings were loaded from %s/initials' % roi_home)
embedding_dim = next(iter(embeddings_lookup.values())).shape[-1]
print('dimension of embeddings is %d' % embedding_dim)

In [None]:
# Load the feature vectors used for similarity assessment, in the same way as
# the embeddings: one pickle per AOI, keyed by '<aoi_tag> : <tile_tag>'.
# NOTE: pickle.load can execute arbitrary code; only load trusted local files.
features_lookup = {}

for aoi_tag in aoi_tags:
    fname = f'{roi_home}/features/{aoi_tag}.pkl'
    with open(fname, 'rb') as source:
        records = pickle.load(source)
    for record in records:
        key = f"{record['aoi_tag']} : {record['tile_tag']}"
        features_lookup[key] = record['features']

print(f'{len(features_lookup)} total feature vectors')


In [None]:
# Organize the data the way the model expects it: two parallel lists, where
# sequences[i] is the initial embedding (as a numpy array) and
# similarity_features[i] is the feature vector for the same aoi/tile key.
sequences = []
similarity_features = []
for key, f in features_lookup.items():
    if key not in embeddings_lookup:
        print('key mismatch: %s' % key)
        continue
    # .cpu() is a no-op for CPU tensors but required for tensors saved from a
    # GPU, because .numpy() refuses CUDA tensors.
    # NOTE(review): .squeeze() drops *every* size-1 axis; if a tile's sequence
    # can have length 1, that axis collapses too — confirm the stored shape.
    e = embeddings_lookup[key].squeeze().detach().cpu().numpy()
    sequences.append(e)
    similarity_features.append(f)

# The loop above only walks the feature keys, so embeddings without a feature
# vector would otherwise be dropped silently; report how many there are.
unmatched = len(embeddings_lookup.keys() - features_lookup.keys())
if unmatched:
    print('%d embeddings have no matching feature vector' % unmatched)

if not sequences:
    raise ValueError('no keys are shared between features_lookup and embeddings_lookup')

print(type(sequences))
print(type(sequences[0]))
print(sequences[0].shape)

In [None]:
# # # This was the Claude-generated code to generate test data for the model. 
# sequences, similarity_features = generate_sample_data(
#     num_instances=1000, min_R=5, max_R=20, C=32, similarity_dim=16
# )
# print(type(sequences))
# print(type(sequences[0]))
# print(sequences[0].shape)

## Model
The model code (data loaders, the model itself, the loss function, and the training loop)
was generated by Claude through many rounds of iterative prompting and debugging.
It is imported from `embedderv5` in the notebook setup cell.


In [None]:
# Wrap the data in the contrastive dataset, which samples triplets
# (anchor, positive, negatives) explicitly, then batch it for training.
# The right similarity_threshold depends on the similarity features and may
# need tuning; num_negatives is the number of negatives drawn per anchor.
dataset = ContrastivePairDataset(
    sequences,
    similarity_features,
    num_negatives=2,
    similarity_threshold=0.5,
)
train_loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=triplet_collate_fn,
)


In [None]:
# Build the model. Its input and output widths both equal the dimension of
# the initial embeddings.
model = PermutationInvariantModel(
    input_dim=embedding_dim,
    embedding_dim=embedding_dim,
    hidden_dim=128,
    num_attention_heads=4,
    num_linear_layers=3,
    dropout=0.1,
)

num_params = sum(param.numel() for param in model.parameters())
print(f"Model has {num_params} parameters")


In [None]:
# Train for a fixed number of epochs on `device`. The return value is not
# used, so train_model presumably updates `model` in place.
# NOTE(review): the trained weights are never saved in this notebook, even
# though run_id is defined in the parameters; persist them if they are needed.
train_model(
    model,
    train_loader,
    num_epochs=50,
    learning_rate=1e-3,
    device=device,
)


In [None]:
# # Example inference
# model.eval()
# with torch.no_grad():
#     sample_batch = next(iter(train_loader))
#     # Unpack the dictionary structure from triplet data loader
#     anchor_seqs, anchor_masks, anchor_sims = sample_batch['anchor']
#     anchor_seqs = anchor_seqs.to(device)
#     anchor_masks = anchor_masks.to(device)
    
#     # Generate embeddings for anchor samples
#     embeddings = model(anchor_seqs, anchor_masks)
#     print(f"Generated embeddings shape: {embeddings.shape}")
#     print(f"Sample embedding norm: {torch.norm(embeddings[0]).item():.4f}")


In [None]:
        
# Example inference: embed one batch of triplets with the trained model and
# compare anchor-positive against anchor-negative cosine similarities.
# The batch comes from train_loader, so this checks fit on training data,
# not generalization.


def _mean_std(values):
    """Format a 1-D tensor as 'mean ± std'; std is reported as 0 for one value."""
    std = values.std().item() if values.numel() > 1 else 0.0
    return f"{values.mean().item():.4f} ± {std:.4f}"


model.to(device)  # no-op if already there; keeps weights and inputs on one device
model.eval()      # eval mode (disables dropout) for inference
with torch.no_grad():
    sample_batch = next(iter(train_loader))

    # Unpack the dictionary structure from the triplet data loader. The
    # similarity-feature entries are not needed here.
    anchor_seqs, anchor_masks, _ = sample_batch['anchor']
    pos_seqs, pos_masks, _ = sample_batch['positive']
    neg_seqs, neg_masks, _, neg_batch_indices = sample_batch['negatives']

    # Generate embeddings for anchor samples
    embeddings = model(anchor_seqs.to(device), anchor_masks.to(device))
    print(f"Generated embeddings shape: {embeddings.shape}")
    print(f"Sample embedding norm: {torch.norm(embeddings[0]).item():.4f}")

    # Get embeddings for positives
    pos_embeddings = model(pos_seqs.to(device), pos_masks.to(device))
    print(f"Positive embeddings shape: {pos_embeddings.shape}")

    # Get embeddings for negatives
    neg_embeddings = model(neg_seqs.to(device), neg_masks.to(device))
    print(f"Negative embeddings shape: {neg_embeddings.shape}")

    # Check similarity between anchors and positives
    pos_similarities = F.cosine_similarity(embeddings, pos_embeddings, dim=1)
    print(f"Anchor-Positive similarities: {_mean_std(pos_similarities)}")

    # Check similarity between anchors and negatives. neg_batch_indices gives,
    # for each negative, the index of the anchor it belongs to, so gather the
    # matching anchors and score every pair in one vectorized call. Indices
    # outside the anchor batch are ignored.
    batch_size = embeddings.shape[0]
    neg_batch_indices = neg_batch_indices.to(device).long()
    in_batch = (neg_batch_indices >= 0) & (neg_batch_indices < batch_size)
    neg_similarities = F.cosine_similarity(
        embeddings[neg_batch_indices[in_batch]], neg_embeddings[in_batch], dim=1
    ).cpu()

    if neg_similarities.numel() > 0:
        print(f"Anchor-Negative similarities: {_mean_std(neg_similarities)}")

        # Show the difference (should be positive if model is learning well)
        gap = pos_similarities.mean().item() - neg_similarities.mean().item()
        print(f"Positive vs Negative similarity difference: {gap:.4f}")
    else:
        print("No negative samples found in this batch")
