In [1]:
import pandas as pd
import numpy as np
import torch
import random
import gc
import json
from pathlib import Path
from tqdm import tqdm
from datetime import datetime
from collections import defaultdict
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Descriptors
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data
import logging

In [2]:
class Config:
    DATA_DIR = Path("d:/DS/Datasets/TWOSIDES")
    PROC = DATA_DIR / "processed"
    
    MIN_CASE_COUNT = 3        # Minimum adverse event reports (A ≥ 3)
    MIN_PRR = 2.0             # Minimum proportional reporting ratio (PRR ≥ 2)
    # Rationale: PRR≥2 indicates 2x higher risk than background rate
    # A≥3 ensures statistical reliability (standard in pharmacovigilance)
    
    # Train/val/test split ratios
    TRAIN_RATIO = 0.70
    VAL_RATIO = 0.15
    TEST_RATIO = 0.15
    
    # ChemBERTa settings
    USE_CHEMBERTA = True
    CHEMBERTA_MODEL = "DeepChem/ChemBERTa-77M-MTR"
    CHEMBERTA_BATCH_SIZE = 32 
    
    # Molecular fingerprint settings
    ECFP_RADIUS = 2
    ECFP_BITS = 1024
    
    SEED = 42

cfg = Config()
cfg.PROC.mkdir(exist_ok=True)

In [3]:
# Setup logging
log_file = cfg.PROC / f"preprocess_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file, encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

2025-11-22 13:25:55,502 - INFO - Using device: cuda


In [5]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(cfg.SEED)

In [6]:
class DataStats:
    def __init__(self):
        self.stats = defaultdict(int)
        self.details = {}
    
    def log(self, key, value):
        # Convert numpy types to Python native types for JSON serialization
        if hasattr(value, 'item'):
            value = value.item()
        self.stats[key] = value
        logger.info(f"STAT: {key} = {value:,}" if isinstance(value, int) else f"STAT: {key} = {value}")
    
    def save(self, path):
        output = {
            'statistics': dict(self.stats),
            'details': self.details,
            'timestamp': datetime.now().isoformat(),
            'config': {
                'min_case_count': cfg.MIN_CASE_COUNT,
                'min_prr': cfg.MIN_PRR,
                'train_ratio': cfg.TRAIN_RATIO,
                'val_ratio': cfg.VAL_RATIO,
                'test_ratio': cfg.TEST_RATIO,
                'seed': cfg.SEED
            }
        }
        with open(path, 'w') as f:
            json.dump(output, f, indent=2)
        logger.info(f"Statistics saved to {path}")

stats = DataStats()

In [7]:
logger.info("Loading TWOSIDES dataset")
df = pd.read_csv(cfg.DATA_DIR / "TWOSIDES.csv", dtype=str)
stats.log("raw_twosides_rows", len(df))
logger.info(f"Loaded {len(df):,} raw DDI records")

2025-11-22 13:25:55,640 - INFO - Loading TWOSIDES dataset
2025-11-22 13:27:01,442 - INFO - STAT: raw_twosides_rows = 42,920,391
2025-11-22 13:27:01,443 - INFO - Loaded 42,920,391 raw DDI records


In [8]:
# Check for column name typos
expected_cols = ['drug_1_rxnorm_id', 'drug_2_rxnorm_id', 'A', 'PRR', 'PRR_error']
actual_cols = df.columns.tolist()

if 'drug_1_rxnorn_id' in actual_cols:
    logger.warning("Found typo 'rxnorn_id' → fixing to 'rxnorm_id'")
    df.rename(columns={
        'drug_1_rxnorn_id': 'drug_1_rxnorm_id',
        'drug_2_rxnorn_id': 'drug_2_rxnorm_id'
    }, inplace=True)

# Validate required columns
missing_cols = [col for col in expected_cols if col not in df.columns]
if missing_cols:
    logger.error(f"Missing required columns: {missing_cols}")
    raise ValueError(f"Dataset missing columns: {missing_cols}")

