In [None]:
# Install dependencies (optional)

!pip -q install scikit-learn
!pip -q install transformers accelerate peft sacrebleu

import sklearn, transformers, torch
print("sklearn:", sklearn.__version__)
print("transformers:", transformers.__version__)
print("torch:", torch.__version__)


In [None]:
!pip install torch-geometric

In [None]:
# Authenticate this session with Kaggle; the next cell downloads the competition data via kagglehub.
import kagglehub, os, sys
kagglehub.login()

In [None]:
# Download competition and import baseline utilities


comp_path = kagglehub.competition_download("molecular-graph-captioning")
print("comp_path:", comp_path)

BASELINE_DIR = os.path.join(comp_path, "data_baseline")
DATA_DIR     = os.path.join(BASELINE_DIR, "data")
# Guard so re-running this cell does not keep appending the same entry to sys.path.
if BASELINE_DIR not in sys.path:
    sys.path.append(BASELINE_DIR)

# Baseline helpers shipped with the competition (must come after the sys.path update).
from data_utils import PreprocessedGraphDataset, collate_fn, load_id2emb, x_map, e_map

TRAIN_GRAPHS = os.path.join(DATA_DIR, "train_graphs.pkl")
VAL_GRAPHS   = os.path.join(DATA_DIR, "validation_graphs.pkl")
TEST_GRAPHS  = os.path.join(DATA_DIR, "test_graphs.pkl")

print("BASELINE_DIR:", BASELINE_DIR)
print("DATA_DIR    :", DATA_DIR)

# quick peek at one training graph and its caption
train_plain = PreprocessedGraphDataset(TRAIN_GRAPHS)
g0 = train_plain[0]
print(g0)
print("x:", g0.x.shape, "edge_index:", g0.edge_index.shape, "edge_attr:", g0.edge_attr.shape)
print("id:", getattr(g0, "id", None))
print("desc:", str(getattr(g0, "description", ""))[:200])


In [None]:
# Generate RoBERTa teacher embeddings (CSV)
import pickle
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

MODEL_NAME = "roberta-base"
MAX_TOKEN_LENGTH = 128
BATCH_SIZE = 64
POOL = "mean"  # "mean" recommended; "cls" also possible

TRAIN_EMB_CSV = os.path.join(DATA_DIR, "train_embeddings_roberta.csv")
VAL_EMB_CSV   = os.path.join(DATA_DIR, "validation_embeddings_roberta.csv")

print("Will write:")
print("  TRAIN_EMB_CSV:", TRAIN_EMB_CSV)
print("  VAL_EMB_CSV  :", VAL_EMB_CSV)

# Only generate splits whose CSV is missing (delete a CSV to force its regeneration).
pending_splits = [
    (pkl_path, csv_path)
    for pkl_path, csv_path in ((TRAIN_GRAPHS, TRAIN_EMB_CSV), (VAL_GRAPHS, VAL_EMB_CSV))
    if not os.path.exists(csv_path)
]

if not pending_splits:
    print("RoBERTa embedding CSVs already exist — skipping generation.")
