In [None]:
!pip install -q torch-geometric torch-scatter
!pip install -q transformers accelerate bitsandbytes peft
!pip install -q pandas networkx scipy tqdm

import torch
import os

# Report whether PyTorch can see a CUDA device.
print(f"CUDA available : {torch.cuda.is_available()}")

if torch.cuda.is_available():
    # Name of CUDA device 0 and its total memory converted from bytes to GiB.
    device_name = torch.cuda.get_device_name(0)
    vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU : {device_name}")
    print(f"VRAM : {vram:.2f} GB")
else:
    # Abort the cell right away when no CUDA device is visible: the rest of
    # the notebook hard-codes 'cuda' as its device.
    raise RuntimeError("‚ùå ERREUR : Activez le GPU dans Ex√©cution > Modifier le type d'ex√©cution !")

[?25l     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m63.7/63.7 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m108.0/108.0 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.3/1.3 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import pickle
from torch.utils.data import DataLoader, Dataset
from torch_geometric.data import Batch
from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool
from transformers import AutoTokenizer, AutoModel
from tqdm.notebook import tqdm

# Configuration shared by the encoder training, the LLM fine-tuning and the
# submission cells.
CONFIG = {
    'lap_dim': 8,          # width of data.pe_lap fed to FullEncoderModel
    'rw_dim': 16,          # width of data.pe_rw fed to FullEncoderModel
    'wl_vocab': 50000,     # number of rows of the WL embedding table
    'hidden_dim': 256,     # per-atom hidden size of the encoder
    'output_dim': 768,     # size of the graph embedding compared with text embeddings
    'batch_size': 512,
    'lr': 4e-4,
    'epochs': 100,
    'temperature': 0.07,   # softmax temperature of the contrastive loss
    'device': 'cuda',
    # Hugging Face id of the causal LM used by the tokenizer / MolQwen cells.
    # The key was missing although those cells read CONFIG['llm_name'], which
    # raised KeyError.
    # NOTE(review): the checkpoint actually used is not visible in this
    # notebook; the prompts use the ChatML format, so a Qwen chat model is
    # assumed here. Confirm and replace the value before training.
    'llm_name': 'Qwen/Qwen2.5-0.5B-Instruct',
}

In [None]:
# --- 1. Encoder ---

class AtomFeatureEncoder(nn.Module):
    """Sum of one learned embedding per categorical atom feature column.

    Nine embedding tables are created, one per feature column, each mapping
    an integer category to a vector of size ``emb_dim``.
    """

    def __init__(self, emb_dim):
        super().__init__()
        # Number of categories of each feature column, in column order.
        vocab_sizes = (119, 9, 11, 12, 9, 5, 8, 2, 2)
        tables = [nn.Embedding(size, emb_dim) for size in vocab_sizes]
        self.embeddings = nn.ModuleList(tables)

    def forward(self, x):
        """Embed column ``i`` of ``x`` with table ``i`` and add everything up.

        Args:
            x: 2-D tensor indexed as ``x[:, i]``; each used column is cast to
               ``long`` before the lookup.

        Returns:
            Tensor with one ``emb_dim`` vector per row of ``x``.
        """
        return sum(
            table(x[:, column].long())
            for column, table in enumerate(self.embeddings)
        )

class LighterInputLayer(nn.Module):
    """Fuse neighbour, bond, rank and geometric features by summation.

    Every feature source is mapped to ``hidden_dim`` so the four terms can be
    added together before a LayerNorm.
    """

    def __init__(self, hidden_dim, num_bond_types=22, max_k=11):
        super().__init__()

        # Lookup tables; bond id 0 is the padding row.
        self.edge_embedding = nn.Embedding(num_bond_types, hidden_dim, padding_idx=0)
        self.rank_embedding = nn.Embedding(max_k, hidden_dim)

        # Small MLP lifting the 5 geometric scalars to the hidden size.
        self.geo_proj = nn.Sequential(
            nn.Linear(5, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, h_neighbors, neighbors_edge_idx, geo_features, mask):
        """Return the normalised sum of the four feature terms.

        Args:
            h_neighbors: tensor unpacked as ``(N, K, hidden)``.
            neighbors_edge_idx: integer bond ids; ``-1`` entries are looked
                up as id 0 (the padding row).
            geo_features: tensor whose last dimension is 5.
            mask: tensor broadcast over the feature dimension; positions
                where it is zero/False come out as zeros.
        """
        num_nodes, num_slots, _ = h_neighbors.shape

        # Bond term: send the -1 sentinel to the padding index.
        bond_ids = neighbors_edge_idx.masked_fill(neighbors_edge_idx == -1, 0)
        bond_term = self.edge_embedding(bond_ids)

        # Rank term: slot position 0..K-1 repeated for every node.
        slot_ids = torch.arange(num_slots, device=h_neighbors.device)
        slot_ids = slot_ids.unsqueeze(0).expand(num_nodes, num_slots)
        rank_term = self.rank_embedding(slot_ids)

        # Geometry term.
        geo_term = self.geo_proj(geo_features)

        fused = h_neighbors + bond_term + rank_term + geo_term
        return self.norm(fused) * mask.unsqueeze(-1)

class PreNormAttention(nn.Module):
    """Pre-Norm transformer block applied over the neighbour slots of a node.

    Pre-Norm means ``x = x + Attention(Norm(x))`` instead of
    ``x = Norm(x + Attention(x))``; the feed-forward part follows the same
    pattern.
    """

    def __init__(self, input_dim, num_heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(input_dim)
        self.mha = nn.MultiheadAttention(input_dim, num_heads, batch_first=True)

        self.norm2 = nn.LayerNorm(input_dim)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, 4*input_dim),
            nn.ReLU(),
            nn.Linear(4*input_dim, input_dim)
        )

    def forward(self, x, neighbor_indices):
        """Run self-attention + FFN and zero the padded slots.

        Args:
            x: tensor passed to the attention with ``batch_first=True``.
            neighbor_indices: integer tensor whose ``-1`` entries mark padded
                slots (used as key padding mask and as output mask).

        Returns:
            Tensor shaped like ``x`` with exact zeros on padded slots.
        """
        key_mask = (neighbor_indices == -1)

        # A row whose keys are ALL padded (atom without any neighbour) makes
        # the attention softmax run over an empty set, which can yield NaN.
        # For those rows the padding mask is lifted inside the attention only;
        # their output is zeroed below anyway, so results for every other row
        # are unchanged.
        no_valid_key = key_mask.all(dim=1, keepdim=True)
        attn_key_mask = key_mask & ~no_valid_key

        # Block 1: attention (Pre-Norm)
        x_norm = self.norm1(x)
        attn, _ = self.mha(x_norm, x_norm, x_norm, key_padding_mask=attn_key_mask, need_weights=False)
        x = x + attn

        # Block 2: FFN (Pre-Norm)
        x_norm = self.norm2(x)
        x = x + self.ffn(x_norm)

        # masked_fill instead of multiplying by the mask: a multiplication
        # would keep NaN/Inf alive on padded slots (NaN * 0 is NaN).
        return x.masked_fill(key_mask.unsqueeze(-1), 0.0)

class FullEncoderModel(nn.Module):
    """Graph encoder: per-atom features, neighbour attention, graph pooling.

    Attribute names and the order in which sub-modules are created are part
    of the checkpoint format (state_dict keys) and are kept as they are.
    """

    def __init__(self, lap_dim, rw_dim, wl_vocab, hidden_dim, output_dim, max_k=11):
        super().__init__()

        # Categorical atom features.
        self.atom_enc = AtomFeatureEncoder(hidden_dim)

        # Positional encodings: concatenated Laplacian / random-walk vectors
        # go through a linear layer, WL ids through an embedding table.
        positional_dim = lap_dim + rw_dim
        self.global_proj = nn.Sequential(nn.Linear(positional_dim, hidden_dim), nn.SiLU())
        self.wl_emb = nn.Embedding(wl_vocab, hidden_dim)

        # Fusion of neighbour / bond / rank / geometry features.
        self.input_layer = LighterInputLayer(hidden_dim, max_k=max_k)

        # Four Pre-Norm attention blocks over the neighbour slots.
        attention_blocks = [PreNormAttention(hidden_dim) for _ in range(4)]
        self.layers = nn.ModuleList(attention_blocks)

        # Sigmoid gate used to weight each neighbour before summation.
        self.agg_scorer = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())
        self.norm_final = nn.LayerNorm(hidden_dim)

        # Graph-level head.
        widened = hidden_dim * 2
        self.final_mlp = nn.Sequential(
            nn.Linear(hidden_dim, widened),
            nn.ReLU(),
            nn.Linear(widened, output_dim),
        )
        self.final_norm = nn.LayerNorm(hidden_dim)

    def forward(self, data, sorted_neighbors, sorted_edges, geo_features):
        """Encode a (batched) graph.

        Args:
            data: object exposing ``x``, ``pe_lap``, ``pe_rw``, ``pe_wl`` and
                ``batch``.
            sorted_neighbors: integer tensor of node indices, ``-1`` = padding.
            sorted_edges: bond ids forwarded to the input layer.
            geo_features: geometric features forwarded to the input layer.

        Returns:
            ``(graph_embedding, data.batch)``.
        """
        geo_features = geo_features.contiguous()

        # Per-atom state = atom embedding + projected positional encodings
        # + summed WL embeddings.
        positional = torch.cat([data.pe_lap, data.pe_rw], dim=-1)
        node_state = (
            self.atom_enc(data.x)
            + self.global_proj(positional)
            + self.wl_emb(data.pe_wl).sum(dim=1)
        )

        # Gather the state of every neighbour; padded slots read node 0 and
        # are masked out afterwards.
        valid = sorted_neighbors != -1
        gather_index = sorted_neighbors.masked_fill(~valid, 0)
        neighbor_state = self.input_layer(
            node_state[gather_index], sorted_edges, geo_features, valid
        )

        for block in self.layers:
            neighbor_state = block(neighbor_state, sorted_neighbors)

        # Gated sum of the neighbours; padded slots get a zero gate.
        gates = self.agg_scorer(neighbor_state).masked_fill(~valid.unsqueeze(-1), 0)
        context = (neighbor_state * gates).sum(dim=1)

        per_node = self.norm_final(node_state + context)

        # Graph readout: mean pooling plus max pooling, then LayerNorm.
        pooled = global_mean_pool(per_node, data.batch) + global_max_pool(per_node, data.batch)
        pooled = self.final_norm(pooled)

        return self.final_mlp(pooled), data.batch

