In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
import pickle
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from pathlib import Path
from rdkit import Chem
from rdkit import RDLogger
from scipy.interpolate import interp1d
from torch.utils.data import DataLoader, TensorDataset

# Disable RDLogger warnings
RDLogger.DisableLog('rdApp.*')
import os

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
functional_groups = {
    'Acid anhydride': Chem.MolFromSmarts('[CX3](=[OX1])[OX2][CX3](=[OX1])'),
    'Acyl halide': Chem.MolFromSmarts('[CX3](=[OX1])[F,Cl,Br,I]'),
    'Alcohol': Chem.MolFromSmarts('[#6][OX2H]'),
    'Aldehyde': Chem.MolFromSmarts('[CX3H1](=O)[#6,H]'),
    'Alkane': Chem.MolFromSmarts('[CX4;H3,H2]'),
    'Alkene': Chem.MolFromSmarts('[CX3]=[CX3]'),
    'Alkyne': Chem.MolFromSmarts('[CX2]#[CX2]'),
    'Amide': Chem.MolFromSmarts('[NX3][CX3](=[OX1])[#6]'),
    'Amine': Chem.MolFromSmarts('[NX3;H2,H1,H0;!$(NC=O)]'),
    'Arene': Chem.MolFromSmarts('[cX3]1[cX3][cX3][cX3][cX3][cX3]1'),
    'Azo compound': Chem.MolFromSmarts('[#6][NX2]=[NX2][#6]'),
    'Carbamate': Chem.MolFromSmarts('[NX3][CX3](=[OX1])[OX2H0]'),
    'Carboxylic acid': Chem.MolFromSmarts('[CX3](=O)[OX2H]'),
    'Enamine': Chem.MolFromSmarts('[NX3][CX3]=[CX3]'),
    'Enol': Chem.MolFromSmarts('[OX2H][#6X3]=[#6]'),
    'Ester': Chem.MolFromSmarts('[#6][CX3](=O)[OX2H0][#6]'),
    'Ether': Chem.MolFromSmarts('[OD2]([#6])[#6]'),
    'Haloalkane': Chem.MolFromSmarts('[#6][F,Cl,Br,I]'),
    'Hydrazine': Chem.MolFromSmarts('[NX3][NX3]'),
    'Hydrazone': Chem.MolFromSmarts('[NX3][NX2]=[#6]'),
    'Imide': Chem.MolFromSmarts('[CX3](=[OX1])[NX3][CX3](=[OX1])'),
    'Imine': Chem.MolFromSmarts('[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]'),
    'Isocyanate': Chem.MolFromSmarts('[NX2]=[C]=[O]'),
    'Isothiocyanate': Chem.MolFromSmarts('[NX2]=[C]=[S]'),
    'Ketone': Chem.MolFromSmarts('[#6][CX3](=O)[#6]'),
    'Nitrile': Chem.MolFromSmarts('[NX1]#[CX2]'),
    'Phenol': Chem.MolFromSmarts('[OX2H][cX3]:[c]'),
    'Phosphine': Chem.MolFromSmarts('[PX3]'),
    'Sulfide': Chem.MolFromSmarts('[#16X2H0]'),
    'Sulfonamide': Chem.MolFromSmarts('[#16X4]([NX3])(=[OX1])(=[OX1])[#6]'),
    'Sulfonate': Chem.MolFromSmarts('[#16X4](=[OX1])(=[OX1])([#6])[OX2H0]'),
    'Sulfone': Chem.MolFromSmarts('[#16X4](=[OX1])(=[OX1])([#6])[#6]'),
    'Sulfonic acid': Chem.MolFromSmarts('[#16X4](=[OX1])(=[OX1])([#6])[OX2H]'),
    'Sulfoxide': Chem.MolFromSmarts('[#16X3]=[OX1]'),
    'Thial': Chem.MolFromSmarts('[CX3H1](=S)[#6,H]'),
    'Thioamide': Chem.MolFromSmarts('[NX3][CX3]=[SX1]'),
    'Thiol': Chem.MolFromSmarts('[#16X2H]')
}
def match_group(mol: Chem.Mol, func_group) -> int:
    if type(func_group) == Chem.Mol:
        n = len(mol.GetSubstructMatches(func_group))
    else:
        n = func_group(mol)
    return 0 if n == 0 else 1
# Function to map SMILES to functional groups (no change)
def get_functional_groups(smiles: str) -> dict:
    smiles = smiles.strip().replace(' ', '')
    mol = Chem.MolFromSmiles(smiles)
    if mol is None: 
        return None
    func_groups = [match_group(mol, smarts) for smarts in functional_groups.values()]
    return func_groups

def interpolate_to_600(spec):
    old_x = np.arange(len(spec))
    new_x = np.linspace(min(old_x), max(old_x), 600)
    interp = interp1d(old_x, spec)
    return interp(new_x)

def make_msms_spectrum(spectrum):
    msms_spectrum = np.zeros(10000)
    for peak in spectrum:
        peak_pos = int(peak[0]*10)
        peak_pos = min(peak_pos, 9999)
        msms_spectrum[peak_pos] = peak[1]
    return msms_spectrum