logger.info(f"All required columns present: {expected_cols}")

2025-11-22 13:27:01,472 - INFO - All required columns present: ['drug_1_rxnorm_id', 'drug_2_rxnorm_id', 'A', 'PRR', 'PRR_error']


In [9]:
# Convert numeric columns
for col in ['A', 'PRR', 'PRR_error']:
    df[col] = pd.to_numeric(df[col], errors='coerce')
    n_invalid = df[col].isna().sum()
    if n_invalid > 0:
        logger.warning(f"{n_invalid:,} invalid values in column '{col}' (set to NaN)")

stats.log("invalid_numeric_values", df[['A', 'PRR', 'PRR_error']].isna().sum().sum())

2025-11-22 13:27:48,242 - INFO - STAT: invalid_numeric_values = 3


In [10]:
logger.info("Filtering DDI records")
initial_count = len(df)

# Filter by case count (A ≥ 3)
df = df[df['A'] >= cfg.MIN_CASE_COUNT].copy()
logger.info(f"After A >= {cfg.MIN_CASE_COUNT}: {len(df):,} records ({len(df)/initial_count*100:.1f}% retained)")

# Filter by PRR (≥ 2.0)
df = df[df['PRR'] >= cfg.MIN_PRR].copy()
logger.info(f"After PRR >= {cfg.MIN_PRR}: {len(df):,} records ({len(df)/initial_count*100:.1f}% retained)")

# Remove records with missing PRR_error
df = df[df['PRR_error'].notna()].copy()
logger.info(f"After removing missing PRR_error: {len(df):,} records")

stats.log("filtered_twosides_rows", len(df))
stats.log("filter_retention_rate", len(df) / initial_count)


2025-11-22 13:27:48,263 - INFO - Filtering DDI records
2025-11-22 13:27:52,316 - INFO - After A >= 3: 13,668,432 records (31.8% retained)
2025-11-22 13:27:54,067 - INFO - After PRR >= 2.0: 10,800,847 records (25.2% retained)
2025-11-22 13:27:56,044 - INFO - After removing missing PRR_error: 10,800,847 records
2025-11-22 13:27:56,045 - INFO - STAT: filtered_twosides_rows = 10,800,847
2025-11-22 13:27:56,046 - INFO - STAT: filter_retention_rate = 0.25164838316594085


In [11]:
# Standardize RxNorm IDs 
df['drug_1_rxnorm_id'] = df['drug_1_rxnorm_id'].astype(str).str.zfill(6)
df['drug_2_rxnorm_id'] = df['drug_2_rxnorm_id'].astype(str).str.zfill(6)

In [12]:
# Extract unique drug pairs
edges = df[['drug_1_rxnorm_id', 'drug_2_rxnorm_id', 'PRR']].copy()
edges = edges.drop_duplicates()
logger.info(f"Unique drug pairs: {len(edges):,}")
stats.log("unique_drug_pairs_before_smiles", len(edges))

2025-11-22 13:28:09,693 - INFO - Unique drug pairs: 5,197,357
2025-11-22 13:28:09,695 - INFO - STAT: unique_drug_pairs_before_smiles = 5,197,357


In [13]:
logger.info("Loading molecular structures (SMILES)")
smiles_file = cfg.PROC / "rxnorm_to_smiles_1902.csv"
smiles_df = pd.read_csv(smiles_file)
smiles_df['rxnorm_id'] = smiles_df['rxnorm_id'].astype(str).str.zfill(6)

2025-11-22 13:28:09,724 - INFO - Loading molecular structures (SMILES)


In [14]:
# Validate SMILES
logger.info("Validating SMILES strings...")
valid_smiles = []
invalid_count = 0

