# In [None]:
# -*- coding: utf-8 -*-
"""
BBBC036: 单文件三段式结构
Part 1: 数据读取与准备
Part 2: 模型
Part 3: 训练与评估（含 main）
"""
# =========================
# 通用导入
# =========================
import os
import time
import pickle
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from scipy.stats import pearsonr

# Fix the PyTorch and NumPy random seeds (kept from the original script)
torch.manual_seed(42)
np.random.seed(42)


# ============================================================
# Part 1) 数据读取与准备
# ============================================================
def load_from_HDF(fname):
    """Load the entries of an HDF5 file into a dictionary of NumPy arrays.

    Every key yielded by iterating the opened file is materialised with
    ``np.asarray``. Entries whose dtype is a fixed-width byte string are
    converted to ``str`` arrays, regardless of their shape, so that callers
    can compare the values against ordinary Python strings.

    Args:
        fname: Path of the HDF5 file; it is opened read-only.

    Returns:
        dict mapping each key of the file to a NumPy array.
    """
    data = {}
    with h5py.File(fname, 'r') as f:
        for key in f:
            values = np.asarray(f[key])
            # Decide from the dtype instead of peeking at element 0:
            # indexing raises IndexError on empty and scalar datasets.
            if values.dtype.kind == 'S':
                values = values.astype(str)
            data[key] = values
    return data


class DrugResponseDataset(Dataset):
    """Dataset yielding a SMILES embedding with its before/after feature rows.

    The three inputs are copied into float32 tensors at construction time.
    Indexing returns a dict with the keys ``'smiles'``, ``'features_before'``
    and ``'features_after'``, each taken at the same position.
    """

    @staticmethod
    def _to_float32(values):
        # Copy the input into a new float32 tensor.
        return torch.tensor(values, dtype=torch.float32)

    def __init__(self, smiles_embeddings, features_before, features_after):
        self.smiles_embeddings = self._to_float32(smiles_embeddings)
        self.features_before = self._to_float32(features_before)
        self.features_after = self._to_float32(features_after)

    def __len__(self):
        # The dataset length is the number of SMILES embedding rows.
        return len(self.smiles_embeddings)

    def __getitem__(self, idx):
        sample = dict(
            smiles=self.smiles_embeddings[idx],
            features_before=self.features_before[idx],
            features_after=self.features_after[idx],
        )
        return sample


def prepare_dataloaders(
    smiles_pickle_path: str,
    h5_path: str,
    test_size: float = 0.1,
    random_state: int = 42,
    batch_size: int = 1024
):
    """Build the train/test DataLoaders from the pickle and HDF5 inputs.

    The pickle is expected to hold a mapping from SMILES string to an
    embedding array. The HDF5 file is read through ``load_from_HDF`` and must
    provide the keys ``'canonical_smiles'``, ``'control'`` and ``'target'``.
    Rows whose SMILES has no embedding are reported and dropped; the remaining
    rows are split with ``train_test_split``.

    Args:
        smiles_pickle_path: Path of the pickled SMILES-to-embedding mapping.
        h5_path: Path of the HDF5 data file.
        test_size: Forwarded to ``train_test_split``.
        random_state: Forwarded to ``train_test_split``.
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader, smiles_dim, feature_dim)`` where the two
        dims are the second-axis sizes of the embeddings and of the
        ``'control'`` features. Only the training loader shuffles.
    """
    print("Loading data...")
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(smiles_pickle_path, 'rb') as handle:
        smi2emb = pickle.load(handle)

    cpg_data = load_from_HDF(h5_path)
    print(f"Data shapes: SMILES {cpg_data['canonical_smiles'].shape}, Target {cpg_data['target'].shape}, Control {cpg_data['control'].shape}")

    all_smiles = cpg_data['canonical_smiles']
    features_before = cpg_data['control']  # features before treatment
    features_after = cpg_data['target']    # features after treatment

    # Distribution summary of the raw features
    print(f"Features before mean: {np.mean(features_before):.6f}, std: {np.std(features_before):.6f}")
    print(f"Features after  mean: {np.mean(features_after):.6f}, std: {np.std(features_after):.6f}")

    # Look up an embedding per SMILES and remember the rows that have none
    smiles_embeddings = []
    missing_rows = []
    for i, smiles in enumerate(all_smiles):
        if smiles not in smi2emb:
            print(f"Warning: SMILES '{smiles}' not found in embeddings, skipping index {i}.")
            missing_rows.append(i)
            continue
        smiles_embeddings.append(smi2emb[smiles].astype(np.float32))

    smiles_embeddings = np.array(smiles_embeddings)
    # Drop the same rows from both feature matrices so everything stays aligned
    features_before = np.delete(features_before, missing_rows, axis=0)
    features_after = np.delete(features_after, missing_rows, axis=0)

    print(f"Processed data shapes: SMILES Embeddings {smiles_embeddings.shape}, Features Before {features_before.shape}, Features After {features_after.shape}")

    # Train/test split
    split = train_test_split(
        smiles_embeddings,
        features_before,
        features_after,
        test_size=test_size,
        random_state=random_state,
    )
    X_smiles_train, X_smiles_test, before_train, before_test, after_train, after_test = split
    print(f"Train size: {len(X_smiles_train)}, Test size: {len(X_smiles_test)}")

    # Dimensions reported back to the caller
    smiles_dim = smiles_embeddings.shape[1]
    feature_dim = features_before.shape[1]

    train_loader = DataLoader(
        DrugResponseDataset(X_smiles_train, before_train, after_train),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        DrugResponseDataset(X_smiles_test, before_test, after_test),
        batch_size=batch_size,
    )
    return train_loader, test_loader, smiles_dim, feature_dim


