# V6 Phase 3: Contrastive Fine-tuning

## Goal
Fine-tune **HieroBERT** to align with **English BERT** using **Contrastive Learning** (InfoNCE).

## Strategy
1. **Encoders**:
    - **Source**: HieroBERT (Trainable). Fused with Visual Embeddings.
    - **Target**: English BERT (Frozen). Provides the "ground truth" semantic space.
2. **Objective**: Minimize the distance between correct translation pairs $(h_i, e_i)$ and maximize it for negatives.
3. **Data**: 8,541 Anchor Pairs.
4. **Visual Fusion**: `Hiero_Emb = BERT(text) + MLP(Visual_Mean)`

## Inputs
- `models/hierobert_small`: Pre-trained HieroBERT.
- `data/processed/visual_embeddings_768d.pkl`: Visual features.
- `data/processed/anchors.json`: Anchor pairs.

In [1]:
!pip install transformers torch scikit-learn numpy pandas tqdm


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import json
import numpy as np
from pathlib import Path
from transformers import BertModel, BertTokenizerFast, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# Paths
MODEL_PATH = Path("../models/hierobert_small")
VISUAL_PATH = Path("../data/processed/visual_embeddings_768d.pkl")
ANCHORS_PATH = Path("../data/processed/anchors.json")

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")



Using device: mps


## 1. Data Loading & Dataset Class

In [3]:
# Load Resources
with open(VISUAL_PATH, 'rb') as f:
    visual_embeddings = pickle.load(f)

with open(ANCHORS_PATH, 'r') as f:
    anchors = json.load(f)

# Tokenizers
hiero_tokenizer = BertTokenizerFast.from_pretrained(str(MODEL_PATH))
en_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class MultimodalDataset(Dataset):
    def __init__(self, anchors, visual_emb, h_tokenizer, e_tokenizer, max_len=32):
        self.anchors = anchors
        self.visual_emb = visual_emb
        self.h_tokenizer = h_tokenizer
        self.e_tokenizer = e_tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.anchors)

    def __getitem__(self, idx):
        item = self.anchors[idx]
        h_text = item['hieroglyphic']
        e_text = item['english']

        # Hiero Tokenization
        h_enc = self.h_tokenizer(h_text, truncation=True, padding='max_length', max_length=self.max_len, return_tensors='pt')
        
        # English Tokenization
        e_enc = self.e_tokenizer(e_text, truncation=True, padding='max_length', max_length=self.max_len, return_tensors='pt')

        # Visual Features (Mean of glyphs)
        visual_vecs = []
        for char in h_text:
            if char in self.visual_emb:
                visual_vecs.append(self.visual_emb[char])
        
        if visual_vecs:
            v_vec = np.mean(visual_vecs, axis=0)
        else:
            v_vec = np.zeros(768)

        return {
            'h_input_ids': h_enc['input_ids'].squeeze(0),
            'h_attention_mask': h_enc['attention_mask'].squeeze(0),
            'e_input_ids': e_enc['input_ids'].squeeze(0),
            'e_attention_mask': e_enc['attention_mask'].squeeze(0),
            'visual_vec': torch.tensor(v_vec, dtype=torch.float32)
        }

# Split & Loader
train_anchors, test_anchors = train_test_split(anchors, test_size=0.1, random_state=42)

train_dataset = MultimodalDataset(train_anchors, visual_embeddings, hiero_tokenizer, en_tokenizer)
test_dataset = MultimodalDataset(test_anchors, visual_embeddings, hiero_tokenizer, en_tokenizer)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

## 2. Model Architecture
We define `MultimodalHieroBERT` which takes text and visual inputs and outputs a single embedding.

In [4]:
class MultimodalHieroBERT(nn.Module):
    def __init__(self, model_path):
        super().__init__()
        self.bert = BertModel.from_pretrained(str(model_path))
        # Adapter to fuse visual features
        self.visual_adapter = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Linear(768, 768)
        )
        # LayerNorm for stability
        self.ln = nn.LayerNorm(768)

    def forward(self, input_ids, attention_mask, visual_vec):
        # Get BERT output
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Mean pooling (excluding CLS/SEP)
        # Mask out padding tokens for mean calculation
        # Simple approach: use pooler_output (CLS) or mean of last hidden state
        # Let's use CLS for now as it's standard for sentence embeddings in BERT
        text_emb = outputs.pooler_output 
        
        # Process visual vec
        vis_emb = self.visual_adapter(visual_vec)
        
        # Fusion: Additive
        fused_emb = self.ln(text_emb + vis_emb)
        return fused_emb

## 3. Training Setup

In [5]:
# Initialize Models
hiero_model = MultimodalHieroBERT(MODEL_PATH).to(device)
en_model = BertModel.from_pretrained('bert-base-uncased').to(device)

