In [None]:
# ==========================================
# 1. INSTALLATION DES DÉPENDANCES
# ==========================================

# A. Biblioth√®ques Graphe & Optimisation
# torch-scatter est crucial pour acc√©l√©rer le pooling du GNN
!pip install -q torch-geometric torch-scatter

# B. Biblioth√®ques LLM & Fine-Tuning (LoRA/4-bit)
# bitsandbytes : pour charger le mod√®le en 4-bit (gain m√©moire)
# peft : pour LoRA (Parameter-Efficient Fine-Tuning)
# accelerate : pour g√©rer le chargement efficace sur GPU
!pip install -q transformers accelerate bitsandbytes peft

# C. Outils Data Science classiques
!pip install -q pandas networkx scipy tqdm

# ==========================================
# 2. VÉRIFICATION DU MATÉRIEL
# ==========================================
import torch
import os

print(f"CUDA disponible : {torch.cuda.is_available()}")

# Guard clause: stop the notebook early when no GPU is attached to the runtime.
if not torch.cuda.is_available():
    raise RuntimeError("‚ùå ERREUR : Activez le GPU dans Ex√©cution > Modifier le type d'ex√©cution !")

# Report the name and total memory (converted from bytes to GiB) of device 0.
device_name = torch.cuda.get_device_name(0)
vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
print(f"‚úÖ GPU d√©tect√© : {device_name}")
print(f"üíæ VRAM Totale : {vram:.2f} GB")

# Compute capability of the current device; major >= 7 is treated as "fast path".
major, minor = torch.cuda.get_device_capability()
if major < 7:
    print("‚ö†Ô∏è GPU un peu ancien, √ßa marchera mais moins vite.")
else:
    print("üöÄ GPU compatible avec l'entra√Ænement optimis√© (Mixed Precision).")

[?25l     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m63.7/63.7 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m108.0/108.0 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.3/1.3 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import pickle
from torch.utils.data import DataLoader, Dataset
from torch_geometric.data import Batch
from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool
from transformers import AutoTokenizer, AutoModel
from tqdm.notebook import tqdm

# Configuration Globale
# Global configuration for the contrastive pre-training cells.
CONFIG = {
    'lap_dim': 8,          # width of data.pe_lap (concatenated with pe_rw in the encoder)
    'rw_dim': 16,          # width of data.pe_rw
    'wl_vocab': 50000,     # number of rows in the WL embedding table
    'hidden_dim': 256,     # internal width of the graph encoder
    'output_dim': 768,     # width of the graph embedding fed to the contrastive loss
    'batch_size': 512,
    'lr': 4e-4,
    'epochs': 100,
    'temperature': 0.07,   # softmax temperature of ContrastiveLoss
    'max_k': 11,           # neighbour slots per node (size of the rank embedding)
    # Fall back to CPU instead of hard-coding 'cuda', like the other configs in this notebook.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

In [None]:
# --- 1. SOUS-COUCHES OPTIMISÉES (LIGHTWEIGHT) ---

class AtomFeatureEncoder(nn.Module):
    """Embed nine categorical atom-feature columns and sum them per atom.

    Args:
        emb_dim: width of every embedding table (and of the output).
    """

    def __init__(self, emb_dim):
        super().__init__()
        # Vocabulary size of each categorical column, in column order.
        vocab_sizes = (119, 9, 11, 12, 9, 5, 8, 2, 2)
        tables = []
        for size in vocab_sizes:
            tables.append(nn.Embedding(size, emb_dim))
        self.embeddings = nn.ModuleList(tables)

    def forward(self, x):
        """Return the sum over columns of the embedding of ``x[:, col]``.

        ``x`` must have at least as many columns as there are tables; each
        column is cast to ``long`` before the lookup.
        """
        return sum(
            table(x[:, col].long())
            for col, table in enumerate(self.embeddings)
        )

class LighterInputLayer(nn.Module):
    """Fuse per-neighbour signals by addition instead of concatenation.

    Neighbour state, edge-type embedding, rank (slot position) embedding and a
    projection of the 5 geometric features all live in ``hidden_dim`` and are
    summed, then layer-normalised. Padded slots are zeroed with ``mask``.

    Args:
        hidden_dim: width of all internal representations.
        num_bond_types: rows of the edge-type embedding (index 0 is padding).
        max_k: rows of the rank embedding, i.e. the largest supported slot count.
    """

    def __init__(self, hidden_dim, num_bond_types=22, max_k=11):
        super().__init__()

        # Both tables share the hidden width so their outputs can be added.
        self.edge_embedding = nn.Embedding(num_bond_types, hidden_dim, padding_idx=0)
        self.rank_embedding = nn.Embedding(max_k, hidden_dim)

        # Lift the 5 geometric features to the hidden width.
        self.geo_proj = nn.Sequential(
            nn.Linear(5, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, h_neighbors, neighbors_edge_idx, geo_features, mask):
        """Return ``norm(h + edge + rank + geo)`` with masked-out slots set to 0.

        ``h_neighbors`` must be 3-D; its second dimension is the number of
        neighbour slots. ``neighbors_edge_idx`` uses -1 for "no edge", which is
        mapped to the padding index 0. The input tensors are not modified.
        """
        _, num_slots, _ = h_neighbors.shape

        # Edge types: -1 -> padding row 0.
        edge_ids = neighbors_edge_idx.masked_fill(neighbors_edge_idx == -1, 0)
        edge_term = self.edge_embedding(edge_ids)

        # Slot positions 0..K-1; the (K, H) result broadcasts over the node axis.
        slot_ids = torch.arange(num_slots, device=h_neighbors.device)
        rank_term = self.rank_embedding(slot_ids)

        geo_term = self.geo_proj(geo_features)

        # Same left-to-right addition order as a single chained expression.
        fused = h_neighbors + edge_term
        fused = fused + rank_term
        fused = fused + geo_term

        return self.norm(fused) * mask.unsqueeze(-1)

class PreNormAttention(nn.Module):
    """Pre-norm Transformer block over the neighbour slots of each node.

    Computes ``x = x + Attention(Norm(x))`` followed by ``x = x + FFN(Norm(x))``
    (residuals are applied to the un-normalised stream), then zeroes the
    padded slots.

    Args:
        input_dim: feature width of ``x``.
        num_heads: number of attention heads.
    """

    def __init__(self, input_dim, num_heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(input_dim)
        self.mha = nn.MultiheadAttention(input_dim, num_heads, batch_first=True)

        self.norm2 = nn.LayerNorm(input_dim)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, 4*input_dim),
            nn.ReLU(),
            nn.Linear(4*input_dim, input_dim)
        )

    def forward(self, x, neighbor_indices):
        """Apply the block.

        Args:
            x: tensor of shape (rows, slots, input_dim).
            neighbor_indices: tensor of shape (rows, slots); -1 marks padding.

        Returns:
            Tensor shaped like ``x`` with every padded slot set to exactly 0.
        """
        key_mask = (neighbor_indices == -1)

        # A row with no valid slot would hand attention a fully masked key set;
        # the softmax over an all -inf row can be NaN, and NaN survives the
        # final zeroing (NaN * 0 == NaN). Un-mask such rows for the attention
        # call only: their outputs are finite and are discarded below anyway.
        row_is_empty = key_mask.all(dim=1, keepdim=True)
        attn_key_mask = key_mask & ~row_is_empty

        # Block 1: attention (pre-norm), plain residual.
        x_norm = self.norm1(x)
        attn, _ = self.mha(x_norm, x_norm, x_norm, key_padding_mask=attn_key_mask, need_weights=False)
        x = x + attn

        # Block 2: feed-forward (pre-norm), plain residual.
        x_norm = self.norm2(x)
        x = x + self.ffn(x_norm)

        # masked_fill (rather than multiplying by the mask) guarantees zeros
        # in padded slots regardless of what was computed there.
        return x.masked_fill(key_mask.unsqueeze(-1), 0.0)

class FullEncoderModel(nn.Module):
    """Graph encoder producing one embedding per graph.

    Pipeline: atom embedding + global positional features + WL embedding,
    neighbour gathering and enrichment, four pre-norm attention blocks over
    the neighbour slots, gated aggregation back onto each node, then
    mean + max pooling per graph and a final MLP.

    Args:
        lap_dim: width of ``data.pe_lap``.
        rw_dim: width of ``data.pe_rw``.
        wl_vocab: number of rows of the WL embedding table.
        hidden_dim: internal width.
        output_dim: width of the returned graph embedding.
        max_k: number of neighbour slots supported by the input layer.
    """

    def __init__(self, lap_dim, rw_dim, wl_vocab, hidden_dim, output_dim, max_k=11):
        super().__init__()

        # Atom-level categorical features.
        self.atom_enc = AtomFeatureEncoder(hidden_dim)

        # Positional features: projected (lap ++ rw) and an embedded WL code.
        self.global_proj = nn.Sequential(
            nn.Linear(lap_dim + rw_dim, hidden_dim),
            nn.SiLU(),
        )
        self.wl_emb = nn.Embedding(wl_vocab, hidden_dim)

        # Additive neighbour input layer.
        self.input_layer = LighterInputLayer(hidden_dim, max_k=max_k)

        # Stack of four pre-norm attention blocks.
        self.layers = nn.ModuleList(PreNormAttention(hidden_dim) for _ in range(4))

        # Sigmoid gate used to weight neighbours before summing them.
        self.agg_scorer = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())
        self.norm_final = nn.LayerNorm(hidden_dim)

        # Output head; its input is hidden_dim wide because mean and max pools are added.
        self.final_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, output_dim),
        )
        self.final_norm = nn.LayerNorm(hidden_dim)

    def forward(self, data, sorted_neighbors, sorted_edges, geo_features):
        """Encode a batch of graphs.

        Args:
            data: batch object exposing ``x``, ``pe_lap``, ``pe_rw``, ``pe_wl`` and ``batch``.
            sorted_neighbors: per-node neighbour indices into the batch, -1 for padding.
            sorted_edges: per-slot edge-type ids, -1 for padding.
            geo_features: per-slot geometric features (last dim 5).

        Returns:
            ``(graph_embeddings, data.batch)``.
        """
        geo_features = geo_features.contiguous()

        # Phase 1: per-node state = atoms + positional projection + summed WL embeddings.
        positional = torch.cat([data.pe_lap, data.pe_rw], dim=-1)
        h = self.atom_enc(data.x)
        h_global = self.global_proj(positional)
        h_wl = self.wl_emb(data.pe_wl).sum(dim=1)
        h = h + h_global + h_wl

        # Phase 2: gather neighbour states (padding reads row 0, masked later).
        mask = sorted_neighbors != -1
        gather_idx = sorted_neighbors.masked_fill(~mask, 0)
        h_neighbors = self.input_layer(h[gather_idx], sorted_edges, geo_features, mask)

        # Phase 3: attention blocks (residuals handled inside each block).
        for block in self.layers:
            h_neighbors = block(h_neighbors, sorted_neighbors)

        # Phase 4: gated sum of neighbours, added to the node state.
        gate = self.agg_scorer(h_neighbors).masked_fill(~mask.unsqueeze(-1), 0)
        h_context = (h_neighbors * gate).sum(dim=1)
        h_final = self.norm_final(h + h_context)

        # Phase 5: mean + max pooling per graph, summed rather than concatenated.
        pooled = global_mean_pool(h_final, data.batch) + global_max_pool(h_final, data.batch)
        z_graph = self.final_norm(pooled)

        return self.final_mlp(z_graph), data.batch