In [None]:
class PreprocessedGraphDataset(Dataset):
    """Dataset over a pickled list of preprocessed graphs.

    Args:
        graph_path: path of the pickle file holding the list of graphs.
        emb_dict: optional mapping ``str(graph.id) -> text embedding``. When
            given, every item also carries the matching text embedding.
    """

    def __init__(self, graph_path, emb_dict=None):
        print(f"Loading {graph_path}...")
        # NOTE: pickle.load runs arbitrary code contained in the file; only
        # open pickles produced by a trusted preprocessing step.
        with open(graph_path, 'rb') as f:
            self.graphs = pickle.load(f)
        self.emb_dict = emb_dict

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return ``(graph, neighbors, edges, geo[, text_embedding])``.

        Raises:
            KeyError: an embedding dictionary was supplied but holds no entry
                for this graph id.
        """
        g = self.graphs[idx]

        # Build the zero fallback only when the graph really has no bond ids
        # (the previous getattr default allocated it on every call).
        edges = getattr(g, 'sorted_edges', None)
        if edges is None:
            edges = torch.zeros_like(g.sorted_neighbors)

        # `is None` rather than truthiness: an empty dictionary used to switch
        # silently to 4-tuples, which then failed far away in the caller.
        if self.emb_dict is None:
            return g, g.sorted_neighbors, edges, g.geo_features

        key = str(g.id)
        try:
            text_emb = self.emb_dict[key]
        except KeyError:
            raise KeyError(f"No text embedding found for graph id {key!r}") from None
        return g, g.sorted_neighbors, edges, g.geo_features, text_emb

def collate_fn(batch):
    """Merge dataset items into one batched graph plus flat side tensors.

    Items are 4-tuples ``(graph, neighbors, edges, geo)`` or 5-tuples with a
    trailing text embedding; the first item decides which layout is expected.

    Neighbour indices are local to their graph, so each graph's indices are
    shifted by the number of nodes of the graphs placed before it; the ``-1``
    padding entries are left untouched.

    Returns:
        ``(batch_graph, neighbors, edges, geos)`` and, for 5-tuples, the
        stacked text embeddings as a fifth element.
    """
    with_text = len(batch[0]) == 5

    if with_text:
        graphs, neighbors, edges, geos, texts = zip(*batch)
        texts = torch.stack(texts)
    else:
        graphs, neighbors, edges, geos = zip(*batch)
        texts = None

    batch_graph = Batch.from_data_list(list(graphs))

    # Shift the valid indices of every graph by the running node count.
    shifted = []
    node_base = 0
    for graph, local_idx in zip(graphs, neighbors):
        shifted.append(torch.where(local_idx != -1, local_idx + node_base, local_idx))
        node_base += graph.num_nodes

    total_neighbors = torch.cat(shifted, dim=0)
    total_edges = torch.cat(edges, dim=0)
    total_geos = torch.cat(geos, dim=0).contiguous()

    merged = (batch_graph, total_neighbors, total_edges, total_geos)
    if with_text:
        return merged + (texts,)
    return merged

In [None]:
# --- 2. Encoder Loss and Training---

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
import pandas as pd
import os
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader

class ContrastiveLoss(nn.Module):
    """Symmetric cross-entropy over graph/text cosine similarities.

    Row ``i`` of the graph batch is the positive of row ``i`` of the text
    batch; every other pairing in the batch acts as a negative.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temp = temperature
        self.ce = nn.CrossEntropyLoss()

    def forward(self, g_emb, t_emb):
        """Return the mean of the graph->text and text->graph losses."""
        # Project both sides onto the unit sphere so the dot product is a
        # cosine similarity.
        graph_unit = nn.functional.normalize(g_emb, dim=-1)
        text_unit = nn.functional.normalize(t_emb, dim=-1)

        similarity = (graph_unit @ text_unit.T) / self.temp

        # The matching pair sits on the diagonal.
        targets = torch.arange(similarity.size(0), device=similarity.device)

        graph_to_text = self.ce(similarity, targets)
        text_to_graph = self.ce(similarity.T, targets)
        return (graph_to_text + text_to_graph) / 2


