In [1]:
# =========================
# Cell 1 | Imports & 全局配置
# =========================
import os, json, math, random
from pathlib import Path
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

from sklearn.model_selection import GroupShuffleSplit, GroupKFold, RandomizedSearchCV
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.ensemble import RandomForestRegressor
from scipy.stats import randint

from tqdm import tqdm

from rdkit import Chem
from rdkit.Chem import rdchem
from rdkit.Chem import Crippen, Lipinski, Descriptors, GraphDescriptors, AllChem
from rdkit.Chem.EState import EStateIndices
from rdkit.Chem.rdMolDescriptors import (
    CalcLabuteASA, CalcTPSA, CalcNumAromaticRings,
    CalcFractionCSP3, CalcKappa1, CalcKappa2, CalcKappa3,
)

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATv2Conv, global_mean_pool

from transformers import RobertaModel, RobertaTokenizer

# -------- Input paths (adjust to your environment) --------
# Meta dataset master table.
META_XLSX = Path("/root/fusion_dataset/Aqutic_unique.xlsx")
# Pre-filtered hazard inventory, its prebuilt graph list, and the row_id order of those graphs.
HAZARD_XLSX = Path("/root/危险化合物应用/hazard_ready/hazard_filtered_with_physchem.xlsx")
HAZARD_GRAPH_PT = Path("/root/危险化合物应用/hazard_ready/hazard_graph_list.pt")
HAZARD_GRAPH_ROWID = Path("/root/危险化合物应用/hazard_ready/hazard_row_id_graph.npy")

# Local ChemBERTa checkpoint directory; must exist (it is loaded with local_files_only=True).
CHEMBERTA_DIR = Path("/root/多模态/model")

# -------- Output directory (created if missing) --------
OUT_BASE = Path("/root/危险化合物应用/apply_outputs")
OUT_BASE.mkdir(exist_ok=True, parents=True)

# -------- Run switches (validate with a small run first) --------
# FAST_DEV_RUN: subsample each task and use few epochs to smoke-test the pipeline.
FAST_DEV_RUN: bool = False
# DO_RF_SEARCH: run a RandomForest randomized search on the fused embedding (slower).
DO_RF_SEARCH: bool = False

# -------- Reproducibility --------
SEED = 42