In [None]:
class PreprocessedGraphDataset(Dataset):
    """Dataset over a pickled list of preprocessed graphs.

    Args:
        graph_path: path of the pickle file holding the list of graphs.
        emb_dict: optional mapping ``str(graph.id) -> text embedding``. When it
            is falsy (``None`` or empty) items carry no text embedding.
    """

    def __init__(self, graph_path, emb_dict=None):
        print(f"Chargement de {graph_path}...")
        # NOTE: unpickling runs arbitrary code; only load files you trust.
        with open(graph_path, 'rb') as handle:
            self.graphs = pickle.load(handle)
        self.emb_dict = emb_dict

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return ``(graph, neighbors, edges, geo)`` plus the text embedding if available."""
        graph = self.graphs[idx]
        # Older files may lack sorted_edges; use an all-zero stand-in of the same shape.
        placeholder = torch.zeros_like(graph.sorted_neighbors)
        edge_types = getattr(graph, 'sorted_edges', placeholder)

        sample = (graph, graph.sorted_neighbors, edge_types, graph.geo_features)
        if not self.emb_dict:
            return sample
        return sample + (self.emb_dict[str(graph.id)],)

def collate_fn(batch):
    """Merge dataset items into one batch.

    Items are 4-tuples ``(graph, neighbors, edges, geo)`` or 5-tuples with a
    trailing text embedding; the first item decides which layout is used.

    Returns:
        ``(batch_graph, neighbors, edges, geos)`` and, for 5-tuples, the
        stacked text embeddings as a fifth element.
    """
    fields = tuple(zip(*batch))
    with_text = len(batch[0]) == 5

    if with_text:
        graphs, neighbors, edges, geos, text_list = fields
        texts = torch.stack(text_list)
    else:
        graphs, neighbors, edges, geos = fields
        texts = None

    batch_graph = Batch.from_data_list(list(graphs))

    # Neighbour indices are local to each graph: shift the valid ones by the
    # number of nodes that precede the graph in the batch; keep -1 as padding.
    shifted = []
    node_offset = 0
    for graph, local_idx in zip(graphs, neighbors):
        shifted.append(torch.where(local_idx != -1, local_idx + node_offset, local_idx))
        node_offset += graph.num_nodes
    total_neighbors = torch.cat(shifted, dim=0)

    # Edge types are categories, not node indices: concatenate without offset.
    total_edges = torch.cat(edges, dim=0)

    # Geometric features are concatenated along the node axis.
    total_geos = torch.cat(geos, dim=0).contiguous()

    merged = (batch_graph, total_neighbors, total_edges, total_geos)
    if with_text:
        return merged + (texts,)
    return merged

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
import pandas as pd
import os
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader

# ==========================================
# 1. LA LOSS (Ta version originale)
# ==========================================
class ContrastiveLoss(nn.Module):
    """Symmetric cross-entropy over cosine similarities of paired embeddings.

    Row ``i`` of the graph batch is the positive for row ``i`` of the text
    batch; every other row in the batch acts as a negative.

    Args:
        temperature: divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temp = temperature
        self.ce = nn.CrossEntropyLoss()

    def forward(self, g_emb, t_emb):
        """Return the mean of the graph->text and text->graph cross-entropies."""
        # Project both sides onto the unit sphere so the dot product is a cosine.
        graph_unit = nn.functional.normalize(g_emb, dim=-1)
        text_unit = nn.functional.normalize(t_emb, dim=-1)

        similarity = torch.matmul(graph_unit, text_unit.T) / self.temp

        # The matching pair sits on the diagonal.
        targets = torch.arange(similarity.size(0), device=similarity.device)

        graph_to_text = self.ce(similarity, targets)
        text_to_graph = self.ce(similarity.T, targets)
        return (graph_to_text + text_to_graph) / 2

# ==========================================
# 2. BOUCLE D'ENTRAÎNEMENT (Adaptée)
# ==========================================
import torch.optim as optim

def _parse_embedding(raw):
    """Turn a serialized list such as ``"[0.1, 0.2]"`` into a float32 tensor.

    Raises:
        ValueError: if a token is not a number or no number is present.
    """
    cleaned = str(raw).replace('[', '').replace(']', '').replace('\n', '')
    values = [float(tok) for tok in cleaned.split(',') if tok.strip()]
    if not values:
        raise ValueError("empty embedding")
    return torch.tensor(values, dtype=torch.float32)