# Define CNN Model in PyTo




import torch
import torch.nn as nn
import torch.nn.functional as F

class IndependentCNN(nn.Module):
    def __init__(self, num_fgs):
        super(IndependentCNN, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=31, kernel_size=11, padding='same')
        self.conv2 = nn.Conv1d(in_channels=31, out_channels=62, kernel_size=11, padding='same')

        self.batch_norm1 = nn.BatchNorm1d(31)
        self.batch_norm2 = nn.BatchNorm1d(62)

        # MLP for selecting important channels (62 channels)
        self.mlp = nn.Sequential(
            nn.Linear(62, 128),  # Input 150 features per channel
            nn.ReLU(),
            nn.Linear(128, 1)     # Output importance score for each channel
        )

    def forward(self, x):
        x = F.relu(self.batch_norm1(self.conv1(x)))
        x = F.max_pool1d(x, 1)
        x = F.relu(self.batch_norm2(self.conv2(x)))
        x = F.max_pool1d(x, 4)
        x = x.permute(0, 2, 1)
        # 通道重要性计算(这里对通道的重要性计算应该改成对频率的重要性计算，这样才能算采样。)（那就是先）
        static_feature_map = x.clone().detach()
        channel_means = x.mean(dim=1)
        channel_std = x.std(dim=1)

        channel_importance = torch.sigmoid(self.mlp(x))
        ib_x_mean = x * channel_importance + (1 - channel_importance) * channel_means.unsqueeze(1)
        ib_x_std = (1 - channel_importance) * channel_std.unsqueeze(1)
        ib_x = ib_x_mean + torch.rand_like(ib_x_mean) * ib_x_std

        # KL Divergence loss
        epsilon = 1e-8
        KL_tensor = 0.5 * (
            (ib_x_std**2) / (channel_std.unsqueeze(1) + epsilon)**2 +
            (channel_std.unsqueeze(1)**2) / (ib_x_std + epsilon)**2 - 1
        ) + ((ib_x_mean - channel_means.unsqueeze(1))**2) / (channel_std.unsqueeze(1) + epsilon)**2

        KL_Loss = torch.mean(KL_tensor)

        # Flatten and pass through fully connected layers
        ib_x = ib_x.permute(0, 2, 1)
        return ib_x, KL_Loss,channel_importance

import torch
import torch.nn as nn
import torch.nn.functional as F

def rbf_kernel(x, y, sigma=1.0):
    """
    x: [B, D]
    y: [B, D]
    sigma: RBF 核的带宽 (可调)
    返回: [B, B] 的核矩阵
    """
    # x.unsqueeze(1): [B, 1, D]
    # y.unsqueeze(0): [1, B, D]
    # dist: [B, B]，表示 x_i 与 y_j 的欧式距离平方
    x = x.unsqueeze(1)
    y = y.unsqueeze(0)
    dist = (x - y).pow(2).sum(dim=2)
    kxy = torch.exp(-dist / (2 * sigma**2))
    return kxy

# =============== 2. 定义 HSIC 计算函数 ===============
def hsic(x, y, sigma=1.0):
    """
    计算 x, y 的 HSIC 值: HSIC(X, Y) = 1/(n-1)^2 * Tr(H Kx H Ky)
    x: [B, D]
    y: [B, D]
    sigma: RBF 核带宽
    返回标量 HSIC 值
    """
    assert x.size(0) == y.size(0), "x,y 的 batch size 不一致"
    n = x.size(0)

    # 计算核矩阵
    Kx = rbf_kernel(x, x, sigma=sigma)
    Ky = rbf_kernel(y, y, sigma=sigma)

    # 居中矩阵 H = I - 1/n
    H = torch.eye(n, device=x.device) - (1./n) * torch.ones((n, n), device=x.device)

    # H Kx H
    HKxH = H.mm(Kx).mm(H)

    # HSIC = Tr( (H Kx H) * Ky ) / (n-1)^2
    # 注: 这里相乘可以写为 trace(HKxH @ Ky)，矩阵乘法后再 trace
    hsic_val = torch.trace(HKxH.mm(Ky)) / (float(n - 1) ** 2)

    return hsic_val