def seed_everything(seed=42):
    """Seed Python's `random`, NumPy's global RNG, and torch (CPU and all CUDA devices).

    NOTE(review): RDKit's randomized-SMILES generator is presumably not covered by
    any of these seeds; confirm if bit-exact augmentation matters.

    Args:
        seed: integer seed applied to each RNG above.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


seed_everything(SEED)

# Use the GPU when torch can see one.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("device =", device)


device = cuda


In [2]:
# =========================
# Cell 2 | Tasks 定义 + 工具函数
# =========================
# One row per task: (task name, species group, endpoint, duration in hours, effect code).
# Task names encode the same fields as <SPECIES>_<ENDPOINT>_<DURATION>_<EFFECT>.
_TASK_ROWS = [
    # Acute EC50
    ("ALGAE_EC50_72_POP", "algae", "EC50", 72, "POP"),
    ("INV_EC50_48_ITX", "invertebrates", "EC50", 48, "ITX"),
    ("FISH_EC50_96_MOR", "fish", "EC50", 96, "MOR"),
    # Chronic EC10
    ("ALGAE_EC10_72_POP", "algae", "EC10", 72, "POP"),
    ("INV_EC10_48_ITX", "invertebrates", "EC10", 48, "ITX"),
    ("FISH_EC10_96_MOR", "fish", "EC10", 96, "MOR"),
]
TASK_SPECS = {
    task: {"species": species, "endpoint": endpoint, "duration": duration, "effect": effect}
    for task, species, endpoint, duration, effect in _TASK_ROWS
}

def normalize_species(s: str) -> str:
    """Canonicalize a species-group label: stringify, trim surrounding whitespace, lowercase."""
    text = str(s)
    return text.strip().lower()

def compute_metrics(y_true, y_pred):
    """Regression metrics between targets and predictions.

    Both inputs are first converted to float NumPy arrays.

    Returns:
        dict of Python floats with keys "r2", "mae" and "rmse".
    """
    truth = np.asarray(y_true, dtype=float)
    guess = np.asarray(y_pred, dtype=float)
    mse = mean_squared_error(truth, guess)
    return {
        "r2": float(r2_score(truth, guess)),
        "mae": float(mean_absolute_error(truth, guess)),
        "rmse": float(np.sqrt(mse)),
    }

def safe_log10_mgL(x):
    """Return log10 of a concentration given in mg/L, or None if the value must be discarded.

    A value is discarded when it is missing (None / NaN / pd.NA), cannot be parsed
    as a number (for example a qualifier string like "<0.5" or an empty string),
    is non-finite, or is <= 0.

    Args:
        x: raw ``mgperL`` cell value (number, numeric string, None or NaN).

    Returns:
        float log10(x), or None when the value is unusable.
    """
    # Test for None *before* converting: float(None) raises TypeError, so a
    # check placed after the conversion could never trigger.
    if x is None:
        return None
    try:
        x = float(x)
    except (TypeError, ValueError):
        # Unparseable cell: drop this record instead of crashing the whole Series.apply.
        return None
    if not np.isfinite(x) or x <= 0:
        return None
    return math.log10(x)

print("Tasks:", [*TASK_SPECS])


Tasks: ['ALGAE_EC50_72_POP', 'INV_EC50_48_ITX', 'FISH_EC50_96_MOR', 'ALGAE_EC10_72_POP', 'INV_EC10_48_ITX', 'FISH_EC10_96_MOR']


In [3]:
# =========================
# Cell 3 | 读取元数据集 + 读取名录(已过滤) + 基础检查
# =========================
meta_df = pd.read_excel(META_XLSX, engine="openpyxl")
print("meta_df:", meta_df.shape, meta_df.columns.tolist())

haz_df = pd.read_excel(HAZARD_XLSX, engine="openpyxl")
print("haz_df:", haz_df.shape)

# Required columns
need_meta_cols = {"SMILES_Canonical_RDKit","Duration_Value(hour)","Effect","Endpoint","mgperL","Species Group"}
assert need_meta_cols.issubset(meta_df.columns), f"meta_df 缺列: {need_meta_cols - set(meta_df.columns)}"

need_haz_cols = {"row_id","SMILES_Canonical_RDKit"}
assert need_haz_cols.issubset(haz_df.columns), f"haz_df 缺列: {need_haz_cols - set(haz_df.columns)}"

# Prebuilt hazard graphs and the row_id of each graph
assert HAZARD_GRAPH_PT.exists(), f"缺文件: {HAZARD_GRAPH_PT}"
assert HAZARD_GRAPH_ROWID.exists(), f"缺文件: {HAZARD_GRAPH_ROWID}"

# The .pt file holds a pickled Python list of graph objects (presumably
# torch_geometric `Data`), not a plain tensor state_dict. Such files cannot be
# loaded with weights_only=True, which is torch.load's default since PyTorch 2.6
# (earlier versions emit the FutureWarning seen in this cell's output).
# Only do this for trusted, locally produced files: unpickling can execute code.
haz_graph_list = torch.load(HAZARD_GRAPH_PT, map_location="cpu", weights_only=False)
haz_rowid_graph = np.load(HAZARD_GRAPH_ROWID)
print("haz_graph_list:", len(haz_graph_list), "haz_rowid_graph:", haz_rowid_graph.shape)
# Graph i must correspond to row id haz_rowid_graph[i]; fail early on a mismatch.
assert len(haz_graph_list) == len(haz_rowid_graph), (
    f"hazard graph/row_id length mismatch: {len(haz_graph_list)} vs {len(haz_rowid_graph)}"
)

# 24 physchem columns already written into the hazard table during filtering
PHYS_KEYS = [
    "DESC_MolWt","DESC_ExactMolWt","DESC_HeavyAtomCount","DESC_RingCount","DESC_NumAromaticRings",
    "DESC_FractionCSP3","DESC_MolLogP","DESC_TPSA","DESC_ASA_Labute","DESC_HBA","DESC_HBD",
    "DESC_RotatableBonds","DESC_FormalCharge","DESC_MaxAbsPartialCharge","DESC_MinAbsPartialCharge",
    "KIER_Kappa1","KIER_Kappa2","KIER_Kappa3","KIER_Chi0v","KIER_Chi1v","KIER_Chi2v",
    "ESTATE_mean","ESTATE_std","ESTATE_sum",
]
missing_phys = [c for c in PHYS_KEYS if c not in haz_df.columns]
assert len(missing_phys) == 0, f"haz_df 缺理化列: {missing_phys}"


meta_df: (28461, 10) ['SMILES_Canonical_RDKit', 'Duration_Value(hour)', 'Effect', 'Endpoint', 'mgperL', 'Species Group', 'ChemicalName', 'CAS', 'CanonicalSMILES', 'database']
haz_df: (2106, 28)


  haz_graph_list = torch.load(HAZARD_GRAPH_PT, map_location="cpu")


haz_graph_list: 2106 haz_rowid_graph: (2106,)


In [4]:
# =========================
# Cell 4 | 复用你工程的 Graph 构图细节（atom/bond 特征 + mol_to_graph）
# =========================
# Element vocabulary for the atom one-hot; list order = feature order, unlisted symbols -> "other".
ATOM_LIST = ["C", "N", "O", "F", "Cl", "Br", "I", "S", "P", "B", "Si", "other"]

# Hybridization vocabulary; list order = feature order, unlisted states -> "other".
_HYB = rdchem.HybridizationType
HYB_LIST = [_HYB.SP, _HYB.SP2, _HYB.SP3, _HYB.SP3D, _HYB.SP3D2, "other"]
del _HYB
def one_hot(x, allowed):
    """One-hot encode `x` against the ordered vocabulary `allowed`.

    Returns a list of 0/1 ints with one entry per vocabulary item (matched with ``==``);
    all zeros if `x` is not in `allowed`.
    """
    encoded = []
    for candidate in allowed:
        encoded.append(1 if x == candidate else 0)
    return encoded

def atom_to_feature(atom: Chem.Atom):
    """Encode one RDKit atom as a flat list of 0/1 ints.

    Layout (in this exact order; 42 entries in total, matching the printed ATOM_FEAT_DIM):
      element one-hot over ATOM_LIST (12) + degree 0..5/"6+" (7) +
      formal charge -2..2 (5) + hybridization over HYB_LIST (6) +
      total H count 0..4/"5+" (6) + in-ring (1) + aromatic (1) + chirality (4).

    NOTE(review): the layout presumably has to match the prebuilt hazard graphs loaded
    from HAZARD_GRAPH_PT, so do not reorder or re-bucket features without rebuilding them.
    """
    symbol = atom.GetSymbol()
    if symbol not in ATOM_LIST:
        symbol = "other"

    hyb = atom.GetHybridization()
    if hyb not in HYB_LIST:
        hyb = "other"

    feat_symbol = one_hot(symbol, ATOM_LIST)

    # Degrees above 5 share the "6+" bucket.
    deg = atom.GetDegree()
    deg_list = [0,1,2,3,4,5,"6+"]
    deg_clamp = deg if deg <= 5 else "6+"
    feat_degree = one_hot(deg_clamp, deg_list)

    # NOTE: charges outside [-2, 2] are encoded as 0 (the neutral slot), not as a separate bucket.
    fc = atom.GetFormalCharge()
    fc_list = [-2,-1,0,1,2]
    fc_clamp = fc if fc in fc_list else 0
    feat_charge = one_hot(fc_clamp, fc_list)

    feat_hyb = one_hot(hyb, HYB_LIST)

    # Total (implicit + explicit) H count; more than 4 shares the "5+" bucket.
    num_h = atom.GetTotalNumHs()
    h_list = [0,1,2,3,4,"5+"]
    h_clamp = num_h if num_h <= 4 else "5+"
    feat_num_h = one_hot(h_clamp, h_list)

    feat_ring     = [int(atom.IsInRing())]
    feat_aromatic = [int(atom.GetIsAromatic())]

    # Only unspecified / tetrahedral CW / CCW get their own slot; every other tag -> "other".
    chiral_tag = atom.GetChiralTag()
    chiral_list = [
        rdchem.ChiralType.CHI_UNSPECIFIED,
        rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        "other"
    ]
    if chiral_tag not in chiral_list:
        chiral_tag = "other"
    feat_chiral = one_hot(chiral_tag, chiral_list)

    return (
        feat_symbol + feat_degree + feat_charge + feat_hyb + feat_num_h +
        feat_ring + feat_aromatic + feat_chiral
    )

def bond_to_feature(bond: Chem.Bond):
    """Encode an RDKit bond as 6 ints.

    Layout: one-hot [single, double, triple, other] + [is conjugated] + [is in ring].
    Aromatic bonds, like any other non-single/double/triple type, use the 4th slot.
    """
    bond_type = bond.GetBondType()
    slot = {
        rdchem.BondType.SINGLE: 0,
        rdchem.BondType.DOUBLE: 1,
        rdchem.BondType.TRIPLE: 2,
    }.get(bond_type, 3)
    type_vec = [0, 0, 0, 0]
    type_vec[slot] = 1
    return type_vec + [int(bond.GetIsConjugated()), int(bond.IsInRing())]

# Derive both feature widths from the encoders themselves (probe molecule: ethane,
# which has atoms and one bond). The edge width is no longer hard-coded, so it can't
# silently drift out of sync with bond_to_feature.
_probe_mol = Chem.MolFromSmiles("CC")
ATOM_FEAT_DIM = len(atom_to_feature(_probe_mol.GetAtomWithIdx(0)))
EDGE_FEAT_DIM = len(bond_to_feature(_probe_mol.GetBondWithIdx(0)))
del _probe_mol
print("ATOM_FEAT_DIM:", ATOM_FEAT_DIM, "EDGE_FEAT_DIM:", EDGE_FEAT_DIM)

def mol_to_graph(smiles: str, y: float, row_id: int):
    """Build a torch_geometric ``Data`` graph from a SMILES string.

    Each bond becomes two directed edges (i->j, then j->i) sharing one bond feature vector.
    The resulting Data has: x (num_atoms, atom feature dim) float32,
    edge_index (2, 2*num_bonds) long, edge_attr (2*num_bonds, 6) float32,
    y (1,) float32 and row_id (1,) long.

    Returns None when RDKit cannot parse `smiles` or the molecule has no bonds
    (bond-less species are dropped, as in the original project).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    node_features = torch.tensor([atom_to_feature(atom) for atom in mol.GetAtoms()], dtype=torch.float32)

    pairs, feats = [], []
    for bond in mol.GetBonds():
        src, dst = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        vec = bond_to_feature(bond)
        pairs += [[src, dst], [dst, src]]
        feats += [vec, vec]

    if not pairs:
        return None

    return Data(
        x=node_features,
        edge_index=torch.tensor(pairs, dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(feats, dtype=torch.float32),
        y=torch.tensor([y], dtype=torch.float32),
        row_id=torch.tensor([row_id], dtype=torch.long),
    )


ATOM_FEAT_DIM: 42 EDGE_FEAT_DIM: 6


In [5]:
# =========================
# Cell 5 | 复用你工程的 24 理化特征计算（用于元数据集）
# =========================
def compute_physchem_24(mol: Chem.Mol) -> dict:
    """Compute the 24 physicochemical descriptors named in PHYS_KEYS for one molecule.

    Args:
        mol: RDKit molecule, or None.

    Returns:
        dict keyed exactly by the PHYS_KEYS names. All values are NaN when `mol` is
        None. Otherwise counts and net formal charge are ints and everything else is
        float. The partial-charge, Kappa/Chi and EState groups fall back to NaN when
        their computation raises; EState also falls back when it yields no indices.

    NOTE(review): this presumably must stay identical to the code that produced the
    DESC_/KIER_/ESTATE_ columns of the hazard table (HAZARD_XLSX). Otherwise training
    and application features diverge, so confirm before changing any formula.
    """
    if mol is None:
        return {k: np.nan for k in PHYS_KEYS}

    # Size / ring descriptors
    MolWt_val      = float(Descriptors.MolWt(mol))
    ExactMolWt_val = float(Descriptors.ExactMolWt(mol))
    HeavyAtomCount = int(mol.GetNumHeavyAtoms())
    RingCount      = int(mol.GetRingInfo().NumRings())
    NumAromRings   = int(CalcNumAromaticRings(mol))
    FractionCSP3   = float(CalcFractionCSP3(mol))

    # Lipophilicity / polar and total surface area
    MolLogP_val = float(Crippen.MolLogP(mol))
    TPSA_val    = float(CalcTPSA(mol))
    ASA_Labute  = float(CalcLabuteASA(mol))

    # H-bond acceptors/donors, rotatable bonds, net formal charge
    HBA  = int(Lipinski.NumHAcceptors(mol))
    HBD  = int(Lipinski.NumHDonors(mol))
    RotB = int(Lipinski.NumRotatableBonds(mol))
    FormalCharge_val = int(Chem.GetFormalCharge(mol))

    # Gasteiger partial charges, computed on a copy with explicit hydrogens added.
    # NOTE(review): if some atoms receive NaN charges, Python's max()/min() are
    # position-dependent: a NaN in first position propagates, while later NaNs are skipped.
    # Rows that end up with NaN are dropped by the strict filter in enrich_physchem_and_graph.
    MaxAbsQ, MinAbsQ = np.nan, np.nan
    try:
        mol_h = Chem.AddHs(mol)
        AllChem.ComputeGasteigerCharges(mol_h)
        charges = [float(a.GetProp("_GasteigerCharge")) for a in mol_h.GetAtoms()]
        MaxAbsQ = float(max(abs(c) for c in charges))
        MinAbsQ = float(min(abs(c) for c in charges))
    except Exception:
        MaxAbsQ, MinAbsQ = np.nan, np.nan

    # Kier shape indices and valence connectivity (chi) indices; any failure -> all NaN.
    try:
        K1 = float(CalcKappa1(mol))
        K2 = float(CalcKappa2(mol))
        K3 = float(CalcKappa3(mol))
        Chi0v = float(GraphDescriptors.Chi0v(mol))
        Chi1v = float(GraphDescriptors.Chi1v(mol))
        Chi2v = float(GraphDescriptors.Chi2v(mol))
    except Exception:
        K1=K2=K3=np.nan
        Chi0v=Chi1v=Chi2v=np.nan

    # Per-atom EState indices summarized as mean / population std (np.std, ddof=0) / sum.
    try:
        est = np.array(EStateIndices(mol), dtype=np.float32)
        if est.size == 0:
            EST_mean = EST_std = EST_sum = np.nan
        else:
            EST_mean = float(np.mean(est))
            EST_std  = float(np.std(est))
            EST_sum  = float(np.sum(est))
    except Exception:
        EST_mean = EST_std = EST_sum = np.nan

    return {
        "DESC_MolWt": MolWt_val,
        "DESC_ExactMolWt": ExactMolWt_val,
        "DESC_HeavyAtomCount": HeavyAtomCount,
        "DESC_RingCount": RingCount,
        "DESC_NumAromaticRings": NumAromRings,
        "DESC_FractionCSP3": FractionCSP3,
        "DESC_MolLogP": MolLogP_val,
        "DESC_TPSA": TPSA_val,
        "DESC_ASA_Labute": ASA_Labute,
        "DESC_HBA": HBA,
        "DESC_HBD": HBD,
        "DESC_RotatableBonds": RotB,
        "DESC_FormalCharge": FormalCharge_val,
        "DESC_MaxAbsPartialCharge": MaxAbsQ,
        "DESC_MinAbsPartialCharge": MinAbsQ,
        "KIER_Kappa1": K1,
        "KIER_Kappa2": K2,
        "KIER_Kappa3": K3,
        "KIER_Chi0v": Chi0v,
        "KIER_Chi1v": Chi1v,
        "KIER_Chi2v": Chi2v,
        "ESTATE_mean": EST_mean,
        "ESTATE_std": EST_std,
        "ESTATE_sum": EST_sum,
    }

def build_task_dataframe(meta_df: pd.DataFrame, spec: dict) -> pd.DataFrame:
    """Select the meta-dataset records of one task and attach a log10 target.

    A row is kept when all of the following hold:
      * normalized ``Species Group`` == spec["species"]
      * upper-cased/stripped ``Endpoint`` == spec["endpoint"] and ``Effect`` == spec["effect"]
      * ``Duration_Value(hour)``, coerced to numeric (unparseable -> dropped), == spec["duration"]
      * SMILES is present and non-empty after stripping
      * ``safe_log10_mgL(mgperL)`` is not None

    With FAST_DEV_RUN, at most 800 rows are sampled (random_state=SEED).

    Args:
        meta_df: full meta dataset (not modified; a copy is filtered).
        spec: one TASK_SPECS entry with keys species / endpoint / duration / effect.

    Returns:
        New DataFrame with a 0..n-1 RangeIndex, a float ``y_log10`` column and a
        contiguous int64 ``row_id`` column.
    """
    df = meta_df.copy()
    df["Species Group"] = df["Species Group"].map(normalize_species)
    df["Endpoint"] = df["Endpoint"].astype(str).str.upper().str.strip()
    df["Effect"] = df["Effect"].astype(str).str.upper().str.strip()
    # Coerce rather than astype(float): a single non-numeric duration cell would
    # otherwise raise and abort every task.
    duration = pd.to_numeric(df["Duration_Value(hour)"], errors="coerce")

    df = df[
        (df["Species Group"] == spec["species"]) &
        (df["Endpoint"] == spec["endpoint"]) &
        (df["Effect"] == spec["effect"]) &
        (duration == float(spec["duration"]))
    ].copy()

    # Drop missing SMILES *before* stringifying: astype(str) turns NaN into the
    # literal "nan", which a later notna() check can no longer catch.
    df = df[df["SMILES_Canonical_RDKit"].notna()].copy()
    df["SMILES_Canonical_RDKit"] = df["SMILES_Canonical_RDKit"].astype(str).str.strip()
    df = df[df["SMILES_Canonical_RDKit"] != ""].copy()

    # Target: log10(mg/L); unusable concentrations become None and are removed.
    df["y_log10"] = df["mgperL"].apply(safe_log10_mgL)
    df = df[df["y_log10"].notna()].copy()
    df["y_log10"] = df["y_log10"].astype(float)

    # FAST_DEV_RUN subsampling
    if FAST_DEV_RUN and len(df) > 800:
        df = df.sample(800, random_state=SEED).copy()

    # Contiguous row_id, assigned after all filtering
    df = df.reset_index(drop=True)
    df["row_id"] = np.arange(len(df), dtype=np.int64)

    return df

def enrich_physchem_and_graph(df_task: pd.DataFrame):
    """Attach the 24 physchem descriptors and build a graph per row, keeping only full successes.

    Strict filter (the same rule as the hazard-side filtering): a row survives only if
    RDKit parses its SMILES, none of the PHYS_KEYS descriptors is NaN, and mol_to_graph
    returns a graph (the molecule has at least one bond).

    Args:
        df_task: task frame with ``SMILES_Canonical_RDKit``, ``row_id`` and ``y_log10``.

    Returns:
        df_keep: kept rows (original columns + PHYS_KEYS) with a RangeIndex; row i
            corresponds to graph_list_keep[i]. The columns are kept even if no row survives.
        graph_list_keep: list of graph ``Data`` objects in the same order.
        drop_stats: {reason: count} for dropped rows, most frequent first; reasons are
            "mol_fail", "phys_nan" and "graph_fail_or_no_bond".
    """
    from collections import Counter

    keep_rows = []
    graph_list_keep = []
    drop_counter = Counter()

    for _, r in tqdm(df_task.iterrows(), total=len(df_task), desc="physchem+graph"):
        smi = str(r["SMILES_Canonical_RDKit"]).strip()
        rid = int(r["row_id"])
        y = float(r["y_log10"])

        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            drop_counter["mol_fail"] += 1
            continue

        phys = compute_physchem_24(mol)
        # Strict: every descriptor must be non-NaN (int-valued descriptors can never be NaN).
        if any(isinstance(v, float) and math.isnan(v) for v in (phys.get(k, np.nan) for k in PHYS_KEYS)):
            drop_counter["phys_nan"] += 1
            continue

        g = mol_to_graph(smi, y=y, row_id=rid)
        if g is None:
            drop_counter["graph_fail_or_no_bond"] += 1
            continue

        # Append to both outputs together so df_keep row i <-> graph_list_keep[i].
        keep_rows.append({**r.to_dict(), **phys})
        graph_list_keep.append(g)

    # Explicit column list: pd.DataFrame([]) would have no columns at all, which
    # breaks downstream column access when a task loses every row. The order matches
    # the key order of the merged row dicts above.
    out_cols = list(dict.fromkeys([*df_task.columns, *PHYS_KEYS]))
    df_keep = pd.DataFrame(keep_rows, columns=out_cols)
    drop_stats = dict(drop_counter.most_common())

    return df_keep.reset_index(drop=True), graph_list_keep, drop_stats


In [6]:
# =========================
# Cell 6 | 文本模态：ChemBERTaRegressor + Dataset + 训练/嵌入提取
# =========================
class SMILESDatasetAug(torch.utils.data.Dataset):
    """Tokenized SMILES regression dataset with optional randomized-SMILES augmentation.

    Each item is ``(input_ids, attention_mask, y)``: two long tensors padded/truncated
    to ``max_length`` by the tokenizer, and a float32 scalar target.

    When ``augment`` is true and the SMILES parsed, each access re-renders the molecule
    as a random (non-canonical) SMILES with RDKit; otherwise the stored string is used.
    Molecules are parsed once, at construction time.
    NOTE(review): the RDKit random-SMILES generator is presumably not seeded by
    seed_everything; confirm if bit-exact reproducibility matters.
    """

    def __init__(self, smiles, targets, tokenizer, max_length=256, augment=False):
        self.smiles = list(smiles)
        self.targets = np.array(targets, dtype=np.float32)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.augment = augment
        self.mols = [self._parse(s) for s in self.smiles]

    @staticmethod
    def _parse(smi):
        # Any parse failure or RDKit exception gives None, which disables augmentation for that item.
        try:
            return Chem.MolFromSmiles(smi)
        except Exception:
            return None

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        mol = self.mols[idx]
        use_random = self.augment and mol is not None
        text = Chem.MolToSmiles(mol, doRandom=True) if use_random else self.smiles[idx]
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )
        return (
            torch.tensor(encoded["input_ids"], dtype=torch.long),
            torch.tensor(encoded["attention_mask"], dtype=torch.long),
            torch.tensor(self.targets[idx], dtype=torch.float32),
        )