else:
    print(f"Loading text model {MODEL_NAME} ...")
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    enc = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()

    @torch.no_grad()
    def embed_text_batch(texts):
        """Encode a list of strings into L2-normalized float32 vectors.

        Returns a NumPy array with one row per input text (768 columns for roberta-base).
        """
        batch = tok(
            texts,
            padding=True,
            truncation=True,
            max_length=MAX_TOKEN_LENGTH,
            return_tensors="pt",
        ).to(DEVICE)
        h = enc(**batch).last_hidden_state  # [B, T, 768]

        if POOL == "cls":
            z = h[:, 0, :]  # RoBERTa first token (<s>)
        else:
            # Mean over real tokens only; padding positions are zeroed by the mask.
            mask = batch["attention_mask"].unsqueeze(-1)  # [B, T, 1]
            z = (h * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-12)

        z = F.normalize(z, dim=-1)
        return z.detach().cpu().numpy().astype(np.float32)

    def generate_split(pkl_path, out_csv_path):
        """Embed every graph description in ``pkl_path`` and write an ID/embedding CSV.

        The CSV has columns ``ID`` and ``embedding`` (comma-joined floats). It is first
        written to a temporary file and then renamed, so an interrupted run never leaves
        a truncated CSV that the skip-if-exists check would silently accept later.
        """
        print(f"\nLoading graphs from: {pkl_path}")
        # NOTE: pickle.load can execute arbitrary code — only use on trusted (competition) files.
        with open(pkl_path, "rb") as f:
            graphs = pickle.load(f)
        print("Loaded:", len(graphs))

        ids = [getattr(g, "id", None) for g in graphs]
        texts = [str(getattr(g, "description", "") or "") for g in graphs]

        embs = []
        batch_starts = range(0, len(texts), BATCH_SIZE)
        for start in tqdm(batch_starts, desc=f"Embedding {os.path.basename(pkl_path)}"):
            # Iterating the [B, 768] array yields one row per text, aligned with `ids`.
            embs.extend(embed_text_batch(texts[start:start + BATCH_SIZE]))

        df = pd.DataFrame({
            "ID": ids,
            "embedding": [",".join(map(str, e.tolist())) for e in embs],
        })
        tmp_path = out_csv_path + ".tmp"
        df.to_csv(tmp_path, index=False)
        os.replace(tmp_path, out_csv_path)  # atomic on the same filesystem
        print("Saved:", out_csv_path, "| rows:", len(df))

    for pkl_path, csv_path in pending_splits:
        generate_split(pkl_path, csv_path)

print("DONE teacher CSVs.")


In [None]:
# Load teacher embeddings and build dataloaders
import random
import numpy as np
import torch
from torch.utils.data import DataLoader

SEED = 42
# Seed every RNG used downstream: Python, NumPy, torch CPU and all CUDA devices.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

# Teacher text embeddings (frozen RoBERTa) keyed by graph ID, from the CSVs written above.
train_emb, val_emb = load_id2emb(TRAIN_EMB_CSV), load_id2emb(VAL_EMB_CSV)
EMB_DIM = 768  # roberta-base hidden size; must match the models' out_dim below

# Training pairs (graph, teacher embedding) in shuffled mini-batches.
train_ds = PreprocessedGraphDataset(TRAIN_GRAPHS, train_emb)
train_dl = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

print("Loaded embeddings:", len(train_emb), len(val_emb), "dim =", EMB_DIM)


In [None]:
# Define MolGNN, RoBERTa text encoder, and losses
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from torch_geometric.nn import GINEConv, global_mean_pool

# Feature embedders
class AtomEncoder(nn.Module):
    """Embed categorical atom features by summing one learned table per feature column.

    One ``nn.Embedding`` is created per entry of ``x_map`` (in its iteration order),
    sized to that feature's vocabulary; column ``i`` of the input indexes table ``i``.
    """

    def __init__(self, emb_dim):
        super().__init__()
        tables = [nn.Embedding(len(vocab), emb_dim) for vocab in x_map.values()]
        self.embeddings = nn.ModuleList(tables)

    def forward(self, x):
        # x: [N, num_features] integer indices -> summed embeddings of shape [N, emb_dim]
        return sum(table(x[:, col]) for col, table in enumerate(self.embeddings))

class BondEncoder(nn.Module):
    """Embed categorical bond features by summing one learned table per feature column.

    One ``nn.Embedding`` is created per entry of ``e_map`` (in its iteration order),
    sized to that feature's vocabulary; column ``i`` of the input indexes table ``i``.
    """

    def __init__(self, emb_dim):
        super().__init__()
        tables = [nn.Embedding(len(vocab), emb_dim) for vocab in e_map.values()]
        self.embeddings = nn.ModuleList(tables)

    def forward(self, edge_attr):
        # edge_attr: [E, num_features] integer indices -> summed embeddings [E, emb_dim]
        return sum(table(edge_attr[:, col]) for col, table in enumerate(self.embeddings))

