In [None]:

# Authenticate this session with Kaggle through kagglehub.
# login() is called without arguments, so no credentials are stored in the notebook.
import kagglehub
kagglehub.login()


In [None]:


# Download the competition files through kagglehub and keep the returned value
# (presumably the local download directory -- TODO confirm against kagglehub docs).
# NOTE(review): this variable is not referenced again in the visible notebook; the
# "load graphs" cell calls competition_download() a second time into `comp_path`.
molecular_graph_captioning_path = kagglehub.competition_download('molecular-graph-captioning')

print('Data source import complete.')


In [None]:
# Human-readable tags describing this experiment (informational only).
RUN_NAME = "g2t_gineVNJK__flan-t5-base__lora"
GRAPH_ENCODER = "GINE + VirtualNode + JumpingKnowledge(cat)"
DECODER = "google/flan-t5-base"
LORA = "r=32, alpha=64, dropout=0.05"

# Echo the tags, one per line.
print(RUN_NAME, GRAPH_ENCODER, DECODER, LORA, sep="\n")


In [None]:
!pip -q install -U sacrebleu

import torch
TORCH = torch.__version__.split("+")[0]
CUDA = torch.version.cuda.replace(".", "")  # e.g. "121"
print("TORCH:", TORCH, "CUDA:", CUDA)

!pip -q install -U pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv \
  -f https://data.pyg.org/whl/torch-{TORCH}+cu{CUDA}.html

!pip -q install -U torch-geometric

In [None]:
import importlib

def check(pkg):
    """Try to import ``pkg`` and print a one-line status report.

    Prints ``[OK]`` followed by the module's ``__version__`` (empty when the
    attribute is missing) on success, or ``[FAIL]`` followed by the exception
    text when the import raises an ``Exception``.
    """
    try:
        module = importlib.import_module(pkg)
        version = getattr(module, "__version__", "")
        status = f"[OK] {pkg:15s}  {version}"
    except Exception as err:
        status = f"[FAIL] {pkg:15s}  {err}"
    print(status)

# Report the import status / version of each core dependency, in this order.
for _pkg in ("torch", "torch_geometric", "transformers", "peft", "sacrebleu"):
    check(_pkg)


In [None]:
import os
import gc
import math
import random
import pickle
from dataclasses import dataclass
from typing import Optional, Dict, Any, List

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import AdamW

from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn import GINEConv, global_add_pool

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    get_cosine_schedule_with_warmup,
)
from transformers.modeling_outputs import BaseModelOutput

from peft import LoraConfig, get_peft_model, TaskType

from sacrebleu import corpus_bleu

def seed_everything(seed=42, deterministic=False):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Seed passed to ``random``, ``numpy.random`` and ``torch``
            (CPU and all CUDA devices).
        deterministic: When ``False`` (the default, matching the previous
            hard-coded behaviour) cuDNN autotuning is enabled and cuDNN
            determinism is off, which favours speed over bit-exact
            reproducibility. When ``True`` the two flags are flipped.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    deterministic = bool(deterministic)
    # benchmark mode picks kernels by timing them, so it is only enabled when
    # determinism was not requested.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

# Fix all RNG seeds before anything stochastic is created.
seed_everything(42)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)
print("torch:", torch.__version__)


paths

In [None]:
# Dataset locations using the /kaggle/input mount layout.
# NOTE(review): DATA_DIR, TRAIN_PKL, VAL_PKL and TEST_PKL are all reassigned in the
# "load graphs" cell below (from the kagglehub download path), so the values set
# here are not the ones used when the pickles are actually loaded.
DATA_DIR = "/kaggle/input/molecular-graph-captioning/data_baseline/data"

TRAIN_PKL = f"{DATA_DIR}/train_graphs.pkl"
VAL_PKL   = f"{DATA_DIR}/validation_graphs.pkl"
TEST_PKL  = f"{DATA_DIR}/test_graphs.pkl"


load graphs + feature maps

In [None]:
import sys, os, pickle
import kagglehub