# ============================================================
# Part 2) 模型
# ============================================================
class ImprovedMLP(nn.Module):
    """Two-branch MLP mapping (SMILES embedding, features before) to features after.

    Each input goes through its own branch of two Linear -> BatchNorm1d ->
    LeakyReLU(0.1) -> Dropout stages. The branch outputs are concatenated and
    passed through a fusion network that ends in ``Tanh``, so the raw output
    lies in [-1, 1]; a learnable scalar scale and shift are applied on top.

    Args:
        smiles_dim: Width of the SMILES embedding input.
        feature_dim: Width of the feature input and of the output.
        hidden_dim: Width of the hidden layers; each branch ends at
            ``hidden_dim // 2``.
        dropout: Dropout probability used after every hidden stage.
    """

    def __init__(self, smiles_dim=384, feature_dim=591, hidden_dim=1024, dropout=0.2):
        super().__init__()
        branch_dim = hidden_dim // 2

        # Branch for the SMILES embedding
        self.smiles_branch = nn.Sequential(
            *self._dense_block(smiles_dim, hidden_dim, dropout),
            *self._dense_block(hidden_dim, branch_dim, dropout),
        )

        # Branch for the features measured before treatment
        self.feature_branch = nn.Sequential(
            *self._dense_block(feature_dim, hidden_dim, dropout),
            *self._dense_block(hidden_dim, branch_dim, dropout),
        )

        # Fusion network. Its input is the concatenation of both branch
        # outputs, i.e. 2 * (hidden_dim // 2) wide. That equals hidden_dim
        # only for even hidden_dim, so the layer is sized from the real width
        # to keep odd hidden_dim values working.
        self.combined = nn.Sequential(
            *self._dense_block(2 * branch_dim, hidden_dim, dropout),
            *self._dense_block(hidden_dim, branch_dim, dropout),
            nn.Linear(branch_dim, feature_dim),
            nn.Tanh(),
        )

        # Learnable scalar scale and shift applied to the Tanh output
        self.scale_factor = nn.Parameter(torch.ones(1))
        self.shift_factor = nn.Parameter(torch.zeros(1))

    @staticmethod
    def _dense_block(in_dim, out_dim, dropout):
        """Return the layers of one Linear -> BatchNorm1d -> LeakyReLU -> Dropout stage."""
        return [
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout),
        ]

    def forward(self, smiles, features_before):
        """Return the prediction for a batch of inputs.

        Both inputs are concatenated along ``dim=1`` after their branches, so
        they must share the same batch size.
        """
        smiles_out = self.smiles_branch(smiles)
        features_out = self.feature_branch(features_before)
        fused = torch.cat([smiles_out, features_out], dim=1)
        return self.scale_factor * self.combined(fused) + self.shift_factor