def train():
    """Train the graph encoder against precomputed text embeddings.

    Reads 'train_graphs_preprocessed.pkl' and 'train_embeddings.csv' from the
    working directory, optimises the symmetric contrastive loss with mixed
    precision, saves 'model_best.pt' whenever the epoch-average training loss
    improves and 'model_epoch_{n}.pt' every 10 epochs.

    Returns None. Prints a message and returns early when an input file is
    missing.

    Raises:
        KeyError: the CSV has no 'ID' or 'embedding' column.
    """
    GRAPH_PATH = 'train_graphs_preprocessed.pkl'
    EMB_PATH = 'train_embeddings.csv'

    # Both inputs are required; the CSV used to be opened without any check.
    if not (os.path.exists(GRAPH_PATH) and os.path.exists(EMB_PATH)):
        print("Missing files !")
        return

    print("Loading Dataset...")
    df = pd.read_csv(EMB_PATH)
    emb_dict = {}
    skipped = 0
    rows = zip(df['ID'], df['embedding'])
    for mol_id, raw in tqdm(rows, total=len(df), desc="Parsing CSV"):
        try:
            vals = [float(x) for x in str(raw).replace('[','').replace(']','').split(',')]
        except ValueError:
            # Malformed embedding string: skip the row but keep count so the
            # loss of data is visible.
            skipped += 1
            continue
        emb_dict[str(mol_id)] = torch.tensor(vals, dtype=torch.float32)
    if skipped:
        print(f"Skipped {skipped} rows with an unparsable embedding.")

    dataset = PreprocessedGraphDataset(GRAPH_PATH, emb_dict)
    loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=True, collate_fn=collate_fn)

    model = FullEncoderModel(
        CONFIG['lap_dim'], CONFIG['rw_dim'], CONFIG['wl_vocab'],
        CONFIG['hidden_dim'], CONFIG['output_dim'],
        max_k=11
    ).to(CONFIG['device'])

    optimizer = optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=1e-4)
    criterion = ContrastiveLoss(CONFIG['temperature'])
    scaler = GradScaler()

    print("Training...")
    model.train()

    best_loss = float('inf')

    for epoch in range(CONFIG['epochs']):
        total_loss = 0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")

        for batch in pbar:
            g_data, neighbors, edges, geos, texts = batch

            g_data = g_data.to(CONFIG['device'])
            neighbors = neighbors.to(CONFIG['device'])
            edges = edges.to(CONFIG['device'])
            geos = geos.to(CONFIG['device'])
            texts = texts.to(CONFIG['device'])

            optimizer.zero_grad()

            with autocast():
                z_graph, _ = model(g_data, neighbors, edges, geos)
                loss = criterion(z_graph, texts)

            # The scaler multiplies the loss before backward and skips the
            # optimizer step when the gradients contain inf/NaN.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            batch_loss = loss.item()
            total_loss += batch_loss
            pbar.set_postfix({'loss': f"{batch_loss:.4f}"})

        avg_loss = total_loss / len(loader)

        print(f"Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")

        # Model selection is done on the training loss (no validation here).
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), "model_best.pt")
            print("Best model saved.")

        if (epoch+1) % 10 == 0:
            torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pt")