for idx, row in tqdm(smiles_df.iterrows(), total=len(smiles_df), desc="Validating SMILES"):
    mol = Chem.MolFromSmiles(row['smiles'])
    if mol is not None and mol.GetNumAtoms() > 0:
        valid_smiles.append(row)
    else:
        invalid_count += 1

smiles_df = pd.DataFrame(valid_smiles)

2025-11-22 13:28:09,830 - INFO - Validating SMILES strings...
Validating SMILES: 100%|██████████| 1663/1663 [00:00<00:00, 2692.25it/s]


In [15]:
logger.info(f"Valid SMILES: {len(smiles_df):,} | Invalid: {invalid_count:,}")
stats.log("total_smiles_available", len(smiles_df))
stats.log("invalid_smiles_count", invalid_count)
rx_to_smiles = dict(zip(smiles_df.rxnorm_id, smiles_df.smiles))

2025-11-22 13:28:10,502 - INFO - Valid SMILES: 1,663 | Invalid: 0
2025-11-22 13:28:10,504 - INFO - STAT: total_smiles_available = 1,663
2025-11-22 13:28:10,506 - INFO - STAT: invalid_smiles_count = 0


In [16]:
# Filter edges to only those with valid SMILES
edges_before = len(edges)
edges = edges[
    edges['drug_1_rxnorm_id'].isin(rx_to_smiles) &
    edges['drug_2_rxnorm_id'].isin(rx_to_smiles)
].copy()

logger.info(f"Drug pairs with valid SMILES: {len(edges):,} (removed {edges_before - len(edges):,})")
stats.log("drug_pairs_with_smiles", len(edges))
stats.log("pairs_removed_no_smiles", edges_before - len(edges))

2025-11-22 13:28:12,409 - INFO - Drug pairs with valid SMILES: 4,735,600 (removed 461,757)
2025-11-22 13:28:12,411 - INFO - STAT: drug_pairs_with_smiles = 4,735,600
2025-11-22 13:28:12,413 - INFO - STAT: pairs_removed_no_smiles = 461,757


In [17]:
logger.info("Building node list and creating drug-cold splits")
all_drugs = sorted(set(edges['drug_1_rxnorm_id']) | set(edges['drug_2_rxnorm_id']))
node_to_idx = {rx: i for i, rx in enumerate(all_drugs)}
idx_to_rx = all_drugs
N_nodes = len(all_drugs)

logger.info(f"Total unique drugs: {N_nodes:,}")
stats.log("total_unique_drugs", N_nodes)

2025-11-22 13:28:12,442 - INFO - Building node list and creating drug-cold splits
2025-11-22 13:28:13,371 - INFO - Total unique drugs: 1,602
2025-11-22 13:28:13,372 - INFO - STAT: total_unique_drugs = 1,602


In [18]:
# Drug-cold split (inductive evaluation)
indices = np.arange(N_nodes)
rng = np.random.default_rng(cfg.SEED)
rng.shuffle(indices)

n_train = int(cfg.TRAIN_RATIO * N_nodes)
n_val = int(cfg.VAL_RATIO * N_nodes)

train_nodes = set(indices[:n_train].tolist())
val_nodes = set(indices[n_train:n_train+n_val].tolist())
test_nodes = set(indices[n_train+n_val:].tolist())

logger.info(f"Train drugs:{len(train_nodes):,} ({len(train_nodes)/N_nodes*100:.1f}%)")
logger.info(f"Val drugs:{len(val_nodes):,} ({len(val_nodes)/N_nodes*100:.1f}%)")
logger.info(f"Test drugs:{len(test_nodes):,} ({len(test_nodes)/N_nodes*100:.1f}%)")

stats.log("train_drugs", len(train_nodes))
stats.log("val_drugs", len(val_nodes))
stats.log("test_drugs", len(test_nodes))