def train():
    """Train ``FullEncoderModel`` with the contrastive loss against text embeddings.

    Reads ``train_graphs_preprocessed.pkl`` and ``train_embeddings.csv`` from
    the working directory (columns ``ID`` and ``embedding``), trains for
    ``CONFIG['epochs']`` epochs, saves ``model_best.pt`` whenever the average
    training loss improves and ``model_epoch_<n>.pt`` every 10 epochs.

    Returns:
        None. Prints an error and returns early when an input file is missing
        or the dataset is empty.

    Raises:
        KeyError: if the CSV lacks the ``ID`` or ``embedding`` column.
    """
    # --- A. CONFIG ---
    GRAPH_PATH = 'train_graphs_preprocessed.pkl'
    EMB_PATH = 'train_embeddings.csv'

    # Both inputs are required; checking only the graphs would crash in read_csv.
    if not (os.path.exists(GRAPH_PATH) and os.path.exists(EMB_PATH)):
        print("‚ùå ERREUR : Fichiers manquants !")
        return

    # --- B. DATA ---
    print("üìÇ Chargement Dataset...")
    df = pd.read_csv(EMB_PATH)
    emb_dict = {}
    skipped = 0
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Parsing CSV"):
        try:
            emb_dict[str(row['ID'])] = _parse_embedding(row['embedding'])
        except (ValueError, TypeError):
            # Unreadable row: skip it, but keep count so the loss of data is visible.
            skipped += 1
    if skipped:
        print(f"‚ö†Ô∏è {skipped} lignes CSV ignor√©es (embedding illisible).")

    dataset = PreprocessedGraphDataset(GRAPH_PATH, emb_dict)
    loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=True, collate_fn=collate_fn)
    if len(loader) == 0:
        # Avoid a division by zero when averaging the loss below.
        print("‚ùå ERREUR : Dataset vide.")
        return

    # --- C. MODEL ---
    device = CONFIG['device']
    model = FullEncoderModel(
        CONFIG['lap_dim'], CONFIG['rw_dim'], CONFIG['wl_vocab'],
        CONFIG['hidden_dim'], CONFIG['output_dim'],
        max_k=CONFIG.get('max_k', 11)
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=1e-4)
    criterion = ContrastiveLoss(CONFIG['temperature'])
    # Mixed precision only makes sense on CUDA; both helpers become no-ops otherwise.
    use_amp = str(device).startswith('cuda')
    scaler = GradScaler(enabled=use_amp)

    # No learning-rate scheduler is configured here, so the message does not claim one.
    print("üöÄ D√©but de l'entra√Ænement...")
    model.train()

    best_loss = float('inf')

    for epoch in range(CONFIG['epochs']):
        total_loss = 0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")

        for batch in pbar:
            g_data, neighbors, edges, geos, texts = batch

            g_data = g_data.to(device)
            neighbors = neighbors.to(device)
            edges = edges.to(device)
            geos = geos.to(device)
            texts = texts.to(device)

            optimizer.zero_grad()

            with autocast(enabled=use_amp):
                # Index instead of unpacking: FullEncoderModel is redefined later in
                # this notebook with an extra return value; element 0 is the graph
                # embedding in both definitions.
                z_graph = model(g_data, neighbors, edges, geos)[0]
                loss = criterion(z_graph, texts)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss_value = loss.item()
            total_loss += loss_value
            pbar.set_postfix({'loss': f"{loss_value:.4f}"})

        # --- END OF EPOCH ---
        avg_loss = total_loss / len(loader)

        print(f"‚úÖ Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")

        # Keep the best checkpoint (by average training loss) ...
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), "model_best.pt")
            print("üíæ Nouveau record ! Mod√®le sauvegard√©.")

        # ... and a periodic snapshot every 10 epochs.
        if (epoch+1) % 10 == 0:
            torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pt")

In [None]:
 # Launch the contrastive training loop; train() prints an error and returns early when its input files are missing.
 train()

‚ùå ERREUR : Fichiers manquants !