import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNModelWithVAE(nn.Module): 
    def __init__(self, num_fgs, channel=62, feature_dim=150, hidden_dim=256, latent_dim=64, m_dim=10):
        """
        参数：
        - num_fgs: 预测目标的维度
        - channel: 每个光谱的通道数（不同频率段）
        - feature_dim: 每个光谱的特征维度
        - hidden_dim: 隐藏层维度
        - latent_dim: 潜在变量 z 的维度
        - m_dim: 预测目标的维度（如有需要）
        """
        super(CNNModelWithVAE, self).__init__()
        self.channel = channel
        self.feature_dim = feature_dim

        # 创建三个独立的CNN模块
        self.cnn1 = IndependentCNN(num_fgs)
        self.cnn2 = IndependentCNN(num_fgs)
        self.cnn3 = IndependentCNN(num_fgs)


        # VAE Encoder: 将三个光谱特征融合成潜在表示 z
        # 将 [B, 3*channel, feature_dim] 展平为 [B, 3*channel*feature_dim]
        self.fc_fusion = nn.Sequential(
            nn.Linear(3 * channel * feature_dim, hidden_dim),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # VAE Decoder: 从潜在表示 z 重建三个光谱
        self.decoder = nn.ModuleList([
            nn.Sequential(
                nn.Linear(latent_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, channel * feature_dim),
                nn.ReLU()
            ) for _ in range(3)
        ])


        # 增加一个线性层处理 x3 的特征
        self.fc_x3 = nn.Linear(channel * feature_dim, latent_dim)

        # 全连接层用于最终预测，使用 z 和 x3 作为输入
        self.fc1 = nn.Linear(latent_dim *2, 4927)  # z 和 x3
        self.fc2 = nn.Linear(4927, 2785)
        self.fc3 = nn.Linear(2785, 1574)
        self.fc4 = nn.Linear(1574,1)
        self.dropout = nn.Dropout(0.48599073736368)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)   # ~ N(0, I)
        return mu + std * eps
    
    def forward(self, x):
        """
        前向传播函数。
        
        参数：
        - x: 输入张量，形状为 [batch_size, 3, feature_dim]
        
        返回：
        - 一个包含预测结果和各类损失组件的字典
        """
        # 拆分输入为三个光谱通道
        x1, x2, x3 = x[:, 0:1, :], x[:, 1:2, :], x[:, 2:3, :]  # 每个 [B, 1, feature_dim]

        # 分别通过三个独立的CNN
        ib_x_1, kl_loss1,channal_importance_1= self.cnn1(x1)  # [B, channel, feature_dim]
        ib_x_2, kl_loss2 ,channal_importance_2= self.cnn2(x2)
        ib_x_3, kl_loss3 ,channal_importance_3= self.cnn3(x3)

        # 将三个通道的输出堆叠
        ib_x_stacked = torch.cat([ib_x_1, ib_x_2, ib_x_3], dim=1)  # [B, 3*channel, feature_dim]
        # 展平为 [B, 3*channel*feature_dim]
        ib_x_flat = ib_x_stacked.view(ib_x_stacked.size(0), -1)  # [B, 3*channel*feature_dim]
        # VAE Encoder
        h = self.fc_fusion(ib_x_flat)  # [B, hidden_dim]
        mu = self.fc_mu(h)             # [B, latent_dim]
        logvar = self.fc_logvar(h)     # [B, latent_dim]
        z = self.reparameterize(mu, logvar)  # [B, latent_dim]

        # VAE Decoder: 重建三个光谱
        recon_x = []
        for decoder in self.decoder:
            recon = decoder(z)  # [B, channel * feature_dim]
            recon = recon.view(z.size(0), self.channel, self.feature_dim)  # [B, channel, feature_dim]
            recon_x.append(recon)
        recon_x1, recon_x2, recon_x3 = recon_x  # 各自的重构光谱

        # 条件互信息估计器
        # 将 ib_x_* 展平
        ib_x1_flat = ib_x_1.view(z.size(0), -1)  # [B, channel * feature_dim]
        ib_x2_flat = ib_x_2.view(z.size(0), -1)
        ib_x3_flat = ib_x_3.view(z.size(0), -1)
        sigma=1.0
        # ====== 3.3 计算 HSIC 并做加和 ======
        hsic_x3_x1 = hsic(ib_x3_flat, ib_x1_flat, sigma=sigma)
        hsic_x3_x2 = hsic(ib_x3_flat, ib_x2_flat, sigma=sigma)
        hsic_x3_z  = hsic(ib_x3_flat, z,     sigma=sigma)

        # 将这几个 HSIC 值相加
        hsic_loss = hsic_x3_x1 + hsic_x3_x2 + hsic_x3_z

        # ====== 3.4 总损失: 主损失 + alpha * HSIC(总和) ======
        cmi_loss = hsic_loss

        # 增加 x3 的处理
        x3_processed = self.fc_x3(ib_x3_flat)  # [B, latent_dim]
        z_x3 = torch.cat([z, x3_processed], dim=1)  # [B, 2 * latent_dim]
        x_pred = F.relu(self.fc1(z_x3 ))  # [B, 4927]
        x_pred = self.dropout(x_pred)
        x_pred = F.relu(self.fc2(x_pred))  # [B, 2785]
        x_pred = self.dropout(x_pred)
        x_pred = F.relu(self.fc3(x_pred))  # [B, 1574]
        x_pred = self.dropout(x_pred)
        x_pred = torch.sigmoid(self.fc4(x_pred))  # [B, num_fgs]

        # KL散度损失取平均值（来自 VAE）
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
        kl=( kl_loss1+ kl_loss2+ kl_loss3)/3
        return {
            'x': x_pred,
            'vae_mu': mu,
            'vae_logvar': logvar,
            'recon_x1': recon_x1,
            'recon_x2': recon_x2,
            'recon_x3': recon_x3,
            'cmi_loss': cmi_loss,  # InfoNCE 损失
            'ib_x_1': ib_x_1,
            'ib_x_2': ib_x_2,
            'ib_x_3': ib_x_3,
            'kl':kl,
            'channal_importance_1':channal_importance_1,
            'channal_importance_2':channal_importance_2,
            'channal_importance_3':channal_importance_3
        }










