In [1]:
!pip install rdkit torch_geometric mendeleev torch_scatter torch_cluster e3nn optuna --quiet

# ============================================================================
# Imports und Konstanten
# ============================================================================
from __future__ import annotations
import warnings, random, math, os, copy
warnings.filterwarnings("ignore", category=DeprecationWarning)

import matplotlib.pyplot as plt
from pathlib import Path
from collections import Counter
from typing import List, Dict

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem, Descriptors, Crippen, Lipinski
import mendeleev, optuna

import torch, torch.nn as nn
from torch_geometric.data import Data
from torch_cluster import radius_graph
from torch_scatter import scatter_sum, scatter_add
from torch_geometric.nn import MetaLayer
from torch_geometric.loader import DataLoader

from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import OneHotEncoder

import seaborn as sns

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

SUBSET_K = 14
CSV_FILE = Path("/content/Ilm-NMR-P31.csv")
NUM_CONFORMERS = 1
RADIUS = 2  # Å
TRIALS_OPT = 15 #Hyperparameter Kombinationen mit Optuna

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

RDLogger.logger().setLevel(RDLogger.CRITICAL)

# ============================================================================
# Hilfsfunktionen
# ============================================================================

def clean_array(arr: np.ndarray) -> np.ndarray:
    """Ersetzt NaN sowie Inf durch 0 """
    return np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)

def clean_tensor(t: torch.Tensor) -> torch.Tensor:
    if torch.isnan(t).any() or torch.isinf(t).any():
        t = torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0)
    return t

def seed_worker(worker_id):
    np.random.seed(SEED + worker_id); random.seed(SEED + worker_id)

def process_graph(g):
    g_proc = copy.deepcopy(g)
    for name, sc in fold_scalers.items():
        idx = cont_idx[name]
        if not idx:
            continue

        # Hole die rohe Matrix …
        m = getattr(g_proc, name).numpy()

        # … erst einmal säubern
        m = clean_array(m)

        # … dann skalieren
        m[:, idx] = sc.transform(m[:, idx])

        # … und noch einmal säubern
        m = clean_array(m)

        setattr(g_proc, name, torch.tensor(m, dtype=torch.float32))

    # Shift ebenfalls sauber anwenden
    y_raw = g_proc.y.reshape(-1, 1).numpy()
    y_raw = clean_array(y_raw)
    y_scaled = fold_shift_scaler.transform(y_raw)
    g_proc.y = torch.tensor(y_scaled.flatten(), dtype=torch.float32)

    # Edge/Global zusammenführen
    g_proc.edge_attr = torch.cat([g_proc.edge_attr_topo, g_proc.edge_attr_geom], dim=1)
    g_proc.u = g_proc.u_topo
    del g_proc.edge_attr_topo, g_proc.edge_attr_geom, g_proc.u_topo

    # Finales Tensor-Cleaning (optional, aber sicher)
    g_proc.x = clean_tensor(g_proc.x)
    g_proc.edge_attr = clean_tensor(g_proc.edge_attr)
    g_proc.u = clean_tensor(g_proc.u)

    return g_proc

# ============================================================================
# 1. Daten einlesen und bereinigen
# ============================================================================

df2 = pd.read_csv(CSV_FILE, sep=r"\s+", quotechar='"', engine='python')
df2.columns = df2.columns.str.strip('"')

for col in df2.select_dtypes('object').columns:
    df2[col] = df2[col].str.strip('"')

for col in ['shift', 'MW', 'C', 'N', 'P', 'O']:
    df2[col] = pd.to_numeric(df2[col], errors='coerce')

df2.index = df2.index.astype(str).str.strip('"').astype(int)

df = df2.sample(n=SUBSET_K, random_state=SEED)  # Subset
print(f"Geladene Zeilen: {len(df)} – nach NaN‑Drop:", end=' ')
df.dropna(subset=['shift', 'cansmi'], inplace=True)
print(len(df))

# ============================================================================
# 2. Smiles in Mol überführen
# ============================================================================

tqdm.pandas()