In [None]:
# =============================================================================
# CELLULE D'ÉVALUATION COMPLÈTE (CORRIGÉE)
# =============================================================================
import torch
import torch.nn.functional as F
import pandas as pd
import os
import numpy as np
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# --- 1. CONFIGURATION ---
# Settings for the retrieval evaluation cell.
EVAL_CONFIG = dict(
    # Checkpoint to evaluate.
    model_path='model_best.pt',

    # Validation (or test) inputs.
    val_graph='validation_graphs_preprocessed.pkl',
    val_csv='validation_embeddings.csv',

    # Architecture hyper-parameters: must match the values used for training.
    lap_dim=8,
    rw_dim=16,
    wl_vocab=50000,
    hidden_dim=256,
    output_dim=768,
    max_k=11,

    # No gradients are kept during evaluation, so this can be raised (e.g. 256).
    batch_size=128,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

def _recall_hit_counts(similarity_matrix, ks=(1, 5, 10)):
    """Count, for each k, the rows whose diagonal entry is among the row's top-k scores."""
    n_samples = similarity_matrix.size(0)
    # topk raises when k exceeds the number of columns, so cap it for small sets.
    widest = min(max(ks), similarity_matrix.size(1))
    _, topk_indices = torch.topk(similarity_matrix, k=widest, dim=1)
    targets = torch.arange(n_samples, device=similarity_matrix.device).unsqueeze(1)
    hits = topk_indices == targets
    return {f"R@{k}": int(hits[:, :k].any(dim=1).sum()) for k in ks}


def run_evaluation():
    """Evaluate graph-to-text retrieval (R@1 / R@5 / R@10) on the validation set.

    Loads the text embeddings from ``EVAL_CONFIG['val_csv']``, the graphs from
    ``EVAL_CONFIG['val_graph']`` and the weights from
    ``EVAL_CONFIG['model_path']``, encodes every graph, and ranks all text
    embeddings by cosine similarity. Sample ``i`` is counted as a hit when
    text ``i`` is within the top-k for graph ``i``.

    Returns:
        None. Results are printed; the function returns early (after printing
        an error) when an input file is missing or there is nothing to score.
    """
    print(f"‚öôÔ∏è D√©marrage de l'√©valuation sur {EVAL_CONFIG['device']}...")

    # --- A. Text embeddings (CSV) ---
    print(f"üìÇ Lecture des embeddings textes depuis {EVAL_CONFIG['val_csv']}...")
    if not os.path.exists(EVAL_CONFIG['val_csv']):
        print("‚ùå ERREUR : Fichier CSV introuvable. Upload-le !")
        return

    try:
        df = pd.read_csv(EVAL_CONFIG['val_csv'], engine='python', on_bad_lines='skip')
        emb_dict = {}
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Parsing CSV"):
            try:
                # "[0.1, ...]" -> tensor
                val_str = str(row['embedding']).replace('[','').replace(']','').replace('\n','')
                vals = [float(x) for x in val_str.split(',') if x.strip()]
                emb_dict[str(row['ID'])] = torch.tensor(vals, dtype=torch.float32)
            except (ValueError, TypeError):
                # Unreadable row: skip it. Structural problems (e.g. a missing
                # column) are reported by the outer handler instead of being hidden.
                continue
        print(f"‚úÖ {len(emb_dict)} embeddings textes charg√©s.")
    except Exception as e:
        print(f"‚ùå Erreur critique lecture CSV : {e}")
        return

    # --- B. Dataset ---
    if not os.path.exists(EVAL_CONFIG['val_graph']):
        print(f"‚ùå ERREUR : {EVAL_CONFIG['val_graph']} introuvable.")
        return

    dataset = PreprocessedGraphDataset(EVAL_CONFIG['val_graph'], emb_dict)

    loader = DataLoader(
        dataset,
        batch_size=EVAL_CONFIG['batch_size'],
        shuffle=False,
        collate_fn=collate_fn
    )

    # --- C. Model ---
    print(f"üß† Chargement des poids depuis {EVAL_CONFIG['model_path']}...")

    model = FullEncoderModel(
        lap_dim=EVAL_CONFIG['lap_dim'],
        rw_dim=EVAL_CONFIG['rw_dim'],
        wl_vocab=EVAL_CONFIG['wl_vocab'],
        hidden_dim=EVAL_CONFIG['hidden_dim'],
        output_dim=EVAL_CONFIG['output_dim'],
        max_k=EVAL_CONFIG['max_k']
    ).to(EVAL_CONFIG['device'])

    if os.path.exists(EVAL_CONFIG['model_path']):
        state_dict = torch.load(EVAL_CONFIG['model_path'], map_location=EVAL_CONFIG['device'])
        # Try an exact match first so an architecture mismatch is noticed.
        try:
            model.load_state_dict(state_dict, strict=True)
            print("‚úÖ Poids charg√©s parfaitement.")
        except RuntimeError as e:
            # NOTE(review): with strict=False any missing parameter keeps its random
            # initialisation, so the metrics below may not be meaningful.
            print(f"‚ö†Ô∏è Attention : mismatch des cl√©s (architecture diff√©rente ?). Essai avec strict=False.\nErreur: {e}")
            model.load_state_dict(state_dict, strict=False)
    else:
        print("‚ùå ERREUR : Le fichier .pt du mod√®le n'existe pas.")
        return

    model.eval()

    # --- D. Inference ---
    all_graph_embs = []
    all_text_embs = []

    print("üöÄ Extraction des vecteurs...")
    with torch.no_grad():
        for batch in tqdm(loader, desc="Inf√©rence"):
            g_data, neighbors, edges, geos, texts = batch

            g_data = g_data.to(EVAL_CONFIG['device'])
            neighbors = neighbors.to(EVAL_CONFIG['device'])
            edges = edges.to(EVAL_CONFIG['device'])
            geos = geos.to(EVAL_CONFIG['device'])
            texts = texts.to(EVAL_CONFIG['device'])

            with torch.amp.autocast('cuda', enabled=(EVAL_CONFIG['device']=='cuda')):
                # Index instead of unpacking: FullEncoderModel is redefined later in
                # this notebook with an extra return value; element 0 is the pooled
                # graph embedding in both definitions.
                z_graph = model(g_data, neighbors, edges, geos)[0]

            # Keep results on the CPU so they do not accumulate in GPU memory.
            all_graph_embs.append(z_graph.cpu())
            all_text_embs.append(texts.cpu())

    if not all_graph_embs:
        # torch.cat on an empty list and the percentage division would both fail.
        print("‚ùå ERREUR : Aucun √©chantillon √† √©valuer.")
        return

    # --- E. Metrics (R@K) ---
    print("üìä Calcul des similarit√©s...")

    graph_matrix = torch.cat(all_graph_embs, dim=0).float()
    text_matrix = torch.cat(all_text_embs, dim=0).float()

    # L2-normalise so the dot product below is a cosine similarity.
    graph_matrix = F.normalize(graph_matrix, p=2, dim=-1)
    text_matrix = F.normalize(text_matrix, p=2, dim=-1)

    # Full [n_samples, n_samples] matrix; would need chunking for very large sets.
    similarity_matrix = torch.matmul(graph_matrix, text_matrix.T)

    n_samples = similarity_matrix.size(0)

    print("üîç Recherche des Top-K...")
    # The correct text for graph i is text i (pairs are stacked in the same order).
    metrics = _recall_hit_counts(similarity_matrix, ks=(1, 5, 10))

    print("\n" + "="*40)
    print(f"üèÜ R√âSULTATS VALIDATION ({n_samples} Mol√©cules)")
    print("="*40)
    print(f"R@1  (Exact)  : {metrics['R@1'] / n_samples * 100:.2f}%")
    print(f"R@5  (Top 5)  : {metrics['R@5'] / n_samples * 100:.2f}%")
    print(f"R@10 (Top 10) : {metrics['R@10'] / n_samples * 100:.2f}%")
    print("="*40)

# Run the evaluation when the cell is executed.
if __name__ == "__main__":
    run_evaluation()

‚öôÔ∏è D√©marrage de l'√©valuation sur cuda...
üìÇ Lecture des embeddings textes depuis validation_embeddings.csv...
‚ùå ERREUR : Fichier CSV introuvable. Upload-le !


In [None]:
import torch
import os
import pickle
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader

# ==========================================
# CONFIGURATION
# ==========================================
# Settings for the test-set sanity check.
CHECK_CONFIG = dict(
    test_path='test_graphs_preprocessed.pkl',   # pickled test graphs
    model_path='model_best.pt',                 # trained weights (optional for this check)
    device='cuda' if torch.cuda.is_available() else 'cpu',
    batch_size=128,
    max_k=11,
)

# ==========================================
# FONCTION DE SÉCURITÉ (ANTI-CRASH)
# ==========================================
def secure_batch(g, max_k=11):
    """Truncate and clamp the tensors of ``g`` in place so indices stay in range.

    Args:
        g: object exposing ``x`` and, optionally, ``sorted_neighbors``,
            ``sorted_edges`` and ``geo_features`` (``num_nodes`` is read only
            when ``sorted_neighbors`` is present).
        max_k: maximum number of neighbour slots to keep.

    Returns:
        The same object ``g``, modified in place.
    """
    def _clamp_valid(tensor, upper):
        # Clamp every entry that is not the -1 padding marker into [0, upper].
        valid = tensor != -1
        if valid.any():
            tensor[valid] = tensor[valid].clamp(min=0, max=upper)

    # 1. Drop neighbour slots beyond max_k.
    for attr in ('sorted_neighbors', 'sorted_edges'):
        if hasattr(g, attr) and getattr(g, attr).size(1) > max_k:
            setattr(g, attr, getattr(g, attr)[:, :max_k])
    if hasattr(g, 'geo_features') and g.geo_features.size(1) > max_k:
        g.geo_features = g.geo_features[:, :max_k, :]

    # 2. Clamp atom feature columns to the vocabulary size of each column.
    vocab_sizes = (119, 9, 11, 12, 9, 5, 8, 2, 2)
    for col, size in enumerate(vocab_sizes):
        g.x[:, col] = g.x[:, col].clamp(min=0, max=size - 1)

    # 3. Neighbour indices must be below the node count, edge types below 22.
    if hasattr(g, 'sorted_neighbors'):
        _clamp_valid(g.sorted_neighbors, g.num_nodes - 1)
    if hasattr(g, 'sorted_edges'):
        _clamp_valid(g.sorted_edges, 21)

    return g

# ==========================================
# SCRIPT DE VÉRIFICATION
# ==========================================
def run_sanity_check():
    """Run the encoder over the whole test set and report batches that fail.

    Uses the trained weights when ``CHECK_CONFIG['model_path']`` exists and an
    untrained model otherwise; the aim is only to detect runtime errors.
    Returns None; everything is reported through ``print``.
    """
    device = CHECK_CONFIG['device']
    print(f"ü©∫ D√©marrage du Sanity Check sur {device}...")

    # 1. Data (no text embeddings, so the collate function yields 4-tuples).
    test_path = CHECK_CONFIG['test_path']
    if not os.path.exists(test_path):
        print(f"‚ùå Fichier introuvable : {test_path}")
        return

    dataset = PreprocessedGraphDataset(test_path, emb_dict=None)
    loader = DataLoader(
        dataset,
        batch_size=CHECK_CONFIG['batch_size'],
        shuffle=False,
        collate_fn=collate_fn,
    )
    print(f"‚úÖ Dataset charg√© : {len(dataset)} mol√©cules.")

    # 2. Model.
    model = FullEncoderModel(
        lap_dim=8, rw_dim=16, wl_vocab=50000, hidden_dim=256, output_dim=768,
        max_k=CHECK_CONFIG['max_k'],
    ).to(device)

    weights_path = CHECK_CONFIG['model_path']
    if not os.path.exists(weights_path):
        print("‚ö†Ô∏è Poids introuvables. Test avec mod√®le non-entra√Æn√© (juste pour v√©rifier les bugs).")
    else:
        print(f"üì• Poids charg√©s : {weights_path}")
        model.load_state_dict(torch.load(weights_path, map_location=device), strict=False)

    model.eval()

    # 3. Forward pass over every batch.
    success_count = 0
    error_count = 0

    print("üöÄ Run en cours...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(loader, desc="V√©rification")):
            g_data, neighbors, edges, geos = batch

            # NOTE(review): secure_batch only touches attributes of g_data; the
            # separately collated neighbors/edges/geos tensors passed to the model
            # below are not sanitised by it — confirm this is intended.
            try:
                g_data = secure_batch(g_data, max_k=CHECK_CONFIG['max_k'])
            except Exception as e:
                print(f"‚ö†Ô∏è Erreur Preprocessing Batch {i}: {e}")
                error_count += 1
                continue

            g_data = g_data.to(device)
            neighbors, edges, geos = (t.to(device) for t in (neighbors, edges, geos))

            try:
                # Only successful execution matters here; the output is discarded.
                _ = model(g_data, neighbors, edges, geos)
                success_count += g_data.num_graphs
            except RuntimeError as e:
                print(f"üî• CRASH CUDA Batch {i} : {e}")
                error_count += 1
                torch.cuda.empty_cache()
                # Keep going to see whether the failure is isolated.

    separator = "=" * 30
    print("\n" + separator)
    print("BILAN DU RUN")
    print(separator)
    if error_count:
        print(f"‚ùå √âCHEC : {error_count} batchs ont plant√©.")
        print("Regarde les messages d'erreur ci-dessus (souvent device-side assert).")
    else:
        print(f"‚úÖ SUCC√àS TOTAL : {success_count} mol√©cules trait√©es sans erreur.")
        print("Le mod√®le est stable sur le Test Set.")
    print(separator)

if __name__ == "__main__":
    run_sanity_check()

ü©∫ D√©marrage du Sanity Check sur cuda...
Chargement de test_graphs_preprocessed.pkl...
‚úÖ Dataset charg√© : 1000 mol√©cules.
‚ö†Ô∏è Poids introuvables. Test avec mod√®le non-entra√Æn√© (juste pour v√©rifier les bugs).
üöÄ Run en cours...


V√©rification:   0%|          | 0/8 [00:00<?, ?it/s]


BILAN DU RUN
‚úÖ SUCC√àS TOTAL : 1000 mol√©cules trait√©es sans erreur.
Le mod√®le est stable sur le Test Set.


In [None]:
# ==========================================
# BLOC 1 : IMPORTS ET ARCHITECTURE ENCODEUR
# ==========================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Batch
import pandas as pd
import pickle
import os
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType

# --- CONFIGURATION GLOBALE ---
# Global configuration for the generation stage (graph encoder + LLM).
CONFIG = dict(
    # Graph encoder architecture.
    lap_dim=8,
    rw_dim=16,
    wl_vocab=50000,
    hidden_dim=256,
    output_dim=768,
    max_k=11,
    # Language model and tokenisation.
    llm_name="Qwen/Qwen2.5-1.5B-Instruct",
    max_text_len=128,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

class FullEncoderModel(nn.Module):
    """Graph encoder that also exposes the per-node states.

    NOTE: this definition replaces the ``FullEncoderModel`` defined earlier in
    the notebook. The layers are the same, but ``forward`` returns three values
    (graph embedding, node states, batch vector) instead of two.

    Args:
        lap_dim: width of ``data.pe_lap``.
        rw_dim: width of ``data.pe_rw``.
        wl_vocab: number of rows of the WL embedding table.
        hidden_dim: internal width.
        output_dim: width of the returned graph embedding.
        max_k: number of neighbour slots supported by the input layer.
    """

    def __init__(self, lap_dim, rw_dim, wl_vocab, hidden_dim, output_dim, max_k=11):
        super().__init__()

        # 1. Atom features.
        self.atom_enc = AtomFeatureEncoder(hidden_dim)

        # 2. Positional features (projected lap ++ rw, plus embedded WL codes).
        positional_width = lap_dim + rw_dim
        self.global_proj = nn.Sequential(nn.Linear(positional_width, hidden_dim), nn.SiLU())
        self.wl_emb = nn.Embedding(wl_vocab, hidden_dim)

        # 3. Additive neighbour input layer.
        self.input_layer = LighterInputLayer(hidden_dim, max_k=max_k)

        # 4. Four pre-norm attention blocks.
        blocks = [PreNormAttention(hidden_dim) for _ in range(4)]
        self.layers = nn.ModuleList(blocks)

        # 5. Sigmoid gate for the neighbour aggregation.
        self.agg_scorer = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())
        self.norm_final = nn.LayerNorm(hidden_dim)

        # 6. Output head (input is hidden_dim because mean and max pools are added).
        wide = hidden_dim * 2
        self.final_mlp = nn.Sequential(
            nn.Linear(hidden_dim, wide),
            nn.ReLU(),
            nn.Linear(wide, output_dim),
        )
        self.final_norm = nn.LayerNorm(hidden_dim)

    def forward(self, data, sorted_neighbors, sorted_edges, geo_features):
        """Encode a batch of graphs.

        Returns:
            ``(graph_embeddings, node_states, data.batch)`` where
            ``node_states`` is the normalised per-node tensor used for pooling.
        """
        geo_features = geo_features.contiguous()

        # Phase 1: central node state.
        node_state = self.atom_enc(data.x)
        global_term = self.global_proj(torch.cat([data.pe_lap, data.pe_rw], dim=-1))
        wl_term = self.wl_emb(data.pe_wl).sum(dim=1)
        node_state = node_state + global_term + wl_term

        # Phase 2: neighbour enrichment (padding slots gather row 0, then get masked).
        valid = sorted_neighbors != -1
        lookup = sorted_neighbors.masked_fill(~valid, 0)
        neighbor_state = self.input_layer(node_state[lookup], sorted_edges, geo_features, valid)

        # Phase 3: attention (each block handles its own residual).
        for block in self.layers:
            neighbor_state = block(neighbor_state, sorted_neighbors)

        # Phase 4: gated aggregation back onto the node.
        gate = self.agg_scorer(neighbor_state)
        gate = gate.masked_fill(~valid.unsqueeze(-1), 0)
        context = (neighbor_state * gate).sum(dim=1)
        h_final = self.norm_final(node_state + context)

        # Phase 5: summed mean/max pooling per graph.
        graph_mean = global_mean_pool(h_final, data.batch)
        graph_max = global_max_pool(h_final, data.batch)
        z_graph = self.final_norm(graph_mean + graph_max)

        return self.final_mlp(z_graph), h_final, data.batch

In [None]:
# ==========================================
# BLOC 2 : DATASET & TOKENIZATION (QWEN)
# ==========================================

# 1. Load the tokenizer named by CONFIG['llm_name'].
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint to
# run locally — confirm it is actually required for this model.
print(f"Chargement du Tokenizer {CONFIG['llm_name']}...")
tokenizer = AutoTokenizer.from_pretrained(CONFIG['llm_name'], trust_remote_code=True)

# A. Make sure a PAD token exists (fall back to EOS) and pad on the right.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# B. Register "<graph>" as a special token so the tokenizer never splits it,
# and keep its id in GRAPH_TOKEN_ID for later cells.
# NOTE(review): adding a token presumably requires resizing the LLM's embedding
# matrix where the model is loaded — verify that this is done.
num_added_tokens = tokenizer.add_tokens(["<graph>"], special_tokens=True)
GRAPH_TOKEN_ID = tokenizer.convert_tokens_to_ids("<graph>")

print(f"‚úÖ Token <graph> ajout√© ! (ID: {GRAPH_TOKEN_ID})")
# add_tokens returned 0: nothing new was added (e.g. the cell was re-run).
if num_added_tokens == 0:
    print("‚ö†Ô∏è Attention : Le token existait d√©j√† (C'est bon si tu relances la cellule).")

# ==========================================
# 1. DATASET SIMPLIFI√â (LECTURE DIRECTE)
# ==========================================
class MoleculeGenDataset(Dataset):
    """Pairs each pickled molecule graph with a tokenized chat prompt + description.

    Every sample is the fixed chat prompt followed by the molecule description
    and a closing ``<|im_end|>``, padded/truncated to ``max_len`` tokens. The
    labels are a copy of ``input_ids`` in which the prompt and the padding are
    replaced by -100 so that only the description contributes to the loss.

    Args:
        pkl_path: path to a pickle file holding the list of graph objects.
        tokenizer: Hugging Face style tokenizer (callable returning
            ``input_ids`` / ``attention_mask``).
        max_len: fixed token length of every returned sample.
    """

    def __init__(self, pkl_path, tokenizer, max_len=128):
        print(f"üìÇ Chargement des graphes depuis {pkl_path}...")
        # NOTE(security): pickle.load can execute arbitrary code; only open
        # files produced by your own preprocessing.
        with open(pkl_path, 'rb') as f:
            self.graphs = pickle.load(f)
        self.tokenizer = tokenizer
        self.max_len = max_len

        # Fixed prompt placed in front of every description. The closing
        # <|im_end|> of the assistant turn is appended per sample.
        self.prompt_prefix = (
            "<|im_start|>system\nYou are an expert chemist.<|im_end|>\n"
            "<|im_start|>user\n<graph>\nDescribe this molecule.<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

        # Token length of the prompt: this is how many leading label positions
        # get masked in __getitem__.
        self.prefix_ids = self.tokenizer(self.prompt_prefix, add_special_tokens=False).input_ids
        self.prefix_len = len(self.prefix_ids)

    def __len__(self):
        """Number of graphs loaded from the pickle file."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return ``(graph, input_ids, attention_mask, labels)`` for sample ``idx``.

        ``input_ids``, ``attention_mask`` and ``labels`` are 1-D tensors of
        length ``max_len``.
        """
        data = self.graphs[idx]

        # Target text; fall back to a placeholder when the attribute is
        # missing or explicitly None (None would break the concatenation).
        description = getattr(data, 'description', None)
        if description is None:
            description = "Description unavailable."

        full_text = self.prompt_prefix + description + "<|im_end|>"

        # The prompt already spells out every chat special token, and
        # prefix_len was measured with add_special_tokens=False. Tokenizing
        # the sample the same way keeps the prompt mask aligned even with a
        # tokenizer that would otherwise prepend a BOS token.
        tokens = self.tokenizer(
            full_text,
            add_special_tokens=False,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        input_ids = tokens.input_ids.squeeze(0)
        attention_mask = tokens.attention_mask.squeeze(0)

        labels = input_ids.clone()

        # A. Padding positions never contribute to the loss.
        labels[attention_mask == 0] = -100

        # B. Neither does the prompt. min() guards the case where truncation
        # made the sample shorter than the prompt itself.
        safe_len = min(self.prefix_len, len(labels))
        labels[:safe_len] = -100

        return data, input_ids, attention_mask, labels

# Fonction pour assembler les batchs (PyG + HuggingFace)
def multimodal_collate(batch):
    """Collate ``(graph, input_ids, attention_mask, labels)`` samples into one batch.

    The three tensor fields are stacked along a new leading dimension and the
    graphs are merged with ``Batch.from_data_list``.

    Returns:
        ``(batched_graph, input_ids, attention_mask, labels)``.
    """
    def _stack_field(position):
        # Stack the tensor stored at `position` in every sample.
        return torch.stack([sample[position] for sample in batch])

    token_ids = _stack_field(1)
    mask = _stack_field(2)
    targets = _stack_field(3)

    merged_graph = Batch.from_data_list([sample[0] for sample in batch])

    return merged_graph, token_ids, mask, targets

print("‚úÖ Dataset et Tokenizer pr√™ts.")

Chargement du Tokenizer Qwen/Qwen2.5-1.5B-Instruct...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

‚úÖ Token <graph> ajout√© ! (ID: 151665)
‚úÖ Dataset et Tokenizer pr√™ts.


In [None]:
import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_batch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType

class MolQwen(nn.Module):
    """Graph-conditioned captioner: frozen graph encoder -> projector -> 4-bit LLM with LoRA.

    Per-atom embeddings returned by the encoder are layer-normalised,
    projected to the LLM hidden size and prepended to the text-token
    embeddings before being fed to the language model.

    Args:
        encoder: graph encoder called as
            ``encoder(g, g.sorted_neighbors, g.sorted_edges, g.geo_features)``
            and returning ``(graph_embedding, per_atom_embeddings, batch_index)``.
            It is frozen here (no gradients, always in eval mode).
        llm_model_name: Hugging Face model id of the causal LM.
        graph_dim: width of the per-atom embeddings produced by the encoder.
    """

    # Prompt used at generation time (same text as the training prompt prefix).
    GENERATION_PROMPT = (
        "<|im_start|>system\nYou are an expert chemist.<|im_end|>\n"
        "<|im_start|>user\n<graph>\nDescribe this molecule.<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    def __init__(self, encoder, llm_model_name, graph_dim=256):
        super().__init__()
        # The encoder is a fixed feature extractor: no gradients, eval mode.
        self.encoder = encoder
        self.encoder.requires_grad_(False)
        self.encoder.eval()

        # 1. LLM loaded in 4-bit (NF4, double quantization, fp16 compute).
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )

        # 2. LoRA adapters on the attention projections.
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
        )
        self.llm = get_peft_model(self.llm, peft_config)

        # 3. Normalisation + projector, kept in float32 for numerical stability.
        llm_dim = self.llm.config.hidden_size
        self.ln_graph = nn.LayerNorm(graph_dim)
        self.projector = nn.Sequential(
            nn.Linear(graph_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim)
        )

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen encoder in eval mode.

        Without this override ``model.train()`` would re-enable train-time
        behaviour inside the encoder, so it would act differently during
        training than during validation and generation.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def _embed_graph(self, g_data, dtype):
        """Encode a graph batch and return ``(dense_embeds, dense_mask)``.

        The encoder runs without gradients; normalisation and projection run
        in float32 and the result is cast to ``dtype`` before being packed
        into a dense per-graph layout with ``to_dense_batch``.
        """
        with torch.no_grad():
            _, h_atoms, batch_idx = self.encoder(
                g_data, g_data.sorted_neighbors, g_data.sorted_edges, g_data.geo_features
            )
        h_projected = self.projector(self.ln_graph(h_atoms.float()))
        return to_dense_batch(h_projected.to(dtype), batch_idx)

    def forward(self, g_data, input_ids, attention_mask, labels):
        """Return the causal-LM loss for a batch of (graph, text) pairs.

        Graph tokens are placed in front of the text tokens; their label
        positions are set to -100 so they never contribute to the loss.
        """
        text_embeds = self.llm.get_input_embeddings()(input_ids)
        # Match the dtype of the LLM embeddings instead of hard-coding fp16.
        visual_embeds, visual_mask = self._embed_graph(g_data, text_embeds.dtype)

        inputs_embeds = torch.cat([visual_embeds, text_embeds], dim=1)
        combined_mask = torch.cat([visual_mask.to(attention_mask.dtype), attention_mask], dim=1)

        ignore_graph = torch.full(
            visual_embeds.shape[:2], -100, dtype=torch.long, device=visual_embeds.device
        )
        final_labels = torch.cat([ignore_graph, labels], dim=1)

        return self.llm(inputs_embeds=inputs_embeds, attention_mask=combined_mask, labels=final_labels).loss

    @torch.no_grad()
    def generate_caption(self, g_data, tokenizer, max_new_tokens=100):
        """Greedy-decode a description for a single graph and return it as a string.

        ``g_data`` is expected to hold one molecule: the prompt is tokenized
        once (batch size 1) and only the first generated sequence is decoded.
        A missing ``batch`` vector is filled with zeros on ``g_data`` itself.
        """
        if not hasattr(g_data, 'batch') or g_data.batch is None:
            g_data.batch = torch.zeros(g_data.x.size(0), dtype=torch.long, device=g_data.x.device)

        # The prompt already contains every chat special token it needs, so
        # the tokenizer must not add any of its own.
        text_inputs = tokenizer(
            self.GENERATION_PROMPT, return_tensors="pt", add_special_tokens=False
        ).to(self.llm.device)
        text_embeds = self.llm.get_input_embeddings()(text_inputs.input_ids)

        visual_embeds, visual_mask = self._embed_graph(g_data, text_embeds.dtype)

        inputs_embeds = torch.cat([visual_embeds, text_embeds], dim=1)
        combined_mask = torch.cat(
            [visual_mask.to(text_inputs.attention_mask.dtype), text_inputs.attention_mask], dim=1
        )

        # `is not None` rather than truthiness: a pad id of 0 is a valid id.
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

        output_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=combined_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=pad_id
        )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)

In [None]:
# ==========================================
# BLOC 4 : ENTRA√éNEMENT COMPLET (TRAIN + VAL)
# ==========================================
def train_llm():
    """Fine-tune MolQwen (LoRA + projector) and report train/validation loss per epoch.

    Reads the preprocessed train/validation pickles, optionally restores the
    graph encoder from ``model_best.pt``, trains with gradient accumulation
    and writes ``molqwen_val_epoch{N}.pt`` after every training epoch.

    Averages are taken over the batches that actually produced a finite loss;
    batches that raised or returned a non-finite loss are excluded and, for
    validation, counted in a separate message.
    """
    LLM_CONFIG = {
        'batch_size': 2,
        'grad_accum': 32,
        'lr': 5e-5,
        'epochs': 2,
        'device': CONFIG['device']
    }
    device = LLM_CONFIG['device']
    grad_accum = LLM_CONFIG['grad_accum']

    # A. Datasets and loaders
    train_dataset = MoleculeGenDataset('train_graphs_preprocessed.pkl', tokenizer)
    val_dataset = MoleculeGenDataset('validation_graphs_preprocessed.pkl', tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=LLM_CONFIG['batch_size'], shuffle=True, collate_fn=multimodal_collate)
    val_loader = DataLoader(val_dataset, batch_size=LLM_CONFIG['batch_size'], shuffle=False, collate_fn=multimodal_collate)

    # B. Model and optimizer (the encoder is only used as a feature extractor)
    encoder = FullEncoderModel(CONFIG['lap_dim'], CONFIG['rw_dim'], CONFIG['wl_vocab'], CONFIG['hidden_dim'], CONFIG['output_dim'], max_k=11).to(device)
    if os.path.exists("model_best.pt"):
        encoder.load_state_dict(torch.load("model_best.pt", map_location=device), strict=False)

    model = MolQwen(encoder, CONFIG['llm_name']).to(device)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=LLM_CONFIG['lr'], weight_decay=0.01)

    def _apply_accumulated_gradients():
        # Clip, step and reset: one optimizer update per accumulation window.
        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        optimizer.step()
        optimizer.zero_grad()

    # C. Main loop
    for epoch in range(LLM_CONFIG['epochs']):
        # --- TRAINING PHASE ---
        model.train()
        # model.train() also flips the encoder to train mode; it is used under
        # no_grad as a fixed feature extractor, so keep it in eval mode.
        model.encoder.eval()
        optimizer.zero_grad()

        train_loss = 0.0
        train_steps = 0   # batches that contributed a finite loss
        pending = 0       # backward passes since the last optimizer step
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} [TRAIN]")

        for step, batch in enumerate(pbar):
            g, t, m, l = batch
            try:
                g = g.to(device)
                t, m, l = t.to(device), m.to(device), l.to(device)

                loss = model(g, t, m, l)

                # A NaN/inf loss would poison every accumulated gradient.
                if not torch.isfinite(loss):
                    print(f"Erreur train step {step}: loss non finie, batch exclu")
                    continue

                (loss / grad_accum).backward()
                pending += 1

                if pending == grad_accum:
                    _apply_accumulated_gradients()
                    pending = 0

                train_loss += loss.item()
                train_steps += 1
                pbar.set_postfix({'loss': f"{loss.item():.4f}"})
            except Exception as e:
                # Best effort: report the failing batch and keep training.
                print(f"Erreur train step {step}: {e}")
                continue

        # Apply the gradients of the last, incomplete accumulation window so
        # they are neither lost nor mixed into the next epoch.
        if pending > 0:
            _apply_accumulated_gradients()

        # --- VALIDATION PHASE ---
        model.eval()
        # Checkpoint written before validation so a validation failure does
        # not lose the epoch.
        torch.save(model.state_dict(), f"molqwen_val_epoch{epoch+1}.pt")
        val_loss = 0.0
        val_steps = 0
        val_nonfinite = 0
        print(f"üìä Validation Epoch {epoch+1}...")

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [VAL]"):
                g, t, m, l = batch
                # NOTE(review): ultra_secure_batch is defined in a later cell;
                # on a fresh kernel that cell must be run before train_llm().
                g = ultra_secure_batch(g, max_k=11, wl_vocab=CONFIG['wl_vocab']).to(device)
                t, m, l = t.to(device), m.to(device), l.to(device)

                loss = model(g, t, m, l)
                # A single NaN batch must not turn the whole metric into NaN.
                if torch.isfinite(loss):
                    val_loss += loss.item()
                    val_steps += 1
                else:
                    val_nonfinite += 1

        avg_train = train_loss / train_steps if train_steps else float('nan')
        avg_val = val_loss / val_steps if val_steps else float('nan')
        print(f"‚úÖ Epoch {epoch+1} termin√©e | Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")
        if val_nonfinite:
            print(f"Attention : {val_nonfinite} batchs de validation avec une loss non finie (exclus de la moyenne).")