2025-11-22 13:28:13,384 - INFO - Train drugs:1,121 (70.0%)
2025-11-22 13:28:13,386 - INFO - Val drugs:240 (15.0%)
2025-11-22 13:28:13,386 - INFO - Test drugs:241 (15.0%)
2025-11-22 13:28:13,387 - INFO - STAT: train_drugs = 1,121
2025-11-22 13:28:13,388 - INFO - STAT: val_drugs = 240
2025-11-22 13:28:13,388 - INFO - STAT: test_drugs = 241


In [19]:
def rx_pair_to_idx_pair(r1, r2):
    return node_to_idx[r1], node_to_idx[r2]

raw_pairs = []
for _, row in edges.iterrows():
    u, v = rx_pair_to_idx_pair(row['drug_1_rxnorm_id'], row['drug_2_rxnorm_id'])
    if u == v:  # Remove self-loops
        continue
    # Store as undirected (sorted)
    if u > v:
        u, v = v, u
    raw_pairs.append((u, v))

raw_pairs = list(set(raw_pairs))
logger.info(f"Unique undirected edges: {len(raw_pairs):,}")
stats.log("unique_undirected_edges", len(raw_pairs))

2025-11-22 13:30:38,130 - INFO - Unique undirected edges: 162,914
2025-11-22 13:30:38,131 - INFO - STAT: unique_undirected_edges = 162,914


In [None]:
train_pairs, val_pairs, test_pairs = [], [], []
cross_split_pairs = 0

for (u, v) in raw_pairs:
    if u in train_nodes and v in train_nodes:
        train_pairs.append((u, v))
    elif u in val_nodes and v in val_nodes:
        val_pairs.append((u, v))
    elif u in test_nodes and v in test_nodes:
        test_pairs.append((u, v))
    else:
        cross_split_pairs += 1  # Edge crosses splits

logger.info(f"Train edges (both drugs in train set): {len(train_pairs):,}")
logger.info(f"Val edges (both drugs in val set):     {len(val_pairs):,}")
logger.info(f"Test edges (both drugs in test set):   {len(test_pairs):,}")
logger.info(f"Cross-split edges (excluded):          {cross_split_pairs:,}")

stats.log("train_edges", len(train_pairs))
stats.log("val_edges", len(val_pairs))
stats.log("test_edges", len(test_pairs))
stats.log("cross_split_edges_excluded", cross_split_pairs)

2025-11-22 13:30:38,235 - INFO - Train edges (both drugs in train set): 77,815
2025-11-22 13:30:38,236 - INFO - Val edges (both drugs in val set):     4,198
2025-11-22 13:30:38,237 - INFO - Test edges (both drugs in test set):   3,530
2025-11-22 13:30:38,238 - INFO - Cross-split edges (excluded):          77,371
2025-11-22 13:30:38,240 - INFO - STAT: train_edges = 77,815
2025-11-22 13:30:38,241 - INFO - STAT: val_edges = 4,198
2025-11-22 13:30:38,243 - INFO - STAT: test_edges = 3,530
2025-11-22 13:30:38,244 - INFO - STAT: cross_split_edges_excluded = 77,371


In [21]:
logger.info("Computing molecular features")

def compute_ecfp_physchem(smiles: str) -> np.ndarray:
    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    
    # ECFP fingerprint
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, cfg.ECFP_RADIUS, nBits=cfg.ECFP_BITS)
    arr = np.zeros((cfg.ECFP_BITS,), dtype=np.float32)
    DataStructs.ConvertToNumpyArray(fp, arr)
    
    # Physicochemical properties
    try:
        phys = np.array([
            Descriptors.MolWt(mol),
            Descriptors.MolLogP(mol),
            Descriptors.NumHAcceptors(mol),
            Descriptors.NumHDonors(mol),
            Descriptors.NumRotatableBonds(mol),
            Descriptors.TPSA(mol),
        ], dtype=np.float32)
    except:
        logger.warning(f"Failed to compute descriptors for SMILES: {smiles[:50]}...")
        return None
    
    return np.concatenate([arr, phys])