# ============================================================
# Part 3) 训练与评估（含 main）
# ============================================================
def calculate_per_sample_metrics(predictions, targets):
    """Compute the Pearson correlation and RMSE of every row.

    Row ``i`` of ``predictions`` is compared with row ``i`` of ``targets``.

    Returns:
        ``(avg_pcc, avg_rmse, pcc_values, rmse_values)``: the means over all
        rows followed by the per-row arrays.
    """
    row_count = predictions.shape[0]
    pcc_values = np.empty(row_count, dtype=np.float64)
    rmse_values = np.empty(row_count, dtype=np.float64)

    for row in range(row_count):
        predicted = predictions[row]
        observed = targets[row]

        statistic, _pvalue = pearsonr(predicted, observed)
        pcc_values[row] = statistic

        residual = predicted - observed
        rmse_values[row] = np.sqrt(np.mean(residual ** 2))

    return np.mean(pcc_values), np.mean(rmse_values), pcc_values, rmse_values


def train_model(model, train_loader, criterion, optimizer, device, epochs=200, print_interval=20):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    There is no validation step. The progress bar is shown only for epochs
    whose zero-based index is a multiple of ``print_interval``; a loss summary
    is printed after every ``print_interval``-th epoch. The summary reports
    the sample-weighted mean loss since the previous summary and the mean loss
    of the epoch that just finished.

    Returns:
        The same ``model`` object after training.
    """
    window_loss_sum = 0.0
    window_sample_count = 0
    input_keys = ('smiles', 'features_before', 'features_after')

    for epoch in range(epochs):
        start_time = time.time()
        model.train()
        epoch_loss_sum = 0.0

        show_bar = epoch % print_interval == 0
        progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]', disable=not show_bar)
        for batch in progress:
            smiles, features_before, features_after = (batch[key].to(device) for key in input_keys)

            optimizer.zero_grad()
            loss = criterion(model(smiles, features_before), features_after)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by its size so the averages are per sample
            n_in_batch = smiles.size(0)
            weighted_loss = loss.item() * n_in_batch
            epoch_loss_sum += weighted_loss
            window_loss_sum += weighted_loss
            window_sample_count += n_in_batch

        epoch_train_loss = epoch_loss_sum / len(train_loader.dataset)

        if (epoch + 1) % print_interval != 0:
            continue

        avg_interval_loss = window_loss_sum / window_sample_count
        end_time = time.time()
        print(f'\nEpoch {epoch+1}/{epochs} | Time: {end_time-start_time:.2f}s')
        print(f'Train Loss (interval avg): {avg_interval_loss:.6f} | Train Loss (epoch): {epoch_train_loss:.6f}')
        # Start a fresh accumulation window for the next summary
        window_loss_sum = 0.0
        window_sample_count = 0

    return model


def test_model(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` and print summary statistics.

    The model is switched to eval mode and run without gradient tracking.
    Outputs and targets of all batches are stacked, per-sample PCC and RMSE
    are computed with ``calculate_per_sample_metrics``, and the overall loss,
    the prediction/target mean and std, and the samples with the highest and
    lowest PCC are printed.

    Returns:
        ``(test_loss, avg_pcc, avg_rmse, pcc_values, rmse_values,
        all_outputs, all_targets)`` where ``test_loss`` is the
        sample-weighted mean of the batch losses and the last two entries are
        the stacked NumPy arrays of predictions and targets.
    """
    model.eval()
    test_loss = 0.0
    all_outputs = []
    all_targets = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Testing'):
            smiles = batch['smiles'].to(device)
            features_before = batch['features_before'].to(device)
            features_after = batch['features_after'].to(device)

            outputs = model(smiles, features_before)
            loss = criterion(outputs, features_after)

            # Weight by batch size so the final value is a per-sample mean
            test_loss += loss.item() * smiles.size(0)
            all_outputs.append(outputs.cpu().numpy())
            all_targets.append(features_after.cpu().numpy())

    test_loss /= len(test_loader.dataset)
    all_outputs = np.vstack(all_outputs)
    all_targets = np.vstack(all_targets)

    # Per-sample evaluation metrics
    avg_pcc, avg_rmse, pcc_values, rmse_values = calculate_per_sample_metrics(all_outputs, all_targets)

    print(f'Test Loss: {test_loss:.6f} | PCC: {avg_pcc:.6f} | RMSE: {avg_rmse:.6f}')
    print(f"Predictions mean: {np.mean(all_outputs):.6f}, std: {np.std(all_outputs):.6f}")
    # The format spec must be ':.6f'; the former '::.6f' raised
    # "ValueError: Invalid format specifier" and aborted every evaluation.
    print(f"Targets mean: {np.mean(all_targets):.6f}, std: {np.std(all_targets):.6f}")

    # Report the samples with the highest and lowest PCC
    best_sample_idx = np.argmax(pcc_values)
    worst_sample_idx = np.argmin(pcc_values)
    print(f"\nBest sample (PCC = {pcc_values[best_sample_idx]:.6f}, RMSE = {rmse_values[best_sample_idx]:.6f})")
    print(f" Worst sample (PCC = {pcc_values[worst_sample_idx]:.6f}, RMSE = {rmse_values[worst_sample_idx]:.6f})")

    return test_loss, avg_pcc, avg_rmse, pcc_values, rmse_values, all_outputs, all_targets