# Graph encoder 
class MolGNN(nn.Module):
    """Residual GINE graph encoder that returns one L2-normalized vector per graph.

    Atom and bond categorical features are embedded to ``hidden_dim``; ``num_layers``
    rounds of GINEConv -> LayerNorm -> ReLU -> dropout are added residually to the node
    states, which are then mean-pooled per graph and linearly projected to ``out_dim``.
    """

    def __init__(self, hidden_dim=128, out_dim=768, num_layers=4, dropout=0.1):
        super().__init__()
        self.atom_encoder = AtomEncoder(hidden_dim)
        self.bond_encoder = BondEncoder(hidden_dim)

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(GINEConv(self._make_mlp(hidden_dim)))
            self.norms.append(nn.LayerNorm(hidden_dim))

        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_dim, out_dim)

    @staticmethod
    def _make_mlp(width):
        """Two-layer MLP (Linear-ReLU-Linear) used as GINEConv's update network."""
        return nn.Sequential(nn.Linear(width, width), nn.ReLU(), nn.Linear(width, width))

    def forward(self, batch):
        node_h = self.atom_encoder(batch.x)
        edge_h = self.bond_encoder(batch.edge_attr)

        for conv, norm in zip(self.convs, self.norms):
            update = F.relu(norm(conv(node_h, batch.edge_index, edge_h)))
            node_h = node_h + self.dropout(update)  # residual connection

        graph_h = self.proj(global_mean_pool(node_h, batch.batch))
        return F.normalize(graph_h, dim=-1)

# Text encoder (RoBERTa)
class RobertaTextEncoder(nn.Module):
    """RoBERTa sentence encoder, trainable only in its top blocks, with a linear head.

    Args:
        name: Hugging Face model id passed to ``AutoTokenizer`` / ``AutoModel``.
        out_dim: size of the projected, L2-normalized output embedding.
        train_last_n_layers: number of final transformer blocks left trainable;
            0 (or a negative value) keeps the whole backbone frozen.
        pool: ``"cls"`` uses the first token; any other value mean-pools over
            non-padding tokens.
    """

    def __init__(self, name="roberta-base", out_dim=768, train_last_n_layers=2, pool="mean"):
        super().__init__()
        self.pool = pool
        self.tok = AutoTokenizer.from_pretrained(name)
        self.enc = AutoModel.from_pretrained(name)

        # Freeze the whole backbone, then selectively re-enable the top blocks.
        for p in self.enc.parameters():
            p.requires_grad = False

        # Only models exposing `encoder.layer` get partial unfreezing; others stay frozen.
        if hasattr(self.enc, "encoder") and hasattr(self.enc.encoder, "layer"):
            self._unfreeze_last_layers(self.enc.encoder.layer, train_last_n_layers)

        # Projection head is always trainable.
        self.proj = nn.Linear(self.enc.config.hidden_size, out_dim)

    @staticmethod
    def _unfreeze_last_layers(layers, n):
        """Set ``requires_grad=True`` on every parameter of the last ``n`` entries of ``layers``."""
        # Guard is required: layers[-0:] == layers[0:], which would unfreeze *every* block.
        if n <= 0:
            return
        for block in layers[-n:]:
            for p in block.parameters():
                p.requires_grad = True

    def forward(self, texts, device, max_len=128):
        """Tokenize ``texts`` on ``device`` and return [len(texts), out_dim] unit vectors."""
        batch = self.tok(
            texts,
            padding=True,
            truncation=True,
            max_length=max_len,
            return_tensors="pt",
        ).to(device)
        out = self.enc(**batch)
        h = out.last_hidden_state  # [B,T,H]

        if self.pool == "cls":
            pooled = h[:, 0, :]
        else:
            # Mean over real tokens only; padding positions are zeroed by the mask.
            mask = batch["attention_mask"].unsqueeze(-1)
            pooled = (h * mask).sum(1) / (mask.sum(1) + 1e-12)

        z = self.proj(pooled)
        return F.normalize(z, dim=-1)