2025-11-22 13:30:38,259 - INFO - Computing molecular features


In [22]:
logger.info("Computing ECFP fingerprints and physicochemical descriptors")
smiles_list = [rx_to_smiles[rx] for rx in idx_to_rx]

ecfp_features = []
failed_ecfp = 0

for smiles in tqdm(smiles_list, desc="ECFP+Physchem"):
    feat = compute_ecfp_physchem(smiles)
    if feat is not None:
        ecfp_features.append(feat)
    else:
        ecfp_features.append(np.zeros(cfg.ECFP_BITS + 6, dtype=np.float32))
        failed_ecfp += 1

ecfp_arr = np.stack(ecfp_features, axis=0)
logger.info(f"ECFP features shape: {ecfp_arr.shape}")
logger.info(f"Failed to compute features for {failed_ecfp} drugs (zero-padded)")
stats.log("failed_ecfp_computation", failed_ecfp)

2025-11-22 13:30:38,288 - INFO - Computing ECFP fingerprints and physicochemical descriptors
ECFP+Physchem: 100%|██████████| 1602/1602 [00:01<00:00, 885.04it/s]
2025-11-22 13:30:40,106 - INFO - ECFP features shape: (1602, 1030)
2025-11-22 13:30:40,107 - INFO - Failed to compute features for 0 drugs (zero-padded)
2025-11-22 13:30:40,107 - INFO - STAT: failed_ecfp_computation = 0


In [23]:
# ChemBERTa embeddings
if cfg.USE_CHEMBERTA:
    logger.info("="*70)
    logger.info("Computing ChemBERTa embeddings")
    logger.info("="*70)
    
    try:
        tokenizer = AutoTokenizer.from_pretrained(cfg.CHEMBERTA_MODEL)
        chemberta = AutoModel.from_pretrained(cfg.CHEMBERTA_MODEL).to(device)
        chemberta.eval()
        
        @torch.no_grad()
        def chemberta_embeddings(smiles_batch, batch_size=32):
            embeddings = []
            for i in tqdm(range(0, len(smiles_batch), batch_size), desc="ChemBERTa"):
                batch = smiles_batch[i:i+batch_size]
                enc = tokenizer(batch, return_tensors="pt", padding=True, 
                               truncation=True, max_length=256).to(device)
                out = chemberta(**enc).last_hidden_state.mean(dim=1)
                embeddings.append(out.cpu().numpy())
            return np.concatenate(embeddings, axis=0)
        
        chemberta_embs = chemberta_embeddings(smiles_list, batch_size=cfg.CHEMBERTA_BATCH_SIZE)
        logger.info(f"ChemBERTa embeddings shape: {chemberta_embs.shape}")
        
        # Combine ECFP + ChemBERTa
        X_global = np.concatenate([chemberta_embs, ecfp_arr], axis=1)
        logger.info(f"Combined features shape: {X_global.shape}")
        stats.log("chemberta_dim", chemberta_embs.shape[1])
        
        del chemberta
        gc.collect()
        torch.cuda.empty_cache()
        
    except Exception as e:
        logger.error(f"ChemBERTa computation failed: {e}")
        logger.info("Falling back to ECFP-only features")
        X_global = ecfp_arr
else:
    logger.info("Skipping ChemBERTa (USE_CHEMBERTA=False)")
    X_global = ecfp_arr

stats.log("global_feature_dim", X_global.shape[1])

2025-11-22 13:30:40,119 - INFO - Computing ChemBERTa embeddings
Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
ChemBERTa: 100%|██████████| 51/51 [00:01<00:00, 45.87it/s]
2025-11-22 13:30:45,490 - INFO - ChemBERTa embeddings shape: (1602, 384)
2025-11-22 13:30:45,493 - INFO - Combined features shape: (1602, 1414)
2025-11-22 13:30:45,493 - INFO - STAT: chemberta_dim = 384
2025-11-22 13:30:45,650 - INFO - STAT: global_feature_dim = 1,414