In [None]:
train_llm()

üìÇ Chargement des graphes depuis train_graphs_preprocessed.pkl...
üìÇ Chargement des graphes depuis validation_graphs_preprocessed.pkl...


Epoch 1 [TRAIN]:   0%|          | 0/15504 [00:00<?, ?it/s]

üìä Validation Epoch 1...


Epoch 1 [VAL]:   0%|          | 0/500 [00:00<?, ?it/s]

‚úÖ Epoch 1 termin√©e | Train Loss: 1.3627 | Val Loss: nan


Epoch 2 [TRAIN]:   0%|          | 0/15504 [00:00<?, ?it/s]

üìä Validation Epoch 2...


Epoch 2 [VAL]:   0%|          | 0/500 [00:00<?, ?it/s]

‚úÖ Epoch 2 termin√©e | Train Loss: 1.0917 | Val Loss: nan


In [None]:
# ==========================================
# CELLULE DE G√âN√âRATION FINALE "ULTRA-SAFE"
# ==========================================
import csv
import os
import pickle
import torch
import numpy as np
from tqdm.notebook import tqdm

def ultra_secure_batch(g, max_k=11, wl_vocab=50000, num_edge_types=22):
    """Sanitize a graph (or graph batch) in place so lookups stay in range.

    Args:
        g: graph object exposing ``x`` and ``num_nodes`` and, optionally,
            ``pe_wl``, ``pe_lap``, ``pe_rw``, ``geo_features``,
            ``sorted_neighbors`` and ``sorted_edges``. ``None`` is passed through.
        max_k: number of neighbor/edge columns to keep.
        wl_vocab: size of the WL vocabulary; ``pe_wl`` is clamped to
            ``[0, wl_vocab - 1]``.
        num_edge_types: number of edge types; non-sentinel entries of
            ``sorted_edges`` are clamped to ``[0, num_edge_types - 1]``.

    Returns:
        The same object ``g`` (mutated), or ``None`` when ``g`` is ``None``.
    """
    if g is None:
        return None

    # 1. Clamp each atom feature column to its allowed index range.
    LIMITS = [119, 9, 11, 12, 9, 5, 8, 2, 2]
    for i, m in enumerate(LIMITS):
        g.x[:, i] = torch.clamp(g.x[:, i], min=0, max=m-1)

    # 2. Keep WL labels inside the vocabulary.
    if hasattr(g, 'pe_wl'):
        g.pe_wl = torch.clamp(g.pe_wl, min=0, max=wl_vocab - 1)

    # 3. Replace NaN and +/-inf by 0 in the continuous descriptors.
    #    geo_features gets the same treatment as pe_lap / pe_rw: with the
    #    nan_to_num defaults an inf would become the largest finite float
    #    instead of being neutralised.
    if hasattr(g, 'pe_lap'):
        g.pe_lap = torch.nan_to_num(g.pe_lap, nan=0.0, posinf=0.0, neginf=0.0)
    if hasattr(g, 'pe_rw'):
        g.pe_rw = torch.nan_to_num(g.pe_rw, nan=0.0, posinf=0.0, neginf=0.0)
    if hasattr(g, 'geo_features'):
        g.geo_features = torch.nan_to_num(g.geo_features, nan=0.0, posinf=0.0, neginf=0.0)

    # 4. Keep at most max_k neighbor/edge columns; entries equal to -1 are
    #    left untouched, every other entry is clamped to a valid index.
    if hasattr(g, 'sorted_neighbors'):
        g.sorted_neighbors = g.sorted_neighbors[:, :max_k]
        mask = (g.sorted_neighbors != -1)
        g.sorted_neighbors[mask] = torch.clamp(g.sorted_neighbors[mask], min=0, max=g.num_nodes - 1)

    if hasattr(g, 'sorted_edges'):
        g.sorted_edges = g.sorted_edges[:, :max_k]
        mask = (g.sorted_edges != -1)
        g.sorted_edges[mask] = torch.clamp(g.sorted_edges[mask], min=0, max=num_edge_types - 1)

    return g