# Training function in PyTorch
from tqdm import tqdm  # 引入 tqdm

b=0.0001
# 定义训练函数
# 定义训练函数
from tqdm import tqdm  # 引入 tqdm

# 定义训练函数
def train_model(X_train, y_train, X_test, y_test, num_fgs, weighted=False, batch_size=41, epochs=41, 
                annealing_epochs=10, max_lambda_kl=1.0, lambda_cmi=0.5, lambda_recon=0.1):
    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
    model = CNNModelWithVAE(num_fgs).to(device)
    
    # 定义优化器和损失函数
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    if weighted:
        class_weights = calculate_class_weights(y_train)
        criterion = WeightedBinaryCrossEntropyLoss(class_weights).to(device)
    else:
        criterion = nn.BCELoss().to(device)

    # 创建 DataLoader
    y_train = np.array([np.array(item, dtype=np.float32) for item in y_train], dtype=np.float32)
    y_test = np.array([np.array(item, dtype=np.float32) for item in y_test], dtype=np.float32)
    train_data = TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.float32))
    test_data = TensorDataset(torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.float32))
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # 确保保存路径存在
    out_path.mkdir(parents=True, exist_ok=True)
    
    best_f1 = 0
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        recon_loss_avg = 0.0
        kl_weight = min(max_lambda_kl, (epoch + 1) / annealing_epochs)
        with tqdm(train_loader, unit='batch', desc=f"Epoch {epoch+1}/{epochs}") as tepoch:
            for batch in tepoch:
                inputs, targets = batch
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                x_pred = outputs['x']
                mu = outputs['vae_mu']
                logvar = outputs['vae_logvar']
                recon_x1 = outputs['recon_x1']
                recon_x2 = outputs['recon_x2']
                recon_x3 = outputs['recon_x3']
                kl = outputs['kl']
                cmi_loss = outputs['cmi_loss']

                # 预测损失
                targets = targets[:, 0].unsqueeze(1)
                pred_loss = criterion(x_pred, targets)
                
                # 重建损失
                recon_loss = F.mse_loss(recon_x1, outputs['ib_x_1']) + \
                             F.mse_loss(recon_x2, outputs['ib_x_2']) + \
                             F.mse_loss(recon_x3, outputs['ib_x_3'])
        
                # KL散度损失
                kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
                # 总损失：预测损失 + KL散度 + 互信息损失 + 重建损失
                total_loss = pred_loss + kl_weight * kl_div + \
                             lambda_cmi * cmi_loss + lambda_recon * recon_loss + 0.0001 * kl
                total_loss.backward()
                
                # 梯度裁剪
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                optimizer.step()
                
                running_loss += total_loss.item()
                recon_loss_avg += recon_loss.item()
                tepoch.set_postfix(loss=running_loss / (tepoch.n + 1),
                                  kl_weight=kl_weight)
        
        avg_loss = running_loss / len(train_loader)
        recon_loss_a = recon_loss_avg / len(train_loader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss}, KL Weight: {kl_weight}, Recon Loss: {recon_loss_a}')
        
        # 评估F1分数
        model.eval()
        predictions = []
        with torch.no_grad():
            for batch in test_loader:
                inputs, targets = batch
                targets = targets[:, 0].unsqueeze(1)
                targets=targets[:][0]
                inputs = inputs.to(device)
                outputs = model(inputs)
                x_pred = outputs['x']
                predictions.append(x_pred.cpu().numpy())
        predictions = np.concatenate(predictions)
        binary_predictions = (predictions > 0.5).astype(int)
        y_test = y_test[:, 0].reshape(-1, 1)
        f1 = f1_score(y_test, binary_predictions, average='micro')
        print(f'F1 Score: {f1}')
        
        # 保存最佳模型
        if f1 > best_f1:
            best_f1 = f1
            model_save_path = out_path / "best_model.pth"
            torch.save(model.state_dict(), model_save_path)
            print(f'Best model saved with F1 Score: {best_f1} at {model_save_path}')

    return binary_predictions




# Custom loss function with class weights
class WeightedBinaryCrossEntropyLoss(nn.Module):
    def __init__(self, class_weights):
        super(WeightedBinaryCrossEntropyLoss, self).__init__()
        self.class_weights = class_weights

    def forward(self, y_pred, y_true):
        loss = self.class_weights[0] * (1 - y_true) * torch.log(1 - y_pred + 1e-15) + \
               self.class_weights[1] * y_true * torch.log(y_pred + 1e-15)
        return -loss.mean()

# Calculate class weights
def calculate_class_weights(y_true):
    num_samples = y_true.shape[0]
    class_weights = np.zeros((2, y_true.shape[1]))
    for i in range(y_true.shape[1]):
        weights_n = num_samples / (2 * (y_true[:, i] == 0).sum())
        weights_p = num_samples / (2 * (y_true[:, i] == 1).sum())
        class_weights[0, i] = weights_n
        class_weights[1, i] = weights_p
    return torch.tensor(class_weights.T, dtype=torch.float32)