# Download (or reuse cached) competition files in Colab
comp_path = kagglehub.competition_download("molecular-graph-captioning")
print("comp_path:", comp_path)

# Paths inside the downloaded package
BASELINE_DIR = os.path.join(comp_path, "data_baseline")
DATA_DIR     = os.path.join(BASELINE_DIR, "data")

# Make sure data_utils.py is importable. The membership check keeps the cell
# idempotent: re-running it must not keep appending the same entry to sys.path.
if BASELINE_DIR not in sys.path:
    sys.path.append(BASELINE_DIR)
from data_utils import x_map, e_map

# PKL paths
TRAIN_PKL = os.path.join(DATA_DIR, "train_graphs.pkl")
VAL_PKL   = os.path.join(DATA_DIR, "validation_graphs.pkl")
TEST_PKL  = os.path.join(DATA_DIR, "test_graphs.pkl")

# Quick visibility check before the (slower) pickle loads below.
print("TRAIN exists:", os.path.exists(TRAIN_PKL))
print("VAL exists:", os.path.exists(VAL_PKL))
print("TEST exists:", os.path.exists(TEST_PKL))

def load_pkl(p):
    """Open the file at path ``p`` in binary mode and return the unpickled object.

    NOTE: unpickling can execute arbitrary code; only use this on trusted files.
    """
    with open(p, "rb") as handle:
        obj = pickle.load(handle)
    return obj

# Load the three splits in train / validation / test order.
train_graphs, val_graphs, test_graphs = (
    load_pkl(path) for path in (TRAIN_PKL, VAL_PKL, TEST_PKL)
)

# Sanity checks: split sizes, whether captions are present, and feature shapes
# of the first training graph.
_first_train = train_graphs[0]
print(len(train_graphs), len(val_graphs), len(test_graphs))
print("train has description:", hasattr(_first_train, "description"))
print("test has description:", hasattr(test_graphs[0], "description"))
print("x:", _first_train.x.shape, "edge_attr:", _first_train.edge_attr.shape)


Tokenizer + helper for labels
We use FLAN-T5 as a strong decoder and fine-tune it with LoRA.

In [None]:
# Hugging Face checkpoint used for both the tokenizer and the seq2seq decoder.
MODEL_NAME = "google/flan-t5-base"   # strong + fits most GPUs
# Upper bound (in tokens) applied when tokenizing captions; the evaluation and
# test cells also pass it as max_new_tokens for generation.
MAX_LEN = 128                        # caption max tokens