class ChemBERTaRegressor(nn.Module):
    """RoBERTa (ChemBERTa) backbone with a dropout + linear head on the first-token embedding.

    Args:
        model_dir: local pretrained directory (loaded with local_files_only=True).
        dropout: dropout rate applied to the first-token vector before the head.
        freeze_embeddings: if true, the backbone's embedding parameters are not trained.
        freeze_n_layers: number of leading encoder layers to freeze (0 = none).
    """

    def __init__(self, model_dir: str, dropout=0.3, freeze_embeddings=True, freeze_n_layers=0):
        super().__init__()
        self.backbone = RobertaModel.from_pretrained(str(model_dir), local_files_only=True)
        self.dropout = nn.Dropout(dropout)
        self.head = nn.Linear(self.backbone.config.hidden_size, 1)

        frozen_modules = []
        if freeze_embeddings:
            frozen_modules.append(self.backbone.embeddings)
        if freeze_n_layers > 0:
            frozen_modules.extend(self.backbone.encoder.layer[:freeze_n_layers])
        for module in frozen_modules:
            for param in module.parameters():
                param.requires_grad = False

    def _first_token(self, input_ids, attention_mask):
        # Last-layer hidden state at position 0 (the CLS-style token).
        hidden = self.backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return hidden[:, 0, :]

    def forward(self, input_ids, attention_mask):
        """Return one scalar prediction per sequence (trailing dim squeezed)."""
        features = self.dropout(self._first_token(input_ids, attention_mask))
        return self.head(features).squeeze(-1)

    @torch.no_grad()
    def get_cls_embedding(self, input_ids, attention_mask):
        """First-token embedding without gradient tracking (no dropout applied)."""
        return self._first_token(input_ids, attention_mask)