# Loading data (no change)
analytical_data = Path("/data/zjh2/multimodal-spectroscopic-dataset-main/data/multimodal_spectroscopic_dataset")
out_path = Path("/home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all")
columns = ["h_nmr_spectra", "c_nmr_spectra", "ir_spectra"]
seed = 3245

# 准备存储合并后的数据
all_data = []
i=0
# 一次性读取文件并处理所有列
for parquet_file in analytical_data.glob("*.parquet"):
    i+=1
    # 读取所有需要的列
    data = pd.read_parquet(parquet_file, columns=columns + ['smiles'])
    
    # 对每个列进行插值
    for column in columns:
        data[column] = data[column].map(interpolate_to_600)
    
    # 添加功能团信息
    data['func_group'] = data.smiles.map(get_functional_groups)
    #在这里就是0/1矩阵了
    all_data.append(data)
    print(f"Loaded Data from: ", i)
    if i==3:
        break
# 合并所有数据
training_data = pd.concat(all_data, ignore_index=True)


# 将数据划分为训练集和测试集
train, test = train_test_split(training_data, test_size=0.1, random_state=seed)

# 定义特征列
columns = ["h_nmr_spectra", "c_nmr_spectra", "ir_spectra"]

# 提取训练集特征和标签
X_train = np.array(train[columns].values.tolist())  # 确保特征值是一个二维数组
y_train = np.array(train['func_group'].values)      # 标签转换为一维数组

# 提取测试集特征和标签
X_test = np.array(test[columns].values.tolist())    # 同样确保二维数组
y_test = np.array(test['func_group'].values)        # 标签一维数组

# 检查数组形状以验证正确性
print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
print("X_test shape:", X_test.shape)
print("y_test shape:", y_test.shape)
# Train extended model
predictions = train_model(X_train, y_train, X_test, y_test,num_fgs=37, weighted=False, batch_size=41, epochs=41, 
                annealing_epochs=10, max_lambda_kl=1.0, lambda_cmi=0.1, lambda_recon=0.1)

# Evaluate the model
y_test = np.array([np.array(item, dtype=np.float32) for item in y_test], dtype=np.float32)
f1 = f1_score(y_test, predictions, average='micro')
print(f'F1 Score: {f1}')

# Save results
with open(out_path / "results.pickle", "wb") as file:
    pickle.dump({'pred': predictions, 'tgt': y_test}, file)

Loaded Data from:  1
Loaded Data from:  2
Loaded Data from:  3
X_train shape: (8740, 3, 600)
y_train shape: (8740,)
X_test shape: (972, 3, 600)
y_test shape: (972,)


Epoch 1/41: 100%|██████████| 214/214 [00:07<00:00, 26.92batch/s, kl_weight=0.1, loss=121]   


Epoch 1/41, Loss: 120.14816641819307, KL Weight: 0.1, Recon Loss: 4.062616068060362
F1 Score: 0.9989711934156379
Best model saved with F1 Score: 0.9989711934156379 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 2/41: 100%|██████████| 214/214 [00:07<00:00, 28.29batch/s, kl_weight=0.2, loss=0.111] 


Epoch 2/41, Loss: 0.10971375003404846, KL Weight: 0.2, Recon Loss: 0.008368838194400837
F1 Score: 0.9989711934156379


Epoch 3/41: 100%|██████████| 214/214 [00:07<00:00, 28.12batch/s, kl_weight=0.3, loss=0.106] 


Epoch 3/41, Loss: 0.10621489464113855, KL Weight: 0.3, Recon Loss: 0.002522643497779478
F1 Score: 0.9989711934156379


Epoch 4/41: 100%|██████████| 214/214 [00:07<00:00, 28.62batch/s, kl_weight=0.4, loss=0.105] 


Epoch 4/41, Loss: 0.10494953173508641, KL Weight: 0.4, Recon Loss: 0.001062062885160732
F1 Score: 0.9989711934156379


Epoch 5/41: 100%|██████████| 214/214 [00:07<00:00, 28.59batch/s, kl_weight=0.5, loss=0.159] 


Epoch 5/41, Loss: 0.15854250617740376, KL Weight: 0.5, Recon Loss: 0.00043860562698546773
F1 Score: 0.9989711934156379


Epoch 6/41: 100%|██████████| 214/214 [00:07<00:00, 28.11batch/s, kl_weight=0.6, loss=0.104]  


Epoch 6/41, Loss: 0.1026440447934466, KL Weight: 0.6, Recon Loss: 0.0001452774318648017
F1 Score: 0.9989711934156379


Epoch 7/41: 100%|██████████| 214/214 [00:07<00:00, 27.99batch/s, kl_weight=0.7, loss=0.103]  


Epoch 7/41, Loss: 0.10262793474772207, KL Weight: 0.7, Recon Loss: 0.00045397491364476813
F1 Score: 0.9989711934156379


Epoch 8/41: 100%|██████████| 214/214 [00:07<00:00, 27.98batch/s, kl_weight=0.8, loss=0.104] 


Epoch 8/41, Loss: 0.10258266620458314, KL Weight: 0.8, Recon Loss: 0.00024009740789160872
F1 Score: 0.9989711934156379


Epoch 9/41: 100%|██████████| 214/214 [00:07<00:00, 27.82batch/s, kl_weight=0.9, loss=0.104]  