# Module-level tokenizer shared by tokenize_captions() and the decoding steps.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_captions(captions: List[str], max_len: int):
    """Tokenize target captions and build the matching label tensor.

    Args:
        captions: Caption strings for one batch.
        max_len: Truncation length (in tokens) passed to the tokenizer.

    Returns:
        A ``(encoded, labels)`` pair: ``encoded`` is the tokenizer output
        (padded, PyTorch tensors) and ``labels`` is a copy of its
        ``input_ids`` in which every pad-token position is replaced by -100.
    """
    encoded = tokenizer(
        captions,
        max_length=max_len,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    token_ids = encoded["input_ids"]
    # masked_fill returns a new tensor, so the tokenizer output stays untouched.
    labels = token_ids.masked_fill(token_ids == tokenizer.pad_token_id, -100)
    return encoded, labels


DataLoaders

In [None]:
BATCH_SIZE = 16  # start with 8 or 16 depending on GPU
NUM_WORKERS = 2

# Only the training split is shuffled; validation and test keep dataset order.
train_loader = DataLoader(
    train_graphs,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
val_loader = DataLoader(
    val_graphs,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)
test_loader = DataLoader(
    test_graphs,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)


**Model definition (robust graph encoder + LoRA T5)**
Categorical node/edge embedders:
we embed each categorical feature separately and sum the embeddings.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINEConv, global_add_pool

class CategoricalFeatureEmbedder(nn.Module):
    """Sum of per-feature embedding lookups for integer-coded categorical features.

    One ``nn.Embedding`` is created per key of ``feature_map``; its vocabulary
    size is ``len(feature_map[key])``. Column ``i`` of the input is looked up in
    the table belonging to the ``i``-th key.
    """

    def __init__(self, feature_map, emb_dim: int):
        super().__init__()
        self.keys = list(feature_map.keys())
        # Build the tables in key order so parameter creation order is stable.
        tables = {}
        for name in self.keys:
            tables[name] = nn.Embedding(len(feature_map[name]), emb_dim)
        self.embs = nn.ModuleDict(tables)

    def forward(self, x_cat: torch.Tensor) -> torch.Tensor:
        """Embed ``x_cat`` ([N, num_features] integer IDs) into [N, emb_dim]."""
        per_feature = (
            self.embs[name](x_cat[:, col]) for col, name in enumerate(self.keys)
        )
        # Built-in sum starts from 0, i.e. 0 + first_embedding + ...
        return sum(per_feature)


class GINEVirtualNodeJK(nn.Module):
    """GINE message-passing encoder with an optional virtual node and JK output.

    NOTE(review): a class with the same name is defined again in the next code
    cell; once that cell runs, the later definition replaces this one. Unlike
    the later definition, this one asserts that ``jk`` is "cat" or "last".

    Args:
        x_map: Mapping used to build the node ``CategoricalFeatureEmbedder``.
        e_map: Mapping used to build the edge ``CategoricalFeatureEmbedder``.
        hidden_dim: Width of all embeddings and hidden layers.
        num_layers: Number of GINEConv layers (and virtual-node MLPs).
        dropout: Dropout probability applied after each layer and inside the
            virtual-node MLPs.
        jk: "cat" concatenates the outputs of all layers along the feature
            axis; "last" returns only the final layer's output.
        use_virtual_node: Whether to add/update a per-graph virtual-node state.
    """

    def __init__(
        self,
        x_map,
        e_map,
        hidden_dim: int = 256,
        num_layers: int = 5,
        dropout: float = 0.1,
        jk: str = "cat",           # "cat" or "last"
        use_virtual_node: bool = True,
    ):
        super().__init__()
        assert jk in ["cat", "last"]
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.jk = jk
        self.use_virtual_node = use_virtual_node

        # Categorical node / edge features are embedded to hidden_dim.
        self.node_emb = CategoricalFeatureEmbedder(x_map, hidden_dim)
        self.edge_emb = CategoricalFeatureEmbedder(e_map, hidden_dim)

        if use_virtual_node:
            # Single learned initial state, repeated per graph in forward().
            self.vn = nn.Embedding(1, hidden_dim)
            # One update MLP per message-passing layer.
            self.vn_mlp = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim, hidden_dim),
                )
                for _ in range(num_layers)
            ])

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        # Each layer: GINEConv wrapping a 2-layer MLP, followed by LayerNorm.
        for _ in range(num_layers):
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.convs.append(GINEConv(nn=mlp))
            self.norms.append(nn.LayerNorm(hidden_dim))

    def forward(self, data):
        """Return per-node features for the batched graph ``data``.

        Reads ``data.x``, ``data.edge_attr``, ``data.edge_index``,
        ``data.batch`` and ``data.num_graphs``.
        """
        x = self.node_emb(data.x)           # [total_nodes, hidden]
        e = self.edge_emb(data.edge_attr)   # [total_edges, hidden]
        h_list = []

        if self.use_virtual_node:
            vn = self.vn.weight[0].unsqueeze(0).repeat(data.num_graphs, 1)  # [B, hidden]

        for layer in range(self.num_layers):
            if self.use_virtual_node:
                # Broadcast each graph's virtual-node state to its own nodes.
                x = x + vn[data.batch]

            x = self.convs[layer](x, data.edge_index, e)
            x = self.norms[layer](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            h_list.append(x)

            if self.use_virtual_node:
                # Residual update of the virtual node from sum-pooled node states.
                pooled = global_add_pool(x, data.batch)  # [B, hidden]
                vn = vn + self.vn_mlp[layer](pooled)

        if self.jk == "cat":
            return torch.cat(h_list, dim=-1)  # [total_nodes, hidden*num_layers]
        else:
            return h_list[-1]                 # [total_nodes, hidden]


Robust GNN: GINE + Virtual Node + JK
This outputs node embeddings (not pooled), so T5 gets many "encoder tokens".

In [None]:
class GINEVirtualNodeJK(nn.Module):
    """GINE message-passing encoder with an optional virtual node and JK output.

    Produces one feature vector per node (no graph-level pooling of the output).

    Args:
        x_map: Mapping used to build the node ``CategoricalFeatureEmbedder``.
        e_map: Mapping used to build the edge ``CategoricalFeatureEmbedder``.
        hidden_dim: Width of all embeddings and hidden layers.
        num_layers: Number of GINEConv layers (and virtual-node MLPs).
        dropout: Dropout probability applied after each layer and inside the
            virtual-node MLPs.
        jk: "cat" concatenates the outputs of all layers along the feature
            axis; "last" returns only the final layer's output.
        use_virtual_node: Whether to add/update a per-graph virtual-node state.

    Raises:
        ValueError: If ``jk`` is neither "cat" nor "last".
    """

    def __init__(
        self,
        x_map: Dict[str, List[Any]],
        e_map: Dict[str, List[Any]],
        hidden_dim: int = 256,
        num_layers: int = 5,
        dropout: float = 0.1,
        jk: str = "cat",   # "cat" or "last"
        use_virtual_node: bool = True,
    ):
        # Reject unknown modes up front: forward() would otherwise silently
        # treat any value other than "cat" as "last".
        if jk not in ("cat", "last"):
            raise ValueError(f"jk must be 'cat' or 'last', got {jk!r}")

        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.jk = jk
        self.use_virtual_node = use_virtual_node

        # Node/edge categorical embeddings
        self.node_emb = CategoricalFeatureEmbedder(x_map, hidden_dim)
        self.edge_emb = CategoricalFeatureEmbedder(e_map, hidden_dim)

        # Virtual node embedding (one per graph, broadcasted to nodes)
        if use_virtual_node:
            self.vn = nn.Embedding(1, hidden_dim)
            # One update MLP per message-passing layer.
            self.vn_mlp = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim, hidden_dim),
                )
                for _ in range(num_layers)
            ])

        # GINE layers: GINEConv wrapping a 2-layer MLP, followed by LayerNorm
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        for _ in range(num_layers):
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            self.convs.append(GINEConv(nn=mlp))
            self.norms.append(nn.LayerNorm(hidden_dim))

    def forward(self, data):
        """Return per-node features for the batched graph ``data``.

        Reads ``data.x``, ``data.edge_attr``, ``data.edge_index``,
        ``data.batch`` and ``data.num_graphs``. The result has
        ``hidden_dim * num_layers`` features per node for ``jk="cat"`` and
        ``hidden_dim`` for ``jk="last"``.
        """
        # data.x: [total_nodes, 9] categorical IDs
        # data.edge_attr: [total_edges, 3] categorical IDs
        x = self.node_emb(data.x)
        e = self.edge_emb(data.edge_attr)

        h_list = []

        if self.use_virtual_node:
            vn = self.vn.weight[0].unsqueeze(0)  # [1, hidden]
            vn = vn.repeat(data.num_graphs, 1)   # [B, hidden]

        for layer in range(self.num_layers):
            if self.use_virtual_node:
                # add virtual node representation to each node
                x = x + vn[data.batch]

            x = self.convs[layer](x, data.edge_index, e)
            x = self.norms[layer](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

            h_list.append(x)

            if self.use_virtual_node:
                # update virtual node from pooled node states (residual)
                pooled = global_add_pool(x, data.batch)  # [B, hidden]
                vn = vn + self.vn_mlp[layer](pooled)

        if self.jk == "cat":
            x_out = torch.cat(h_list, dim=-1)  # [total_nodes, hidden*num_layers]
        else:
            x_out = h_list[-1]                 # [total_nodes, hidden]

        return x_out


Full model: graph -> encoder states + LoRA-T5 decoder
Key point: we convert the nodes of each graph into padded sequences with `to_dense_batch`.

In [None]:
class Graph2TextT5(nn.Module):
    """Graph-to-text model: a GNN supplies the encoder states for a LoRA-adapted T5.

    The node embeddings produced by ``GINEVirtualNodeJK`` are padded into a
    dense ``[B, Nmax, *]`` tensor, projected to the T5 hidden size by
    ``bridge`` and passed to T5 as ``encoder_outputs``.

    Args:
        x_map, e_map: Feature maps forwarded to ``GINEVirtualNodeJK``.
        t5_name: Checkpoint name for ``AutoModelForSeq2SeqLM.from_pretrained``.
        gnn_hidden, gnn_layers, dropout, jk, use_virtual_node: GNN settings.
        lora_r, lora_alpha, lora_dropout: LoRA hyper-parameters.
    """

    def __init__(
        self,
        x_map,
        e_map,
        t5_name: str,
        gnn_hidden: int = 256,
        gnn_layers: int = 5,
        dropout: float = 0.1,
        jk: str = "cat",
        use_virtual_node: bool = True,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.05,
    ):
        super().__init__()

        self.t5 = AutoModelForSeq2SeqLM.from_pretrained(t5_name)
        d_model = self.t5.config.d_model

        self.gnn = GINEVirtualNodeJK(
            x_map=x_map,
            e_map=e_map,
            hidden_dim=gnn_hidden,
            num_layers=gnn_layers,
            dropout=dropout,
            jk=jk,
            use_virtual_node=use_virtual_node,
        )

        # With jk="cat" the GNN concatenates all layer outputs per node.
        gnn_out_dim = (gnn_hidden * gnn_layers) if jk == "cat" else gnn_hidden

        # Projects GNN node features into the T5 hidden size.
        self.bridge = nn.Sequential(
            nn.Linear(gnn_out_dim, d_model),
            nn.LayerNorm(d_model),
        )

        # LoRA on T5
        lora_cfg = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias="none",
            target_modules=["q", "k", "v", "o"],  # T5Attention projections
        )
        self.t5 = get_peft_model(self.t5, lora_cfg)

    def _encode_graph(self, batch):
        """Turn a batched graph into (encoder_outputs, attention_mask) for T5."""
        # 1) GNN node embeddings (ragged): [total_nodes, gnn_out_dim]
        node_h = self.gnn(batch)

        # 2) Dense encoder tokens [B, Nmax, gnn_out_dim] + bool mask [B, Nmax].
        #    batch_size is passed explicitly so the result always has one row
        #    per graph, even if the last graph(s) in the batch have no nodes.
        dense_h, mask = to_dense_batch(
            node_h, batch.batch, batch_size=batch.num_graphs
        )

        # 3) Project to the T5 dimension: [B, Nmax, d_model]
        enc_hidden = self.bridge(dense_h)
        return BaseModelOutput(last_hidden_state=enc_hidden), mask.long()

    def forward(self, batch, labels=None, decoder_input_ids=None):
        """Run T5 on the graph-derived encoder states.

        Args:
            batch: Batched graph consumed by the GNN.
            labels: Optional target token IDs forwarded to T5.
            decoder_input_ids: Optional decoder inputs forwarded to T5.

        Returns:
            The output object returned by the wrapped T5 model.
        """
        enc_out, attn_mask = self._encode_graph(batch)

        # 4) T5 decode; the attention mask hides the padded node positions.
        out = self.t5(
            encoder_outputs=enc_out,
            attention_mask=attn_mask,
            labels=labels,
            decoder_input_ids=decoder_input_ids,
        )
        return out

    @torch.no_grad()
    def generate(self, batch, **gen_kwargs):
        """Generate token IDs for ``batch``; ``gen_kwargs`` go to ``t5.generate``."""
        enc_out, attn_mask = self._encode_graph(batch)

        return self.t5.generate(
            encoder_outputs=enc_out,
            attention_mask=attn_mask,
            **gen_kwargs
        )