def generate_test_submission():
    """Generate one caption per test graph and write them to ``submission.csv``.

    Loads the test pickle, rebuilds MolQwen, restores the fine-tuned weights
    when the checkpoint file exists (a warning is printed otherwise) and
    decodes every graph greedily. Rows are written in test-set order with the
    columns ``ID`` and ``description``.
    """
    TEST_CONFIG = {
        'test_path': 'test_graphs_preprocessed.pkl',
        'model_weights': 'molqwen_val_epoch2.pt',
        'output_file': 'submission.csv',
        'device': CONFIG['device']
    }
    device = TEST_CONFIG['device']
    # Only the first few captions are echoed to keep the cell output readable.
    PREVIEW_COUNT = 5

    if not os.path.exists(TEST_CONFIG['test_path']):
        print("‚ùå Fichier de test introuvable.")
        return

    # NOTE(security): pickle.load can execute arbitrary code; only open files
    # produced by your own preprocessing.
    with open(TEST_CONFIG['test_path'], 'rb') as f:
        test_graphs = pickle.load(f)

    # Model construction
    encoder = FullEncoderModel(CONFIG['lap_dim'], CONFIG['rw_dim'], CONFIG['wl_vocab'],
                               CONFIG['hidden_dim'], CONFIG['output_dim'], max_k=11).to(device)
    model = MolQwen(encoder, CONFIG['llm_name']).to(device)

    if os.path.exists(TEST_CONFIG['model_weights']):
        state = torch.load(TEST_CONFIG['model_weights'], map_location=device)
        load_report = model.load_state_dict(state, strict=False)
        print(f"‚úÖ Poids charg√©s.")
        # strict=False hides parameters absent from the checkpoint; those
        # keep their initial values, so make that visible.
        if load_report.missing_keys:
            print(f"Attention : {len(load_report.missing_keys)} tenseurs absents du checkpoint (valeurs initiales gardees).")
    else:
        # Without the checkpoint the projector and LoRA adapters are untrained.
        print(f"Attention : poids '{TEST_CONFIG['model_weights']}' introuvables, utilisation du projecteur/LoRA initiaux.")

    model.eval()
    results = []

    # Inference, one molecule at a time. Errors are not caught here: a
    # failure stops the run.
    with torch.no_grad():
        for i, g_data in enumerate(tqdm(test_graphs, desc="G√©n√©ration Test")):
            # 1. Sanitize on CPU, then move to the target device
            g_data = ultra_secure_batch(g_data, max_k=11, wl_vocab=CONFIG['wl_vocab'])
            g_data = g_data.to(device)

            # 2. Generation
            caption = model.generate_caption(g_data, tokenizer, max_new_tokens=128)
            if i < PREVIEW_COUNT:
                print(caption)

            mol_id = getattr(g_data, 'id', i)
            results.append({'ID': mol_id, 'description': caption.replace('\n', ' ').strip()})

    # Write the submission file
    with open(TEST_CONFIG['output_file'], 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['ID', 'description'])
        writer.writeheader()
        writer.writerows(results)
    print(f"‚ú® Termin√© ! Fichier : {TEST_CONFIG['output_file']}")