# Contrastive loss (InfoNCE symmetric)
def contrastive_loss(a, b, temperature=0.07):
    """Symmetric InfoNCE loss between two batches of paired embeddings.

    Row i of ``a`` is the positive for row i of ``b``; every other row in the batch
    serves as a negative. Returns the mean of the a->b and b->a cross-entropies.
    """
    scores = a @ b.T / temperature
    targets = torch.arange(len(a), device=a.device)
    forward_ce = F.cross_entropy(scores, targets)
    backward_ce = F.cross_entropy(scores.T, targets)
    return (forward_ce + backward_ce) / 2

def distill_cosine(a, b):
    """Mean cosine distance between paired rows of ``a`` and ``b``.

    Uses the plain row-wise dot product, so it equals 1 - cosine only when both
    inputs are already L2-normalized (as the encoders here produce).
    """
    row_dots = torch.sum(a * b, dim=-1)
    return 1.0 - row_dots.mean()


In [None]:
# Train hybrid model (graph + text)
from torch.utils.data import DataLoader

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- models ---------------------------------------------------------------
# Student graph encoder and partially unfrozen RoBERTa text encoder, both
# projecting into the same 768-d space as the teacher embeddings.
gnn = MolGNN(hidden_dim=128, out_dim=768, num_layers=4, dropout=0.1).to(DEVICE)
txt = RobertaTextEncoder("roberta-base", out_dim=768, train_last_n_layers=2, pool="mean").to(DEVICE)

# --- optimizer ------------------------------------------------------------
# Only the trainable text params (unfrozen top blocks + projection head) are
# optimized; the graph side uses a 5x higher learning rate than the text side.
txt_params = list(filter(lambda p: p.requires_grad, txt.parameters()))
opt = torch.optim.AdamW(
    [
        dict(params=gnn.parameters(), lr=1e-3),
        dict(params=txt_params, lr=2e-4),
    ],
    weight_decay=1e-2,
)

# --- loss weighting / schedule ---------------------------------------------
W_ONLINE = 0.5       # weight of graph <-> online-text InfoNCE (after warm-up)
W_DISTILL = 0.1      # weight of online-text -> teacher cosine distillation
WARMUP_EPOCHS = 3    # the two weights ramp linearly from 0 over this many epochs
EPOCHS = 20

def temp_schedule(epoch):
    """Contrastive temperature for ``epoch`` (kept from the Model B schedule).

    Starts at 0.07, decays geometrically by 0.95 per epoch, and never drops below 0.02.
    """
    decayed = 0.07 * pow(0.95, epoch)
    return decayed if decayed > 0.02 else 0.02

def train_epoch(epoch):
    """Run one epoch of the hybrid graph/text objective and return the mean loss per graph.

    Per batch the loss is
        InfoNCE(graph, teacher)
        + w_online  * InfoNCE(graph, online text)
        + w_distill * (1 - cos(online text, teacher)),
    where w_online / w_distill ramp linearly from 0 to W_ONLINE / W_DISTILL over
    WARMUP_EPOCHS epochs. Uses the globals gnn, txt, opt, txt_params and train_dl.

    Args:
        epoch: 0-based epoch index; drives the weight ramp and the temperature schedule.
    """
    gnn.train()
    txt.train()

    ramp = 1.0 if WARMUP_EPOCHS <= 0 else min(1.0, epoch / float(WARMUP_EPOCHS))
    w_online  = W_ONLINE  * ramp
    w_distill = W_DISTILL * ramp
    # With both weights at 0 (epoch 0 of the warm-up) the online RoBERTa branch cannot
    # affect the loss. Skipping it saves a full transformer forward/backward per batch and
    # leaves the text grads as None, so AdamW does not decay those weights on zero grads.
    use_online = w_online > 0 or w_distill > 0

    t = temp_schedule(epoch)
    clip_params = list(gnn.parameters()) + txt_params  # identical every step; build once
    total = 0.0

    for graphs, teacher_emb in tqdm(train_dl, desc=f"train ep {epoch+1}", leave=False):
        graphs = graphs.to(DEVICE)
        teacher_emb = F.normalize(teacher_emb.to(DEVICE), dim=-1)

        graph_emb = gnn(graphs)
        loss = contrastive_loss(graph_emb, teacher_emb, temperature=t)

        if use_online:
            # Captions travel with the graph objects in the batch.
            texts = [str(g.description) for g in graphs.to_data_list()]
            online_emb = txt(texts, DEVICE, max_len=128)
            loss_online  = contrastive_loss(graph_emb, online_emb, temperature=t)
            loss_distill = distill_cosine(online_emb, teacher_emb)
            loss = loss + w_online * loss_online + w_distill * loss_distill

        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
        opt.step()

        total += loss.item() * graphs.num_graphs

    return total / len(train_dl.dataset)