Epoch 9/41, Loss: 0.10256880464728786, KL Weight: 0.9, Recon Loss: 9.043352318292257e-05
F1 Score: 0.9989711934156379


Epoch 10/41: 100%|██████████| 214/214 [00:07<00:00, 27.93batch/s, kl_weight=1, loss=0.103]  


Epoch 10/41, Loss: 0.10255412827027655, KL Weight: 1.0, Recon Loss: 8.496944056472972e-05
F1 Score: 0.9989711934156379


Epoch 11/41: 100%|██████████| 214/214 [00:07<00:00, 28.00batch/s, kl_weight=1, loss=0.103]  


Epoch 11/41, Loss: 0.10254498976985041, KL Weight: 1.0, Recon Loss: 4.123550182488985e-05
F1 Score: 0.9989711934156379


Epoch 12/41: 100%|██████████| 214/214 [00:07<00:00, 28.13batch/s, kl_weight=1, loss=0.104] 


Epoch 12/41, Loss: 0.10254058312887189, KL Weight: 1.0, Recon Loss: 3.999176292849387e-05
F1 Score: 0.9989711934156379


Epoch 13/41: 100%|██████████| 214/214 [00:07<00:00, 28.64batch/s, kl_weight=1, loss=0.104]


Epoch 13/41, Loss: 0.10253981652186522, KL Weight: 1.0, Recon Loss: 5.139084316639451e-05
F1 Score: 0.9989711934156379


Epoch 14/41: 100%|██████████| 214/214 [00:07<00:00, 28.24batch/s, kl_weight=1, loss=0.104] 


Epoch 14/41, Loss: 0.10253433020984078, KL Weight: 1.0, Recon Loss: 1.3127151078151655e-05
F1 Score: 0.9989711934156379


Epoch 15/41: 100%|██████████| 214/214 [00:07<00:00, 28.26batch/s, kl_weight=1, loss=0.104]  


Epoch 15/41, Loss: 0.10253384862950621, KL Weight: 1.0, Recon Loss: 1.984819859803665e-05
F1 Score: 0.9989711934156379


Epoch 16/41: 100%|██████████| 214/214 [00:09<00:00, 23.38batch/s, kl_weight=1, loss=0.103] 


Epoch 16/41, Loss: 0.10253222654978132, KL Weight: 1.0, Recon Loss: 9.31751040438217e-06
F1 Score: 0.9989711934156379


Epoch 17/41: 100%|██████████| 214/214 [00:09<00:00, 23.02batch/s, kl_weight=1, loss=0.104]  


Epoch 17/41, Loss: 0.10253686418218422, KL Weight: 1.0, Recon Loss: 6.183564091111931e-05
F1 Score: 0.9989711934156379


Epoch 18/41: 100%|██████████| 214/214 [00:09<00:00, 22.81batch/s, kl_weight=1, loss=0.103] 


Epoch 18/41, Loss: 0.1025312290451345, KL Weight: 1.0, Recon Loss: 1.5583161153523248e-05
F1 Score: 0.9989711934156379


Epoch 19/41: 100%|██████████| 214/214 [00:09<00:00, 22.67batch/s, kl_weight=1, loss=0.103] 


Epoch 19/41, Loss: 0.10253014984142436, KL Weight: 1.0, Recon Loss: 6.768524708902171e-06
F1 Score: 0.9989711934156379


Epoch 20/41: 100%|██████████| 214/214 [00:09<00:00, 22.83batch/s, kl_weight=1, loss=0.103]  


Epoch 20/41, Loss: 0.10253102987692975, KL Weight: 1.0, Recon Loss: 1.8426757770813125e-05
F1 Score: 0.9989711934156379


Epoch 21/41: 100%|██████████| 214/214 [00:09<00:00, 22.72batch/s, kl_weight=1, loss=0.103] 


Epoch 21/41, Loss: 0.10252884818862108, KL Weight: 1.0, Recon Loss: 2.822640338932386e-06
F1 Score: 0.9989711934156379


Epoch 22/41: 100%|██████████| 214/214 [00:09<00:00, 22.27batch/s, kl_weight=1, loss=0.103] 


Epoch 22/41, Loss: 0.10252779229098467, KL Weight: 1.0, Recon Loss: 1.7831481461072806e-06
F1 Score: 0.9989711934156379


Epoch 23/41: 100%|██████████| 214/214 [00:09<00:00, 22.83batch/s, kl_weight=1, loss=0.103]  


Epoch 23/41, Loss: 0.10252756663616922, KL Weight: 1.0, Recon Loss: 2.4530711925231923e-06
F1 Score: 0.9989711934156379


Epoch 24/41: 100%|██████████| 214/214 [00:09<00:00, 22.73batch/s, kl_weight=1, loss=0.103]  


Epoch 24/41, Loss: 0.10252736915802663, KL Weight: 1.0, Recon Loss: 1.5552595059120033e-06
F1 Score: 0.9989711934156379


Epoch 25/41: 100%|██████████| 214/214 [00:09<00:00, 22.46batch/s, kl_weight=1, loss=0.103]  