In [24]:
# Standardize features
logger.info("Standardizing features (fitted on training set only)")
scaler = StandardScaler()
train_indices = list(train_nodes)
scaler.fit(X_global[train_indices])
X_scaled = scaler.transform(X_global).astype(np.float32)

logger.info(f"Feature mean (train): {X_global[train_indices].mean():.4f}")
logger.info(f"Feature std (train):  {X_global[train_indices].std():.4f}")
logger.info(f"Feature mean (scaled): {X_scaled.mean():.4f}")
logger.info(f"Feature std (scaled):  {X_scaled.std():.4f}")

np.save(cfg.PROC / "node_features.npy", X_scaled)
logger.info(f"Saved node_features.npy with shape {X_scaled.shape}")

2025-11-22 13:30:45,673 - INFO - Standardizing features (fitted on training set only)
2025-11-22 13:30:45,709 - INFO - Feature mean (train): 0.4439
2025-11-22 13:30:45,717 - INFO - Feature std (train):  21.8085
2025-11-22 13:30:45,719 - INFO - Feature mean (scaled): 0.0020
2025-11-22 13:30:45,728 - INFO - Feature std (scaled):  1.0116
2025-11-22 13:30:45,734 - INFO - Saved node_features.npy with shape (1602, 1414)


In [None]:
logger.info("Building molecular graphs for GNN")
def atom_features(atom: Chem.rdchem.Atom) -> torch.Tensor:
    """Extract atom-level features"""
    return torch.tensor([
        atom.GetAtomicNum(), # Atomic number
        atom.GetTotalDegree(), # Degree
        int(atom.GetIsAromatic()), # Aromaticity
        atom.GetFormalCharge(), # Formal charge
    ], dtype=torch.float)