**Training + evaluation + submission**
init model

In [None]:
# Build the graph-to-text model and move it to the selected device.
# The LoRA values here (r=32, alpha=64, dropout=0.05) override the constructor
# defaults and match the LORA tag printed in the run-config cell.
model = Graph2TextT5(
    x_map=x_map,
    e_map=e_map,
    t5_name=MODEL_NAME,
    gnn_hidden=256,
    gnn_layers=5,
    dropout=0.1,
    jk="cat",
    use_virtual_node=True,
    lora_r=32,
    lora_alpha=64,
    lora_dropout=0.05,
).to(device)

# Reports trainable vs. total parameters of the PEFT-wrapped T5 only
# (the GNN and bridge are not part of model.t5).
model.t5.print_trainable_parameters()


optimizer + scheduler + AMP

In [None]:
# ===== Training setup (Base++): param groups + scheduler + early stopping =====
EPOCHS = 30
WARMUP_RATIO = 0.05

# Separate learning rates / weight decay for the LoRA adapters and for
# everything else that is trainable (GNN + bridge).
LR_GNN = 3e-4
LR_LORA = 1e-4
WEIGHT_DECAY_GNN = 0.01
WEIGHT_DECAY_LORA = 0.0

# Early-stopping and gradient-clipping settings used by the training code.
PATIENCE = 4
MIN_DELTA = 0.05
GRAD_CLIP = 1.0