# Run the submission generation when this cell is executed
# (__name__ is "__main__" in the notebook kernel, so the guard passes there).
if __name__ == "__main__":
    generate_test_submission()

config.json:   0%|          | 0.00/660 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

‚úÖ Poids charg√©s.


G√©n√©ration Test:   0%|          | 0/1000 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


The molecule is a 1,2-diacyl-sn-glycero-3-phosphoethanolamine in which the acyl groups at positions 1 and 2 are specified as hexadecanoyl and (9Z)-octadecenoyl respectively. It derives from an octadecanoic acid and a hexadecanoic acid.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
The molecule is a 1,2-diol that consists of (3R)-3-[(4S)-5-(hydroxyethyl)oxetan-2-yl]butanol having an alpha-D-glucosyl residue attached at position 6. It has a role as a metabolite and a human xenobiotic metabolite.
The molecule is a 3-hydroxy-5-methylhexanoate that consists of (R)-2,4-dimethoxyphenol carrying an acetyl group at position 1. It has a role as a metabolite and a plant metabolite.
The molecule is a member of the class of 2,3-dihydroxybenzoic acids that has an additional hydroxyl group attached to the benzene ring. It has a role as a metabolite and a human xenobiotic metabolite.
The molecule is a glycosylphosphatid

In [None]:
import torch
import torch.nn.functional as F
import pandas as pd
import pickle
from tqdm import tqdm