def smiles_to_pyg(smiles: str) -> Data:
    """Convert SMILES to PyG graph"""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    
    # Atom features
    x_list = [atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.stack(x_list, dim=0)
    
    # Bond edges (undirected)
    edges = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edges.append([i, j])
        edges.append([j, i])
    
    if len(edges) == 0:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    else:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    return Data(x=x, edge_index=edge_index)

2025-11-22 13:30:45,752 - INFO - Building molecular graphs for GNN


In [26]:
drug_graphs = []
failed_graphs = []

for rx in tqdm(idx_to_rx, desc="SMILES → Graphs"):
    smiles = rx_to_smiles[rx]
    graph = smiles_to_pyg(smiles)
    
    if graph is None:
        failed_graphs.append(rx)
        # Create minimal valid graph as fallback
        x = torch.zeros((1, 4), dtype=torch.float)
        edge_index = torch.empty((2, 0), dtype=torch.long)
        graph = Data(x=x, edge_index=edge_index)
    
    drug_graphs.append(graph)

logger.info(f"Successfully built {len(drug_graphs)} molecular graphs")
logger.info(f"Failed to parse {len(failed_graphs)} drugs (using fallback graphs)")
stats.log("total_graphs", len(drug_graphs))
stats.log("failed_graph_parsing", len(failed_graphs))

SMILES → Graphs: 100%|██████████| 1602/1602 [00:01<00:00, 1375.58it/s]
2025-11-22 13:30:46,933 - INFO - Successfully built 1602 molecular graphs
2025-11-22 13:30:46,933 - INFO - Failed to parse 0 drugs (using fallback graphs)
2025-11-22 13:30:46,934 - INFO - STAT: total_graphs = 1,602
2025-11-22 13:30:46,935 - INFO - STAT: failed_graph_parsing = 0


In [27]:
torch.save(drug_graphs, cfg.PROC / "drug_graphs.pt")

In [28]:
pd.DataFrame({"rxnorm_id": idx_to_rx}).to_csv(cfg.PROC / "node_list.csv", index=False)
pd.DataFrame(train_pairs, columns=["u", "v"]).to_csv(cfg.PROC / "train_edges.csv", index=False)
pd.DataFrame(val_pairs, columns=["u", "v"]).to_csv(cfg.PROC / "val_edges.csv", index=False)
pd.DataFrame(test_pairs, columns=["u", "v"]).to_csv(cfg.PROC / "test_edges.csv", index=False)

stats.save(cfg.PROC / "preprocessing_stats.json")

2025-11-22 13:30:47,206 - INFO - Statistics saved to d:\DS\Datasets\TWOSIDES\processed\preprocessing_stats.json


In [29]:
logger.info(f"")
logger.info(f"Dataset Summary:")
logger.info(f"Total drugs: {N_nodes:,}")
logger.info(f"Train drugs: {len(train_nodes):,} ({len(train_nodes)/N_nodes*100:.1f}%)")
logger.info(f"Val drugs: {len(val_nodes):,} ({len(val_nodes)/N_nodes*100:.1f}%)")
logger.info(f"Test drugs: {len(test_nodes):,} ({len(test_nodes)/N_nodes*100:.1f}%)")
logger.info(f"")
logger.info(f"Total interactions: {len(raw_pairs):,}")
logger.info(f"Train interactions: {len(train_pairs):,}")
logger.info(f"Val interactions: {len(val_pairs):,}")
logger.info(f" Test interactions: {len(test_pairs):,}")
logger.info(f"")
logger.info(f"Features:")
logger.info(f"Global feature dim: {X_scaled.shape[1]}")
logger.info(f" - ChemBERTa: {'Yes' if cfg.USE_CHEMBERTA else 'No'}")
logger.info(f" - ECFP-{cfg.ECFP_RADIUS}: {cfg.ECFP_BITS} bits")
logger.info(f" - Physicochemical: 6 descriptors")
logger.info(f"")
logger.info(f"Output files saved to: {cfg.PROC}")
logger.info(f"Log file: {log_file}")

print(f"Check {log_file} for detailed logs")
print(f"Check {cfg.PROC / 'preprocessing_stats.json'} for statistics")

2025-11-22 13:30:47,239 - INFO - 
2025-11-22 13:30:47,242 - INFO - Dataset Summary:
2025-11-22 13:30:47,245 - INFO - Total drugs: 1,602
2025-11-22 13:30:47,250 - INFO - Train drugs: 1,121 (70.0%)
2025-11-22 13:30:47,252 - INFO - Val drugs: 240 (15.0%)
2025-11-22 13:30:47,255 - INFO - Test drugs: 241 (15.0%)
2025-11-22 13:30:47,258 - INFO - 
2025-11-22 13:30:47,261 - INFO - Total interactions: 162,914
2025-11-22 13:30:47,266 - INFO - Train interactions: 77,815
2025-11-22 13:30:47,270 - INFO - Val interactions: 4,198
2025-11-22 13:30:47,274 - INFO -  Test interactions: 3,530
2025-11-22 13:30:47,278 - INFO - 
2025-11-22 13:30:47,281 - INFO - Features:
2025-11-22 13:30:47,285 - INFO - Global feature dim: 1414
2025-11-22 13:30:47,287 - INFO -  - ChemBERTa: Yes
2025-11-22 13:30:47,291 - INFO -  - ECFP-2: 1024 bits
2025-11-22 13:30:47,295 - INFO -  - Physicochemical: 6 descriptors
2025-11-22 13:30:47,296 - INFO - 
2025-11-22 13:30:47,299 - INFO - Output files saved to: d:\DS\Datasets\TWOSIDES

Check d:\DS\Datasets\TWOSIDES\processed\preprocess_20251122_132555.log for detailed logs
Check d:\DS\Datasets\TWOSIDES\processed\preprocessing_stats.json for statistics
