In [1]:
import os
import pickle
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from torch_geometric.loader import DataLoader

from mmfdl.util.utils_smiecfp import getInput_mask
from mmfdl.util.utils import formDataset_Single
from mmfdl.model.model_combination import comModel

In [2]:
def load_weights_from_csv(weight_path: str) -> np.ndarray:
    """
    Read the three modality fusion weights from a Key/Value csv.

    Args:
        weight_path (str): csv file containing a "Key" column (values 1, 2, 3)
            and a "Value" column holding the corresponding weight.

    Returns:
        np.ndarray: float32 array of shape (3,), ordered by Key 1, 2, 3.

    Raises:
        KeyError: if any of the keys 1, 2 or 3 is absent from the "Key" column.
    """
    table = pd.read_csv(weight_path)
    # Later rows win on duplicate keys, exactly like dict(zip(...)).
    lookup = {key: value for key, value in zip(table["Key"], table["Value"])}
    ordered = [lookup[key] for key in (1, 2, 3)]
    return np.array(ordered, dtype=np.float32)

In [3]:
def _fused_batch_embedding(
    model: torch.nn.Module,
    data,
    w: np.ndarray,
    device: torch.device,
) -> np.ndarray:
    """
    Run ``model.get_embeddings`` on one collated batch and return the weighted concatenation.

    Args:
        model (torch.nn.Module): Model exposing
            get_embeddings(smi, smi_mask, ecfp, x, edge_index, batch) -> (smi, ecfp, graph).
        data: One collated batch providing smi, ep, x, edge_index and batch.
        w (np.ndarray): (3,) float32 fusion weights for the (SMILES, ECFP, graph) branches.
        device (torch.device): Device the model lives on.

    Returns:
        np.ndarray: (B, D) float32 array laid out as [w0*smi | w1*ecfp | w2*graph].
    """
    encodedSmi = torch.as_tensor(data.smi, dtype=torch.long, device=device)  # (B, L)
    # getInput_mask operates on numpy, so the mask is built on CPU and moved back.
    smi_mask_np = getInput_mask(encodedSmi.cpu().numpy())
    encodedSmi_mask = torch.as_tensor(smi_mask_np, dtype=torch.long, device=device)

    ecfp = torch.as_tensor(data.ep, dtype=torch.float32, device=device)
    # GRU input must be (B, T, F); a flat (B, F) fingerprint becomes a length-1 sequence.
    if ecfp.dim() == 2:
        ecfp = ecfp.unsqueeze(1)

    smi_emb, ep_emb, gc_emb = model.get_embeddings(
        encodedSmi,
        encodedSmi_mask,
        ecfp,
        data.x.to(device),
        data.edge_index.to(device),
        data.batch.to(device),
    )

    parts = [
        scale * emb.detach().cpu().numpy().astype(np.float32)
        for scale, emb in zip(w, (smi_emb, ep_emb, gc_emb))
    ]
    return np.concatenate(parts, axis=1)


def _batch_smiles(data, n_rows: int) -> List[str]:
    """
    Return the SMILES strings of one batch, aligned with its ``n_rows`` embedding rows.

    ``data.SMILES`` is tried first, then ``data.smiles``; attributes that are missing or
    None are skipped. If neither is available, ``n_rows`` empty strings are returned so
    the SMILES list stays aligned with the embeddings.

    Raises:
        ValueError: if the SMILES collection does not contain exactly ``n_rows`` entries.
    """
    for attr in ("SMILES", "smiles"):
        raw = getattr(data, attr, None)
        if raw is None:
            continue
        # list() on a bare str would split it into characters; wrap it instead.
        smiles = [raw] if isinstance(raw, str) else list(raw)
        if len(smiles) != n_rows:
            raise ValueError(
                f"data.{attr} has {len(smiles)} entries but the batch produced "
                f"{n_rows} embedding rows"
            )
        return smiles
    return [""] * n_rows