def smiles2mol(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return Chem.AddHs(mol) if mol else None

df['BaseMol'] = df['cansmi'].progress_apply(smiles2mol)
df_valid = df[df['BaseMol'].notnull()].copy()
print("Gültige Moleküle:", len(df_valid))

# ============================================================================
# 3. One Hot Encoder
# ============================================================================
formal_charges = sorted({a.GetFormalCharge() for m in df_valid.BaseMol for a in m.GetAtoms()})
enc_charge = OneHotEncoder(sparse_output=False).fit(np.array(formal_charges).reshape(-1,1))
oh_charge = lambda c: enc_charge.transform([[c]]).ravel()

hybs = [getattr(Chem.rdchem.HybridizationType, h) for h in ('UNSPECIFIED','S','SP','SP2','SP3','SP3D','SP3D2','OTHER')]
enc_hyb = OneHotEncoder(sparse_output=False).fit(np.array([int(h) for h in hybs]).reshape(-1,1))
oh_hyb = lambda h: enc_hyb.transform([[int(h)]]).ravel()

atom_list = ['P','Al','As','Au','B','Br','C','Cl','Cs','F','Fe','Ge','H','Hg','I','K','Li','Mo','N','Na','Nb','Ni','O','Pb','Pd','Pt','S','Sb','Se','Si','Sn','Ta','Te','V','W']
enc_atom = OneHotEncoder(categories=[atom_list], sparse_output=False).fit(np.array(atom_list).reshape(-1,1))
oh_atom = lambda s: enc_atom.transform([[s]]).ravel()

bond_types = sorted({b.GetBondType() for m in df_valid.BaseMol for b in m.GetBonds()}, key=int)
enc_bond = OneHotEncoder(sparse_output=False).fit(np.array([int(b) for b in bond_types]).reshape(-1,1))
oh_bond = lambda bt: enc_bond.transform([[int(bt)]]).ravel()

# ============================================================================
# 4. Features berechnen
# ============================================================================

EL_CACHE: Dict[int, mendeleev.Element] = {}

def EL(n):
    if n not in EL_CACHE:
        EL_CACHE[n] = mendeleev.element(n)
    return EL_CACHE[n]

def ring_size(obj):
    if not obj.IsInRing():
        return 0
    for i in range(10,2,-1):
        if obj.IsInRingSize(i):
            return i
    return 0

def atom_features(mol, atom, ranks):
    me = EL(atom.GetAtomicNum())
    gchg = float(atom.GetProp('_GasteigerCharge')) if atom.HasProp('_GasteigerCharge') else 0.
    feats = list(oh_atom(atom.GetSymbol())) + [
        atom.GetDegree(), me.atomic_radius or 0, me.atomic_volume or 0] + \
        list(oh_charge(atom.GetFormalCharge())) + [
        me.covalent_radius or 0, me.vdw_radius or 0, me.en_pauling or 0,
        me.electrons or 0, me.neutrons or 0,
        int(atom.GetChiralTag()), int(atom.IsInRing()), int(atom.GetIsAromatic())] + \
        list(oh_hyb(atom.GetHybridization())) + [
        atom.GetMass(), atom.GetNumRadicalElectrons(), atom.GetTotalValence(),
        ring_size(atom), ranks[atom.GetIdx()], gchg]
    ri = mol.GetRingInfo()
    feats += [sum(1 for r in ri.AtomRings() if atom.GetIdx() in r),
              sum(1 for b in atom.GetBonds() if b.GetIsAromatic()),
              sum(1 for n in atom.GetNeighbors() if n.GetAtomicNum()!=1),
              sum(1 for n in atom.GetNeighbors() if n.GetAtomicNum()!=6),
              gchg - atom.GetFormalCharge()]
    return feats

def bond_features(mol, bond):
    a1,a2 = bond.GetBeginAtom(), bond.GetEndAtom()
    en_diff = abs((EL(a1.GetAtomicNum()).en_pauling or 0)-(EL(a2.GetAtomicNum()).en_pauling or 0))
    feats = list(oh_bond(bond.GetBondType())) + [
        float(bond.GetIsConjugated()), float(bond.GetIsAromatic()), float(bond.IsInRing()),
        ring_size(bond), float(bond.GetStereo()),
        1. if (bond.GetBondType()==Chem.BondType.SINGLE and not bond.IsInRing()) else 0.,
        en_diff, float(mol.GetRingInfo().NumBondRings(bond.GetIdx())>1)]
    return feats

def global_features(mol):
    safe = lambda f, d=0.: (f(mol) if callable(f) else f) if not pd.isna(f) else d
    feats=[mol.GetNumAtoms(),mol.GetNumBonds(),mol.GetNumHeavyAtoms(),
           safe(Descriptors.MolWt),safe(Crippen.MolLogP),safe(Descriptors.TPSA),safe(Crippen.MolMR)]
    cnt=Counter(a.GetSymbol() for a in mol.GetAtoms());
    ri=mol.GetRingInfo(); feats+=[ri.NumRings(),safe(Lipinski.NumAromaticRings,0),safe(Lipinski.NumAliphaticRings,0),
                                  safe(Lipinski.NumRotatableBonds,0),safe(Lipinski.NumHDonors,0),safe(Lipinski.NumHAcceptors,0),
                                  safe(Descriptors.BalabanJ),safe(Descriptors.Kappa2)]
    return np.array(feats,dtype=np.float32)

# ============================================================================
# 5. Topo graphen
# ============================================================================

topo_graphs=[]
for mol in df_valid.BaseMol:
    AllChem.ComputeGasteigerCharges(mol)

for idx,row in tqdm(df_valid.iterrows(), total=len(df_valid)):
    mol=row.BaseMol; ranks=Chem.CanonicalRankAtoms(mol)
    node=[atom_features(mol,a,ranks) for a in mol.GetAtoms()]
    ei,ea=[],[]
    for b in mol.GetBonds():
        f=bond_features(mol,b); i,j=b.GetBeginAtomIdx(),b.GetEndAtomIdx()
        ei+= [[i,j],[j,i]]; ea+=[f,f]
    topo_graphs.append(dict(molecule_id=idx,smiles=row.cansmi,x=np.asarray(node,np.float32),
                            edge_index=np.asarray(ei,np.int64).T,edge_attr=np.asarray(ea,np.float32),
                            u=global_features(mol),y=float(row['shift'])))
print("Topo‑Graphen:", len(topo_graphs))

# ============================================================================
# 6. Konformere und Geometrie
# ============================================================================

def embed(smi: str):
    m = Chem.AddHs(Chem.MolFromSmiles(smi))
    p = AllChem.ETKDGv3(); p.randomSeed = SEED
    cids = list(AllChem.EmbedMultipleConfs(m, NUM_CONFORMERS, p))

    # Fallback: Einzelkonformer, falls mehrere nicht klappten
    if not cids and AllChem.EmbedMolecule(m, randomSeed=SEED) == -1:
        return None

    # prüfe, ob UFF Parameter kennt
    if not AllChem.UFFHasAllMoleculeParams(m):
        # optional MMFF verwenden
        return None

    try:
        res = AllChem.UFFOptimizeMoleculeConfs(m)
    except RuntimeError:
        return None

    # nur Konformer behalten die sauber optimiert wurden
    good_cids = [cid for cid, (code, _) in zip(cids or [0], res)
                 if code == 0 and not np.isnan(m.GetConformer(cid).GetPositions()).any()]

    return m if good_cids else None


def geom_dist(pos, eidx):
    """Berechnet die euklidische Distanz für gegebene Kanten."""
    r, c = eidx
    return torch.linalg.norm(pos[r] - pos[c], dim=1, keepdim=True)

final_graphs = []
for g in tqdm(topo_graphs):
    m3 = embed(g['smiles'])
    if m3 is None:
        continue

    for conf in m3.GetConformers():
        pos = torch.tensor(conf.GetPositions(), dtype=torch.float32)

        # Topologische Kanten
        tei = torch.tensor(g['edge_index'], dtype=torch.long)
        tea = torch.tensor(g['edge_attr'], dtype=torch.float32)

        # Feautures für kovalente Bindungen
        geom_cov = geom_dist(pos, tei)
        flag_cov = torch.tensor([[1., 0.]], dtype=torch.float32).repeat(tea.size(0), 1)

        #räumliche Kanten
        sei_all = radius_graph(pos, r=RADIUS, loop=False)

        #Redundante Kanten entfernen
        covalent_edges_set = set(
            tuple(sorted(pair)) for pair in tei.t().tolist()
        )

        # nur Kanten die NICHT kovalent sind beibhalten
        non_covalent_spatial_indices = []
        for i in range(sei_all.size(1)):
            u, v = sei_all[:, i].tolist()
            if tuple(sorted((u, v))) not in covalent_edges_set:
                non_covalent_spatial_indices.append(i)

        sei_clean = sei_all[:, non_covalent_spatial_indices]

        # feature für geometrische Kanten
        geom_spa = geom_dist(pos, sei_clean)
        flag_spa = torch.tensor([[0., 1.]], dtype=torch.float32).repeat(sei_clean.size(1), 1)

        placeholder_spa = torch.zeros(sei_clean.size(1), tea.size(1), dtype=torch.float32)

        edge_attr_topo_cov = torch.cat([flag_cov, tea], dim=1)
        edge_attr_topo_spa = torch.cat([flag_spa, placeholder_spa], dim=1)
        final_edge_attr_topo = torch.cat([edge_attr_topo_cov, edge_attr_topo_spa], dim=0)
        final_edge_attr_geom = torch.cat([geom_cov, geom_spa], dim=0)

        final_graphs.append(Data(
            x=torch.tensor(g['x'], dtype=torch.float32),
            edge_index=torch.cat([tei, sei_clean], dim=1),
            edge_attr_topo=final_edge_attr_topo,
            edge_attr_geom=final_edge_attr_geom,
            pos=pos,
            u_topo=torch.tensor(g['u'], dtype=torch.float32).unsqueeze(0),
            y=torch.tensor([g['y']], dtype=torch.float32),
            molecule_id=g['molecule_id'],
            conformer_id=conf.GetId(),
            is_p_atom=torch.tensor([a.GetSymbol() == 'P' for a in m3.GetAtoms()])
        ))

print("Finale Graphen (ohne redundante Kanten):", len(final_graphs))

# ============================================================================
# 7.Split, Train, Test, VAl
# ============================================================================

mids=np.array([g.molecule_id for g in final_graphs])
train_ids,temp_ids=train_test_split(np.unique(mids),test_size=0.2,random_state=SEED)
val_ids,test_ids=train_test_split(temp_ids,test_size=0.5,random_state=SEED)
SPLIT={p:[g for g in final_graphs if g.molecule_id in ids] for p,ids in
       zip(['train','val','test'],[train_ids,val_ids,test_ids])}
print({k:len(v) for k,v in SPLIT.items()})

# ============================================================================
# 8. Skalierung und Bereinigung
# ============================================================================

def is_onehot(col):
    u=np.unique(col); return (len(u)==1) or (set(u)<= {0,1})

def continuous_idx(mat):
    idx=[]
    for i in range(mat.shape[1]):
        col=mat[:,i]
        if is_onehot(col): continue
        if np.allclose(col,col[0]): continue
        idx.append(i)
    return idx

train_mats={k:np.vstack([getattr(g,k).numpy() for g in SPLIT['train']]) for k in ['x','edge_attr_topo','edge_attr_geom','u_topo']}

for k in train_mats:
    train_mats[k] = clean_array(train_mats[k])

cont_idx={k:continuous_idx(v) for k,v in train_mats.items()}
SCALERS={}
for k,idx in cont_idx.items():
    if not idx: continue
    sc=MinMaxScaler() if k=='edge_attr_geom' else StandardScaler()
    sc.fit(train_mats[k][:,idx]); SCALERS[k]=sc
print('Scaler:',list(SCALERS))

shift_scaler=StandardScaler().fit(df.loc[train_ids,'shift'].values.reshape(-1,1))

def apply_scaling(g):
    for name, sc in SCALERS.items():
        idx = cont_idx[name]
        m = getattr(g, name).numpy()
        m = clean_array(m)

        if idx:
            m[:, idx] = sc.transform(m[:, idx])

        m = clean_array(m)
        setattr(g, name, torch.tensor(m, dtype=torch.float32))
    return g


for phase in SPLIT:
    SPLIT[phase]=[apply_scaling(copy.deepcopy(g)) for g in SPLIT[phase]]

for phase,ids in zip(['train','val','test'],[train_ids,val_ids,test_ids]):
    scaled=shift_scaler.transform(df.loc[ids,'shift'].values.reshape(-1,1)).flatten();
    map_y=dict(zip(ids,scaled))
    for g in SPLIT[phase]:
        g.y=torch.tensor([map_y[g.molecule_id]],dtype=torch.float32)

# combine und clean
for phase in SPLIT:
    for g in SPLIT[phase]:
        g.edge_attr=torch.cat([g.edge_attr_topo,g.edge_attr_geom],dim=1); g.u=g.u_topo
        del g.edge_attr_topo,g.edge_attr_geom,g.u_topo
        g.x=clean_tensor(g.x); g.edge_attr=clean_tensor(g.edge_attr); g.u=clean_tensor(g.u)
        for name in ('x','edge_attr','u'):
            t=getattr(g,name)
            if torch.isnan(t).any() or torch.isinf(t).any():
                raise ValueError(f"{name} noch NaN/Inf nach Reinigung – Mol {g.molecule_id}")
print('Alle Tensoren gereinigt.')

# ============================================================================
# 9. Modell sowie Train und Eval
# ============================================================================
class EdgeModel(nn.Module):
    def __init__(self, node_in_dim, edge_in_dim, hidden_dim):
        super(EdgeModel, self).__init__()
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * node_in_dim + edge_in_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU()
        )
    def forward(self, src, dest, edge_attr, u, batch):
        out = torch.cat([src, dest, edge_attr], dim=1)
        return self.edge_mlp(out)