def load_trained_model(model_path, smiles_dim, feature_dim, device):
    """Rebuild an ``ImprovedMLP`` and restore its weights from ``model_path``.

    The checkpoint is expected to be a state dict. Its tensors are mapped to
    ``device`` while loading; the model is then moved to ``device`` and put
    into eval mode before it is returned.
    """
    model = ImprovedMLP(smiles_dim=smiles_dim, feature_dim=feature_dim)
    saved_weights = torch.load(model_path, map_location=device)
    model.load_state_dict(saved_weights)
    model = model.to(device)
    model.eval()
    print(f"Loaded trained model from {model_path}")
    return model


def main():
    """Run the full pipeline: load data, build or load the model, train, save, test.

    When ``load_existing_model`` is set and the checkpoint file exists, its
    weights are loaded first; training then runs in either case. The state
    dict is written back to ``model_path`` before the test set is evaluated.
    """
    # ---- configuration ----
    print_interval = 50        # print a summary every print_interval epochs
    total_epochs = 1000        # number of training epochs
    model_path = 'BBBC036_improved_mlp.pth'
    load_existing_model = True  # load the checkpoint when it exists

    data_config = dict(
        smiles_pickle_path='/home/bob/boom/VCBench/Molecule_encoder/embeddings/Image/ECFP4_emb2048.pickle',
        h5_path='/home/bob/boom/VCBench/data/Image/BBBC036_data.h5',
        test_size=0.1,
        random_state=42,
        batch_size=1024,
    )

    # ---- data ----
    train_loader, test_loader, smiles_dim, feature_dim = prepare_dataloaders(**data_config)

    # ---- device ----
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # ---- model: load the checkpoint or start fresh ----
    use_checkpoint = load_existing_model and os.path.exists(model_path)
    if not use_checkpoint:
        model = ImprovedMLP(smiles_dim=smiles_dim, feature_dim=feature_dim).to(device)
    else:
        model = load_trained_model(model_path, smiles_dim, feature_dim, device)

    # ---- loss and optimizer ----
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # ---- training ----
    print(f"Training model, will print every {print_interval} epochs...")
    model = train_model(
        model=model,
        train_loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        epochs=total_epochs,
        print_interval=print_interval,
    )

    # ---- save ----
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # ---- test ----
    print("Testing model...")
    results = test_model(
        model=model,
        test_loader=test_loader,
        criterion=nn.MSELoss(),
        device=device,
    )
    _test_loss, _avg_pcc, _avg_rmse, pcc_values, rmse_values, _predictions, _targets = results

    print("\nPerformance distribution:")
    print(f"PCC - Min: {np.min(pcc_values):.6f}, Max: {np.max(pcc_values):.6f}, Mean: {np.mean(pcc_values):.6f}, Median: {np.median(pcc_values):.6f}")
    print(f"RMSE - Min: {np.min(rmse_values):.6f}, Max: {np.max(rmse_values):.6f}, Mean: {np.mean(rmse_values):.6f}, Median: {np.median(rmse_values):.6f}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