def train_text_model(smiles, y, groups, out_dir: Path):
    """Fine-tune a ChemBERTaRegressor on (SMILES, y) with a group-aware hold-out split.

    Split: GroupShuffleSplit, 20% validation, random_state=SEED. Training uses
    randomized-SMILES augmentation, L1 loss, AdamW + OneCycleLR (one step per batch),
    and early stopping on validation L1. The best-validation weights are restored and
    saved to ``out_dir / "text_chemberta.pt"``.

    Args:
        smiles: iterable of SMILES strings (list, ndarray or pandas Series).
        y: targets aligned with `smiles` (array-like).
        groups: group labels; rows sharing a label stay on the same side of the split.
        out_dir: output directory (created if missing).

    Returns:
        (model, tokenizer) with the best-validation weights loaded.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # The splitter returns *positions*. Materialize positional containers so that
    # a pandas Series with a non-default index cannot be label-indexed (which would
    # misalign rows or raise KeyError).
    smiles = list(smiles)
    y = np.asarray(y, dtype=np.float32)

    tokenizer = RobertaTokenizer.from_pretrained(str(CHEMBERTA_DIR), local_files_only=True)

    gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
    tr_idx, va_idx = next(gss.split(smiles, y, groups=groups))

    if FAST_DEV_RUN:
        tr_idx = tr_idx[:400]
        va_idx = va_idx[:200]

    ds_tr = SMILESDatasetAug([smiles[i] for i in tr_idx], y[tr_idx], tokenizer, max_length=256, augment=True)
    ds_va = SMILESDatasetAug([smiles[i] for i in va_idx], y[va_idx], tokenizer, max_length=256, augment=False)

    dl_tr = torch.utils.data.DataLoader(ds_tr, batch_size=32, shuffle=True)
    dl_va = torch.utils.data.DataLoader(ds_va, batch_size=64, shuffle=False)

    max_epochs = 5 if FAST_DEV_RUN else 15
    patience = 3 if FAST_DEV_RUN else 8

    model = ChemBERTaRegressor(CHEMBERTA_DIR, dropout=0.3, freeze_embeddings=True, freeze_n_layers=0).to(device)
    # Optimize only the parameters left trainable (the embeddings are frozen).
    opt = AdamW([p for p in model.parameters() if p.requires_grad], lr=2e-4, weight_decay=1e-4)
    # Size the schedule for the full epoch budget; early stopping just ends it sooner.
    sch = OneCycleLR(opt, max_lr=2e-4, total_steps=max(1, len(dl_tr) * max_epochs))

    best = math.inf
    best_state = None
    bad = 0

    for ep in range(1, max_epochs + 1):
        model.train()
        tr_losses = []
        for input_ids, attn, yy in dl_tr:
            input_ids, attn, yy = input_ids.to(device), attn.to(device), yy.to(device)
            opt.zero_grad()
            loss = F.l1_loss(model(input_ids, attn), yy)
            loss.backward()
            opt.step()
            sch.step()
            tr_losses.append(loss.item())

        model.eval()
        va_losses, yv_true, yv_pred = [], [], []
        with torch.no_grad():
            for input_ids, attn, yy in dl_va:
                input_ids, attn, yy = input_ids.to(device), attn.to(device), yy.to(device)
                pred = model(input_ids, attn)
                va_losses.append(F.l1_loss(pred, yy).item())
                yv_true.append(yy.cpu().numpy())
                yv_pred.append(pred.cpu().numpy())

        # Defensive guard (mirrors train_graph_model): with an empty validation
        # loader, np.concatenate([]) would raise instead of just skipping metrics.
        if va_losses:
            met = compute_metrics(np.concatenate(yv_true), np.concatenate(yv_pred))
            va_loss = float(np.mean(va_losses))
        else:
            met = {"r2": np.nan, "mae": np.nan, "rmse": np.nan}
            va_loss = np.nan
        print(f"[Text][Ep {ep:03d}] train_L1={np.mean(tr_losses):.4f} val_L1={va_loss:.4f} val_r2={met['r2']:.3f}")

        # NaN never compares smaller, so an empty validation counts as "no improvement".
        if va_loss < best:
            best = va_loss
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    torch.save(model.state_dict(), out_dir / "text_chemberta.pt")
    return model, tokenizer

@torch.no_grad()
def embed_text(model, tokenizer, smiles_list, batch_size=128):
    """First-token (CLS-style) embeddings for a collection of SMILES, in input order.

    Args:
        model: ChemBERTaRegressor; switched to eval mode, and ``get_cls_embedding`` is used.
        tokenizer: matching tokenizer. Batches are padded to their longest sequence
            and truncated at 256 tokens.
        smiles_list: list / tuple / NumPy array / pandas Series of SMILES strings;
            a single str is treated as one SMILES.
        batch_size: sequences per forward pass.

    Returns:
        float32 array of shape (n, hidden_size); (0, hidden_size) for empty input.
    """
    model.eval()
    # HF tokenizers accept only str or list[str]; a Series/ndarray slice is rejected,
    # so materialize a plain list first. Keep a bare str as a single sequence.
    if isinstance(smiles_list, str):
        smiles_list = [smiles_list]
    smiles_list = list(smiles_list)
    if not smiles_list:
        return np.zeros((0, model.backbone.config.hidden_size), dtype=np.float32)

    embs = []
    for start in range(0, len(smiles_list), batch_size):
        batch = smiles_list[start:start + batch_size]
        enc = tokenizer(batch, truncation=True, padding=True, max_length=256, return_tensors="pt")
        cls = model.get_cls_embedding(enc["input_ids"].to(device), enc["attention_mask"].to(device))
        embs.append(cls.cpu().numpy())
    return np.concatenate(embs, axis=0).astype(np.float32)


In [7]:
# =========================
# Cell 7 | 图模态：GATv2Regressor + 训练/嵌入提取（复用 GRAPH.ipynb 结构）
# =========================
class GraphDataset(torch.utils.data.Dataset):
    """Index view over a list of graphs: item k is ``data_list[indices[k]]``.

    The underlying list is shared, not copied.
    """

    def __init__(self, data_list, indices):
        self.data_list = data_list
        self.indices = np.array(indices, dtype=np.int64)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        source_pos = self.indices[idx]
        return self.data_list[source_pos]

class GATv2Regressor(nn.Module):
    """Stacked GATv2 encoder + global mean pooling + 2-layer MLP regression head.

    Each conv has ``heads`` attention heads of width ``hidden_channels // heads``
    (or ``hidden_channels`` when heads == 1), followed by ELU and dropout.

    Args:
        in_channels: node feature width.
        hidden_channels: width after each conv (must be divisible by ``heads`` when heads > 1).
        num_layers: number of GATv2 layers (at least one is always built).
        heads: attention heads per layer.
        edge_dim: edge feature width passed to every conv (None = no edge features).
        dropout: dropout used in the convs, between layers and inside the head.
    """

    def __init__(self, in_channels, hidden_channels=256, num_layers=2, heads=4, edge_dim=None, dropout=0.05):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        assert heads == 1 or hidden_channels % heads == 0
        per_head = hidden_channels if heads <= 1 else hidden_channels // heads

        # First layer reads raw node features; later layers read the concatenated heads.
        layer_inputs = [in_channels] + [hidden_channels] * (num_layers - 1)
        self.convs = nn.ModuleList(
            GATv2Conv(c_in, per_head, heads=heads, edge_dim=edge_dim, dropout=dropout)
            for c_in in layer_inputs
        )

        self.lin_head = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, 1),
        )

    def encode_graph(self, data):
        """Per-graph embedding: node states after all convs, mean-pooled by ``data.batch``."""
        h = data.x
        for conv in self.convs:
            h = self.dropout(F.elu(conv(h, data.edge_index, data.edge_attr)))
        return global_mean_pool(h, data.batch)

    def forward(self, data):
        """One scalar prediction per graph, flattened to 1-D."""
        return self.lin_head(self.encode_graph(data)).view(-1)


In [8]:
# =========================
# Cell 7（续）| 图模态：GATv2 训练 + 嵌入提取
# =========================
def _group_split_indices(groups, test_size=0.2, seed=SEED):
    """One group-aware train/validation split.

    Rows sharing a group label always land on the same side. The split is
    deterministic for a given ``seed``.

    Returns:
        (train_idx, val_idx): integer position arrays into ``groups``.
    """
    positions = np.arange(len(groups))
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    train_idx, val_idx = next(iter(splitter.split(positions, positions, groups=groups)))
    return train_idx, val_idx

def train_graph_model(graph_list, groups, out_dir: Path):
    """Train a GATv2Regressor on a group-aware hold-out split and save the best weights.

    Uses an 80/20 group split (seed SEED), L1 loss, AdamW, gradient-norm clipping
    at 3.0, and early stopping on validation L1.

    Args:
        graph_list: list of graph ``Data`` objects, each carrying a ``y`` target.
        groups: group labels aligned with ``graph_list``.
        out_dir: output directory (created if missing).

    Returns:
        The model with best-validation weights loaded (also saved as ``graph_gatv2.pt``).
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    tr_idx, va_idx = _group_split_indices(groups, test_size=0.2, seed=SEED)
    if FAST_DEV_RUN:
        tr_idx, va_idx = tr_idx[:300], va_idx[:150]

    train_loader = DataLoader(GraphDataset(graph_list, tr_idx), batch_size=64, shuffle=True, drop_last=False)
    val_loader = DataLoader(GraphDataset(graph_list, va_idx), batch_size=128, shuffle=False, drop_last=False)

    model = GATv2Regressor(
        in_channels=ATOM_FEAT_DIM,
        hidden_channels=256,
        num_layers=2,
        heads=4,
        edge_dim=EDGE_FEAT_DIM,
        dropout=0.05,
    ).to(device)
    optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

    patience = 3 if FAST_DEV_RUN else 8
    max_epochs = 6 if FAST_DEV_RUN else 20
    best_loss, best_state, stale_epochs = 1e9, None, 0

    def _validate():
        # Mean validation L1 and metrics; (NaN, NaN-metrics) if the loader is empty.
        losses, truths, preds = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                target = batch.y.view(-1).to(device)
                pred = model(batch)
                losses.append(F.l1_loss(pred, target).item())
                truths.append(target.cpu().numpy())
                preds.append(pred.cpu().numpy())
        if not losses:
            return np.nan, {"r2": np.nan, "mae": np.nan, "rmse": np.nan}
        return float(np.mean(losses)), compute_metrics(np.concatenate(truths), np.concatenate(preds))

    for ep in range(1, max_epochs + 1):
        model.train()
        tr_losses = []
        for batch in train_loader:
            batch = batch.to(device)
            target = batch.y.view(-1).to(device)
            optimizer.zero_grad()
            loss = F.l1_loss(model(batch), target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
            optimizer.step()
            tr_losses.append(loss.item())

        model.eval()
        va_loss, met = _validate()
        print(f"[Graph][Ep {ep:03d}] train_L1={np.mean(tr_losses):.4f} val_L1={va_loss:.4f} val_r2={met['r2']:.3f}")

        if va_loss < best_loss:
            best_loss = va_loss
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            stale_epochs = 0
            continue
        stale_epochs += 1
        if stale_epochs >= patience:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    torch.save(model.state_dict(), out_dir / "graph_gatv2.pt")
    return model

@torch.no_grad()
def embed_graph(model: GATv2Regressor, graph_list, batch_size=256):
    """Graph-level embeddings for every graph, in input order (no shuffling).

    Returns:
        (emb, rid): float32 array of ``model.encode_graph`` outputs (one row per graph)
        and the int64 ``row_id`` of each graph, aligned with ``emb``.
    """
    model.eval()
    loader = DataLoader(graph_list, batch_size=batch_size, shuffle=False, drop_last=False)
    emb_chunks, rid_chunks = [], []
    for batch in loader:
        batch = batch.to(device)
        emb_chunks.append(model.encode_graph(batch).cpu().numpy())
        rid_chunks.append(batch.row_id.view(-1).cpu().numpy())
    emb = np.concatenate(emb_chunks, axis=0).astype(np.float32)
    rid = np.concatenate(rid_chunks, axis=0).astype(np.int64)
    return emb, rid


In [16]:
# =========================
# Cell 8 | 理化模态：PhysChem MLP（修复：inf/超大值导致的 sklearn 崩溃）
# =========================
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer

def sanitize_phys_matrix(X, clip_quantile=0.001):
    """Return a cleaned float64 *copy* of a 2-D feature matrix.

    1. Every non-finite entry (inf, -inf, NaN) becomes NaN.
    2. If ``0 < clip_quantile < 0.5``, each column with at least 5 finite values is
       clipped to its [clip_quantile, 1 - clip_quantile] quantiles. NaNs are ignored
       when computing the bounds and remain NaN. For example, 0.001 clips at 0.1% / 99.9%.

    NOTE(review): the bounds are recomputed from whatever matrix is passed in, so
    inference-time calls clip with the inference data's own quantiles rather than
    the training ones.

    Args:
        X: array-like of shape (n_samples, n_features).
        clip_quantile: lower-tail quantile for clipping; None (or out of range) disables it.

    Returns:
        np.ndarray of dtype float64 with the same shape as X. The input is never modified.
    """
    # Always copy: np.asarray would hand back the caller's own float64 array, and
    # the writes below would silently mutate the caller's data.
    X = np.array(X, dtype=np.float64, copy=True)
    X[~np.isfinite(X)] = np.nan

    if clip_quantile is not None and 0.0 < clip_quantile < 0.5:
        lo_q, hi_q = clip_quantile, 1.0 - clip_quantile
        for j in range(X.shape[1]):
            col = X[:, j]
            if np.isfinite(col).sum() < 5:
                continue
            lo, hi = np.nanquantile(col, [lo_q, hi_q])
            # Skip degenerate or undefined bounds (e.g. a constant column).
            if np.isfinite(lo) and np.isfinite(hi) and lo < hi:
                X[:, j] = np.clip(col, lo, hi)

    return X

class PhysChemMLP(nn.Module):
    """Small MLP over physchem descriptors that also exposes its penultimate embedding.

    in_dim -> hidden_dim (ReLU, dropout) -> emb_dim (ReLU) = embedding -> 1 output.
    """

    def __init__(self, in_dim, hidden_dim=128, emb_dim=64, dropout=0.10):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc_emb = nn.Linear(hidden_dim, emb_dim)
        self.fc_out = nn.Linear(emb_dim, 1)

    def forward(self, x, return_emb=False):
        """Flattened predictions; with ``return_emb`` also the (non-negative) embedding."""
        hidden = self.dropout(self.act(self.fc1(x)))
        embedding = self.act(self.fc_emb(hidden))
        prediction = self.fc_out(embedding).view(-1)
        return (prediction, embedding) if return_emb else prediction

def train_phys_model(X_phys, y, groups, out_dir: Path):
    """Train a PhysChemMLP on physchem descriptors with a group-aware hold-out split.

    Pipeline:
      * sanitize_phys_matrix (inf -> NaN, 0.1% / 99.9% per-column clipping), applied
        to the whole matrix before splitting;
      * median SimpleImputer and StandardScaler, both fitted on the training split only;
      * MLP trained with L1 loss, AdamW and gradient-norm clipping at 3.0, with early
        stopping on validation L1.

    Writes ``phys_mlp.pt``, ``phys_imputer.pkl`` and ``phys_scaler.pkl`` to ``out_dir``.

    Returns:
        (model, imputer, scaler): reuse the fitted imputer/scaler at inference time.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # Sanitize before splitting so no inf ever reaches sklearn (NaN is fine for the imputer).
    X_clean = sanitize_phys_matrix(X_phys, clip_quantile=0.001)

    tr_idx, va_idx = _group_split_indices(groups, test_size=0.2, seed=SEED)
    if FAST_DEV_RUN:
        tr_idx, va_idx = tr_idx[:400], va_idx[:200]

    imputer = SimpleImputer(strategy="median")
    scaler = StandardScaler()
    X_tr = scaler.fit_transform(imputer.fit_transform(X_clean[tr_idx]))
    X_va = scaler.transform(imputer.transform(X_clean[va_idx]))

    def _as_tensor(arr):
        # float32 tensors for torch
        return torch.tensor(np.asarray(arr, np.float32), dtype=torch.float32)

    X_tr_t, y_tr_t = _as_tensor(X_tr), _as_tensor(y[tr_idx])
    X_va_t, y_va_t = _as_tensor(X_va), _as_tensor(y[va_idx])
    y_va_np = y_va_t.numpy()

    model = PhysChemMLP(in_dim=X_tr.shape[1], hidden_dim=128, emb_dim=64, dropout=0.10).to(device)
    opt = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

    best, best_state, bad = 1e9, None, 0
    patience = 3 if FAST_DEV_RUN else 10
    max_epochs = 8 if FAST_DEV_RUN else 40
    batch_size = 256
    n_train = len(X_tr)

    for ep in range(1, max_epochs + 1):
        model.train()
        order = np.random.permutation(n_train)
        losses = []
        for start in range(0, n_train, batch_size):
            idx = order[start:start + batch_size]
            xb, yb = X_tr_t[idx].to(device), y_tr_t[idx].to(device)
            opt.zero_grad()
            loss = F.l1_loss(model(xb), yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
            opt.step()
            losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            pred_va = model(X_va_t.to(device)).cpu().numpy()
        met = compute_metrics(y_va_np, pred_va)
        va_l1 = float(np.mean(np.abs(pred_va - y_va_np)))

        print(f"[Phys][Ep {ep:03d}] train_L1={np.mean(losses):.4f} val_L1={va_l1:.4f} val_r2={met['r2']:.3f}")

        if va_l1 < best:
            best = va_l1
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
            continue
        bad += 1
        if bad >= patience:
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    torch.save(model.state_dict(), out_dir / "phys_mlp.pt")
    import joblib
    joblib.dump(imputer, out_dir / "phys_imputer.pkl")
    joblib.dump(scaler, out_dir / "phys_scaler.pkl")
    return model, imputer, scaler

@torch.no_grad()
def embed_phys(model: PhysChemMLP, X_phys, imputer, scaler, batch_size=512):
    """Predictions and embeddings from a trained PhysChemMLP.

    Applies the same chain as training: sanitize_phys_matrix -> imputer.transform ->
    scaler.transform, using the already-fitted imputer and scaler passed in.
    NOTE(review): sanitize_phys_matrix recomputes its clipping quantiles on this
    matrix, so the clip bounds differ from the training-time ones.

    Returns:
        (pred, emb): float32 arrays with one prediction / one embedding row per sample.
    """
    # Sanitize here too, otherwise an inf in the inference data would break sklearn.
    X_clean = sanitize_phys_matrix(X_phys, clip_quantile=0.001)

    model.eval()
    X_scaled = scaler.transform(imputer.transform(X_clean))
    X_t = torch.tensor(X_scaled.astype(np.float32), dtype=torch.float32)

    pred_chunks, emb_chunks = [], []
    for start in range(0, len(X_t), batch_size):
        out, emb = model(X_t[start:start + batch_size].to(device), return_emb=True)
        pred_chunks.append(out.cpu().numpy())
        emb_chunks.append(emb.cpu().numpy())
    return (
        np.concatenate(pred_chunks, axis=0).astype(np.float32),
        np.concatenate(emb_chunks, axis=0).astype(np.float32),
    )


In [10]:
# =========================
# Cell 9 | MID 融合：三模态 token 注意力（TGP-like），输出 fused_emb + y
# =========================
class TriTokenFusion(nn.Module):
    """
    Mid-level fusion of three modality embeddings via token self-attention.

    The text / graph / phys vectors are each projected to ``d_model`` and
    treated as a sequence of 3 tokens. One multi-head self-attention layer plus
    an FFN mixes the tokens, which are then mean-pooled into a fused embedding.
    A small MLP head regresses one scalar per sample from that embedding.

    Args:
        d_text, d_graph, d_phys: input widths of the three modality embeddings.
        d_model: shared token width (nn.MultiheadAttention requires it to be
            divisible by ``nhead``).
        nhead: number of attention heads.
        dropout: dropout rate used in attention, FFN and head.
    """
    def __init__(self, d_text, d_graph, d_phys, d_model=256, nhead=8, dropout=0.10):
        super().__init__()
        # Submodules keep fixed attribute names and creation order, so both
        # the state_dict keys and the parameter-initialisation order are stable.
        self.p_text = nn.Linear(d_text, d_model)
        self.p_graph = nn.Linear(d_graph, d_model)
        self.p_phys = nn.Linear(d_phys, d_model)

        self.attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True
        )

        ffn_layers = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model),
        ]
        self.ffn = nn.Sequential(*ffn_layers)

        # NOTE(review): this one LayerNorm is applied twice in forward() (after
        # attention and after the FFN), so both sub-layers share its affine
        # parameters; a textbook post-LN block would use two separate norms.
        self.ln = nn.LayerNorm(d_model)

        half = d_model // 2
        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, half),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(half, 1),
        )

    def forward(self, z_text, z_graph, z_phys, return_emb=False):
        """
        Args:
            z_text, z_graph, z_phys: (B, d_text) / (B, d_graph) / (B, d_phys).
            return_emb: if True, also return the pooled fused embedding.

        Returns:
            yhat of shape (B,), or (yhat, fused) with fused of shape (B, d_model).
        """
        # Each (B, d_*) input -> (B, D); stacked into a 3-token sequence (B, 3, D).
        tokens = torch.stack(
            [self.p_text(z_text), self.p_graph(z_graph), self.p_phys(z_phys)],
            dim=1,
        )

        mixed, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        hidden = self.ln(tokens + mixed)
        hidden = self.ln(hidden + self.ffn(hidden))

        fused = hidden.mean(dim=1)  # average over the 3 tokens -> (B, D)
        yhat = self.head(fused).view(-1)
        return (yhat, fused) if return_emb else yhat

def train_fusion_model(Zt, Zg, Zp, y, groups, out_dir: Path,
                       d_model=256, nhead=8, dropout=0.10,
                       lr=3e-4, weight_decay=1e-4, batch_size=256,
                       max_epochs=None, patience=None, grad_clip=3.0):
    """
    Train a TriTokenFusion regressor on precomputed modality embeddings.

    Makes a group-aware train/val split with ``_group_split_indices`` (seed
    SEED, 20% validation). The model is trained with L1 loss and early-stopped
    on validation L1. The best-epoch weights are restored, and the state_dict is
    saved to ``out_dir / "fusion_tgp.pt"``.

    Args:
        Zt, Zg, Zp: text / graph / phys embedding matrices, one row per sample.
        y: regression targets aligned with the embedding rows.
        groups: group labels for the split (e.g. canonical SMILES).
        out_dir: checkpoint directory (created if missing).
        d_model, nhead, dropout: TriTokenFusion hyper-parameters.
        lr, weight_decay: AdamW settings.
        batch_size: training mini-batch size.
        max_epochs, patience: default to 8 / 3 under FAST_DEV_RUN, else 50 / 10.
        grad_clip: max global gradient norm.

    Returns:
        The trained model on ``device`` with its best-validation weights loaded.

    Raises:
        ValueError: if the inputs disagree in row count or a split is empty.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # Misaligned inputs would otherwise be silently truncated by the indexing below.
    n_samples = len(y)
    if not (len(Zt) == len(Zg) == len(Zp) == len(groups) == n_samples):
        raise ValueError(
            f"length mismatch: Zt={len(Zt)} Zg={len(Zg)} Zp={len(Zp)} "
            f"groups={len(groups)} y={n_samples}"
        )

    if max_epochs is None:
        max_epochs = 8 if FAST_DEV_RUN else 50
    if patience is None:
        patience = 3 if FAST_DEV_RUN else 10

    tr_idx, va_idx = _group_split_indices(groups, test_size=0.2, seed=SEED)
    if FAST_DEV_RUN:
        tr_idx = tr_idx[:500]
        va_idx = va_idx[:250]
    if len(tr_idx) == 0 or len(va_idx) == 0:
        raise ValueError(f"empty split: train={len(tr_idx)} val={len(va_idx)}")

    def _rows(arr, idx):
        # Select rows and wrap them as a float32 CPU tensor.
        return torch.tensor(arr[idx], dtype=torch.float32)

    Zt_tr, Zg_tr, Zp_tr, y_tr = (_rows(a, tr_idx) for a in (Zt, Zg, Zp, y))
    Zt_va, Zg_va, Zp_va, y_va = (_rows(a, va_idx) for a in (Zt, Zg, Zp, y))
    y_va_np = y_va.numpy()

    model = TriTokenFusion(d_text=Zt.shape[1], d_graph=Zg.shape[1], d_phys=Zp.shape[1],
                           d_model=d_model, nhead=nhead, dropout=dropout).to(device)
    opt = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    best = float("inf")
    best_state = None
    bad = 0

    for ep in range(1, max_epochs + 1):
        model.train()
        # Shuffling uses the global numpy RNG (seeded by seed_everything).
        perm = np.random.permutation(len(tr_idx))
        losses = []
        for s in range(0, len(perm), batch_size):
            idx = perm[s:s + batch_size]
            opt.zero_grad()
            pred = model(Zt_tr[idx].to(device), Zg_tr[idx].to(device), Zp_tr[idx].to(device))
            loss = F.l1_loss(pred, y_tr[idx].to(device))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()
            losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            pred_va = model(Zt_va.to(device), Zg_va.to(device), Zp_va.to(device)).cpu().numpy()
        met = compute_metrics(y_va_np, pred_va)
        va_l1 = float(np.mean(np.abs(pred_va - y_va_np)))
        print(f"[Fusion][Ep {ep:03d}] train_L1={np.mean(losses):.4f} val_L1={va_l1:.4f} val_r2={met['r2']:.3f}")

        if va_l1 < best:
            best = va_l1
            # Snapshot on CPU so later optimizer steps cannot mutate it.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    torch.save(model.state_dict(), out_dir / "fusion_tgp.pt")
    return model

@torch.no_grad()
def embed_fusion(model: TriTokenFusion, Zt, Zg, Zp, batch_size=512):
    """
    Run a trained TriTokenFusion model in inference mode over aligned embeddings.

    Args:
        model: trained fusion model (switched to eval mode here).
        Zt, Zg, Zp: text / graph / phys embedding matrices with matching row order.
        batch_size: rows per forward pass.

    Returns:
        (yhat, fused): float32 numpy arrays of predictions and pooled fused
        embeddings, one row per input row.
    """
    model.eval()

    def _batch(arr, start):
        # Slice one batch of rows and move it to `device` as float32.
        return torch.tensor(arr[start:start + batch_size], dtype=torch.float32).to(device)

    yhat_parts, emb_parts = [], []
    for start in range(0, len(Zt), batch_size):
        yhat_b, fused_b = model(_batch(Zt, start), _batch(Zg, start), _batch(Zp, start), return_emb=True)
        yhat_parts.append(yhat_b.cpu().numpy())
        emb_parts.append(fused_b.cpu().numpy())

    yhat_all = np.concatenate(yhat_parts, axis=0).astype(np.float32)
    emb_all = np.concatenate(emb_parts, axis=0).astype(np.float32)
    return yhat_all, emb_all


In [11]:
# =========================
# Cell 10 | Hazard data preparation (sort + align)
# =========================
# Sort by row_id; every task below reuses exactly this hazard order.
# A stable sort keeps the order deterministic even if row_id had ties.
haz_df_sorted = haz_df.sort_values("row_id", kind="mergesort").reset_index(drop=True)
haz_row_ids = haz_df_sorted["row_id"].to_numpy(dtype=np.int64)
haz_smiles = haz_df_sorted["SMILES_Canonical_RDKit"].astype(str).tolist()

# Hazard physchem matrix (rows follow haz_row_ids)
haz_X_phys = haz_df_sorted[PHYS_KEYS].to_numpy(dtype=np.float32)

# Hazard graphs must be re-ordered to follow haz_row_ids.
# The .pt file holds pickled graph objects (not just tensors), so full
# unpickling is required: pass weights_only=False explicitly, because newer
# torch versions default to True and would refuse to load this file.
# Only load files you produced yourself -- unpickling can execute arbitrary code.
haz_graph_list = torch.load(HAZARD_GRAPH_PT, map_location="cpu", weights_only=False)
haz_rowid_graph = np.load(HAZARD_GRAPH_ROWID).astype(np.int64)

graph_row_ids = np.array([int(g.row_id.item()) for g in haz_graph_list], dtype=np.int64)
# Cross-check against the saved row_id file (presumably one row_id per graph,
# in list order); a mismatch means the two artefacts are out of sync.
if not np.array_equal(graph_row_ids, haz_rowid_graph.reshape(-1)):
    print("WARNING: row_ids inside hazard graph list differ from", HAZARD_GRAPH_ROWID)

# row_id -> graph. Duplicate row_ids would make the mapping ambiguous
# (a dict silently keeps only the last graph), so reject them.
rid2graph = dict(zip(graph_row_ids.tolist(), haz_graph_list))
assert len(rid2graph) == len(haz_graph_list), "duplicate row_id in hazard graph list"

missing = [rid for rid in haz_row_ids if int(rid) not in rid2graph]
assert len(missing) == 0, f"haz_df 有 row_id 在 graph_list 缺失: {missing[:10]} ..."

haz_graph_aligned = [rid2graph[int(rid)] for rid in haz_row_ids]

print("hazard aligned:", len(haz_smiles), len(haz_graph_aligned), haz_X_phys.shape)


  haz_graph_list = torch.load(HAZARD_GRAPH_PT, map_location="cpu")


hazard aligned: 2106 2106 (2106, 24)


In [17]:
# =========================
# Cell 11 | 单个 Task：训练三模态 + 融合 + 输出名录预测
# =========================
def _save_task_keep_table(df_keep: pd.DataFrame, path: Path) -> None:
    """Write the task's training rows (ids, SMILES, label, physchem) to parquet with clean dtypes."""
    keep_cols = ["row_id", "SMILES_Canonical_RDKit", "y_log10"] + list(PHYS_KEYS)
    df_save = df_keep.loc[:, [c for c in keep_cols if c in df_keep.columns]].copy()

    # Force dtypes: SMILES -> str, label/physchem -> numeric. Mixed object
    # columns (e.g. stray datetimes) would otherwise make to_parquet fail.
    df_save["SMILES_Canonical_RDKit"] = df_save["SMILES_Canonical_RDKit"].astype(str)
    for c in ["y_log10"] + list(PHYS_KEYS):
        if c in df_save.columns:
            df_save[c] = pd.to_numeric(df_save[c], errors="coerce")

    df_save.to_parquet(path, index=False)


def run_one_task(task_id: str, spec: dict):
    """
    Train the text, graph and phys encoders and the fusion head for one task, then predict the hazard list.

    All artefacts (embeddings, checkpoints, metrics, per-task Excel file) are
    written under ``OUT_BASE / task_id``.

    Args:
        task_id: task key; used as the output sub-directory name.
        spec: task filter passed to build_task_dataframe.

    Returns:
        DataFrame with row_id, CAS, SMILES_Canonical_RDKit, pred_log10_mgL,
        pred_mgL and task_id, one row per hazard compound in haz_df_sorted order.

    Raises:
        RuntimeError: if any embedding has the wrong row count, or if the hazard
            graph embeddings do not come back in haz_row_ids order.
    """
    task_dir = OUT_BASE / task_id
    task_dir.mkdir(parents=True, exist_ok=True)
    print("\n" + "="*80)
    print("TASK:", task_id, spec)
    print("="*80)

    # 1) Slice the task subset from the master table
    df_task = build_task_dataframe(meta_df, spec)
    print("raw task df:", df_task.shape)

    # 2) Physchem + graph filtering (same logic as for the application set)
    df_keep, graph_list_keep, stats_drop = enrich_physchem_and_graph(df_task)
    print("keep df:", df_keep.shape, "graph_keep:", len(graph_list_keep), "drop_stats:", stats_drop)

    # Per the enrich step's author, df_keep rows follow graph_list_keep order
    # (both appended in the same traversal).
    df_keep = df_keep.reset_index(drop=True)
    assert len(df_keep) == len(graph_list_keep), "df_keep 与 graph_list_keep 长度不一致"
    n_train = len(df_keep)
    n_haz = len(haz_row_ids)

    # Targets, SMILES and split groups (grouping by SMILES keeps duplicates on one side)
    y = df_keep["y_log10"].to_numpy(dtype=np.float32)
    train_smiles = df_keep["SMILES_Canonical_RDKit"].astype(str).tolist()
    groups = df_keep["SMILES_Canonical_RDKit"].astype(str).to_numpy()

    # Persist the training manifest (only the columns needed downstream)
    _save_task_keep_table(df_keep, task_dir / "train_task_keep.parquet")

    # 3) Text encoder: train + embed
    text_dir = task_dir / "text"
    text_model, tokenizer = train_text_model(
        smiles=train_smiles,
        y=y,
        groups=groups,
        out_dir=text_dir
    )
    Zt_train = embed_text(text_model, tokenizer, train_smiles, batch_size=128)
    Zt_haz   = embed_text(text_model, tokenizer, haz_smiles, batch_size=128)
    np.save(text_dir / "Zt_train.npy", Zt_train)
    np.save(text_dir / "Zt_hazard.npy", Zt_haz)

    # 4) Graph encoder: train + embed
    graph_dir = task_dir / "graph"
    g_model = train_graph_model(graph_list_keep, groups, out_dir=graph_dir)
    Zg_train, rid_g_train = embed_graph(g_model, graph_list_keep, batch_size=256)
    Zg_haz, rid_g_haz = embed_graph(g_model, haz_graph_aligned, batch_size=256)
    np.save(graph_dir / "Zg_train.npy", Zg_train)
    np.save(graph_dir / "rid_train.npy", rid_g_train)
    np.save(graph_dir / "Zg_hazard.npy", Zg_haz)
    np.save(graph_dir / "rid_hazard.npy", rid_g_haz)

    # Fusion combines rows positionally, so the hazard graph embeddings must come
    # back in exactly haz_row_ids order; a silent reorder would corrupt predictions.
    rid_haz_arr = np.asarray(rid_g_haz).reshape(-1).astype(np.int64)
    if not np.array_equal(rid_haz_arr, haz_row_ids):
        raise RuntimeError("hazard graph embeddings are not in haz_row_ids order")

    # 5) Phys encoder: train + embed (the MLP's own predictions are not needed here)
    phys_dir = task_dir / "phys"
    X_phys_train = df_keep[PHYS_KEYS].to_numpy(dtype=np.float32)
    phys_model, imputer, scaler = train_phys_model(X_phys_train, y, groups, out_dir=phys_dir)
    _, Zp_train = embed_phys(phys_model, X_phys_train, imputer, scaler)
    _, Zp_haz = embed_phys(phys_model, haz_X_phys, imputer, scaler)
    np.save(phys_dir / "Zp_train.npy", Zp_train)
    np.save(phys_dir / "Zp_hazard.npy", Zp_haz)

    # All modality embeddings must have one row per sample before fusion.
    for name, arr, n_expected in (
        ("Zt_train", Zt_train, n_train), ("Zg_train", Zg_train, n_train), ("Zp_train", Zp_train, n_train),
        ("Zt_haz", Zt_haz, n_haz), ("Zg_haz", Zg_haz, n_haz), ("Zp_haz", Zp_haz, n_haz),
    ):
        if len(arr) != n_expected:
            raise RuntimeError(f"{name} has {len(arr)} rows, expected {n_expected}")

    # 6) Fusion: train + predict + fused embedding
    fusion_dir = task_dir / "fusion"
    fusion_model = train_fusion_model(Zt_train, Zg_train, Zp_train, y, groups, out_dir=fusion_dir)

    yhat_train, Zf_train = embed_fusion(fusion_model, Zt_train, Zg_train, Zp_train)
    yhat_haz,   Zf_haz   = embed_fusion(fusion_model, Zt_haz,   Zg_haz,   Zp_haz)

    np.save(fusion_dir / "Zf_train.npy", Zf_train)
    np.save(fusion_dir / "Zf_hazard.npy", Zf_haz)

    # Training-set fit (NOT out-of-fold CV) -- a sanity check only.
    met_train = compute_metrics(y, yhat_train)
    with open(task_dir / "metrics_train_sanity.json", "w", encoding="utf-8") as f:
        # default=float lets numpy scalar metrics (e.g. np.float32) serialize.
        json.dump(met_train, f, ensure_ascii=False, indent=2, default=float)
    print("Train sanity metrics:", met_train)

    # 7) Hazard predictions (log10 mg/L -> mg/L)
    pred_log10 = yhat_haz.astype(float)
    pred_mgL = np.power(10.0, pred_log10)

    out_df = haz_df_sorted.copy()
    out_df["pred_log10_mgL"] = pred_log10
    out_df["pred_mgL"] = pred_mgL
    out_df["task_id"] = task_id

    out_path = task_dir / f"hazard_pred_{task_id}.xlsx"
    out_df.to_excel(out_path, index=False)
    print("Saved:", out_path)

    # Long-format slice used by the cross-task summary
    return out_df[["row_id", "CAS", "SMILES_Canonical_RDKit", "pred_log10_mgL", "pred_mgL", "task_id"]].copy()


In [18]:
# =========================
# Cell 12 | Run every task + merge the results into one summary workbook
# =========================
all_preds = [run_one_task(task_id, spec) for task_id, spec in TASK_SPECS.items()]
all_long = pd.concat(all_preds, axis=0, ignore_index=True)

# Wide tables: one row per compound, one prediction column per task.
# Pivot on row_id only: pivot_table groups by its index columns, and groupby
# drops NaN keys by default. Keying on CAS (which may be missing for some
# hazard compounds) would silently drop those compounds. Identifier columns
# are merged back afterwards.
id_cols = ["row_id", "CAS", "SMILES_Canonical_RDKit"]
compound_ids = (
    all_long[id_cols]
    .drop_duplicates(subset="row_id")
    .sort_values("row_id")
    .reset_index(drop=True)
)

def _pivot_predictions(value_col):
    """Pivot one prediction column to compound x task and re-attach the identifier columns."""
    wide = all_long.pivot_table(index="row_id", columns="task_id", values=value_col, aggfunc="first")
    return compound_ids.merge(wide.reset_index(), on="row_id", how="left")

wide_log10 = _pivot_predictions("pred_log10_mgL")
wide_mgL = _pivot_predictions("pred_mgL")

# One Excel file, three sheets
out_xlsx = OUT_BASE / "hazard_predictions_all_tasks.xlsx"
with pd.ExcelWriter(out_xlsx, engine="openpyxl") as w:
    all_long.to_excel(w, sheet_name="long", index=False)
    wide_log10.to_excel(w, sheet_name="wide_log10", index=False)
    wide_mgL.to_excel(w, sheet_name="wide_mgL", index=False)

print("\n✅ ALL DONE:", out_xlsx)



TASK: ALGAE_EC50_72_POP {'species': 'algae', 'endpoint': 'EC50', 'duration': 72, 'effect': 'POP'}
raw task df: (2046, 12)


physchem+graph:  14%|█▍        | 283/2046 [00:00<00:03, 510.04it/s][15:06:12] SMILES Parse Error: syntax error while parsing: CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O)OCC(CC)CCCC)CCCC
[15:06:12] SMILES Parse Error: check for mistakes around position 10:
[15:06:12] CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O
[15:06:12] ~~~~~~~~~^
[15:06:12] SMILES Parse Error: Failed parsing SMILES 'CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O)OCC(CC)CCCC)CCCC' for input: 'CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O)OCC(CC)CCCC)CCCC'
[15:06:12] SMILES Parse Error: syntax error while parsing: O|[Co](|O)|O
[15:06:12] SMILES Parse Error: check for mistakes around position 2:
[15:06:12] O|[Co](|O)|O
[15:06:12] ~^
[15:06:12] SMILES Parse Error: Failed parsing SMILES 'O|[Co](|O)|O' for input: 'O|[Co](|O)|O'
[15:06:12] SMILES Parse Error: syntax error while parsing: [Cl]|[W](|[Cl])(|[Cl])(|[Cl])(|[Cl])|[Cl]
[15:06:12] SMILES Parse Error: check for mistakes around position 5:
[15:06:12] [Cl]|[W](|[Cl])(|[Cl])(|[Cl])(

keep df: (1907, 36) graph_keep: 1907 drop_stats: {'graph_fail_or_no_bond': 75, 'phys_nan': 36, 'mol_fail': 28}


Some weights of RobertaModel were not initialized from the model checkpoint at /root/多模态/model and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Text][Ep 001] train_L1=1.0048 val_L1=0.8516 val_r2=-0.016
[Text][Ep 002] train_L1=0.9192 val_L1=0.8472 val_r2=0.010
[Text][Ep 003] train_L1=0.9011 val_L1=0.9113 val_r2=-0.185
[Text][Ep 004] train_L1=0.8975 val_L1=0.8346 val_r2=-0.032
[Text][Ep 005] train_L1=0.9168 val_L1=0.9019 val_r2=-0.167
[Text][Ep 006] train_L1=0.8916 val_L1=0.7972 val_r2=0.086
[Text][Ep 007] train_L1=0.9144 val_L1=0.8805 val_r2=0.000
[Text][Ep 008] train_L1=0.9474 val_L1=0.8792 val_r2=-0.000
[Text][Ep 009] train_L1=0.9333 val_L1=0.8946 val_r2=-0.126
[Text][Ep 010] train_L1=0.9383 val_L1=0.8731 val_r2=-0.026
[Text][Ep 011] train_L1=0.9396 val_L1=0.9021 val_r2=-0.151
[Text][Ep 012] train_L1=0.9342 val_L1=0.8729 val_r2=-0.015
[Text][Ep 013] train_L1=0.9284 val_L1=0.8778 val_r2=-0.056
[Text][Ep 014] train_L1=0.9271 val_L1=0.8751 val_r2=-0.042
[Graph][Ep 001] train_L1=0.9703 val_L1=0.8810 val_r2=-0.042
[Graph][Ep 002] train_L1=0.9179 val_L1=0.8717 val_r2=0.001
[Graph][Ep 003] train_L1=0.9143 val_L1=0.8686 val_r2=0.010