class NodeModel(nn.Module):
    def __init__(self, node_in_dim, hidden_dim):
        super(NodeModel, self).__init__()
        self.node_mlp = nn.Sequential(
            nn.Linear(node_in_dim + hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU()
        )
    def forward(self, x, edge_index, edge_attr, u, batch):
        row, col = edge_index
        agg = scatter_sum(edge_attr, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, agg], dim=1)
        return self.node_mlp(out)

class GlobalModel(nn.Module):
    def __init__(self, global_in_dim, hidden_dim):
        super(GlobalModel, self).__init__()
        self.global_mlp = nn.Sequential(
            nn.Linear(global_in_dim + hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU()
        )
    def forward(self, x, edge_index, edge_attr, u, batch):
        node_agg = scatter_sum(x, batch, dim=0)
        out = torch.cat([u, node_agg], dim=1)
        return self.global_mlp(out)

class MetaGNN(nn.Module):
    def __init__(self, node_in_dim, edge_in_dim, global_in_dim, hidden_dim):
        super(MetaGNN, self).__init__()
        self.meta_layer = MetaLayer(
            edge_model=EdgeModel(node_in_dim, edge_in_dim, hidden_dim),
            node_model=NodeModel(node_in_dim, hidden_dim),
            global_model=GlobalModel(global_in_dim, hidden_dim)
        )
        self.final_mlp = nn.Sequential(nn.Linear(hidden_dim, 1))

    def forward(self, data):
        x, edge_index, edge_attr, u, batch = data.x, data.edge_index, data.edge_attr, data.u, data.batch
        x, edge_attr, u = self.meta_layer(x, edge_index, edge_attr, u, batch)
        out = self.final_mlp(u)
        return out.squeeze()

def train_epoch(model,loader,opt,crit):
    model.train(); tot=0
    for d in loader:
        d=d.to(DEVICE); opt.zero_grad(); loss=crit(model(d),d.y.squeeze()); loss.backward(); nn.utils.clip_grad_norm_(model.parameters(),1.0); opt.step(); tot+=loss.item()*d.num_graphs
    return tot/len(loader.dataset)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def eval_loader(model, loader, ysc, train_y_unscaled):
    model.eval()
    preds_unscaled, targets_unscaled, molecule_ids = [], [], []
    with torch.no_grad():
        for d in loader:
            d = d.to(DEVICE)
            predictions = model(d).cpu()
            preds_unscaled.append(ysc.inverse_transform(predictions.numpy().reshape(-1, 1)).ravel())
            targets_unscaled.append(ysc.inverse_transform(d.y.cpu().numpy().reshape(-1, 1)).ravel())
            molecule_ids.extend(d.molecule_id.tolist())

    preds_unscaled = np.concatenate(preds_unscaled)
    targets_unscaled = np.concatenate(targets_unscaled)

    df_preds = pd.DataFrame({
        'molecule_id': molecule_ids,
        'final_prediction': preds_unscaled,
        'target': targets_unscaled
    })

    df_agg = df_preds.groupby('molecule_id').mean()
    p_agg, t_agg = df_agg['final_prediction'], df_agg['target']

    mae = (p_agg - t_agg).abs().mean()
    rmse = np.sqrt(mean_squared_error(t_agg, p_agg))

    mae_naive = np.mean(np.abs(t_agg - np.mean(train_y_unscaled)))
    mase = mae / (mae_naive + 1e-8)

    metrics = {'mae': float(mae), 'rmse': float(rmse), 'mase': float(mase)}

    return metrics, df_preds


# ============================================================================
# 10. Optuna Hyperparameteroptimierung
# ============================================================================
outer_k = 3; inner_k = 3; max_epochs = 150
unique_mids = np.unique([g.molecule_id for g in final_graphs])
outer_cv = KFold(n_splits=outer_k, shuffle=True, random_state=SEED)
summary = []
all_test_predictions = pd.DataFrame()

# prüfe ob genügend Daten für die Kreuzvalidierung vorhanden sind
if len(unique_mids) >= outer_k:
    for fold, (train_val_idx, test_idx) in enumerate(outer_cv.split(unique_mids), 1):
        print(f"\n===== ÄUẞERER FOLD {fold}/{outer_k} =====")
        train_val_ids = unique_mids[train_val_idx]
        test_ids = unique_mids[test_idx]

        graphs_train_val = [copy.deepcopy(g) for g in final_graphs if g.molecule_id in train_val_ids]
        graphs_test = [copy.deepcopy(g) for g in final_graphs if g.molecule_id in test_ids]

        train_mats = {k: np.vstack([getattr(g, k).numpy() for g in graphs_train_val]) for k in ['x', 'edge_attr_topo', 'edge_attr_geom', 'u_topo']}
        for k in train_mats:
            train_mats[k] = clean_array(train_mats[k])
        cont_idx = {k: continuous_idx(v) for k, v in train_mats.items()}
        fold_scalers = {}
        for k, idx in cont_idx.items():
            if not idx: continue
            sc = MinMaxScaler() if k == 'edge_attr_geom' else StandardScaler()
            sc.fit(train_mats[k][:, idx]); fold_scalers[k] = sc

        fold_shift_scaler = StandardScaler().fit(df.loc[train_val_ids, 'shift'].values.reshape(-1, 1))

        train_y_unscaled = df.loc[train_val_ids, 'shift'].values

        processed_train_val = [process_graph(g) for g in graphs_train_val]
        processed_test = [process_graph(g) for g in graphs_test]

        node_in_dim = processed_train_val[0].x.shape[1]
        edge_in_dim = processed_train_val[0].edge_attr.shape[1]
        global_in_dim = processed_train_val[0].u.shape[1]

        study = optuna.create_study(direction='minimize', sampler=optuna.samplers.TPESampler(seed=SEED))
        def objective(trial):
            hidden_dim = trial.suggest_categorical('hidden_dim', [32, 64, 128, 256])
            lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
            wd = trial.suggest_float('wd', 1e-8, 1e-3, log=True)
            bs = trial.suggest_categorical('bs', [16, 32, 64])

            inner_cv = KFold(n_splits=inner_k, shuffle=True, random_state=SEED)
            maes = []
            for tr_idx, val_idx in inner_cv.split(train_val_ids):
                tr_ids, val_ids = train_val_ids[tr_idx], train_val_ids[val_idx]
                g_tr = [g for g in processed_train_val if g.molecule_id in tr_ids]
                g_val = [g for g in processed_train_val if g.molecule_id in val_ids]
                if not g_tr or not g_val: continue

                inner_train_y_unscaled = df.loc[tr_ids, 'shift'].values

                dl_tr = DataLoader(g_tr, batch_size=bs, shuffle=True, worker_init_fn=seed_worker, drop_last=True)
                dl_val = DataLoader(g_val, batch_size=min(bs, len(g_val)))

                model = MetaGNN(
                    node_in_dim=node_in_dim,
                    edge_in_dim=edge_in_dim,
                    global_in_dim=global_in_dim,
                    hidden_dim=hidden_dim
                ).to(DEVICE)

                opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
                crit = nn.L1Loss()

                for ep in range(max_epochs):
                    train_epoch(model, dl_tr, opt, crit)

                val_metrics, _ = eval_loader(model, dl_val, fold_shift_scaler, inner_train_y_unscaled)
                maes.append(val_metrics['mae'])
                trial.report(np.mean(maes), len(maes))
                if trial.should_prune():
                    raise optuna.TrialPruned()
            return np.mean(maes)

        study.optimize(objective, n_trials=TRIALS_OPT, timeout=1800)
        best = study.best_trial.params
        print(f"Fold {fold}: Bester innerer MAE {study.best_value:.4f} – Params {best}")

        dl_tr = DataLoader(processed_train_val, batch_size=best['bs'], shuffle=True, worker_init_fn=seed_worker, drop_last=True)
        dl_te = DataLoader(processed_test, batch_size=min(best['bs'], len(processed_test)))

        model = MetaGNN(
            node_in_dim=node_in_dim,
            edge_in_dim=edge_in_dim,
            global_in_dim=global_in_dim,
            hidden_dim=best['hidden_dim']
        ).to(DEVICE)

        opt = torch.optim.Adam(model.parameters(), lr=best['lr'], weight_decay=best['wd'])
        crit = nn.L1Loss()
        for ep in range(max_epochs):
            train_epoch(model, dl_tr, opt, crit)

        num_params_millions = count_parameters(model) / 1_000_000
        test_metrics, final_test_predictions = eval_loader(model, dl_te, fold_shift_scaler, train_y_unscaled)

        fold_summary = {
            'fold': fold,
            'mae': test_metrics['mae'],
            'rmse': test_metrics['rmse'],
            'mase': test_metrics['mase'],
            'params_M': num_params_millions,
            'hyperparams': best
        }
        summary.append(fold_summary)

        print(f"  → Test-Ergebnisse Fold {fold}:")
        print(f"    - MAE: {test_metrics['mae']:.4f}")
        print(f"    - RMSE: {test_metrics['rmse']:.4f}")
        print(f"    - MASE: {test_metrics['mase']:.4f}")
        print(f"    - Trainierbare Parameter: {num_params_millions:.3f}M")

        TITLE_FONTSIZE, LABEL_FONTSIZE, TICK_FONTSIZE, LEGEND_FONTSIZE, METRICS_FONTSIZE = 22, 20, 16, 16, 15
        plt.style.use('default')
        fig, ax = plt.subplots(figsize=(8, 8))
        fig.patch.set_facecolor('white')
        ax.set_facecolor('white')

        sns.scatterplot(
            x='target', y='final_prediction', data=final_test_predictions,
            ax=ax, alpha=0.7, edgecolor='b', s=20, label='Vorhersagen'
        )

        min_val = min(final_test_predictions['target'].min(), final_test_predictions['final_prediction'].min())
        max_val = max(final_test_predictions['target'].max(), final_test_predictions['final_prediction'].max())
        ax.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2.5, label='Perfekte Vorhersage')

        metrics_text = (
            f"Test MAE: {test_metrics['mae']:.3f}\n"
            f"Test RMSE: {test_metrics['rmse']:.3f}\n"
            f"Test MASE: {test_metrics['mase']:.3f}\n"
            f"Parameter: {num_params_millions:.3f}M"
        )
        ax.text(0.05, 0.95, metrics_text, transform=ax.transAxes, fontsize=METRICS_FONTSIZE,
                verticalalignment='top', bbox=dict(boxstyle='round,pad=0.5', facecolor='aliceblue', edgecolor='black', alpha=0.8))

        ax.set_xlabel('Experimenteller Shift', fontsize=LABEL_FONTSIZE)
        ax.set_ylabel('Vorhergesagter Shift', fontsize=LABEL_FONTSIZE)
        ax.set_title(f'Modellleistung auf Testdaten (Fold {fold})', fontsize=TITLE_FONTSIZE, pad=20)
        ax.tick_params(axis='both', which='major', labelsize=TICK_FONTSIZE)
        ax.legend(loc='lower right', fontsize=LEGEND_FONTSIZE)
        ax.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plot_filename = f"sdgnn_k{int(SUBSET_K)}_fold_{fold}_performance.png"
        plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
        print(f"Plot gespeichert unter: {plot_filename}")
        plt.close(fig)