In [None]:
 train()

In [None]:
# --- 3. Evaluation of Encoder with R@K ---

import torch
import torch.nn.functional as F
import pandas as pd
import os
import numpy as np
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# Settings of the retrieval evaluation: checkpoint and validation files, the
# encoder hyper-parameters needed to rebuild the model, and runtime options.
EVAL_CONFIG = dict(
    model_path='model_best.pt',

    val_graph='validation_graphs_preprocessed.pkl',
    val_csv='validation_embeddings.csv',

    lap_dim=8,
    rw_dim=16,
    wl_vocab=50000,
    hidden_dim=256,
    output_dim=768,
    max_k=11,

    batch_size=128,
    # Falls back to the CPU when no CUDA device is visible.
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

def _load_eval_embeddings(csv_path):
    """Parse the embeddings CSV into a ``{str(ID): float32 tensor}`` dict."""
    df = pd.read_csv(csv_path, engine='python', on_bad_lines='skip')
    emb_dict = {}
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Parsing CSV"):
        try:
            val_str = str(row['embedding']).replace('[','').replace(']','').replace('\n','')
            vals = [float(x) for x in val_str.split(',') if x.strip()]
            emb_dict[str(row['ID'])] = torch.tensor(vals, dtype=torch.float32)
        except (KeyError, ValueError):
            # Best effort: a malformed row is ignored.
            continue
    return emb_dict


def _count_topk_hits(similarity_matrix, ks=(1, 5, 10)):
    """Count, for each k, the rows whose own index is among their top-k columns."""
    n_samples = similarity_matrix.size(0)
    # topk fails when k exceeds the number of candidates, so cap it.
    max_k = min(max(ks), similarity_matrix.size(1))
    _, topk_indices = torch.topk(similarity_matrix, k=max_k, dim=1)
    hits = topk_indices == torch.arange(n_samples).unsqueeze(1)
    return {f"R@{k}": int(hits[:, :k].any(dim=1).sum()) for k in ks}


def run_evaluation():
    """Evaluate the trained encoder with graph->text Recall@1/5/10.

    Every validation graph is encoded, compared by cosine similarity with all
    validation text embeddings, and counted as a hit at k when its own text
    is among the k most similar ones. Results are printed; the function
    returns None and exits early (with a message) when a file is missing.
    """
    print(f"Evaluation on {EVAL_CONFIG['device']}...")

    print(f"Reading embeddings from {EVAL_CONFIG['val_csv']}...")
    if not os.path.exists(EVAL_CONFIG['val_csv']):
        print("Missing csv file !")
        return

    try:
        emb_dict = _load_eval_embeddings(EVAL_CONFIG['val_csv'])
        print(f"{len(emb_dict)} embeddings textes charg√©s.")
    except Exception as e:
        print(f"Error : {e}")
        return

    if not os.path.exists(EVAL_CONFIG['val_graph']):
        print(f"Error : {EVAL_CONFIG['val_graph']} not found.")
        return

    dataset = PreprocessedGraphDataset(EVAL_CONFIG['val_graph'], emb_dict)

    loader = DataLoader(
        dataset,
        batch_size=EVAL_CONFIG['batch_size'],
        shuffle=False,
        collate_fn=collate_fn
    )

    print(f"Loading weights from {EVAL_CONFIG['model_path']}...")

    model = FullEncoderModel(
        lap_dim=EVAL_CONFIG['lap_dim'],
        rw_dim=EVAL_CONFIG['rw_dim'],
        wl_vocab=EVAL_CONFIG['wl_vocab'],
        hidden_dim=EVAL_CONFIG['hidden_dim'],
        output_dim=EVAL_CONFIG['output_dim'],
        max_k=EVAL_CONFIG['max_k']
    ).to(EVAL_CONFIG['device'])

    if not os.path.exists(EVAL_CONFIG['model_path']):
        print("Error : File .pt does not exist.")
        return

    state_dict = torch.load(EVAL_CONFIG['model_path'], map_location=EVAL_CONFIG['device'])
    try:
        model.load_state_dict(state_dict, strict=True)
        print("Weights loaded.")
    except RuntimeError as e:
        # Fall back to a partial load, but say so: metrics computed with
        # partially initialised weights are not comparable.
        print(f"Strict loading failed: {e}")
        result = model.load_state_dict(state_dict, strict=False)
        print(f"Weights loaded with strict=False | missing: {result.missing_keys} | unexpected: {result.unexpected_keys}")

    model.eval()

    all_graph_embs = []
    all_text_embs = []

    print("Evaluation...")
    with torch.no_grad():
        for batch in tqdm(loader, desc="Inf√©rence"):
            g_data, neighbors, edges, geos, texts = batch

            g_data = g_data.to(EVAL_CONFIG['device'])
            neighbors = neighbors.to(EVAL_CONFIG['device'])
            edges = edges.to(EVAL_CONFIG['device'])
            geos = geos.to(EVAL_CONFIG['device'])
            texts = texts.to(EVAL_CONFIG['device'])

            with torch.amp.autocast('cuda', enabled=(EVAL_CONFIG['device']=='cuda')):
                z_graph, _ = model(g_data, neighbors, edges, geos)

            all_graph_embs.append(z_graph.cpu())
            all_text_embs.append(texts.cpu())

    if not all_graph_embs:
        # torch.cat on an empty list and the percentage division would fail.
        print("No samples to evaluate.")
        return

    # --- Metrics (R@K) ---
    # Is the matching text among the K highest similarity scores?
    print("Computing similarities...")

    graph_matrix = torch.cat(all_graph_embs, dim=0).float()
    text_matrix = torch.cat(all_text_embs, dim=0).float()

    graph_matrix = F.normalize(graph_matrix, p=2, dim=-1)
    text_matrix = F.normalize(text_matrix, p=2, dim=-1)

    similarity_matrix = torch.matmul(graph_matrix, text_matrix.T)

    n_samples = similarity_matrix.size(0)

    print("Looking for Top-K...")
    metrics = _count_topk_hits(similarity_matrix, ks=(1, 5, 10))

    print("\n" + "="*40)
    print(f"Results ({n_samples} Molecules)")
    print("="*40)
    print(f"R@1  (Exact)  : {metrics['R@1'] / n_samples * 100:.2f}%")
    print(f"R@5  (Top 5)  : {metrics['R@5'] / n_samples * 100:.2f}%")
    print(f"R@10 (Top 10) : {metrics['R@10'] / n_samples * 100:.2f}%")
    print("="*40)

if __name__ == "__main__":
    run_evaluation()

In [None]:
# --- 1. Tokenizer and LLM ---

print(f"Loading Tokenizer {CONFIG['llm_name']}...")
tokenizer = AutoTokenizer.from_pretrained(CONFIG['llm_name'], trust_remote_code=True)

# Reuse the end-of-sequence token for padding when the tokenizer has none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Register "<graph>" as a special token marking where the molecule stands in
# the prompt.
# NOTE(review): nothing visible in this notebook resizes the LLM embedding
# matrix after this call; confirm the new id fits in the model's embedding
# table (or call resize_token_embeddings on the model).
num_added_tokens = tokenizer.add_tokens(["<graph>"], special_tokens=True)
GRAPH_TOKEN_ID = tokenizer.convert_tokens_to_ids("<graph>")

# Exactly one message: both used to be printed when the token already existed.
if num_added_tokens == 0:
    print(f"Token already exists. (ID: {GRAPH_TOKEN_ID})")
else:
    print(f"Token <graph> added ! (ID: {GRAPH_TOKEN_ID})")

class MoleculeGenDataset(Dataset):
    """Pairs each pickled graph with a tokenised chat prompt + description.

    The labels returned for language-model training ignore (``-100``) both
    the fixed prompt at the start of the sequence and the padding positions.
    """

    def __init__(self, pkl_path, tokenizer, max_len=128):
        print(f"Loading graphs from {pkl_path}...")
        with open(pkl_path, 'rb') as handle:
            self.graphs = pickle.load(handle)
        self.tokenizer = tokenizer
        self.max_len = max_len

        # Fixed ChatML prompt placed in front of every description.
        system_turn = "<|im_start|>system\nYou are an expert chemist.<|im_end|>\n"
        user_turn = "<|im_start|>user\n<graph>\nDescribe this molecule.<|im_end|>\n"
        assistant_open = "<|im_start|>assistant\n"
        self.prompt_prefix = system_turn + user_turn + assistant_open

        # Token count of the prompt = number of leading label positions to hide.
        self.prefix_ids = self.tokenizer(self.prompt_prefix, add_special_tokens=False).input_ids
        self.prefix_len = len(self.prefix_ids)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return ``(graph, input_ids, attention_mask, labels)`` for one item."""
        graph = self.graphs[idx]
        caption = getattr(graph, 'description', "Description unavailable.")

        encoded = self.tokenizer(
            self.prompt_prefix + caption + "<|im_end|>",
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded.input_ids.squeeze(0)
        attention_mask = encoded.attention_mask.squeeze(0)

        # Hide padding and the first prefix_len positions from the loss.
        positions = torch.arange(input_ids.size(0), device=input_ids.device)
        hidden = (attention_mask == 0) | (positions < self.prefix_len)
        labels = input_ids.masked_fill(hidden, -100)

        return graph, input_ids, attention_mask, labels

def multimodal_collate(batch):
    """Collate ``(graph, input_ids, attention_mask, labels)`` items.

    Returns:
        ``(batched_graph, input_ids, attention_mask, labels)`` where the three
        token tensors are stacked along a new first dimension.
    """
    graphs = [item[0] for item in batch]
    input_ids = torch.stack([item[1] for item in batch])
    attention_mask = torch.stack([item[2] for item in batch])

    labels = torch.stack([item[3] for item in batch])

    batched_graph = Batch.from_data_list(graphs)

    # `sorted_neighbors` stores node indices local to each graph. PyG only
    # offsets attributes whose name contains "index" when it concatenates
    # graphs, so the indices are shifted here by the number of nodes placed
    # before each graph (same logic as collate_fn). Without this, every graph
    # after the first one gathers features from the wrong atoms.
    # The tensor is rebuilt from the per-graph originals and overwrites the
    # batched attribute, so no offset can be applied twice.
    shifted = []
    node_base = 0
    for graph in graphs:
        local_idx = graph.sorted_neighbors
        shifted.append(torch.where(local_idx != -1, local_idx + node_base, local_idx))
        node_base += graph.num_nodes
    batched_graph.sorted_neighbors = torch.cat(shifted, dim=0)

    # Same fallback as PreprocessedGraphDataset: graphs stored without bond
    # ids get an all-zero (padding) bond tensor.
    if getattr(batched_graph, 'sorted_edges', None) is None:
        batched_graph.sorted_edges = torch.zeros_like(batched_graph.sorted_neighbors)

    return batched_graph, input_ids, attention_mask, labels

print("Dataset and Tokenizer ready.")

In [None]:
import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_batch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType

class MolQwen(nn.Module):
    """Frozen graph encoder + LoRA-tuned causal LM for molecule captioning.

    Per-atom encoder states are normalised, projected to the LLM hidden size
    and prepended, as a sequence of soft tokens, to the embedded text prompt.

    Args:
        encoder: graph encoder exposing a ``norm_final`` LayerNorm whose
            output is the per-atom representation.
        llm_model_name: Hugging Face id of the causal LM to load in 4-bit.
    """

    def __init__(self, encoder, llm_model_name):
        super().__init__()
        self.encoder = encoder
        self.encoder.eval()
        # The encoder is only run under no_grad: mark it frozen so its
        # weights are never handed to an optimizer.
        for param in self.encoder.parameters():
            param.requires_grad_(False)

        # 4-bit quantised LLM (memory)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )

        # LoRA adapters on the attention projections
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
        )
        self.llm = get_peft_model(self.llm, peft_config)

        llm_dim = self.llm.config.hidden_size
        # Width of the per-atom states, read from the encoder instead of a
        # hard-coded 256 (identical for the notebook's hidden_dim of 256).
        graph_dim = self.encoder.norm_final.normalized_shape[0]

        # Normalisation + projection are kept in float32 for stability; the
        # result is cast to the LLM embedding dtype right before the fusion.
        self.ln_graph = nn.LayerNorm(graph_dim)
        self.projector = nn.Sequential(
            nn.Linear(graph_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim)
        )

    def train(self, mode=True):
        """Switch train/eval mode but always keep the frozen encoder in eval."""
        super().train(mode)
        self.encoder.eval()
        return self

    def _encode_atoms(self, g_data):
        """Run the encoder and return ``(per_atom_states, batch_index)``.

        An encoder returning three values is read as
        ``(graph_emb, atom_states, batch)``. With the two-value form
        ``(graph_emb, batch)`` the atom states are captured from the output
        of ``encoder.norm_final`` through a temporary forward hook.
        """
        captured = []
        hook = self.encoder.norm_final.register_forward_hook(
            lambda _module, _inputs, output: captured.append(output)
        )
        try:
            outputs = self.encoder(
                g_data, g_data.sorted_neighbors, g_data.sorted_edges, g_data.geo_features
            )
        finally:
            hook.remove()

        if len(outputs) == 3:
            _, h_atoms, batch_idx = outputs
        else:
            h_atoms, batch_idx = captured[-1], outputs[-1]
        return h_atoms, batch_idx

    def _graph_tokens(self, g_data, dtype):
        """Return the padded soft tokens ``(B, max_atoms, llm_dim)`` and their mask."""
        h_atoms, batch_idx = self._encode_atoms(g_data)
        h_projected = self.projector(self.ln_graph(h_atoms.float()))
        return to_dense_batch(h_projected.to(dtype), batch_idx)

    def forward(self, g_data, input_ids, attention_mask, labels):
        """Return the language-modelling loss for a batch.

        Graph positions are excluded from the loss (label ``-100``).
        """
        text_embeds = self.llm.get_input_embeddings()(input_ids)

        with torch.no_grad():
            h_atoms, batch_idx = self._encode_atoms(g_data)

        # Projection in float32, then cast to the dtype of the text embeddings
        h_projected = self.projector(self.ln_graph(h_atoms.float()))
        visual_embeds, visual_mask = to_dense_batch(h_projected.to(text_embeds.dtype), batch_idx)

        # Fusion: [graph tokens | text tokens]
        inputs_embeds = torch.cat([visual_embeds, text_embeds], dim=1)
        combined_mask = torch.cat([visual_mask.to(attention_mask.dtype), attention_mask], dim=1)

        # Labels: ignore every graph position
        ignore_graph = torch.full(visual_embeds.shape[:2], -100, dtype=torch.long, device=visual_embeds.device)
        final_labels = torch.cat([ignore_graph, labels], dim=1)

        return self.llm(inputs_embeds=inputs_embeds, attention_mask=combined_mask, labels=final_labels).loss

    @torch.no_grad()
    def generate_caption(self, g_data, tokenizer, max_new_tokens=100):
        """Greedy-decode a description for ``g_data`` and return it as text."""
        # A single graph has no batch vector: every atom belongs to graph 0.
        if not hasattr(g_data, 'batch') or g_data.batch is None:
             g_data.batch = torch.zeros(g_data.x.size(0), dtype=torch.long, device=g_data.x.device)

        # Same prompt as during training
        prompt = "<|im_start|>system\nYou are an expert chemist.<|im_end|>\n<|im_start|>user\n<graph>\nDescribe this molecule.<|im_end|>\n<|im_start|>assistant\n"
        text_inputs = tokenizer(prompt, return_tensors="pt").to(self.llm.device)
        text_embeds = self.llm.get_input_embeddings()(text_inputs.input_ids)

        visual_embeds, visual_mask = self._graph_tokens(g_data, text_embeds.dtype)

        # Fusion
        inputs_embeds = torch.cat([visual_embeds, text_embeds], dim=1)
        combined_mask = torch.cat(
            [visual_mask.to(text_inputs.attention_mask.dtype), text_inputs.attention_mask], dim=1
        )

        # `is not None`: 0 is a valid pad id and must not trigger the fallback.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id

        # Generation
        output_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=combined_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False, # deterministic here
            pad_token_id=pad_id
        )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)

In [None]:
def train_llm():
    """Fine-tune MolQwen (projector + LoRA adapters) on molecule descriptions.

    Uses micro-batches of 2 with gradient accumulation over 32 steps, saves
    'molqwen_val_epoch{n}.pt' after each training epoch and prints the
    average training and validation losses.

    Raises:
        RuntimeError: not a single training step succeeded during an epoch.
    """
    LLM_CONFIG = {
        'batch_size': 2,
        'grad_accum': 32,
        'lr': 5e-5,
        'epochs': 2,
        'device': CONFIG['device']
    }
    device = LLM_CONFIG['device']
    accum = LLM_CONFIG['grad_accum']

    # Datasets
    train_dataset = MoleculeGenDataset('train_graphs_preprocessed.pkl', tokenizer)
    val_dataset = MoleculeGenDataset('validation_graphs_preprocessed.pkl', tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=LLM_CONFIG['batch_size'], shuffle=True, collate_fn=multimodal_collate)
    val_loader = DataLoader(val_dataset, batch_size=LLM_CONFIG['batch_size'], shuffle=False, collate_fn=multimodal_collate)

    # Model
    encoder = FullEncoderModel(CONFIG['lap_dim'], CONFIG['rw_dim'], CONFIG['wl_vocab'], CONFIG['hidden_dim'], CONFIG['output_dim'], max_k=11).to(device)
    if os.path.exists("model_best.pt"):
        load_result = encoder.load_state_dict(torch.load("model_best.pt", map_location=device), strict=False)
        # strict=False hides mismatches: report them instead of training on a
        # partially initialised encoder without knowing it.
        if load_result.missing_keys or load_result.unexpected_keys:
            print(f"Warning: encoder checkpoint mismatch | missing: {load_result.missing_keys} | unexpected: {load_result.unexpected_keys}")
    else:
        print("Warning: model_best.pt not found, the encoder keeps its random initialisation.")

    model = MolQwen(encoder, CONFIG['llm_name']).to(device)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=LLM_CONFIG['lr'], weight_decay=0.01)

    num_batches = len(train_loader)

    # Training loop
    for epoch in range(LLM_CONFIG['epochs']):
        model.train()
        train_loss = 0
        good_steps = 0   # micro-batches that went through forward + backward
        pending = 0      # micro-batches accumulated since the last optimizer step
        optimizer.zero_grad()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} [TRAIN]")

        for step, batch in enumerate(pbar):
            graphs, token_ids, token_mask, token_labels = batch
            try:
                graphs = graphs.to(device)
                token_ids = token_ids.to(device)
                token_mask = token_mask.to(device)
                token_labels = token_labels.to(device)

                loss = model(graphs, token_ids, token_mask, token_labels)

                (loss / accum).backward() # accumulated gradients for memory

                pending += 1
                good_steps += 1
                train_loss += loss.item()
                pbar.set_postfix({'loss': f"{loss.item():.4f}"})
            except Exception as e:
                # Best effort: a failing micro-batch is reported and skipped.
                print(f"Error train step {step}: {e}")

            # Apply the update at the end of each accumulation window AND on
            # the last batch, so the gradients of an incomplete final window
            # are not lost.
            window_closed = (step + 1) % accum == 0 or (step + 1) == num_batches
            if window_closed and pending:
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) # To avoid NaN
                optimizer.step()
                optimizer.zero_grad()
                pending = 0

        if good_steps == 0:
            raise RuntimeError(f"Epoch {epoch+1}: every training step failed, see the errors above.")

        model.eval()
        torch.save(model.state_dict(), f"molqwen_val_epoch{epoch+1}.pt")
        val_loss = 0
        print(f"Validation Epoch {epoch+1}...")

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [VAL]"):
                graphs, token_ids, token_mask, token_labels = batch
                graphs = graphs.to(device)
                token_ids = token_ids.to(device)
                token_mask = token_mask.to(device)
                token_labels = token_labels.to(device)

                loss = model(graphs, token_ids, token_mask, token_labels)
                val_loss += loss.item()

        # Average over the steps that actually contributed a loss.
        avg_train = train_loss / good_steps
        avg_val = val_loss / max(len(val_loader), 1)
        print(f"Epoch {epoch+1} done | Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")

In [None]:
train_llm()

In [None]:
import csv
import os
import pickle
import torch
import numpy as np
from tqdm.notebook import tqdm

def generate_test_submission():
    """Caption every test graph with MolQwen and write 'submission.csv'.

    The CSV has the columns 'ID' and 'description'. The function prints a
    message and returns None without generating anything when the test
    graphs or the fine-tuned weights are missing.
    """
    TEST_CONFIG = {
        'test_path': 'test_graphs_preprocessed.pkl',
        'model_weights': 'molqwen_val_epoch2.pt',
        'output_file': 'submission.csv',
        'device': CONFIG['device']
    }

    if not os.path.exists(TEST_CONFIG['test_path']):
        print("Test file not found.")
        return

    # Without the fine-tuned weights the projector and the LoRA adapters are
    # randomly initialised: stop instead of writing a meaningless submission.
    if not os.path.exists(TEST_CONFIG['model_weights']):
        print(f"Model weights not found: {TEST_CONFIG['model_weights']}")
        return

    with open(TEST_CONFIG['test_path'], 'rb') as f:
        test_graphs = pickle.load(f)

    # Model (the encoder weights come from the MolQwen checkpoint below)
    encoder = FullEncoderModel(CONFIG['lap_dim'], CONFIG['rw_dim'], CONFIG['wl_vocab'],
                               CONFIG['hidden_dim'], CONFIG['output_dim'], max_k=11).to(TEST_CONFIG['device'])
    model = MolQwen(encoder, CONFIG['llm_name']).to(TEST_CONFIG['device'])

    load_result = model.load_state_dict(
        torch.load(TEST_CONFIG['model_weights'], map_location=TEST_CONFIG['device']),
        strict=False,
    )
    print(f"Weight loaded.")
    # strict=False ignores mismatching names: make them visible.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Warning: {len(load_result.missing_keys)} missing and {len(load_result.unexpected_keys)} unexpected keys in the checkpoint.")

    model.eval()
    results = []

    # Molecule generation one by one
    with torch.no_grad():
        for i, g_data in enumerate(tqdm(test_graphs, desc="G√©n√©ration Test")):
            g_data = g_data.to(TEST_CONFIG['device'])
            caption = model.generate_caption(g_data, tokenizer, max_new_tokens=128)
            print(caption)

            # Fall back to the position in the list when the graph has no id.
            mol_id = getattr(g_data, 'id', i)
            results.append({'ID': mol_id, 'description': caption.replace('\n', ' ').strip()})

    # Save CSV
    with open(TEST_CONFIG['output_file'], 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['ID', 'description'])
        writer.writeheader()
        writer.writerows(results)
    print(f"Done ! File : {TEST_CONFIG['output_file']}")

if __name__ == "__main__":
    generate_test_submission()

In [None]:
import torch
import torch.nn.functional as F
import pandas as pd
import pickle
from tqdm import tqdm

def final_rescue_mission(input_csv, output_csv, train_pkl_path, test_pkl_path, encoder_weights_path):
    """Replace failed generations by the description of the closest training graph.

    A generation counts as failed when its text contains "!!" or is shorter
    than 25 characters. For each failed row the test graph at the same
    position is embedded with the contrastive encoder and the description of
    the training graph with the highest cosine similarity is used instead.

    Args:
        input_csv: submission to repair (needs a 'description' column).
        output_csv: where the repaired submission is written.
        train_pkl_path: pickled training graphs (with descriptions).
        test_pkl_path: pickled test graphs; assumed to be in the same order
            as the rows of ``input_csv``.
        encoder_weights_path: state_dict of the trained FullEncoderModel.

    Raises:
        IndexError: a failed row has no test graph at the same position.
    """
    print("Loading data...")
    with open(train_pkl_path, 'rb') as f:
        train_data = pickle.load(f)
    with open(test_pkl_path, 'rb') as f:
        test_data = pickle.load(f)

    df_sub = pd.read_csv(input_csv)

    # Find the failed rows first: when there are none, the expensive pass
    # over the whole training set is skipped.
    descriptions = df_sub['description'].astype(str)
    failed_rows = [i for i, desc in descriptions.items() if "!!" in desc or len(desc) < 25]

    if not failed_rows:
        df_sub.to_csv(output_csv, index=False)
        print(f"\nDone !")
        print(f"Saved molecules (replaced with Train examples) : 0 / {len(df_sub)}")
        print(f"Submission file ready : {output_csv}")
        return

    # Fail before the long encoding pass rather than after it.
    if max(failed_rows) >= len(test_data):
        raise IndexError(
            f"Row {max(failed_rows)} of {input_csv} has no matching test graph "
            f"({len(test_data)} graphs in {test_pkl_path})"
        )

    print(f"Loading weights : {encoder_weights_path}")

    encoder_only = FullEncoderModel(
        lap_dim=CONFIG['lap_dim'],
        rw_dim=CONFIG['rw_dim'],
        wl_vocab=CONFIG['wl_vocab'],
        hidden_dim=CONFIG['hidden_dim'],
        output_dim=CONFIG['output_dim'],
        max_k=11
    ).to(CONFIG['device'])

    encoder_only.load_state_dict(torch.load(encoder_weights_path, map_location=CONFIG['device']))
    encoder_only.eval()

    device = CONFIG['device']
    train_embeddings = []
    train_descriptions = []

    print("Retrieving vectors from training...")
    with torch.no_grad():
        for g_data in tqdm(train_data):
            # Work on a copy: `.to(device)` on the pickled graph itself would
            # keep every training graph resident on the GPU.
            # NOTE(review): assumes the graphs are PyG Data objects (clone()).
            g_batch = g_data.clone().to(device)

            res = encoder_only(g_batch, g_batch.sorted_neighbors, g_batch.sorted_edges, g_batch.geo_features)
            h_graph = res[0]

            train_embeddings.append(h_graph.cpu())
            train_descriptions.append(getattr(g_data, 'description', ""))

    train_embeddings = torch.cat(train_embeddings, dim=0)

    print("Replacing failures with closest training set neighbor...")
    count_rescued = 0

    for i in failed_rows:
        g_test_batch = test_data[i].clone().to(device)

        with torch.no_grad():
            res_test = encoder_only(g_test_batch, g_test_batch.sorted_neighbors,
                                    g_test_batch.sorted_edges, g_test_batch.geo_features)
            h_test = res_test[0]

        # One similarity score per training graph.
        similarities = F.cosine_similarity(h_test.cpu(), train_embeddings)
        best_match_idx = torch.argmax(similarities).item()

        df_sub.at[i, 'description'] = train_descriptions[best_match_idx]
        count_rescued += 1

    df_sub.to_csv(output_csv, index=False)
    print(f"\nDone !")
    print(f"Saved molecules (replaced with Train examples) : {count_rescued} / {len(df_sub)}")
    print(f"Submission file ready : {output_csv}")

final_rescue_mission(
    input_csv='submission.csv',
    output_csv='submission_final.csv',
    train_pkl_path='train_graphs_preprocessed.pkl',
    test_pkl_path='test_graphs_preprocessed.pkl',
    encoder_weights_path='model_best.pt'
)