# ---- build optimizer param groups ----
# A trainable parameter goes to the LoRA group when its name contains "lora"
# (case-insensitive); every other trainable parameter goes to the GNN group.
_trainable = [
    (name, param) for name, param in model.named_parameters() if param.requires_grad
]
lora_params = [param for name, param in _trainable if "lora" in name.lower()]
gnn_params = [param for name, param in _trainable if "lora" not in name.lower()]

optimizer = AdamW(
    [
        {"params": gnn_params, "lr": LR_GNN, "weight_decay": WEIGHT_DECAY_GNN},
        {"params": lora_params, "lr": LR_LORA, "weight_decay": WEIGHT_DECAY_LORA},
    ]
)

# The scheduler is stepped once per batch, so its horizon is counted in batches.
num_training_steps = len(train_loader) * EPOCHS
num_warmup_steps = int(num_training_steps * WARMUP_RATIO)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

# Loss scaling is only active when a CUDA device is available.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


In [None]:
def train_one_epoch(model, loader):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Uses the module-level ``optimizer``, ``scheduler``, ``scaler``, ``device``,
    ``MAX_LEN`` and ``GRAD_CLIP``. Mixed precision is enabled only when CUDA is
    available.

    Returns:
        Sum of per-batch losses divided by ``max(1, len(loader))``.
    """
    model.train()
    total_loss = 0.0
    use_amp = torch.cuda.is_available()

    for batch in loader:
        batch = batch.to(device)

        captions = [g.description for g in batch.to_data_list()]
        # Only the labels are needed; the tokenizer output itself is unused here.
        _, labels = tokenize_captions(captions, MAX_LEN)
        labels = labels.to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=use_amp):
            out = model(batch, labels=labels)
            loss = out.loss

        # Backprop (AMP)
        scaler.scale(loss).backward()

        # IMPORTANT: unscale before clipping so clipping is meaningful
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)

        # GradScaler skips optimizer.step() when it finds inf/NaN gradients and
        # then lowers its scale in update(). Advance the LR schedule only when
        # the optimizer actually stepped, so the two stay in sync.
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        if scaler.get_scale() >= scale_before:
            scheduler.step()

        total_loss += loss.item()

    return total_loss / max(1, len(loader))


@torch.no_grad()
def evaluate_bleu(model, loader, num_batches=None):
    """Generate captions for ``loader`` and return the corpus BLEU score.

    Args:
        model: Model exposing ``generate(batch, **kwargs)``.
        loader: Loader whose graphs carry a ``description`` reference caption.
        num_batches: ``None`` evaluates every batch (recommended for early
            stopping); an integer such as 50 stops after that many batches for
            a faster but noisier estimate.

    Returns:
        The ``score`` of ``sacrebleu.corpus_bleu`` over all generated captions.
    """
    model.eval()
    predictions, references = [], []

    # Beam-search settings (candidates to try later: 160 tokens, 6 beams, n-gram size 2).
    generation_kwargs = dict(
        max_new_tokens=MAX_LEN,
        num_beams=4,
        no_repeat_ngram_size=3,
        early_stopping=True,
    )

    for step, batch in enumerate(loader):
        limit_reached = num_batches is not None and step >= num_batches
        if limit_reached:
            break

        batch = batch.to(device)
        batch_refs = [g.description for g in batch.to_data_list()]

        token_ids = model.generate(batch, **generation_kwargs)
        batch_hyps = tokenizer.batch_decode(token_ids, skip_special_tokens=True)

        predictions += batch_hyps
        references += batch_refs

    return corpus_bleu(predictions, [references]).score


def fit_with_early_stopping(model, train_loader, val_loader, work_dir="./", eval_num_batches=None):
    """Train with BLEU-based early stopping and reload the best checkpoint.

    Runs up to ``EPOCHS`` epochs. After each epoch the validation BLEU is
    computed; a checkpoint is written to ``<work_dir>/best_model.pt`` whenever
    BLEU improves by more than ``MIN_DELTA``. Training stops after ``PATIENCE``
    consecutive epochs without such an improvement.

    Args:
        model: Model to train (modified in place).
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``evaluate_bleu``.
        work_dir: Directory for the checkpoint; created if it does not exist.
        eval_num_batches: Forwarded to ``evaluate_bleu`` as ``num_batches``.

    Returns:
        ``(model, best_bleu)``. ``best_bleu`` is -1.0 if no checkpoint was saved.
    """
    import os, gc, torch

    # torch.save does not create missing directories.
    if work_dir:
        os.makedirs(work_dir, exist_ok=True)

    best_bleu = -1.0
    best_path = os.path.join(work_dir, "best_model.pt")
    bad_epochs = 0

    for epoch in range(1, EPOCHS + 1):
        tr_loss = train_one_epoch(model, train_loader)

        val_bleu = evaluate_bleu(model, val_loader, num_batches=eval_num_batches)
        print(f"Epoch {epoch}/{EPOCHS} | train_loss={tr_loss:.4f} | val_BLEU={val_bleu:.2f}")

        if val_bleu > best_bleu + MIN_DELTA:
            best_bleu = val_bleu
            bad_epochs = 0
            torch.save(model.state_dict(), best_path)
            print(f"  ✓ saved best checkpoint: {best_path}")
        else:
            bad_epochs += 1
            print(f"  no improvement ({bad_epochs}/{PATIENCE})")
            if bad_epochs >= PATIENCE:
                print(f"Early stopping at epoch {epoch}. Best val_BLEU={best_bleu:.2f}")
                break

        # Release cached memory between epochs.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # best_bleu stays at -1.0 only when no checkpoint was written in this call
    # (e.g. EPOCHS == 0); in that case there is nothing valid to reload.
    if best_bleu >= 0.0 and os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path, map_location=device))
        print(f"Loaded best checkpoint: {best_path} | best val_BLEU={best_bleu:.2f}")
    else:
        print("No checkpoint was saved during this call; returning the model as-is.")
    return model, best_bleu

In [None]:
import os, gc, time
import torch
from google.colab import drive

# ===== Save checkpoints to Google Drive =====
drive.mount("/content/drive")

PROJECT_NAME = "my_project"  # change if you want
# Timestamped folder name, unique per run. Stored under its own name so the
# descriptive RUN_NAME tag set in the run-config cell is not overwritten.
RUN_STAMP = time.strftime("run_%Y%m%d_%H%M%S")

WORK_DIR = f"/content/drive/MyDrive/{PROJECT_NAME}/{RUN_STAMP}"
os.makedirs(WORK_DIR, exist_ok=True)
print("WORK_DIR:", WORK_DIR)

# Early-stopping state and output locations used by the training loop below.
best_bleu = -1.0
best_path = os.path.join(WORK_DIR, "best_model.pt")
last_path = os.path.join(WORK_DIR, "last_model.pt")
log_path  = os.path.join(WORK_DIR, "train_log.txt")
bad_epochs = 0

EVAL_NUM_BATCHES = None   # None = full validation (recommended). set 50 if too slow.

# optional: mirror every message to a log file in Drive
def log(msg: str):
    """Print ``msg`` and append it, followed by a newline, to ``log_path``."""
    print(msg)
    with open(log_path, "a") as log_file:
        log_file.writelines((msg, "\n"))

log(f"Saving to: {WORK_DIR}")
log(f"best_path: {best_path}")

# Main training loop: one training pass + one validation BLEU per epoch,
# with early stopping driven by PATIENCE / MIN_DELTA.
for epoch in range(1, EPOCHS + 1):
    tr_loss = train_one_epoch(model, train_loader)
    val_bleu = evaluate_bleu(model, val_loader, num_batches=EVAL_NUM_BATCHES)

    log(f"Epoch {epoch}/{EPOCHS} | train_loss={tr_loss:.4f} | val_BLEU={val_bleu:.2f}")

    # always save "last" checkpoint (useful if training interrupts)
    torch.save(model.state_dict(), last_path)

    # An epoch counts as an improvement only if BLEU rises by more than MIN_DELTA.
    if val_bleu > best_bleu + MIN_DELTA:
        best_bleu = val_bleu
        bad_epochs = 0
        torch.save(model.state_dict(), best_path)
        log(f"  ✓ saved best: {best_path}")
    else:
        bad_epochs += 1
        log(f"  no improvement ({bad_epochs}/{PATIENCE})")

        # Stop after PATIENCE consecutive epochs without improvement.
        if bad_epochs >= PATIENCE:
            log(f"Early stopping at epoch {epoch}. Best val_BLEU={best_bleu:.2f}")
            break

    # Release cached memory between epochs (skipped on the early-stop break).
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Load best checkpoint at the end (fallback to last if best doesn't exist)
load_path = best_path if os.path.exists(best_path) else last_path
model.load_state_dict(torch.load(load_path, map_location=device))
log(f"Loaded model: {load_path} | best val_BLEU: {best_bleu:.2f}")


load best + generate predictions for test

In [None]:
# Restore the best checkpoint and switch to inference mode.
best_state = torch.load(best_path, map_location=device)
model.load_state_dict(best_state)
model.eval()

# Parallel lists: graph id and generated caption, in test-loader order.
all_ids, all_preds = [], []

with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        batch_list = batch.to_data_list()

        # Same decoding settings as validation
        # (candidates to try later: 160 tokens, 6 beams, n-gram size 2).
        gen_ids = model.generate(
            batch,
            max_new_tokens=MAX_LEN,
            num_beams=4,
            no_repeat_ngram_size=3,
            early_stopping=True,
        )
        gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

        for graph, caption in zip(batch_list, gen_text):
            all_ids.append(graph.id)
            all_preds.append(caption)

print("generated:", len(all_preds))
print("example:", all_preds[0])


In [None]:
# Write the submission (columns: id, description) next to the checkpoints.
sub_path = os.path.join(WORK_DIR, "submission.csv")
sub = pd.DataFrame({"id": all_ids, "description": all_preds})
sub.to_csv(sub_path, index=False)
print("saved:", sub_path)
sub.head()