Epoch 25/41, Loss: 0.10252739784269903, KL Weight: 1.0, Recon Loss: 1.5868528528661525e-06
F1 Score: 0.9989711934156379


Epoch 26/41: 100%|██████████| 214/214 [00:09<00:00, 22.58batch/s, kl_weight=1, loss=0.103]  


Epoch 26/41, Loss: 0.10252769726847423, KL Weight: 1.0, Recon Loss: 2.7174979983413737e-06
F1 Score: 0.9989711934156379


Epoch 27/41: 100%|██████████| 214/214 [00:09<00:00, 22.74batch/s, kl_weight=1, loss=0.103]  


Epoch 27/41, Loss: 0.10252719647656007, KL Weight: 1.0, Recon Loss: 1.366604347940899e-06
F1 Score: 0.9989711934156379


Epoch 28/41: 100%|██████████| 214/214 [00:09<00:00, 22.64batch/s, kl_weight=1, loss=0.103] 


Epoch 28/41, Loss: 0.1025283492907509, KL Weight: 1.0, Recon Loss: 1.1657922179592102e-05
F1 Score: 0.9989711934156379


Epoch 29/41: 100%|██████████| 214/214 [00:09<00:00, 22.73batch/s, kl_weight=1, loss=0.103]  


Epoch 29/41, Loss: 0.1025269769873367, KL Weight: 1.0, Recon Loss: 1.8624491017328876e-07
F1 Score: 0.9989711934156379


Epoch 30/41: 100%|██████████| 214/214 [00:09<00:00, 22.70batch/s, kl_weight=1, loss=0.103] 


Epoch 30/41, Loss: 0.10252656515506281, KL Weight: 1.0, Recon Loss: 4.451868358842588e-07
F1 Score: 0.9989711934156379


Epoch 31/41: 100%|██████████| 214/214 [00:09<00:00, 22.93batch/s, kl_weight=1, loss=0.103] 


Epoch 31/41, Loss: 0.10253078289338403, KL Weight: 1.0, Recon Loss: 4.3301124103177636e-05
F1 Score: 0.9989711934156379


Epoch 32/41: 100%|██████████| 214/214 [00:09<00:00, 22.60batch/s, kl_weight=1, loss=0.103]  


Epoch 32/41, Loss: 0.10252625768035299, KL Weight: 1.0, Recon Loss: 6.630753724906157e-08
F1 Score: 0.9989711934156379


Epoch 33/41: 100%|██████████| 214/214 [00:09<00:00, 22.79batch/s, kl_weight=1, loss=0.103]  


Epoch 33/41, Loss: 0.10252620865214561, KL Weight: 1.0, Recon Loss: 2.1503741317710827e-08
F1 Score: 0.9989711934156379


Epoch 34/41: 100%|██████████| 214/214 [00:09<00:00, 22.81batch/s, kl_weight=1, loss=0.103]  


Epoch 34/41, Loss: 0.10252617882940633, KL Weight: 1.0, Recon Loss: 6.771301962803584e-09
F1 Score: 0.9989711934156379


Epoch 35/41: 100%|██████████| 214/214 [00:09<00:00, 22.64batch/s, kl_weight=1, loss=0.103] 


Epoch 35/41, Loss: 0.10252637383963081, KL Weight: 1.0, Recon Loss: 1.1566686337387991e-07
F1 Score: 0.9989711934156379


Epoch 36/41: 100%|██████████| 214/214 [00:09<00:00, 22.49batch/s, kl_weight=1, loss=0.103]  


Epoch 36/41, Loss: 0.1025268214467428, KL Weight: 1.0, Recon Loss: 3.0866628259393995e-06
F1 Score: 0.9989711934156379


Epoch 37/41: 100%|██████████| 214/214 [00:09<00:00, 22.52batch/s, kl_weight=1, loss=0.103] 


Epoch 37/41, Loss: 0.1025261097769986, KL Weight: 1.0, Recon Loss: 6.840891087427275e-08
F1 Score: 0.9989711934156379


Epoch 38/41: 100%|██████████| 214/214 [00:09<00:00, 22.36batch/s, kl_weight=1, loss=0.103]  


Epoch 38/41, Loss: 0.10252597727628547, KL Weight: 1.0, Recon Loss: 1.1225831970382476e-08
F1 Score: 0.9989711934156379


Epoch 39/41: 100%|██████████| 214/214 [00:09<00:00, 22.43batch/s, kl_weight=1, loss=0.103] 


Epoch 39/41, Loss: 0.10252615714182518, KL Weight: 1.0, Recon Loss: 4.538740518162688e-07
F1 Score: 0.9989711934156379


Epoch 40/41: 100%|██████████| 214/214 [00:09<00:00, 22.46batch/s, kl_weight=1, loss=0.103]  


Epoch 40/41, Loss: 0.10252659576486349, KL Weight: 1.0, Recon Loss: 2.886102313745064e-06
F1 Score: 0.9989711934156379


Epoch 41/41: 100%|██████████| 214/214 [00:09<00:00, 22.62batch/s, kl_weight=1, loss=0.103] 


Epoch 41/41, Loss: 0.10252625779292208, KL Weight: 1.0, Recon Loss: 5.212831338368758e-07
F1 Score: 0.9989711934156379