def final_rescue_mission(input_csv, output_csv, train_pkl_path, test_pkl_path, encoder_weights_path):
    """Replace failed generations in a submission by the nearest training description.

    A row counts as failed when its description contains ``"!!"`` or is
    shorter than 25 characters. For such rows the test graph at the same
    position is embedded with the graph encoder and the description of the
    most cosine-similar training graph is written instead.

    Args:
        input_csv: submission file to repair (needs a ``description`` column).
        output_csv: where the repaired submission is written.
        train_pkl_path: pickle with the training graphs (reference set).
        test_pkl_path: pickle with the test graphs, in submission row order.
        encoder_weights_path: state dict of the graph encoder.

    Raises:
        RuntimeError: if no training graph carries a non-empty description.
    """
    # 1. Data loading
    print("üìÇ Chargement des fichiers de donn√©es...")
    # NOTE(security): pickle.load can execute arbitrary code; only open files
    # produced by your own preprocessing.
    with open(train_pkl_path, 'rb') as f:
        train_data = pickle.load(f)
    with open(test_pkl_path, 'rb') as f:
        test_data = pickle.load(f)

    df_sub = pd.read_csv(input_csv)
    # Rows are matched to test graphs by position; a length mismatch means
    # that pairing cannot be trusted.
    if len(df_sub) != len(test_data):
        print(f"Attention : {len(df_sub)} lignes dans {input_csv} pour {len(test_data)} graphes de test, alignement par position incertain.")

    # 2. Stand-alone graph encoder
    print(f"üì° Initialisation de l'encodeur et chargement des poids : {encoder_weights_path}")
    device = CONFIG['device']

    encoder_only = FullEncoderModel(
        lap_dim=CONFIG['lap_dim'],
        rw_dim=CONFIG['rw_dim'],
        wl_vocab=CONFIG['wl_vocab'],
        hidden_dim=CONFIG['hidden_dim'],
        output_dim=CONFIG['output_dim'],
        max_k=11
    ).to(device)
    encoder_only.load_state_dict(torch.load(encoder_weights_path, map_location=device))
    encoder_only.eval()

    def _embed(graph):
        # Graph-level embedding on CPU. NaN/inf entries are zeroed: a NaN
        # similarity would otherwise win torch.argmax and send every rescued
        # row to the same arbitrary training description.
        g_batch = ultra_secure_batch(graph, max_k=11, wl_vocab=CONFIG['wl_vocab']).to(device)
        with torch.no_grad():
            res = encoder_only(g_batch, g_batch.sorted_neighbors, g_batch.sorted_edges, g_batch.geo_features)
        return torch.nan_to_num(res[0].float().cpu(), nan=0.0, posinf=0.0, neginf=0.0)

    # 3. Reference embeddings from the training set
    train_embeddings = []
    train_descriptions = []
    print("üß† Calcul des vecteurs de structure pour le dataset Train...")
    for g_data in tqdm(train_data):
        description = getattr(g_data, 'description', "")
        # A reference without text cannot repair anything.
        if not description:
            continue
        train_embeddings.append(_embed(g_data))
        train_descriptions.append(description)

    if not train_embeddings:
        raise RuntimeError("No training graph with a non-empty description was found.")
    train_embeddings = torch.cat(train_embeddings, dim=0)

    # 4. Nearest-neighbour replacement of failed generations
    print("ü©π Remplacement des √©checs de g√©n√©ration par les voisins les plus proches...")
    count_rescued = 0

    for position, (row_label, row) in enumerate(df_sub.iterrows()):
        desc = str(row['description'])

        # Failure criteria: "!!" in the text, or fewer than 25 characters.
        if "!!" in desc or len(desc) < 25:
            h_test = _embed(test_data[position])

            # Cosine similarity against every reference embedding
            similarities = F.cosine_similarity(h_test, train_embeddings)
            best_match_idx = torch.argmax(similarities).item()

            df_sub.at[row_label, 'description'] = train_descriptions[best_match_idx]
            count_rescued += 1

    # 5. Write the repaired submission
    df_sub.to_csv(output_csv, index=False)
    print(f"\n‚úÖ Mission termin√©e !")
    print(f"üöë Mol√©cules sauv√©es (remplac√©es par le Train) : {count_rescued} / {len(df_sub)}")
    print(f"üíæ Fichier de soumission pr√™t : {output_csv}")

# --- EXECUTION ---
# Make sure these paths match your files
final_rescue_mission(
    input_csv='submission.csv',
    output_csv='submission_final.csv',
    train_pkl_path='train_graphs_preprocessed.pkl',
    test_pkl_path='test_graphs_preprocessed.pkl',
    encoder_weights_path='model_best.pt' # Or the name of your encoder .pt file
)

üìÇ Chargement des fichiers de donn√©es...
üì° Initialisation de l'encodeur et chargement des poids : model_best.pt
üß† Calcul des vecteurs de structure pour le dataset Train...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 31008/31008 [02:46<00:00, 186.03it/s]


ü©π Remplacement des √©checs de g√©n√©ration par les voisins les plus proches...

‚úÖ Mission termin√©e !
üöë Mol√©cules sauv√©es (remplac√©es par le Train) : 42 / 1000
üíæ Fichier de soumission pr√™t : submission_final.csv