# ============================================================================
# Ergebnisse:
# ============================================================================
print("\n=== Kreuzvalidierungs-Zusammenfassung ===")
if summary:
    summary_df = pd.DataFrame(summary).set_index('fold')
    print(summary_df[['mae', 'rmse', 'mase', 'params_M']])

    if not summary_df.empty:
        print("\n--- Durchschnittliche Leistung ---")
        print(summary_df[['mae', 'rmse', 'mase']].mean())
else:
    print("Keine Kreuzvalidierung durchgeführt (nicht genügend Daten).")

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.5/54.5 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.9/34.9 MB[0m [31m32.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m33.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m39.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m448.0/448.0 kB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

  0%|          | 0/14 [00:00<?, ?it/s]

Gültige Moleküle: 14


  0%|          | 0/14 [00:00<?, ?it/s]

Topo‑Graphen: 14


  0%|          | 0/14 [00:00<?, ?it/s]

[I 2025-07-22 13:08:09,937] A new study created in memory with name: no-name-25273c8b-bd0b-4752-bfdb-29906d68d2ff


Finale Graphen (ohne redundante Kanten): 9
{'train': 7, 'val': 1, 'test': 1}
Scaler: ['x', 'edge_attr_topo', 'edge_attr_geom', 'u_topo']
Alle Tensoren gereinigt.

===== ÄUẞERER FOLD 1/3 =====


[I 2025-07-22 13:08:10,097] Trial 0 finished with value: 37.00062624613444 and parameters: {'hidden_dim': 64, 'lr': 0.0002051338263087451, 'wd': 6.02521573620385e-08, 'bs': 32}. Best is trial 0 with value: 37.00062624613444.
[I 2025-07-22 13:08:10,222] Trial 1 finished with value: 40.63613255818685 and parameters: {'hidden_dim': 128, 'lr': 0.00026587543983272726, 'wd': 8.111941985431907e-08, 'bs': 64}. Best is trial 0 with value: 37.00062624613444.
[I 2025-07-22 13:08:10,339] Trial 2 finished with value: 43.01049041748047 and parameters: {'hidden_dim': 128, 'lr': 0.0003839629299804173, 'wd': 6.78905327169848e-07, 'bs': 32}. Best is trial 0 with value: 37.00062624613444.
[I 2025-07-22 13:08:10,450] Trial 3 finished with value: 37.59418328603109 and parameters: {'hidden_dim': 256, 'lr': 0.00021930485556643703, 'wd': 2.1147447960615646e-08, 'bs': 32}. Best is trial 0 with value: 37.00062624613444.
[I 2025-07-22 13:08:10,556] Trial 4 finished with value: 41.303881327311196 and parameters: 

Fold 1: Bester innerer MAE 35.8563 – Params {'hidden_dim': 64, 'lr': 0.00012315571723666037, 'wd': 4.233032996527588e-07, 'bs': 64}
  → Test-Ergebnisse Fold 1:
    - MAE: 69.9848
    - RMSE: 87.2717
    - MASE: 1.0318
    - Trainierbare Parameter: 0.036M


[I 2025-07-22 13:08:12,299] A new study created in memory with name: no-name-ee3395cd-5c59-40d9-8713-6a5cf60ce06b
[I 2025-07-22 13:08:12,403] Trial 0 finished with value: 47.59460322062174 and parameters: {'hidden_dim': 64, 'lr': 0.0002051338263087451, 'wd': 6.02521573620385e-08, 'bs': 32}. Best is trial 0 with value: 47.59460322062174.


Plot gespeichert unter: sdgnn_k14_fold_1_performance.png

===== ÄUẞERER FOLD 2/3 =====


[I 2025-07-22 13:08:12,507] Trial 1 finished with value: 45.70199330647787 and parameters: {'hidden_dim': 128, 'lr': 0.00026587543983272726, 'wd': 8.111941985431907e-08, 'bs': 64}. Best is trial 1 with value: 45.70199330647787.
[I 2025-07-22 13:08:12,611] Trial 2 finished with value: 46.79461415608724 and parameters: {'hidden_dim': 128, 'lr': 0.0003839629299804173, 'wd': 6.78905327169848e-07, 'bs': 32}. Best is trial 1 with value: 45.70199330647787.
[I 2025-07-22 13:08:12,737] Trial 3 finished with value: 44.91835403442383 and parameters: {'hidden_dim': 256, 'lr': 0.00021930485556643703, 'wd': 2.1147447960615646e-08, 'bs': 32}. Best is trial 3 with value: 44.91835403442383.
[I 2025-07-22 13:08:12,853] Trial 4 finished with value: 53.27860895792643 and parameters: {'hidden_dim': 128, 'lr': 0.00017541893487450815, 'wd': 2.9914693021302164e-06, 'bs': 32}. Best is trial 3 with value: 44.91835403442383.
[I 2025-07-22 13:08:12,894] Trial 5 pruned. 
[I 2025-07-22 13:08:12,969] Trial 6 pruned.

Fold 2: Bester innerer MAE 37.5563 – Params {'hidden_dim': 32, 'lr': 0.000739115897463125, 'wd': 4.471201743493389e-05, 'bs': 16}
  → Test-Ergebnisse Fold 2:
    - MAE: 70.3563
    - RMSE: 77.6632
    - MASE: 0.9403
    - Trainierbare Parameter: 0.013M


[I 2025-07-22 13:08:14,393] A new study created in memory with name: no-name-81edaf54-78d4-46c3-8875-d177fe7646dc
[I 2025-07-22 13:08:14,495] Trial 0 finished with value: 71.27907816569011 and parameters: {'hidden_dim': 64, 'lr': 0.0002051338263087451, 'wd': 6.02521573620385e-08, 'bs': 32}. Best is trial 0 with value: 71.27907816569011.


Plot gespeichert unter: sdgnn_k14_fold_2_performance.png

===== ÄUẞERER FOLD 3/3 =====


[I 2025-07-22 13:08:14,613] Trial 1 finished with value: 71.18150583902995 and parameters: {'hidden_dim': 128, 'lr': 0.00026587543983272726, 'wd': 8.111941985431907e-08, 'bs': 64}. Best is trial 1 with value: 71.18150583902995.
[I 2025-07-22 13:08:14,717] Trial 2 finished with value: 72.62646865844727 and parameters: {'hidden_dim': 128, 'lr': 0.0003839629299804173, 'wd': 6.78905327169848e-07, 'bs': 32}. Best is trial 1 with value: 71.18150583902995.
[I 2025-07-22 13:08:14,836] Trial 3 finished with value: 68.00106557210286 and parameters: {'hidden_dim': 256, 'lr': 0.00021930485556643703, 'wd': 2.1147447960615646e-08, 'bs': 32}. Best is trial 3 with value: 68.00106557210286.
[I 2025-07-22 13:08:14,947] Trial 4 finished with value: 66.46744791666667 and parameters: {'hidden_dim': 128, 'lr': 0.00017541893487450815, 'wd': 2.9914693021302164e-06, 'bs': 32}. Best is trial 4 with value: 66.46744791666667.
[I 2025-07-22 13:08:14,987] Trial 5 pruned. 
[I 2025-07-22 13:08:15,095] Trial 6 finishe

Fold 3: Bester innerer MAE 66.2191 – Params {'hidden_dim': 64, 'lr': 0.00012315571723666037, 'wd': 4.233032996527588e-07, 'bs': 64}
  → Test-Ergebnisse Fold 3:
    - MAE: 38.1310
    - RMSE: 40.0003
    - MASE: 0.9860
    - Trainierbare Parameter: 0.036M
Plot gespeichert unter: sdgnn_k14_fold_3_performance.png

=== Kreuzvalidierungs-Zusammenfassung ===
            mae       rmse      mase  params_M
fold                                          
1     69.984795  87.271745  1.031768  0.035841
2     70.356300  77.663188  0.940299  0.012801
3     38.131008  40.000342  0.985977  0.035841

--- Durchschnittliche Leistung ---
mae     59.490701
rmse    68.311758
mase     0.986014
dtype: float64