@torch.no_grad()
def eval_retrieval_metrics(model, graph_path, emb_dict):
    """Text->graph retrieval metrics of ``model`` against precomputed text embeddings.

    Each text embedding is a query whose correct answer is the graph at the same
    position in the dataset; all graphs in ``graph_path`` are candidates.

    Args:
        model: graph encoder returning one L2-normalized vector per graph.
        graph_path: pickle of graphs, passed to PreprocessedGraphDataset.
        emb_dict: text embeddings passed to PreprocessedGraphDataset (e.g. from load_id2emb).

    Returns:
        dict with ``MRR``, ``R@1``, ``R@5`` and ``R@10`` as Python floats.

    Note: leaves ``model`` in eval mode; train_epoch switches it back to train mode.
    """
    model.eval()
    ds = PreprocessedGraphDataset(graph_path, emb_dict)
    dl = DataLoader(ds, batch_size=64, shuffle=False, collate_fn=collate_fn)

    mol_vecs, txt_vecs = [], []
    for graphs, text_emb in dl:
        graphs = graphs.to(DEVICE)
        mol_vecs.append(model(graphs))
        txt_vecs.append(F.normalize(text_emb.to(DEVICE), dim=-1))

    mol_vecs = torch.cat(mol_vecs, dim=0)
    txt_vecs = torch.cat(txt_vecs, dim=0)

    sims = txt_vecs @ mol_vecs.T  # [N_text, N_graph]; row i = text i vs every graph
    # 1-based rank of the correct graph = 1 + number of graphs scored strictly higher.
    # Same as locating it in a full descending argsort, but O(N^2) instead of
    # O(N^2 log N), a boolean mask instead of an int64 N x N index matrix, and
    # deterministic on ties (tied graphs, e.g. identical molecules, don't count against it).
    correct_scores = sims.diagonal().unsqueeze(1)
    pos = (sims > correct_scores).sum(dim=1) + 1

    return {
        "MRR": (1 / pos.float()).mean().item(),
        "R@1": (pos <= 1).float().mean().item(),
        "R@5": (pos <= 5).float().mean().item(),
        "R@10": (pos <= 10).float().mean().item(),
    }



In [None]:
# Train, reporting validation retrieval metrics against the teacher embeddings each epoch.
for epoch_idx in range(EPOCHS):
    epoch_loss = train_epoch(epoch_idx)
    val_metrics = eval_retrieval_metrics(gnn, VAL_GRAPHS, val_emb)
    print(
        f"Epoch {epoch_idx+1:02d}/{EPOCHS} | temp={temp_schedule(epoch_idx):.4f} "
        f"| Loss={epoch_loss:.4f} | {val_metrics}"
    )

# Persist the final (last-epoch) weights of both encoders.
for module, ckpt_path in ((gnn, "mol_gnn_hybrid.pt"), (txt, "roberta_text_hybrid.pt")):
    torch.save(module.state_dict(), ckpt_path)
print("Saved mol_gnn_hybrid.pt and roberta_text_hybrid.pt")