ValueError: Classification metrics can't handle a mix of multilabel-indicator and binary targets

In [None]:
######初始的
import matplotlib.pyplot as plt
# 定义训练函数
def evalute_model(X_test, y_test, model_path,smiles,ir,num_fgs, weighted=False, batch_size=41, 
                annealing_epochs=10, max_lambda_kl=1.0, lambda_cmi=0.5, lambda_recon=0.1):
    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
    model = CNNModelWithVAE(num_fgs).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    

    # 创建 DataLoader
    y_test = np.array([np.array(item, dtype=np.float32) for item in y_test], dtype=np.float32)
    test_data = TensorDataset(torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.float32))
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    
    # 评估F1分数
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in test_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            outputs = model(inputs)
            x_pred = outputs['x']
            channal_importance_1 = outputs['channal_importance_1']
            channal_importance_2 = outputs['channal_importance_2']
            channal_importance_3 = outputs['channal_importance_3']
            channal_importance_3_cpu = channal_importance_3.squeeze().cpu().numpy()  # squeeze 去除维度 [1, 150, 1] 转为 [150, 1]



            # 步骤4: 可视化
            plt.plot(np.arange(150), channal_importance_3_cpu)
            plt.title(smiles)
            plt.xlabel('Wavelength Index')
            plt.ylabel('Importance')
            plt.show()
            
            
            plt.plot(np.arange(1800), ir)
            plt.title(smiles)
            plt.xlabel('Wavelength Index')
            plt.ylabel('ir')
            plt.show()
            
            predictions.append(x_pred.cpu().numpy())
    predictions = np.concatenate(predictions)
    binary_predictions = (predictions > 0.5).astype(int)
    f1 = f1_score(y_test, binary_predictions, average='micro')
    print(f'F1 Score: {f1}')

    return binary_predictions




# Custom loss function with class weights
class WeightedBinaryCrossEntropyLoss(nn.Module):
    def __init__(self, class_weights):
        super(WeightedBinaryCrossEntropyLoss, self).__init__()
        self.class_weights = class_weights

    def forward(self, y_pred, y_true):
        loss = self.class_weights[0] * (1 - y_true) * torch.log(1 - y_pred + 1e-15) + \
               self.class_weights[1] * y_true * torch.log(y_pred + 1e-15)
        return -loss.mean()

# Calculate class weights
def calculate_class_weights(y_true):
    num_samples = y_true.shape[0]
    class_weights = np.zeros((2, y_true.shape[1]))
    for i in range(y_true.shape[1]):
        weights_n = num_samples / (2 * (y_true[:, i] == 0).sum())
        weights_p = num_samples / (2 * (y_true[:, i] == 1).sum())
        class_weights[0, i] = weights_n
        class_weights[1, i] = weights_p
    return torch.tensor(class_weights.T, dtype=torch.float32)

# Loading data (no change)
analytical_data = Path("/data/zjh2/multimodal-spectroscopic-dataset-main/data/multimodal_spectroscopic_dataset")
out_path = Path("/home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all")
columns = ["h_nmr_spectra", "c_nmr_spectra", "ir_spectra"]
seed = 3245
model_path = Path("/home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth")
# 准备存储合并后的数据
all_data = []
i=0
# 一次性读取文件并处理所有列
for parquet_file in analytical_data.glob("*.parquet"):
    i+=1
    # 读取所有需要的列
    data = pd.read_parquet(parquet_file, columns=columns + ['smiles'])
    # 对每个列进行插值
    for column in columns:
        data[column+"ori"] = data[column]
        data[column] = data[column].map(interpolate_to_600)
    
    # 添加功能团信息
    data['func_group'] = data.smiles.map(get_functional_groups)
    #在这里就是0/1矩阵了
    all_data.append(data)
    print(f"Loaded Data from: ", i)
    if i==3:
        break
# 合并所有数据
training_data = pd.concat(all_data, ignore_index=True)


# 将数据划分为训练集和测试集
train, test = train_test_split(training_data, test_size=1, random_state=seed)
columns = ["h_nmr_spectra", "c_nmr_spectra", "ir_spectra"]


# 提取测试集特征和标签
X_test = np.array(test[columns].values.tolist())    # 同样确保二维数组
y_test = np.array(test['func_group'].values)        # 标签一维数组
print(len(test["ir_spectraori"].values.tolist()[0]))
smiles = test["smiles"]
ir = test["ir_spectraori"].values.tolist()[0]
# Train extended model
print(test['func_group'].values)
predictions = evalute_model( X_test, y_test,model_path,smiles,ir,num_fgs=37, weighted=False, batch_size=1, 
                annealing_epochs=10, max_lambda_kl=1.0, lambda_cmi=0.1, lambda_recon=0.1)

# Evaluate the model
y_test = np.array([np.array(item, dtype=np.float32) for item in y_test], dtype=np.float32)
f1 = f1_score(y_test, predictions, average='micro')
print(f'F1 Score: {f1}')

# Save results
with open(out_path / "results.pickle", "wb") as file:
    pickle.dump({'pred': predictions, 'tgt': y_test}, file)