In [1]:
import os
import pickle
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
import networkx as nx
from torch_geometric.loader import DataLoader
from torch_geometric import data as DATA

from mmfdl.util.data_gen_modify import make_variable_one
from mmfdl.util.utils_smiecfp import getInput_mask
from mmfdl.model.model_combination import comModel

import warnings
warnings.filterwarnings('ignore')

# Helper functions (taken from MMFDL_external_test.ipynb)
def one_of_k_encoding(x, allowable_set):
    """Strict one-hot encoding of ``x`` over ``allowable_set``.

    Args:
        x: value to encode.
        allowable_set: sequence of admissible values.

    Returns:
        A list with one boolean per entry of ``allowable_set``; an entry is
        True where it compares equal to ``x``.

    Raises:
        Exception: if ``x`` is not contained in ``allowable_set``.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception(f"input {x} not in allowable set{allowable_set}:")

def one_of_k_encoding_unk(x, allowable_set):
    """Lenient one-hot encoding of ``x`` over ``allowable_set``.

    A value that is not in ``allowable_set`` is encoded as if it were the
    last element of the set (the "unknown" slot).

    Returns:
        A list with one boolean per entry of ``allowable_set``.
    """
    target = x if x in allowable_set else allowable_set[-1]
    return [target == candidate for candidate in allowable_set]

def atom_features(atom):
    """Build the feature vector of a single atom.

    The vector concatenates, in this order:
      * a 44-slot one-hot of the element symbol (anything not listed falls
        into the trailing 'Unknown' slot),
      * an 11-slot strict one-hot of the degree (values outside 0..10 raise),
      * an 11-slot one-hot of the total hydrogen count (overflow -> last slot),
      * an 11-slot one-hot of the implicit valence (overflow -> last slot),
      * the aromaticity flag,
    for 78 entries in total.

    Args:
        atom: object exposing GetSymbol, GetDegree, GetTotalNumHs,
            GetImplicitValence and GetIsAromatic (an RDKit atom in this notebook).

    Returns:
        np.ndarray built from the concatenated list of flags.
    """
    element_symbols = [
        'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',
        'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
        'Cr', 'Pt', 'Hg', 'Pb', 'Unknown',
    ]
    count_bins = list(range(11))  # 0..10, shared by degree / H count / valence

    # Evaluated in the same order as the concatenation below.
    symbol_bits = one_of_k_encoding_unk(atom.GetSymbol(), element_symbols)
    degree_bits = one_of_k_encoding(atom.GetDegree(), count_bins)
    hydrogen_bits = one_of_k_encoding_unk(atom.GetTotalNumHs(), count_bins)
    valence_bits = one_of_k_encoding_unk(atom.GetImplicitValence(), count_bins)
    aromatic_bit = [atom.GetIsAromatic()]

    return np.array(symbol_bits + degree_bits + hydrogen_bits + valence_bits + aromatic_bit)

def smile_to_graph(smile):
    """Convert a SMILES string into graph components.

    Args:
        smile: SMILES string.

    Returns:
        ``(c_size, features, edge_index)`` where
          * ``c_size`` is the number of atoms,
          * ``features`` is a list with one array per atom: the output of
            ``atom_features`` divided by its own sum,
          * ``edge_index`` is a list of ``[src, dst]`` pairs containing both
            directions of every bond (empty when the molecule has no bonds).
        ``(None, None, [])`` is returned when RDKit cannot parse the SMILES.
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return None, None, []

    c_size = mol.GetNumAtoms()

    # Per-atom features, normalised so that each vector sums to 1.
    features = []
    for atom in mol.GetAtoms():
        raw = atom_features(atom)
        features.append(raw / sum(raw))

    # Undirected bonds -> directed edge list (both directions) via networkx.
    bonds = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in mol.GetBonds()]
    directed = nx.Graph(bonds).to_directed()
    edge_index = [[src, dst] for src, dst in directed.edges]

    return c_size, features, edge_index

In [None]:
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
dataset_name = 'selectivity'
task_name = 'Kd'
START_FOLD = 1          # first fold processed (inclusive)
END_FOLD = 5            # last fold processed (inclusive)
ecfp_bits = 2048        # length of the Morgan fingerprint bit vector
max_smiles_len = 44     # length handed to the SMILES encoder (make_variable_one)

# Project root. The original machine-specific location stays the default so
# existing runs are unaffected; set MMFDL_WORK_DIR to run the notebook elsewhere.
work_dir = os.environ.get('MMFDL_WORK_DIR') or '/home/rlawlsgurjh/hdd/work/MMFDL'
vocab_path = os.path.join(work_dir, 'data', dataset_name, task_name, 'smiles_char_dict.pkl')

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(f'Using device: {device}')
print(f'Task: {task_name}, Folds: {START_FOLD} to {END_FOLD}')