In [None]:
# Evaluate BLEU on validation set (top-1 copy)
import sacrebleu
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Reload the trained graph encoder from its checkpoint (optional right after training).
# Architecture arguments mirror the training cell so the state_dict shapes match.
gnn = MolGNN(hidden_dim=128, out_dim=768, num_layers=4, dropout=0.1).to(DEVICE)
# The checkpoint is a plain state_dict, so restrict unpickling to tensors/containers.
gnn.load_state_dict(torch.load("mol_gnn_hybrid.pt", map_location=DEVICE, weights_only=True))
gnn.eval()

train_plain = PreprocessedGraphDataset(TRAIN_GRAPHS)
val_plain   = PreprocessedGraphDataset(VAL_GRAPHS)

# Retrieval corpus: training captions, row-aligned with train_vecs (both iterate
# train_plain in order, shuffle=False).
train_captions = [str(g.description) for g in train_plain]

train_dl_plain = DataLoader(train_plain, batch_size=64, shuffle=False, collate_fn=collate_fn)
train_vecs = []

with torch.no_grad():
    for graphs in tqdm(train_dl_plain, desc="Encoding TRAIN graphs"):
        graphs = graphs.to(DEVICE)
        train_vecs.append(gnn(graphs))

train_vecs = torch.cat(train_vecs, dim=0)  # [N_train, 768]
assert train_vecs.shape[0] == len(train_captions), (train_vecs.shape, len(train_captions))

val_dl_plain = DataLoader(val_plain, batch_size=64, shuffle=False, collate_fn=collate_fn)

pred_texts = []
ref_texts  = []

# Top-1 copy baseline: each validation graph gets the caption of its nearest training graph.
with torch.no_grad():
    for graphs in tqdm(val_dl_plain, desc="Predicting VAL captions"):
        graphs = graphs.to(DEVICE)
        q = gnn(graphs)                 # [B,768]
        sims = q @ train_vecs.T         # cosine since normalized
        best = sims.argmax(dim=1).cpu().numpy()

        for idx in best:
            pred_texts.append(train_captions[int(idx)])

        ref_texts.extend([str(g.description) for g in graphs.to_data_list()])

bleu = sacrebleu.corpus_bleu(pred_texts, [ref_texts], smooth_method="exp").score
print(f"✅ VAL BLEU: {bleu:.4f}")

# Show a few examples (guarded so a tiny validation set can't raise IndexError).
for i in range(min(3, len(pred_texts))):
    print("\n---", i, "---")
    print("PRED:", pred_texts[i][:200])
    print("REF :", ref_texts[i][:200])


In [None]:
# Predict on test set and write submission.csv
import pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# This cell reuses `gnn`, `train_vecs` and `train_captions` built by the validation
# BLEU cell above; fail early with a clear message if that cell has not been run.
_missing = [name for name in ("gnn", "train_vecs", "train_captions") if name not in globals()]
if _missing:
    raise RuntimeError(f"Run the validation BLEU cell first; undefined: {_missing}")

test_plain = PreprocessedGraphDataset(TEST_GRAPHS)
test_dl = DataLoader(test_plain, batch_size=64, shuffle=False, collate_fn=collate_fn)

pred_ids = []
pred_texts = []

gnn.eval()
with torch.no_grad():
    for graphs in tqdm(test_dl, desc="Predicting TEST captions"):
        graphs = graphs.to(DEVICE)
        q = gnn(graphs)
        sims = q @ train_vecs.T  # cosine similarity (both sides L2-normalized)
        best = sims.argmax(dim=1).cpu().numpy()

        # IDs may come back as a tensor or as a plain list depending on their type.
        batch_ids = graphs.id
        if torch.is_tensor(batch_ids):
            batch_ids = batch_ids.cpu().numpy().tolist()

        for i, idx in enumerate(best):
            pred_ids.append(batch_ids[i])
            pred_texts.append(train_captions[int(idx)])

submission = pd.DataFrame({"ID": pred_ids, "description": pred_texts})
# Exactly one prediction per test graph.
assert len(submission) == len(test_plain), (len(submission), len(test_plain))
submission.to_csv("submission.csv", index=False)
print("✅ wrote submission.csv")
submission.head()