def extract_embeddings_from_pt(
    model: torch.nn.Module,
    loader: DataLoader,
    weights: np.ndarray,
    device: torch.device,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """
    Extract unified (fused) embeddings from a DataLoader built on formDataset_Single (.pt).

    Each batch goes through ``model.get_embeddings``. The three branch embeddings
    (SMILES, ECFP, graph) are scaled by ``weights`` and concatenated along the feature axis.

    Args:
        model (torch.nn.Module): Trained comModel with get_embeddings().
        loader (DataLoader): DataLoader over the .pt dataset. Each batch must provide
            smi, ep, x, edge_index, batch and y, and optionally SMILES or smiles.
        weights (np.ndarray): (3,) fusion weights.
        device (torch.device): cuda/cpu.

    Returns:
        Tuple[np.ndarray, np.ndarray, List[str]]:
            embeddings: (N, D) float32; (0, 0) when the loader is empty.
            labels: (N,) float32, ``data.y`` flattened per batch.
            smiles: N SMILES strings, or empty strings if batches carry none.

    Raises:
        ValueError: if ``weights`` is not shape (3,), or a batch's SMILES count does not
            match its number of embedding rows.
    """
    w = np.asarray(weights, dtype=np.float32)
    if w.shape != (3,):
        raise ValueError(f"weights must be (3,), got {w.shape}")

    # Put the model in eval mode (e.g. dropout off). This mutates the caller's model.
    model.eval()

    all_emb: List[np.ndarray] = []
    all_y: List[np.ndarray] = []
    all_smiles: List[str] = []

    with torch.no_grad():
        for data in loader:
            fused = _fused_batch_embedding(model, data, w, device)
            all_emb.append(fused)
            all_y.append(data.y.view(-1).detach().cpu().numpy().astype(np.float32))
            all_smiles.extend(_batch_smiles(data, fused.shape[0]))

    embeddings = np.vstack(all_emb) if all_emb else np.zeros((0, 0), dtype=np.float32)
    labels = np.concatenate(all_y) if all_y else np.zeros((0,), dtype=np.float32)
    return embeddings, labels, all_smiles

In [4]:
# =============================
# Main
# =============================
# Experiment selection: dataset/task to process and the inclusive fold range.
dataset_name = "selectivity"
task_name = "Ki"
START_FOLD = 1
END_FOLD = 5

# Project root. The absolute default is machine-specific; set the MMFDL_WORK_DIR
# environment variable (non-empty) to run elsewhere without editing the notebook.
work_dir = os.environ.get("MMFDL_WORK_DIR") or "/home/rlawlsgurjh/hdd/work/MMFDL"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"[INFO] device={device}")

# SMILES character vocabulary for this dataset/task.
# NOTE: pickle.load can execute arbitrary code; only load vocab files you produced yourself.
vocab_path = os.path.join(work_dir, "data", dataset_name, task_name, "smiles_char_dict.pkl")
with open(vocab_path, "rb") as f:
    smilesVoc = pickle.load(f)

[INFO] device=cuda:0


In [5]:
# comModel hyper-parameters. They must match the configuration the checkpoints were
# trained with, or load_state_dict() will reject the saved weights.
# The fused embedding width is hidden_dim + output_dim + 2 * num_features_x
# (256 + 128 + 156 = 540), per the per-branch shape notes in the extraction code.
argsCom = dict(
    num_features_smi=len(smilesVoc),  # SMILES vocabulary size
    num_features_ecfp=2048,           # ECFP fingerprint length
    num_features_x=78,                # node feature size of data.x (presumably per-atom)
    dropout=0.1,
    num_layer=2,
    num_heads=2,
    hidden_dim=256,
    output_dim=128,
    n_output=1,
)