Using device: cuda:0
Task: Kd, Folds: 1 to 5


In [None]:
# Hyper-parameters handed to comModel. They have to match the configuration
# the checkpoints were trained with, otherwise load_state_dict will fail.
# NOTE(review): the exact meaning of each key is defined in
# mmfdl.model.model_combination, which is not visible in this notebook.
argsCom = dict(
    num_features_smi=44,
    num_features_ecfp=2048,
    num_features_x=78,   # equals the length of the atom_features vector (44 + 11 + 11 + 11 + 1)
    dropout=0.1,
    num_layer=2,
    num_heads=2,
    hidden_dim=256,
    output_dim=128,
    n_output=1,
)

In [None]:
# Load the SMILES character vocabulary used by the SMILES encoder.
# NOTE: pickle.load can execute arbitrary code from a crafted file, so only
# load vocabulary files that were produced by this project.
with open(vocab_path, 'rb') as vocab_file:
    smilesVoc = pickle.load(vocab_file)
print(f'Vocabulary loaded: {len(smilesVoc)} characters')

Vocabulary loaded: 42 characters


In [None]:
def process_smiles_to_tensor(smiles_list, smilesVoc, max_smiles_len, ecfp_bits):
    """Featurise SMILES strings into the three model inputs.

    For every SMILES this produces the encoded token sequence, the Morgan
    (radius 2) fingerprint and the molecular graph. A SMILES is dropped, and
    a message is printed, when
      * encoding or featurisation raises,
      * RDKit cannot parse it, or
      * the molecule has no bonds (empty edge list).

    Args:
        smiles_list: iterable of SMILES strings.
        smilesVoc: vocabulary passed to ``make_variable_one``.
        max_smiles_len: length passed to ``make_variable_one``.
        ecfp_bits: number of bits of the Morgan fingerprint.

    Returns:
        ``(encoded_smi_list, ecfp_list, graph_data_list, valid_smiles)``.
        The four lists are index-aligned with each other. ``valid_smiles`` is
        an order-preserving subsequence of ``smiles_list``, so it can be
        shorter than the input; callers that carry labels must realign them
        against ``valid_smiles`` rather than truncating.
    """
    encoded_smi_list = []
    ecfp_list = []
    graph_data_list = []
    valid_smiles = []

    # Loop-invariant query: compile the explicit-hydrogen pattern only once.
    explicit_h_pattern = Chem.MolFromSmarts("[H]")

    for smi in smiles_list:
        try:
            # SMILES encoding
            encoded_smi = make_variable_one(smi, smilesVoc, max_smiles_len)

            # Generate ECFP
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                print(f'Skipping SMILES {smi}: could not be parsed by RDKit')
                continue
            if mol.HasSubstructMatch(explicit_h_pattern):
                mol = Chem.RemoveHs(mol)
            ecfp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=ecfp_bits)
            ecfp_array = np.array(ecfp, dtype=np.float32)

            # Generate Graph
            c_size, features, edge_index = smile_to_graph(smi)
            if features is None or len(edge_index) == 0:
                print(f'Skipping SMILES {smi}: molecule graph has no edges')
                continue

            # Append to all lists together so they stay index-aligned.
            encoded_smi_list.append(encoded_smi)
            ecfp_list.append(ecfp_array)
            graph_data_list.append((c_size, features, edge_index))
            valid_smiles.append(smi)

        except Exception as e:
            # Best effort: report the offending SMILES and keep going.
            print(f'Error processing SMILES {smi}: {e}')
            continue

    return encoded_smi_list, ecfp_list, graph_data_list, valid_smiles