# Freeze English BERT (Target)
for param in en_model.parameters():
    param.requires_grad = False
en_model.eval()

# Optimizer
optimizer = optim.AdamW(hiero_model.parameters(), lr=2e-5)

# Contrastive Loss (InfoNCE)
def contrastive_loss(h_emb, e_emb, temperature=0.1):
    # Normalize
    h_emb = torch.nn.functional.normalize(h_emb, dim=1)
    e_emb = torch.nn.functional.normalize(e_emb, dim=1)
    
    # Cosine similarity matrix: [batch, batch]
    logits = torch.matmul(h_emb, e_emb.T) / temperature
    
    # Labels: diagonal is the positive pair (0,0), (1,1), etc.
    labels = torch.arange(logits.size(0)).to(device)
    
    loss = nn.CrossEntropyLoss()(logits, labels)
    return loss

Some weights of BertModel were not initialized from the model checkpoint at ../models/hierobert_small and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 4. Training Loop

In [6]:
EPOCHS = 5

print("Starting Contrastive Fine-tuning...")
for epoch in range(EPOCHS):
    hiero_model.train()
    total_loss = 0
    
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        # Move to device
        h_ids = batch['h_input_ids'].to(device)
        h_mask = batch['h_attention_mask'].to(device)
        e_ids = batch['e_input_ids'].to(device)
        e_mask = batch['e_attention_mask'].to(device)
        v_vec = batch['visual_vec'].to(device)
        
        optimizer.zero_grad()
        
        # Forward Hiero
        h_emb = hiero_model(h_ids, h_mask, v_vec)
        
        # Forward English (Target)
        with torch.no_grad():
            e_out = en_model(input_ids=e_ids, attention_mask=e_mask)
            e_emb = e_out.pooler_output
            
        # Loss
        loss = contrastive_loss(h_emb, e_emb)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} Loss: {avg_loss:.4f}")

# Save Fine-tuned Model
torch.save(hiero_model.state_dict(), "../models/hierobert_contrastive.pth")
print("Model saved.")

Starting Contrastive Fine-tuning...


Epoch 1: 100%|██████████| 241/241 [01:06<00:00,  3.63it/s]


Epoch 1 Loss: 3.3932


Epoch 2: 100%|██████████| 241/241 [00:58<00:00,  4.09it/s]


Epoch 2 Loss: 3.3165


Epoch 3: 100%|██████████| 241/241 [00:57<00:00,  4.16it/s]


Epoch 3 Loss: 3.2534


Epoch 4: 100%|██████████| 241/241 [00:58<00:00,  4.13it/s]


Epoch 4 Loss: 3.1867


Epoch 5: 100%|██████████| 241/241 [00:58<00:00,  4.13it/s]


Epoch 5 Loss: 3.1196
Model saved.


## 5. Evaluation

In [7]:
def evaluate(loader, k_values=[1, 5, 10]):
    hiero_model.eval()
    all_h_embs = []
    all_e_embs = []
    
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            h_ids = batch['h_input_ids'].to(device)
            h_mask = batch['h_attention_mask'].to(device)
            e_ids = batch['e_input_ids'].to(device)
            e_mask = batch['e_attention_mask'].to(device)
            v_vec = batch['visual_vec'].to(device)
            
            h_emb = hiero_model(h_ids, h_mask, v_vec)
            e_out = en_model(input_ids=e_ids, attention_mask=e_mask)
            e_emb = e_out.pooler_output
            
            all_h_embs.append(h_emb.cpu())
            all_e_embs.append(e_emb.cpu())
            
    H = torch.cat(all_h_embs)
    E = torch.cat(all_e_embs)
    
    # Normalize
    H = torch.nn.functional.normalize(H, dim=1)
    E = torch.nn.functional.normalize(E, dim=1)
    
    # Similarity Matrix
    sim_matrix = torch.matmul(H, E.T).numpy()
    
    top_k_hits = {k: 0 for k in k_values}
    n_test = len(H)
    
    for i in range(n_test):
        sorted_indices = np.argsort(-sim_matrix[i])
        for k in k_values:
            if i in sorted_indices[:k]:
                top_k_hits[k] += 1
                
    results = {f"Top-{k}": hits/n_test for k, hits in top_k_hits.items()}
    return results

print("Evaluating on Test Set...")
scores = evaluate(test_loader)
print("Contrastive Alignment Results:")
print(json.dumps(scores, indent=2))

Evaluating on Test Set...


Evaluating: 100%|██████████| 27/27 [00:04<00:00,  6.54it/s]

Contrastive Alignment Results:
{
  "Top-1": 0.004678362573099415,
  "Top-5": 0.023391812865497075,
  "Top-10": 0.047953216374269005
}