physchem+graph:   0%|          | 0/1649 [00:00<?, ?it/s][15:08:17] SMILES Parse Error: syntax error while parsing: CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O)OCC(CC)CCCC)CCCC
[15:08:17] SMILES Parse Error: check for mistakes around position 10:
[15:08:17] CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O
[15:08:17] ~~~~~~~~~^
[15:08:17] SMILES Parse Error: Failed parsing SMILES 'CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O)OCC(CC)CCCC)CCCC' for input: 'CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O)OCC(CC)CCCC)CCCC'
[15:08:17] SMILES Parse Error: syntax error while parsing: CCCCCCCCCCCC[S]|[Sn](|[S]CCCCCCCCCCCC)(CCCC)CCCC
[15:08:17] SMILES Parse Error: check for mistakes around position 16:
[15:08:17] CCCCCCCCCCCC[S]|[Sn](|[S]CCCCCCCCCCCC)(CC
[15:08:17] ~~~~~~~~~~~~~~~^
[15:08:17] SMILES Parse Error: Failed parsing SMILES 'CCCCCCCCCCCC[S]|[Sn](|[S]CCCCCCCCCCCC)(CCCC)CCCC' for input: 'CCCCCCCCCCCC[S]|[Sn](|[S]CCCCCCCCCCCC)(CCCC)CCCC'
[15:08:17] Explicit valence for atom # 0 O, 2, is greater than permitt

keep df: (1534, 36) graph_keep: 1534 drop_stats: {'graph_fail_or_no_bond': 63, 'mol_fail': 29, 'phys_nan': 23}
[Text][Ep 001] train_L1=0.9569 val_L1=0.6729 val_r2=0.171
[Text][Ep 002] train_L1=0.8228 val_L1=0.7000 val_r2=0.133
[Text][Ep 003] train_L1=0.8139 val_L1=0.7446 val_r2=-0.042
[Text][Ep 004] train_L1=0.7852 val_L1=0.7197 val_r2=0.044
[Text][Ep 005] train_L1=0.8126 val_L1=0.7647 val_r2=0.010
[Text][Ep 006] train_L1=0.8344 val_L1=0.8274 val_r2=-0.175
[Text][Ep 007] train_L1=0.8792 val_L1=0.8199 val_r2=-0.148
[Text][Ep 008] train_L1=0.8691 val_L1=0.7788 val_r2=-0.008
[Text][Ep 009] train_L1=0.8641 val_L1=0.8148 val_r2=-0.131
[Graph][Ep 001] train_L1=0.9327 val_L1=0.8204 val_r2=-0.169
[Graph][Ep 002] train_L1=0.8438 val_L1=0.7493 val_r2=0.020
[Graph][Ep 003] train_L1=0.8186 val_L1=0.7498 val_r2=0.004
[Graph][Ep 004] train_L1=0.8173 val_L1=0.7447 val_r2=0.013
[Graph][Ep 005] train_L1=0.8172 val_L1=0.7457 val_r2=0.002
[Graph][Ep 006] train_L1=0.8278 val_L1=0.7229 val_r2=0.064
[Graph]

physchem+graph:   0%|          | 0/1570 [00:00<?, ?it/s][15:09:26] Explicit valence for atom # 0 O, 2, is greater than permitted
physchem+graph:   3%|▎         | 52/1570 [00:00<00:02, 513.72it/s][15:09:26] SMILES Parse Error: syntax error while parsing: [Na+].[Na+].[Na+].[Na+].[Fe-4](|[C]#N)(|[C]#N)(|[C]#N)(|[C]#N)(|[C]#N)|[C]#N
[15:09:26] SMILES Parse Error: check for mistakes around position 32:
[15:09:26] .[Na+].[Na+].[Fe-4](|[C]#N)(|[C]#N)(|[C]#
[15:09:26] ~~~~~~~~~~~~~~~~~~~~^
[15:09:26] SMILES Parse Error: Failed parsing SMILES '[Na+].[Na+].[Na+].[Na+].[Fe-4](|[C]#N)(|[C]#N)(|[C]#N)(|[C]#N)(|[C]#N)|[C]#N' for input: '[Na+].[Na+].[Na+].[Na+].[Fe-4](|[C]#N)(|[C]#N)(|[C]#N)(|[C]#N)(|[C]#N)|[C]#N'
[15:09:26] SMILES Parse Error: syntax error while parsing: [K+].[K+].[K+].[K+].[Fe](|[C]#N)(|[C]#N)(|[C]#N)(|[C]#N)(|[C]#N)|[C]#N
[15:09:26] SMILES Parse Error: check for mistakes around position 26:
[15:09:26] [K+].[K+].[K+].[Fe](|[C]#N)(|[C]#N)(|[C]#
[15:09:26] ~~~~~~~~~~~~~~~~~~~~^
[15:0

keep df: (1462, 36) graph_keep: 1462 drop_stats: {'graph_fail_or_no_bond': 58, 'phys_nan': 26, 'mol_fail': 24}
[Text][Ep 001] train_L1=0.9698 val_L1=0.8338 val_r2=-0.015
[Text][Ep 002] train_L1=0.8497 val_L1=0.7273 val_r2=0.113
[Text][Ep 003] train_L1=0.8003 val_L1=0.7404 val_r2=0.120
[Text][Ep 004] train_L1=0.8031 val_L1=0.8618 val_r2=-0.298
[Text][Ep 005] train_L1=0.8012 val_L1=0.7111 val_r2=0.002
[Text][Ep 006] train_L1=0.8139 val_L1=0.8350 val_r2=-0.199
[Text][Ep 007] train_L1=0.8867 val_L1=0.8092 val_r2=-0.086
[Text][Ep 008] train_L1=0.8587 val_L1=0.7501 val_r2=0.090
[Text][Ep 009] train_L1=0.7733 val_L1=0.7787 val_r2=0.021
[Text][Ep 010] train_L1=0.8619 val_L1=0.7825 val_r2=0.009
[Text][Ep 011] train_L1=0.8230 val_L1=0.7560 val_r2=0.058
[Text][Ep 012] train_L1=0.7792 val_L1=0.8095 val_r2=-0.209
[Text][Ep 013] train_L1=0.7406 val_L1=0.7692 val_r2=-0.107
[Graph][Ep 001] train_L1=1.0084 val_L1=0.7991 val_r2=-0.174
[Graph][Ep 002] train_L1=0.8433 val_L1=0.7919 val_r2=-0.008
[Graph][E

physchem+graph:   8%|▊         | 172/2025 [00:00<00:03, 472.60it/s][15:10:58] SMILES Parse Error: syntax error while parsing: [NH2-]|[Pd++](|[NH2-])(|[NH2-])|[NH2-].[Cl].[Cl]
[15:10:58] SMILES Parse Error: check for mistakes around position 7:
[15:10:58] [NH2-]|[Pd++](|[NH2-])(|[NH2-])|[NH2-].[C
[15:10:58] ~~~~~~^
[15:10:58] SMILES Parse Error: Failed parsing SMILES '[NH2-]|[Pd++](|[NH2-])(|[NH2-])|[NH2-].[Cl].[Cl]' for input: '[NH2-]|[Pd++](|[NH2-])(|[NH2-])|[NH2-].[Cl].[Cl]'
[15:10:58] SMILES Parse Error: syntax error while parsing: CCCCCCCC[Sn]|1(|[O]C(=O)C[S]|1)CCCCCCCC
[15:10:58] SMILES Parse Error: check for mistakes around position 13:
[15:10:58] CCCCCCCC[Sn]|1(|[O]C(=O)C[S]|1)CCCCCCCC
[15:10:58] ~~~~~~~~~~~~^
[15:10:58] SMILES Parse Error: Failed parsing SMILES 'CCCCCCCC[Sn]|1(|[O]C(=O)C[S]|1)CCCCCCCC' for input: 'CCCCCCCC[Sn]|1(|[O]C(=O)C[S]|1)CCCCCCCC'
[15:10:58] SMILES Parse Error: syntax error while parsing: CCCCCCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O)OCC(CC)CCCC)|[S]CC(=

keep df: (1876, 36) graph_keep: 1876 drop_stats: {'graph_fail_or_no_bond': 81, 'phys_nan': 42, 'mol_fail': 26}


Some weights of RobertaModel were not initialized from the model checkpoint at /root/多模态/model and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Text][Ep 001] train_L1=1.0291 val_L1=0.9202 val_r2=0.098
[Text][Ep 002] train_L1=0.9767 val_L1=1.0400 val_r2=0.003
[Text][Ep 003] train_L1=1.0250 val_L1=0.9299 val_r2=0.162
[Text][Ep 004] train_L1=0.9750 val_L1=0.9152 val_r2=0.180
[Text][Ep 005] train_L1=1.0090 val_L1=1.0405 val_r2=-0.027
[Text][Ep 006] train_L1=1.0318 val_L1=1.0605 val_r2=-0.006
[Text][Ep 007] train_L1=1.0385 val_L1=1.0472 val_r2=-0.000
[Text][Ep 008] train_L1=1.0418 val_L1=1.0406 val_r2=-0.003
[Text][Ep 009] train_L1=1.0296 val_L1=1.0452 val_r2=-0.057
[Text][Ep 010] train_L1=1.0304 val_L1=1.0420 val_r2=-0.002
[Text][Ep 011] train_L1=1.0153 val_L1=1.0537 val_r2=-0.002
[Text][Ep 012] train_L1=1.0203 val_L1=1.0406 val_r2=-0.003
[Graph][Ep 001] train_L1=1.0294 val_L1=1.0404 val_r2=-0.007
[Graph][Ep 002] train_L1=1.0046 val_L1=1.0369 val_r2=-0.010
[Graph][Ep 003] train_L1=0.9978 val_L1=1.0340 val_r2=-0.006
[Graph][Ep 004] train_L1=0.9829 val_L1=1.0324 val_r2=0.007
[Graph][Ep 005] train_L1=0.9729 val_L1=1.0300 val_r2=-0.0

physchem+graph:   0%|          | 0/1076 [00:00<?, ?it/s][15:12:45] SMILES Parse Error: syntax error while parsing: [Cl]|[W](|[Cl])(|[Cl])(|[Cl])(|[Cl])|[Cl]
[15:12:45] SMILES Parse Error: check for mistakes around position 5:
[15:12:45] [Cl]|[W](|[Cl])(|[Cl])(|[Cl])(|[Cl])|[Cl]
[15:12:45] ~~~~^
[15:12:45] SMILES Parse Error: Failed parsing SMILES '[Cl]|[W](|[Cl])(|[Cl])(|[Cl])(|[Cl])|[Cl]' for input: '[Cl]|[W](|[Cl])(|[Cl])(|[Cl])(|[Cl])|[Cl]'
[15:12:45] SMILES Parse Error: syntax error while parsing: [Cl]|[Sn](CCCC)(CCCC)CCCC
[15:12:45] SMILES Parse Error: check for mistakes around position 5:
[15:12:45] [Cl]|[Sn](CCCC)(CCCC)CCCC
[15:12:45] ~~~~^
[15:12:45] SMILES Parse Error: Failed parsing SMILES '[Cl]|[Sn](CCCC)(CCCC)CCCC' for input: '[Cl]|[Sn](CCCC)(CCCC)CCCC'
[15:12:45] SMILES Parse Error: syntax error while parsing: CCCCCCCC[Sn]|1(|[O]C(=O)C[S]|1)CCCCCCCC
[15:12:45] SMILES Parse Error: check for mistakes around position 13:
[15:12:45] CCCCCCCC[Sn]|1(|[O]C(=O)C[S]|1)CCCCCCCC
[15:

keep df: (1033, 36) graph_keep: 1033 drop_stats: {'graph_fail_or_no_bond': 18, 'mol_fail': 16, 'phys_nan': 9}
[Text][Ep 001] train_L1=0.9955 val_L1=0.8743 val_r2=0.021
[Text][Ep 002] train_L1=0.9648 val_L1=0.7768 val_r2=0.041
[Text][Ep 003] train_L1=0.9023 val_L1=0.8514 val_r2=-0.156
[Text][Ep 004] train_L1=0.9450 val_L1=0.8861 val_r2=0.114
[Text][Ep 005] train_L1=0.8668 val_L1=0.8556 val_r2=0.054
[Text][Ep 006] train_L1=0.9271 val_L1=0.7806 val_r2=0.100
[Text][Ep 007] train_L1=0.8868 val_L1=0.7716 val_r2=0.159
[Text][Ep 008] train_L1=0.8643 val_L1=0.7107 val_r2=0.181
[Text][Ep 009] train_L1=0.7881 val_L1=0.7534 val_r2=0.099
[Text][Ep 010] train_L1=0.8178 val_L1=0.7993 val_r2=0.221
[Text][Ep 011] train_L1=0.7821 val_L1=0.7832 val_r2=0.143
[Text][Ep 012] train_L1=0.8014 val_L1=0.7423 val_r2=0.242
[Text][Ep 013] train_L1=0.7651 val_L1=0.7222 val_r2=0.268
[Text][Ep 014] train_L1=0.7127 val_L1=0.7380 val_r2=0.248
[Text][Ep 015] train_L1=0.6995 val_L1=0.7401 val_r2=0.238
[Graph][Ep 001] tra

physchem+graph:  12%|█▏        | 136/1143 [00:00<00:01, 536.52it/s][15:14:00] SMILES Parse Error: syntax error while parsing: CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O)OCC(CC)CCCC)CCCC
[15:14:00] SMILES Parse Error: check for mistakes around position 10:
[15:14:00] CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O
[15:14:00] ~~~~~~~~~^
[15:14:00] SMILES Parse Error: Failed parsing SMILES 'CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O)OCC(CC)CCCC)CCCC' for input: 'CCCC[Sn](|[S]CC(=O)OCC(CC)CCCC)(|[S]CC(=O)OCC(CC)CCCC)CCCC'
[15:14:00] SMILES Parse Error: syntax error while parsing: O|[Co](|O)|O
[15:14:00] SMILES Parse Error: check for mistakes around position 2:
[15:14:00] O|[Co](|O)|O
[15:14:00] ~^
[15:14:00] SMILES Parse Error: Failed parsing SMILES 'O|[Co](|O)|O' for input: 'O|[Co](|O)|O'
[15:14:00] SMILES Parse Error: syntax error while parsing: CCCCCCCC[Sn]|1(|[O]C(=O)C[S]|1)CCCCCCCC
[15:14:00] SMILES Parse Error: check for mistakes around position 13:
[15:14:00] CCCCCCCC[Sn]|1(|[O]C(=O)C[S]|1)

keep df: (1086, 36) graph_keep: 1086 drop_stats: {'graph_fail_or_no_bond': 34, 'mol_fail': 15, 'phys_nan': 8}
[Text][Ep 001] train_L1=1.0631 val_L1=0.8618 val_r2=-0.049
[Text][Ep 002] train_L1=0.9093 val_L1=0.9477 val_r2=0.073
[Text][Ep 003] train_L1=0.9239 val_L1=1.2115 val_r2=-0.766
[Text][Ep 004] train_L1=0.9417 val_L1=1.0663 val_r2=-0.548
[Text][Ep 005] train_L1=0.8878 val_L1=0.9276 val_r2=0.056
[Text][Ep 006] train_L1=0.9275 val_L1=1.1163 val_r2=-0.158
[Text][Ep 007] train_L1=1.0139 val_L1=0.9405 val_r2=-0.010
[Text][Ep 008] train_L1=0.9975 val_L1=0.9422 val_r2=-0.155
[Text][Ep 009] train_L1=0.9540 val_L1=0.9469 val_r2=-0.193
[Graph][Ep 001] train_L1=1.0380 val_L1=1.0150 val_r2=-0.168
[Graph][Ep 002] train_L1=0.9448 val_L1=0.9792 val_r2=-0.071
[Graph][Ep 003] train_L1=0.9223 val_L1=0.9585 val_r2=-0.052
[Graph][Ep 004] train_L1=0.9018 val_L1=0.9419 val_r2=-0.013
[Graph][Ep 005] train_L1=0.9005 val_L1=0.9408 val_r2=-0.059
[Graph][Ep 006] train_L1=0.9021 val_L1=0.9230 val_r2=-0.005
[