In [None]:
def extract_embeddings(model, encoded_smi_list, ecfp_list, graph_data_list, weights, device):
    """Compute one fused embedding per molecule.

    Each molecule is pushed through ``model.get_embeddings`` on its own
    (batch of one). The three returned embeddings are flattened, scaled by
    ``weights[0]``, ``weights[1]`` and ``weights[2]`` respectively, and
    concatenated into a single vector.

    Args:
        model: network exposing ``eval()`` and ``get_embeddings(...)``.
        encoded_smi_list: encoded SMILES sequences, one per molecule.
        ecfp_list: fingerprint vectors, one per molecule.
        graph_data_list: ``(c_size, features, edge_index)`` tuples, where
            ``edge_index`` is a list of ``[src, dst]`` pairs.
        weights: indexable holding the three modality weights.
        device: torch device the tensors are moved to.

    Returns:
        ``np.ndarray`` stacking the fused vectors, one row per molecule.
    """
    fused_rows = []

    model.eval()
    with torch.no_grad():
        samples = zip(encoded_smi_list, ecfp_list, graph_data_list)
        for encoded_smi, ecfp, graph in samples:
            c_size, features, edge_index = graph

            # Sequence branch: token ids plus the mask derived from them.
            smi_ids = torch.LongTensor([encoded_smi]).to(device)
            smi_mask = torch.LongTensor(getInput_mask(np.array([encoded_smi]))).to(device)

            # Fingerprint branch.
            ecfp_tensor = torch.FloatTensor([ecfp]).to(device)

            # Graph branch: node features and a 2 x E edge tensor.
            node_feats = torch.Tensor(np.array(features)).to(device)
            edges = np.array(edge_index)
            edges = edges if edges.shape[0] == 2 else edges.transpose(1, 0)
            edge_tensor = torch.LongTensor(edges).to(device)
            # Every node belongs to graph 0 because one molecule is processed at a time.
            batch = torch.zeros(node_feats.shape[0], dtype=torch.long).to(device)

            smi_emb, ep_emb, gc_emb = model.get_embeddings(
                smi_ids, smi_mask, ecfp_tensor, node_feats, edge_tensor, batch
            )

            # Flatten each modality, apply its weight, then concatenate.
            flat_parts = [emb.cpu().numpy().flatten() for emb in (smi_emb, ep_emb, gc_emb)]
            fused_rows.append(
                np.concatenate([weights[pos] * part for pos, part in enumerate(flat_parts)])
            )

    return np.array(fused_rows)

In [None]:
def _align_labels(all_smiles, all_labels, kept_smiles):
    """Return the labels of the rows that survived featurisation.

    ``process_smiles_to_tensor`` may drop rows but never reorders them, so
    ``kept_smiles`` is an ordered subsequence of ``all_smiles``. A single
    forward pass recovers the surviving row positions. This assumes that
    featurisation is deterministic, i.e. identical SMILES strings are either
    all kept or all dropped.

    Raises:
        ValueError: if ``kept_smiles`` is not a subsequence of ``all_smiles``.
    """
    kept_rows = []
    cursor = 0
    for row, smi in enumerate(all_smiles):
        if cursor < len(kept_smiles) and smi == kept_smiles[cursor]:
            kept_rows.append(row)
            cursor += 1
    if cursor != len(kept_smiles):
        raise ValueError('valid SMILES are not an ordered subsequence of the input SMILES')
    return all_labels[np.asarray(kept_rows, dtype=np.intp)]