In [6]:
for fold_num in range(START_FOLD, END_FOLD + 1):
    print("\n" + "=" * 80)
    print(f"[Fold {fold_num}] PT-based embedding extraction")
    print("=" * 80)

    # -----------------------------------------
    # 1) Load trained model + fusion weights
    # -----------------------------------------
    checkpoint_dir = os.path.join(work_dir, "results", "SGD", dataset_name, task_name, f"fold{fold_num}")
    checkpoint_path = os.path.join(checkpoint_dir, "best_model.pt")
    if not os.path.exists(checkpoint_path):
        print(f"[WARN] checkpoint not found: {checkpoint_path}")
        continue

    # weights_only=False unpickles arbitrary objects; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    best_epoch = int(checkpoint["epoch"])

    # The fusion-weight csv is saved per epoch; use the one matching the best checkpoint.
    weight_path = os.path.join(
        checkpoint_dir,
        f"{dataset_name}_{task_name}_fold{fold_num}_weight_epoch_{best_epoch}.csv",
    )
    if not os.path.exists(weight_path):
        print(f"[WARN] weight file not found: {weight_path}")
        continue

    weights = load_weights_from_csv(weight_path)
    print(f"[INFO] weights={weights}")

    model = comModel(argsCom).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # -----------------------------------------
    # 2) Load .pt datasets (same files as training)
    # -----------------------------------------
    pt_dir = os.path.join(work_dir, "data", dataset_name, task_name, f"fold{fold_num}")

    train_pt = os.path.join(pt_dir, f"{dataset_name}_train.pt")
    val_pt = os.path.join(pt_dir, f"{dataset_name}_val.pt")
    test_pt = os.path.join(pt_dir, f"{dataset_name}_test.pt")

    if not (os.path.exists(train_pt) and os.path.exists(val_pt) and os.path.exists(test_pt)):
        print(f"[WARN] missing pt files in: {pt_dir}")
        print(f"  train_pt exists={os.path.exists(train_pt)}")
        print(f"  val_pt   exists={os.path.exists(val_pt)}")
        print(f"  test_pt  exists={os.path.exists(test_pt)}")
        continue

    train_data = formDataset_Single(root=pt_dir, dataset=f"{dataset_name}_train")
    val_data = formDataset_Single(root=pt_dir, dataset=f"{dataset_name}_val")
    test_data = formDataset_Single(root=pt_dir, dataset=f"{dataset_name}_test")

    # Train and val are embedded together as one "tr_val" split: train rows first, then val.
    tr_val_data = list(train_data) + list(val_data)

    # shuffle=False keeps output rows in dataset order, aligned with labels and SMILES.
    tr_val_loader = DataLoader(tr_val_data, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=1, shuffle=False)

    # -----------------------------------------
    # 3) Extract embeddings
    # -----------------------------------------
    print("[INFO] Extracting Train+Val embeddings from PT...")
    emb_tr_val, y_tr_val, smiles_tr_val = extract_embeddings_from_pt(model, tr_val_loader, weights, device)
    print(f"[INFO] Train+Val embeddings: {emb_tr_val.shape}, labels: {y_tr_val.shape}")

    print("[INFO] Extracting Test embeddings from PT...")
    emb_te, y_te, smiles_te = extract_embeddings_from_pt(model, test_loader, weights, device)
    print(f"[INFO] Test embeddings: {emb_te.shape}, labels: {y_te.shape}")

    # Exactly one embedding row and one label per input sample; otherwise the saved
    # arrays would be silently misaligned.
    assert emb_tr_val.shape[0] == y_tr_val.shape[0] == len(tr_val_data), "Train+Val row count mismatch"
    assert emb_te.shape[0] == y_te.shape[0] == len(test_data), "Test row count mismatch"

    # -----------------------------------------
    # 4) Save: each split is a dict pickled into a 0-d object .npy (NOT .npz).
    #    Read back with np.load(path, allow_pickle=True).item()
    # -----------------------------------------
    out_dir = os.path.join(checkpoint_dir, "embeddings")
    os.makedirs(out_dir, exist_ok=True)

    tr_val_path = os.path.join(out_dir, "tr_val_embeddings.npy")
    te_path = os.path.join(out_dir, "te_embeddings.npy")

    for split_path, split_emb, split_y, split_smiles in (
        (tr_val_path, emb_tr_val, y_tr_val, smiles_tr_val),
        (te_path, emb_te, y_te, smiles_te),
    ):
        np.save(
            split_path,
            {
                "embeddings": split_emb,
                "Ssel": split_y,  # labels (data.y) are stored under the "Ssel" key
                "SMILES": np.array(split_smiles, dtype=object),
            },
        )
        print(f"[INFO] Saved: {split_path}")

    print(f"[Fold {fold_num}] Done.")


[Fold 1] PT-based embedding extraction
[INFO] weights=[0.56694883 0.35034388 0.08510989]
[INFO] Extracting Train+Val embeddings from PT...
[INFO] Train+Val embeddings: (1459, 540), labels: (1459,)
[INFO] Extracting Test embeddings from PT...
[INFO] Test embeddings: (365, 540), labels: (365,)
[INFO] Saved: /home/rlawlsgurjh/hdd/work/MMFDL/results/SGD/selectivity/Ki/fold1/embeddings/tr_val_embeddings.npy
[INFO] Saved: /home/rlawlsgurjh/hdd/work/MMFDL/results/SGD/selectivity/Ki/fold1/embeddings/te_embeddings.npy
[Fold 1] Done.

[Fold 2] PT-based embedding extraction
[INFO] weights=[0.6294782  0.26747948 0.10409598]
[INFO] Extracting Train+Val embeddings from PT...
[INFO] Train+Val embeddings: (1459, 540), labels: (1459,)
[INFO] Extracting Test embeddings from PT...
[INFO] Test embeddings: (365, 540), labels: (365,)
[INFO] Saved: /home/rlawlsgurjh/hdd/work/MMFDL/results/SGD/selectivity/Ki/fold2/embeddings/tr_val_embeddings.npy
[INFO] Saved: /home/rlawlsgurjh/hdd/work/MMFDL/results/SGD/sel