# Per-fold pipeline: load the split, restore the best checkpoint and its fusion
# weights, featurise the molecules, extract fused embeddings and save them.
for fold_num in range(START_FOLD, END_FOLD + 1):
    print("\n" + "=" * 80)
    print(f"[Fold {fold_num}] Processing...")
    print("=" * 80)

    data_dir = f'/home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/{task_name}/fold{fold_num}'

    print(f"[INFO] Loading data from {data_dir}")
    df_train = pd.read_csv(os.path.join(data_dir, 'selectivity_train.csv'))
    df_val = pd.read_csv(os.path.join(data_dir, 'selectivity_val.csv'))
    df_test = pd.read_csv(os.path.join(data_dir, 'selectivity_test.csv'))

    # Train and validation rows are embedded together.
    df_train_val = pd.concat([df_train, df_val], ignore_index=True)

    print(f"Train: {len(df_train)}, Val: {len(df_val)}, Train+Val: {len(df_train_val)}, Test: {len(df_test)}")

    X_train_val = df_train_val['SMILES'].values
    y_train_val = df_train_val['Ssel'].values.astype(np.float32)

    X_test = df_test['SMILES'].values
    y_test = df_test['Ssel'].values.astype(np.float32)

    checkpoint_dir = os.path.join(work_dir, 'results', 'SGD', dataset_name, task_name, f'fold{fold_num}')
    checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pt')

    if not os.path.exists(checkpoint_path):
        print(f'Warning: Checkpoint file not found: {checkpoint_path}')
        continue

    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints produced by this project.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    best_epoch = checkpoint['epoch']

    # The fusion weights are stored next to the checkpoint, keyed by the best epoch.
    weight_path = os.path.join(checkpoint_dir,
                               f'{dataset_name}_{task_name}_fold{fold_num}_weight_epoch_{best_epoch}.csv')

    if not os.path.exists(weight_path):
        print(f'Warning: Weight file not found: {weight_path}')
        continue

    print(f"[INFO] Loading model from {checkpoint_path}")
    model = comModel(argsCom).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f'Model loaded from epoch {checkpoint["epoch"]}, val_loss: {checkpoint["val_loss"]:.4f}')

    weight_df = pd.read_csv(weight_path)
    weight_dict = dict(zip(weight_df['Key'], weight_df['Value']))
    # Keys 1, 2, 3 are applied to the first, second and third embedding
    # returned by model.get_embeddings (see extract_embeddings).
    weights = np.array([weight_dict[1], weight_dict[2], weight_dict[3]])
    print(f'Loaded weights: {weights}')

    print("[INFO] Processing Train+Val data...")
    train_val_smiles = X_train_val.tolist()
    encoded_smi_train_val, ecfp_train_val, graph_train_val, valid_smiles_train_val = process_smiles_to_tensor(
        train_val_smiles, smilesVoc, max_smiles_len, ecfp_bits
    )
    print(f"  Successfully processed {len(encoded_smi_train_val)} SMILES")

    print("[INFO] Processing Test data...")
    test_smiles = X_test.tolist()
    encoded_smi_test, ecfp_test, graph_test, valid_smiles_test = process_smiles_to_tensor(
        test_smiles, smilesVoc, max_smiles_len, ecfp_bits
    )
    print(f"  Successfully processed {len(encoded_smi_test)} SMILES")

    # Featurisation can drop molecules from anywhere in the list, so the labels
    # must be realigned to the surviving rows. Truncating with y[:n] would pair
    # embeddings with the wrong labels as soon as one row is dropped.
    y_train_val_valid = _align_labels(train_val_smiles, y_train_val, valid_smiles_train_val)
    y_test_valid = _align_labels(test_smiles, y_test, valid_smiles_test)
    dropped_train_val = len(train_val_smiles) - len(valid_smiles_train_val)
    dropped_test = len(test_smiles) - len(valid_smiles_test)
    if dropped_train_val or dropped_test:
        print(f"[WARN] Dropped molecules - Train+Val: {dropped_train_val}, Test: {dropped_test}")

    print("[INFO] Extracting Train+Val embeddings...")
    emb_train_val = extract_embeddings(model, encoded_smi_train_val, ecfp_train_val, graph_train_val, weights, device)
    print(f"  Train+Val embeddings shape: {emb_train_val.shape}")

    print("[INFO] Extracting Test embeddings...")
    emb_test = extract_embeddings(model, encoded_smi_test, ecfp_test, graph_test, weights, device)
    print(f"  Test embeddings shape: {emb_test.shape}")

    # Embeddings, labels and SMILES must describe the same rows.
    assert len(emb_train_val) == len(y_train_val_valid) == len(valid_smiles_train_val)
    assert len(emb_test) == len(y_test_valid) == len(valid_smiles_test)

    output_dir = os.path.join(work_dir, 'results', 'SGD', dataset_name, task_name,
                              f'fold{fold_num}', 'embeddings')
    os.makedirs(output_dir, exist_ok=True)

    tr_val_path = os.path.join(output_dir, 'tr_val_embeddings.npy')
    te_path = os.path.join(output_dir, 'te_embeddings.npy')

    # np.save pickles the dict as a 0-d object array; read it back with
    # np.load(path, allow_pickle=True).item().
    np.save(tr_val_path, {
        'embeddings': emb_train_val,
        'Ssel': y_train_val_valid,
        'SMILES': valid_smiles_train_val,
    })
    print(f"[INFO] Saved Train+Val embeddings to {tr_val_path}")

    np.save(te_path, {
        'embeddings': emb_test,
        'Ssel': y_test_valid,
        'SMILES': valid_smiles_test,
    })
    print(f"[INFO] Saved Test embeddings to {te_path}")

    print(f"[Fold {fold_num}] Completed!")

print("\n" + "=" * 80)
print("[INFO] All folds processed!")
print("=" * 80)


[Fold 1] Processing...
[INFO] Loading data from /home/rlawlsgurjh/hdd/work/ChEMBLv2/data/selectivity_processed/Kd/fold1
Train: 399, Val: 45, Train+Val: 444, Test: 112
[INFO] Loading model from /home/rlawlsgurjh/hdd/work/MMFDL/results/SGD/selectivity/Kd/fold1/best_model.pt
Model loaded from epoch 50, val_loss: 0.2395
Loaded weights: [0.69886494 0.23145324 0.07628601]
[INFO] Processing Train+Val data...
  Successfully processed 444 SMILES
[INFO] Processing Test data...
  Successfully processed 112 SMILES
[INFO] Extracting Train+Val embeddings...
  Train+Val embeddings shape: (444, 540)
[INFO] Extracting Test embeddings...
  Test embeddings shape: (112, 540)
[INFO] Saved Train+Val embeddings to /home/rlawlsgurjh/hdd/work/MMFDL/results/SGD/selectivity/Kd/fold1/embeddings/tr_val_embeddings.npy
[INFO] Saved Test embeddings to /home/rlawlsgurjh/hdd/work/MMFDL/results/SGD/selectivity/Kd/fold1/embeddings/te_embeddings.npy
[Fold 1] Completed!

[Fold 2] Processing...
[INFO] Loading data from /ho