In [1]:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
import pickle
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from pathlib import Path
from rdkit import Chem
from rdkit import RDLogger
from scipy.interpolate import interp1d
from torch.utils.data import DataLoader, TensorDataset

# Disable RDLogger warnings
RDLogger.DisableLog('rdApp.*')
import os

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
# Functional-group name -> compiled RDKit SMARTS query.
# Insertion order matters: get_functional_groups() iterates over .values(), so
# this order fixes the column order of every label vector.
functional_groups = {
    group_name: Chem.MolFromSmarts(smarts_pattern)
    for group_name, smarts_pattern in (
        ('Acid anhydride', '[CX3](=[OX1])[OX2][CX3](=[OX1])'),
        ('Acyl halide', '[CX3](=[OX1])[F,Cl,Br,I]'),
        ('Alcohol', '[#6][OX2H]'),
        ('Aldehyde', '[CX3H1](=O)[#6,H]'),
        ('Alkane', '[CX4;H3,H2]'),
        ('Alkene', '[CX3]=[CX3]'),
        ('Alkyne', '[CX2]#[CX2]'),
        ('Amide', '[NX3][CX3](=[OX1])[#6]'),
        ('Amine', '[NX3;H2,H1,H0;!$(NC=O)]'),
        ('Arene', '[cX3]1[cX3][cX3][cX3][cX3][cX3]1'),
        ('Azo compound', '[#6][NX2]=[NX2][#6]'),
        ('Carbamate', '[NX3][CX3](=[OX1])[OX2H0]'),
        ('Carboxylic acid', '[CX3](=O)[OX2H]'),
        ('Enamine', '[NX3][CX3]=[CX3]'),
        ('Enol', '[OX2H][#6X3]=[#6]'),
        ('Ester', '[#6][CX3](=O)[OX2H0][#6]'),
        ('Ether', '[OD2]([#6])[#6]'),
        ('Haloalkane', '[#6][F,Cl,Br,I]'),
        ('Hydrazine', '[NX3][NX3]'),
        ('Hydrazone', '[NX3][NX2]=[#6]'),
        ('Imide', '[CX3](=[OX1])[NX3][CX3](=[OX1])'),
        ('Imine', '[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]'),
        ('Isocyanate', '[NX2]=[C]=[O]'),
        ('Isothiocyanate', '[NX2]=[C]=[S]'),
        ('Ketone', '[#6][CX3](=O)[#6]'),
        ('Nitrile', '[NX1]#[CX2]'),
        ('Phenol', '[OX2H][cX3]:[c]'),
        ('Phosphine', '[PX3]'),
        ('Sulfide', '[#16X2H0]'),
        ('Sulfonamide', '[#16X4]([NX3])(=[OX1])(=[OX1])[#6]'),
        ('Sulfonate', '[#16X4](=[OX1])(=[OX1])([#6])[OX2H0]'),
        ('Sulfone', '[#16X4](=[OX1])(=[OX1])([#6])[#6]'),
        ('Sulfonic acid', '[#16X4](=[OX1])(=[OX1])([#6])[OX2H]'),
        ('Sulfoxide', '[#16X3]=[OX1]'),
        ('Thial', '[CX3H1](=S)[#6,H]'),
        ('Thioamide', '[NX3][CX3]=[SX1]'),
        ('Thiol', '[#16X2H]'),
    )
}
def match_group(mol: Chem.Mol, func_group) -> int:
    """Return 1 if ``mol`` contains ``func_group``, otherwise 0.

    Args:
        mol: Molecule to inspect.
        func_group: Either an RDKit query molecule (for example one built with
            ``Chem.MolFromSmarts``) or a callable that takes ``mol`` and
            returns a match count.

    Returns:
        ``1`` when at least one match is reported, ``0`` otherwise.
    """
    # isinstance() also accepts Chem.Mol subclasses; the previous exact
    # ``type(...) ==`` comparison would have tried to call them as functions.
    if isinstance(func_group, Chem.Mol):
        # Only presence matters, so stop at the first match instead of
        # enumerating every substructure match.
        return 1 if mol.HasSubstructMatch(func_group) else 0
    n = func_group(mol)
    return 0 if n == 0 else 1
# Map a SMILES string to its multi-hot functional-group label vector.
def get_functional_groups(smiles: str) -> "list[int] | None":
    """Return the functional-group indicator vector for a SMILES string.

    Whitespace is stripped from ``smiles`` before parsing. The result has one
    0/1 entry per entry of the module-level ``functional_groups`` mapping, in
    that mapping's iteration order.

    Args:
        smiles: SMILES string; spaces inside the string are removed.

    Returns:
        A list of 0/1 integers, or ``None`` when RDKit cannot parse the
        SMILES. Callers must handle the ``None`` case explicitly.
    """
    smiles = smiles.strip().replace(' ', '')
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return [match_group(mol, pattern) for pattern in functional_groups.values()]

def interpolate_to_600(spec):
    """Resample a 1-D spectrum onto 600 evenly spaced positions.

    The samples of ``spec`` are treated as lying at integer positions
    ``0 .. len(spec) - 1``; the output is evaluated at 600 equally spaced
    positions spanning that same range using ``scipy.interpolate.interp1d``
    with its default settings.

    Args:
        spec: Sequence of intensity values (at least two samples).

    Returns:
        NumPy array with 600 interpolated values.
    """
    sample_positions = np.arange(len(spec))
    query_positions = np.linspace(sample_positions.min(), sample_positions.max(), 600)
    resample = interp1d(sample_positions, spec)
    return resample(query_positions)

def make_msms_spectrum(spectrum):
    """Bin a peak list into a fixed-length dense vector of 10000 entries.

    Each peak is indexed as ``peak[0]`` (position) and ``peak[1]`` (value).
    The position is multiplied by 10 and truncated to an integer bin index;
    indices above 9999 are clamped to the last bin. When several peaks land
    in the same bin, the one that appears last in ``spectrum`` wins.

    Args:
        spectrum: Iterable of indexable peaks ``(position, value, ...)``.

    Returns:
        NumPy array of length 10000 holding the binned values.
    """
    binned = np.zeros(10000)
    last_bin = 9999
    for peak in spectrum:
        bin_index = int(peak[0] * 10)
        if bin_index > last_bin:
            bin_index = last_bin
        binned[bin_index] = peak[1]
    return binned

# Define the CNN model in PyTorch




import torch
import torch.nn as nn
import torch.nn.functional as F


'''
class IndependentCNN(nn.Module):
    def __init__(self, num_fgs):
        super(IndependentCNN, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=31, kernel_size=11, padding='same')
        self.conv2 = nn.Conv1d(in_channels=31, out_channels=62, kernel_size=11, padding='same')

        self.batch_norm1 = nn.BatchNorm1d(31)
        self.batch_norm2 = nn.BatchNorm1d(62)

        # MLP for selecting important channels (62 channels)
        self.mlp = nn.Sequential(
            nn.Linear(62, 128),  # Input 150 features per channel
            nn.ReLU(),
            nn.Linear(128, 1)     # Output importance score for each channel
        )
    def compress(self, solute_features):
    
        p = self.mlp(solute_features)
        device = solute_features.device
        temperature = 1.0
        bias = 0.0 + 0.0001  # If bias is 0, we run into problems
        eps = (bias - (1 - bias)) * torch.rand(p.size()) + (1 - bias)
        gate_inputs = torch.log(eps) - torch.log(1 - eps)
        gate_inputs = gate_inputs.to(device)
        gate_inputs = (gate_inputs + p) / temperature
        gate_inputs = torch.sigmoid(gate_inputs).squeeze()
        p =torch.sigmoid(p)
        return gate_inputs,p
    def forward(self, x):
        x = F.relu(self.batch_norm1(self.conv1(x)))
        x = F.max_pool1d(x, 1)
        x = F.relu(self.batch_norm2(self.conv2(x)))
        x = F.max_pool1d(x, 4)
        x = x.permute(0, 2, 1)
        # 通道重要性计算(这里对通道的重要性计算应该改成对频率的重要性计算，这样才能算采样。)（那就是先）
        #其实，光谱应该剩下的补充0，而不是补充均值补充均值是没有道理的。
        static_feature_map = x.clone().detach()
        channel_means = x.mean(dim=1)
        channel_std = x.std(dim=1)

        channel_importance,p = self.compress(x)
        #print(channel_importance.size())#41,150
        channel_importance=channel_importance.unsqueeze(-1)
        ib_x_mean = x * channel_importance + (1 - channel_importance) * channel_means.unsqueeze(1)
        ib_x_std = (1 - channel_importance) * channel_std.unsqueeze(1)
        ib_x = ib_x_mean + torch.rand_like(ib_x_mean) * ib_x_std

        # KL Divergence loss
        epsilon = 1e-8
        KL_tensor = 0.5 * (
            (ib_x_std**2) / (channel_std.unsqueeze(1) + epsilon)**2 +
            (channel_std.unsqueeze(1)**2) / (ib_x_std + epsilon)**2 - 1
        ) + ((ib_x_mean - channel_means.unsqueeze(1))**2) / (channel_std.unsqueeze(1) + epsilon)**2

        KL_Loss = torch.mean(KL_tensor)

        # Flatten and pass through fully connected layers
        ib_x = ib_x.permute(0, 2, 1)
        return ib_x, KL_Loss,p

'''

class IndependentCNN(nn.Module):
    """Single-spectrum 1-D CNN encoder with a stochastic gating bottleneck.

    Two Conv1d/BatchNorm/ReLU stages, each followed by max-pooling with
    window 2, turn an input of shape ``[B, 1, L]`` into a feature map with
    128 channels and ``L // 2 // 2`` positions. A small MLP then scores every
    *position* of that map (it is applied to the 128-dim channel vector at
    each position), and the score gates the features: gated-out features are
    pushed towards zero and perturbed with noise.

    Args:
        num_fgs: Accepted for interface compatibility; not used by this module.
    """

    def __init__(self, num_fgs):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=11, padding='same')
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=11, padding='same')

        self.batch_norm1 = nn.BatchNorm1d(64)
        self.batch_norm2 = nn.BatchNorm1d(128)

        # Scores each position from its 128 channel activations.
        self.mlp = nn.Sequential(
            nn.Linear(128, 256),  # input: the 128 channel values at one position
            nn.ReLU(),
            nn.Linear(256, 1),    # output: one importance logit per position
        )

    def compress(self, solute_features):
        """Compute a noisy gate and the noise-free importance per position.

        Args:
            solute_features: Tensor whose last dimension has size 128
                (``[B, positions, 128]`` when called from ``forward``).

        Returns:
            Tuple ``(gate, p)``: ``gate`` is ``sigmoid((noise + logit) / T)``
            with the trailing singleton dimension removed; ``p`` is
            ``sigmoid(logit)`` with the trailing singleton dimension kept.
        """
        p = self.mlp(solute_features)
        temperature = 1.0
        bias = 0.0001  # keeps eps strictly inside (0, 1) so both logs stay finite
        # Sample directly on the target device instead of sampling on the CPU
        # and copying the noise over on every forward pass.
        eps = (bias - (1 - bias)) * torch.rand(p.size(), device=p.device) + (1 - bias)
        gate_inputs = torch.log(eps) - torch.log(1 - eps)
        gate_inputs = (gate_inputs + p) / temperature
        # squeeze(-1) drops only the score dimension; a bare squeeze() would
        # also drop the batch dimension when the batch size is 1.
        # NOTE(review): the noise is added regardless of self.training, so
        # outputs are stochastic in eval mode as well — confirm this is intended.
        gate_inputs = torch.sigmoid(gate_inputs).squeeze(-1)
        p = torch.sigmoid(p)
        return gate_inputs, p

    def forward(self, x):
        """Encode one spectrum and apply the gating bottleneck.

        Args:
            x: Tensor of shape ``[B, 1, L]``.

        Returns:
            Tuple ``(ib_x, KL_Loss, p)``:
            ``ib_x`` is the gated, noise-perturbed feature map of shape
            ``[B, 128, L // 2 // 2]``; ``KL_Loss`` is the scalar mean of the
            regularisation term computed below; ``p`` is the noise-free
            per-position importance of shape ``[B, L // 2 // 2, 1]``.
        """
        # Convolution + batch norm; each pooling step halves the length.
        x = F.relu(self.batch_norm1(self.conv1(x)))
        x = F.max_pool1d(x, 2)
        x = F.relu(self.batch_norm2(self.conv2(x)))
        x = F.max_pool1d(x, 2)
        x = x.permute(0, 2, 1)  # -> [B, positions, channels]

        # Per-channel standard deviation across positions: [B, channels].
        channel_std = x.std(dim=1)

        # Gate: one importance value per position, broadcast over channels.
        channel_importance, p = self.compress(x)
        channel_importance = channel_importance.unsqueeze(-1)

        # Unimportant positions are pulled towards zero (not towards the
        # channel mean) and receive noise scaled by the channel std.
        ib_x_mean = x * channel_importance
        ib_x_std = (1 - channel_importance) * channel_std.unsqueeze(1)
        # NOTE(review): the noise is drawn with rand_like (uniform on [0, 1)),
        # not randn_like — confirm this is intended.
        ib_x = ib_x_mean + torch.rand_like(ib_x_mean) * ib_x_std

        # Regularisation term. The reference mean is zero, hence ib_x_mean**2.
        # The second ratio divides by (ib_x_std + epsilon)**2 and becomes very
        # large wherever the gate is close to 1.
        epsilon = 1e-8
        KL_tensor = 0.5 * (
            (ib_x_std**2) / (channel_std.unsqueeze(1) + epsilon)**2 +
            (channel_std.unsqueeze(1)**2) / (ib_x_std + epsilon)**2 - 1
        ) + (ib_x_mean**2) / (channel_std.unsqueeze(1) + epsilon)**2

        KL_Loss = torch.mean(KL_tensor)

        # Back to [B, channels, positions].
        ib_x = ib_x.permute(0, 2, 1)
        return ib_x, KL_Loss, p



def rbf_kernel(x, y, sigma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of ``x`` and ``y``.

    Args:
        x: Tensor of shape ``[N, D]``.
        y: Tensor of shape ``[M, D]``.
        sigma: Kernel bandwidth.

    Returns:
        Tensor of shape ``[N, M]`` whose entry ``(i, j)`` equals
        ``exp(-||x_i - y_j||^2 / (2 * sigma^2))``.
    """
    # Broadcast [N, 1, D] against [1, M, D] to get all pairwise differences.
    pairwise_diff = x[:, None] - y[None]
    squared_dist = torch.sum(pairwise_diff ** 2, dim=2)
    two_sigma_sq = 2 * sigma**2
    return torch.exp(-squared_dist / two_sigma_sq)

# =============== 2. HSIC computation ===============
def hsic(x, y, sigma=1.0):
    """Empirical HSIC between two batches using RBF kernels.

    Computes ``Tr(H Kx H Ky) / (n - 1)^2`` where ``Kx`` and ``Ky`` are the RBF
    kernel matrices of ``x`` and ``y`` and ``H = I - 1/n`` is the centering
    matrix.

    Args:
        x: Tensor of shape ``[B, Dx]``.
        y: Tensor of shape ``[B, Dy]`` with the same batch size as ``x``.
        sigma: RBF bandwidth used for both kernels.

    Returns:
        Scalar tensor with the HSIC estimate. For fewer than two samples the
        estimate is undefined (division by ``(n - 1)^2 == 0``), so a zero
        scalar is returned instead of NaN; this keeps a trailing batch of
        size 1 from poisoning the training loss.

    Raises:
        AssertionError: If the batch sizes of ``x`` and ``y`` differ.
    """
    assert x.size(0) == y.size(0), "x,y 的 batch size 不一致"
    n = x.size(0)
    if n < 2:
        return x.new_zeros(())

    # Kernel matrices, each [n, n].
    Kx = rbf_kernel(x, x, sigma=sigma)
    Ky = rbf_kernel(y, y, sigma=sigma)

    # Centering matrix H = I - 1/n, created with the kernels' dtype so that
    # non-float32 inputs do not fail in the matrix products below.
    H = torch.eye(n, device=x.device, dtype=Kx.dtype) - \
        (1. / n) * torch.ones((n, n), device=x.device, dtype=Kx.dtype)

    HKxH = H.mm(Kx).mm(H)

    # Trace of the matrix product (not of an element-wise product).
    hsic_val = torch.trace(HKxH.mm(Ky)) / (float(n - 1) ** 2)

    return hsic_val









import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNModelWithVAE(nn.Module):
    """Three-spectrum classifier: per-spectrum CNN encoders fused by a VAE.

    Each of the three input spectra is encoded by its own ``IndependentCNN``.
    The three feature maps are flattened and fused into a latent vector ``z``
    (VAE encoder), from which three decoders reconstruct the individual
    feature maps. The prediction head consumes ``z`` concatenated with a
    linear projection of the third spectrum's features.

    Args:
        num_fgs: Number of output labels (size of the prediction vector).
        channel: Number of channels in each CNN feature map. Must match the
            channel count produced by ``IndependentCNN``.
        feature_dim: Number of positions in each CNN feature map, i.e. the
            input length after the CNN's pooling.
        hidden_dim: Width of the VAE encoder/decoder hidden layers.
        latent_dim: Size of the latent vector ``z``.
        m_dim: Unused; kept for interface compatibility.
    """

    def __init__(self, num_fgs, channel=128, feature_dim=150, hidden_dim=256, latent_dim=64, m_dim=10):
        super().__init__()
        self.channel = channel
        self.feature_dim = feature_dim

        # One independent CNN encoder per spectrum.
        self.cnn1 = IndependentCNN(num_fgs)
        self.cnn2 = IndependentCNN(num_fgs)
        self.cnn3 = IndependentCNN(num_fgs)

        # VAE encoder: flattened [B, 3 * channel * feature_dim] -> hidden -> (mu, logvar).
        self.fc_fusion = nn.Sequential(
            nn.Linear(3 * channel * feature_dim, hidden_dim),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # VAE decoders: one per spectrum, z -> flattened feature map.
        self.decoder = nn.ModuleList([
            nn.Sequential(
                nn.Linear(latent_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, channel * feature_dim),
                nn.ReLU()
            ) for _ in range(3)
        ])

        # Projects the third spectrum's flattened features to latent_dim.
        self.fc_x3 = nn.Linear(channel * feature_dim, latent_dim)

        # Prediction head over the concatenation [z, projected x3].
        self.fc1 = nn.Linear(latent_dim * 2, 4927)
        self.fc2 = nn.Linear(4927, 2785)
        self.fc3 = nn.Linear(2785, 1574)
        self.fc4 = nn.Linear(1574, num_fgs)
        self.dropout = nn.Dropout(0.48599073736368)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        """Run the full model on a batch of three-spectrum inputs.

        Args:
            x: Tensor of shape ``[B, 3, L]``; ``x[:, k]`` is the k-th
                spectrum. ``L`` must be such that the CNN output has
                ``feature_dim`` positions.

        Returns:
            Dict with the keys:
            ``'x'`` (sigmoid predictions, ``[B, num_fgs]``),
            ``'vae_mu'`` / ``'vae_logvar'`` (``[B, latent_dim]``),
            ``'recon_x1'``..``'recon_x3'`` (reconstructions,
            ``[B, channel, feature_dim]``),
            ``'cmi_loss'`` (sum of three HSIC terms),
            ``'ib_x_1'``..``'ib_x_3'`` (gated CNN feature maps),
            ``'kl'`` (mean of the three CNN regularisation losses) and
            ``'channal_importance_1'``..``'channal_importance_3'``
            (per-position importance returned by each CNN).
        """
        # Split into the three spectra, each [B, 1, L].
        x1, x2, x3 = x[:, 0:1, :], x[:, 1:2, :], x[:, 2:3, :]

        # Independent encoders; each returns (features, loss, importance).
        ib_x_1, kl_loss1, channal_importance_1 = self.cnn1(x1)
        ib_x_2, kl_loss2, channal_importance_2 = self.cnn2(x2)
        ib_x_3, kl_loss3, channal_importance_3 = self.cnn3(x3)

        # Concatenate along the channel axis and flatten per sample.
        ib_x_stacked = torch.cat([ib_x_1, ib_x_2, ib_x_3], dim=1)  # [B, 3*channel, feature_dim]
        ib_x_flat = ib_x_stacked.flatten(start_dim=1)              # [B, 3*channel*feature_dim]

        # VAE encoder.
        h = self.fc_fusion(ib_x_flat)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)

        # VAE decoders: reconstruct each spectrum's feature map from z.
        recon_x = []
        for decoder in self.decoder:
            recon = decoder(z)
            recon = recon.reshape(z.size(0), self.channel, self.feature_dim)
            recon_x.append(recon)
        recon_x1, recon_x2, recon_x3 = recon_x

        # Flatten the individual feature maps. reshape() is used instead of
        # view() because the CNN outputs are produced by a permute and are
        # not guaranteed to be contiguous.
        batch = z.size(0)
        ib_x1_flat = ib_x_1.reshape(batch, -1)
        ib_x2_flat = ib_x_2.reshape(batch, -1)
        ib_x3_flat = ib_x_3.reshape(batch, -1)

        # Dependence penalty: HSIC between the third spectrum's features and
        # (a) the first spectrum, (b) the second spectrum, (c) the latent z.
        sigma = 1.0
        hsic_x3_x1 = hsic(ib_x3_flat, ib_x1_flat, sigma=sigma)
        hsic_x3_x2 = hsic(ib_x3_flat, ib_x2_flat, sigma=sigma)
        hsic_x3_z = hsic(ib_x3_flat, z, sigma=sigma)
        cmi_loss = hsic_x3_x1 + hsic_x3_x2 + hsic_x3_z

        # Prediction head on [z, projected x3].
        x3_processed = self.fc_x3(ib_x3_flat)
        z_x3 = torch.cat([z, x3_processed], dim=1)
        x_pred = F.relu(self.fc1(z_x3))
        x_pred = self.dropout(x_pred)
        x_pred = F.relu(self.fc2(x_pred))
        x_pred = self.dropout(x_pred)
        x_pred = F.relu(self.fc3(x_pred))
        x_pred = self.dropout(x_pred)
        x_pred = torch.sigmoid(self.fc4(x_pred))

        # Average of the three per-CNN regularisation losses. The VAE KL term
        # is not computed here; callers derive it from 'vae_mu'/'vae_logvar'.
        kl = (kl_loss1 + kl_loss2 + kl_loss3) / 3
        return {
            'x': x_pred,
            'vae_mu': mu,
            'vae_logvar': logvar,
            'recon_x1': recon_x1,
            'recon_x2': recon_x2,
            'recon_x3': recon_x3,
            'cmi_loss': cmi_loss,  # sum of HSIC terms (not InfoNCE)
            'ib_x_1': ib_x_1,
            'ib_x_2': ib_x_2,
            'ib_x_3': ib_x_3,
            'kl': kl,
            'channal_importance_1': channal_importance_1,
            'channal_importance_2': channal_importance_2,
            'channal_importance_3': channal_importance_3
        }



# Training function in PyTorch
from tqdm import tqdm  # 引入 tqdm

b=0.0001
# 定义训练函数
# 定义训练函数
from tqdm import tqdm  # 引入 tqdm

# 定义训练函数
def train_model(X_train, y_train, X_test, y_test, num_fgs, weighted=False, batch_size=41, epochs=41,
                annealing_epochs=10, max_lambda_kl=1.0, lambda_cmi=0.5, lambda_recon=0.0001,
                device=None):
    """Train ``CNNModelWithVAE`` and evaluate micro-F1 on the test split each epoch.

    The best-scoring weights are written to ``out_path / "best_model.pth"``
    (``out_path`` is the module-level output directory).

    Args:
        X_train, X_test: Array-likes convertible to float32 tensors and
            accepted by ``CNNModelWithVAE.forward``.
        y_train, y_test: Sequences of per-sample label vectors (0/1).
        num_fgs: Number of labels.
        weighted: Use ``WeightedBinaryCrossEntropyLoss`` with weights from
            ``calculate_class_weights`` instead of plain ``nn.BCELoss``.
        batch_size: Mini-batch size for both loaders.
        epochs: Number of training epochs.
        annealing_epochs: Epochs over which the VAE KL weight ramps linearly
            up to ``max_lambda_kl``; values <= 0 disable the ramp.
        max_lambda_kl: Upper bound of the VAE KL weight.
        lambda_cmi: Weight of the model's ``'cmi_loss'`` term.
        lambda_recon: Weight of the reconstruction loss.
        device: Optional device (string or ``torch.device``). When omitted,
            GPU 3 is used if it exists, otherwise GPU 0, otherwise the CPU.

    Returns:
        Binary test-set predictions (threshold 0.5) from the LAST epoch.
        NOTE(review): these are not necessarily the predictions of the saved
        best checkpoint.
    """
    if device is None:
        if torch.cuda.is_available():
            # Keep the original preference for GPU 3, but do not crash on
            # machines with fewer than four GPUs.
            device = 'cuda:3' if torch.cuda.device_count() > 3 else 'cuda:0'
        else:
            device = 'cpu'
    device = torch.device(device)
    model = CNNModelWithVAE(num_fgs).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Stack the per-sample label vectors into 2-D float arrays first: the
    # class-weight computation below needs a [samples, labels] matrix.
    y_train = np.array([np.array(item, dtype=np.float32) for item in y_train], dtype=np.float32)
    y_test = np.array([np.array(item, dtype=np.float32) for item in y_test], dtype=np.float32)

    if weighted:
        class_weights = calculate_class_weights(y_train).to(device)
        criterion = WeightedBinaryCrossEntropyLoss(class_weights).to(device)
    else:
        criterion = nn.BCELoss().to(device)

    train_data = TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.float32))
    test_data = TensorDataset(torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.float32))
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # Make sure the checkpoint directory exists.
    out_path.mkdir(parents=True, exist_ok=True)

    best_f1 = 0
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        recon_loss_avg = 0.0
        # Linear KL warm-up, capped at max_lambda_kl.
        if annealing_epochs <= 0:
            kl_weight = max_lambda_kl
        else:
            kl_weight = min(max_lambda_kl, (epoch + 1) / annealing_epochs)
        with tqdm(train_loader, unit='batch', desc=f"Epoch {epoch+1}/{epochs}") as tepoch:
            # Count steps explicitly: tqdm's own counter is only refreshed
            # when the bar redraws, so it can lag behind the real step.
            for step, (inputs, targets) in enumerate(tepoch, start=1):
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                x_pred = outputs['x']
                mu = outputs['vae_mu']
                logvar = outputs['vae_logvar']
                kl = outputs['kl']
                cmi_loss = outputs['cmi_loss']

                # Classification loss.
                pred_loss = criterion(x_pred, targets)

                # Reconstruction loss against the gated CNN feature maps.
                # NOTE(review): the targets are not detached, so this term
                # also sends gradients into the encoders — confirm intended.
                recon_loss = F.mse_loss(outputs['recon_x1'], outputs['ib_x_1']) + \
                             F.mse_loss(outputs['recon_x2'], outputs['ib_x_2']) + \
                             F.mse_loss(outputs['recon_x3'], outputs['ib_x_3'])

                # VAE KL divergence to the standard normal prior.
                kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()

                # Total: prediction + annealed VAE KL + dependence penalty
                # + reconstruction + (tiny weight) CNN gate regulariser.
                total_loss = pred_loss + kl_weight * kl_div + \
                             lambda_cmi * cmi_loss + lambda_recon * recon_loss + 0.0000001 * kl
                total_loss.backward()

                # Gradient clipping.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()

                running_loss += total_loss.item()
                recon_loss_avg += recon_loss.item()
                tepoch.set_postfix(loss=running_loss / step,
                                  kl_weight=kl_weight)

        avg_loss = running_loss / len(train_loader)
        recon_loss_a = recon_loss_avg / len(train_loader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss}, KL Weight: {kl_weight}, Recon Loss: {recon_loss_a}')

        # Evaluate micro-F1 on the test split.
        model.eval()
        predictions = []
        with torch.no_grad():
            for inputs, _ in test_loader:
                outputs = model(inputs.to(device))
                predictions.append(outputs['x'].cpu().numpy())
        predictions = np.concatenate(predictions)
        binary_predictions = (predictions > 0.5).astype(int)
        f1 = f1_score(y_test, binary_predictions, average='micro')
        print(f'F1 Score: {f1}')

        # Keep the best checkpoint.
        if f1 > best_f1:
            best_f1 = f1
            model_save_path = out_path / "best_model.pth"
            torch.save(model.state_dict(), model_save_path)
            print(f'Best model saved with F1 Score: {best_f1} at {model_save_path}')

    return binary_predictions




# Custom loss function with class weights
class WeightedBinaryCrossEntropyLoss(nn.Module):
    """Binary cross-entropy with separate weights for negatives and positives.

    Args:
        class_weights: Weights in one of two layouts:
            * ``[num_labels, 2]`` — column 0 holds the negative-class weight
              and column 1 the positive-class weight of each label (the
              layout returned by ``calculate_class_weights``);
            * any other shape — indexed as ``class_weights[0]`` (negative)
              and ``class_weights[1]`` (positive), e.g. ``[2, num_labels]``
              or two scalars.
            A ``[2, 2]`` input is interpreted using the first layout.
    """

    def __init__(self, class_weights):
        super().__init__()
        weights = torch.as_tensor(class_weights, dtype=torch.float32)
        # Registered as a buffer so that .to(device) moves the weights along
        # with the module; a plain attribute would stay on its old device.
        self.register_buffer('class_weights', weights)
        self._label_major = weights.dim() == 2 and weights.size(1) == 2

    def forward(self, y_pred, y_true):
        """Return the mean weighted BCE.

        Args:
            y_pred: Predicted probabilities in ``[0, 1]``.
            y_true: Targets (0/1) with the same shape as ``y_pred``.

        Returns:
            Scalar tensor with the mean loss over all elements.
        """
        if self._label_major:
            # Per-label weights live in the columns; indexing rows here would
            # select one label's (neg, pos) pair and fail to broadcast.
            weight_neg = self.class_weights[:, 0]
            weight_pos = self.class_weights[:, 1]
        else:
            weight_neg = self.class_weights[0]
            weight_pos = self.class_weights[1]
        # The 1e-15 offsets keep log() finite at exactly 0 or 1.
        loss = weight_neg * (1 - y_true) * torch.log(1 - y_pred + 1e-15) + \
               weight_pos * y_true * torch.log(y_pred + 1e-15)
        return -loss.mean()

# Calculate class weights
def calculate_class_weights(y_true):
    """Compute balanced per-label weights for the negative and positive class.

    For each label column the weight of a class is
    ``num_samples / (2 * count_of_that_class)``.

    Args:
        y_true: ``[num_samples, num_labels]`` array-like of 0/1 labels. A 1-D
            object array (or sequence) of per-sample label vectors is also
            accepted and stacked into a matrix.

    Returns:
        float32 tensor of shape ``[num_labels, 2]``; column 0 is the
        negative-class weight and column 1 the positive-class weight.
        A class that never occurs for a label gets weight 1.0 instead of
        ``inf``: an infinite weight multiplied by a zero target would turn
        the loss into NaN.
    """
    labels = np.asarray(y_true)
    if labels.dtype == object:
        labels = np.array([np.asarray(row, dtype=np.float32) for row in labels],
                          dtype=np.float32)

    num_samples = labels.shape[0]
    neg_counts = (labels == 0).sum(axis=0)
    pos_counts = (labels == 1).sum(axis=0)

    def _balanced(counts):
        # Divide by a safe denominator, then overwrite the empty classes.
        safe_counts = np.maximum(counts, 1)
        return np.where(counts > 0, num_samples / (2.0 * safe_counts), 1.0)

    class_weights = np.stack([_balanced(neg_counts), _balanced(pos_counts)], axis=1)
    return torch.tensor(class_weights, dtype=torch.float32)

# ---- Data loading, training and result export ----
analytical_data = Path("/data/zjh2/multimodal-spectroscopic-dataset-main/data/multimodal_spectroscopic_dataset")
out_path = Path("/home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all")
columns = ["h_nmr_spectra", "c_nmr_spectra", "ir_spectra"]
seed = 3245
max_files = 3  # only this many parquet shards are loaded

# Collect the pre-processed shards here before concatenating them.
all_data = []
i = 0
# sorted() makes the choice of shards deterministic; the order in which
# Path.glob() yields files is not specified.
for i, parquet_file in enumerate(sorted(analytical_data.glob("*.parquet"))[:max_files], start=1):
    # Read only the spectra columns plus the SMILES string.
    data = pd.read_parquet(parquet_file, columns=columns + ['smiles'])

    # Resample every spectrum to 600 points.
    for column in columns:
        data[column] = data[column].map(interpolate_to_600)

    # Multi-hot functional-group labels (None for unparsable SMILES).
    data['func_group'] = data.smiles.map(get_functional_groups)
    # Rows without labels cannot be stacked into the label matrix later on.
    data = data[data['func_group'].notna()]
    all_data.append(data)
    print("Loaded Data from: ", i)

# Merge all shards.
training_data = pd.concat(all_data, ignore_index=True)

# Train/test split.
train, test = train_test_split(training_data, test_size=0.1, random_state=seed)

# Features: [samples, 3, 600]; labels: 1-D object arrays of per-sample lists.
X_train = np.array(train[columns].values.tolist())
y_train = np.array(train['func_group'].values)

X_test = np.array(test[columns].values.tolist())
y_test = np.array(test['func_group'].values)

# Sanity-check the array shapes.
print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
print("X_test shape:", X_test.shape)
print("y_test shape:", y_test.shape)

# Train the model.
predictions = train_model(X_train, y_train, X_test, y_test, num_fgs=37, weighted=False, batch_size=41, epochs=41,
                annealing_epochs=10, max_lambda_kl=1.0, lambda_cmi=0.1, lambda_recon=0.1)

# Evaluate the returned predictions (these come from the last epoch).
y_test = np.array([np.array(item, dtype=np.float32) for item in y_test], dtype=np.float32)
f1 = f1_score(y_test, predictions, average='micro')
print(f'F1 Score: {f1}')

# Save predictions and targets.
with open(out_path / "results.pickle", "wb") as file:
    pickle.dump({'pred': predictions, 'tgt': y_test}, file)

Loaded Data from:  1
Loaded Data from:  2
Loaded Data from:  3
X_train shape: (8740, 3, 600)
y_train shape: (8740,)
X_test shape: (972, 3, 600)
y_test shape: (972,)


Epoch 1/41: 100%|██████████| 214/214 [00:41<00:00,  5.13batch/s, kl_weight=0.1, loss=1.84]


Epoch 1/41, Loss: 1.8427282397713616, KL Weight: 0.1, Recon Loss: 0.08978283238188128
F1 Score: 0.6037341689515603
Best model saved with F1 Score: 0.6037341689515603 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 2/41: 100%|██████████| 214/214 [00:19<00:00, 11.11batch/s, kl_weight=0.2, loss=0.223]


Epoch 2/41, Loss: 0.22245526425192289, KL Weight: 0.2, Recon Loss: 0.020214818774887892
F1 Score: 0.6052873865579377
Best model saved with F1 Score: 0.6052873865579377 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 3/41: 100%|██████████| 214/214 [00:19<00:00, 10.76batch/s, kl_weight=0.3, loss=0.217]


Epoch 3/41, Loss: 0.21678034529507717, KL Weight: 0.3, Recon Loss: 0.010436423582016168
F1 Score: 0.6732482542740188
Best model saved with F1 Score: 0.6732482542740188 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 4/41: 100%|██████████| 214/214 [00:19<00:00, 10.91batch/s, kl_weight=0.4, loss=0.195]


Epoch 4/41, Loss: 0.19504428000372148, KL Weight: 0.4, Recon Loss: 0.007323026718032137
F1 Score: 0.6698880976602238


Epoch 5/41: 100%|██████████| 214/214 [00:23<00:00,  9.15batch/s, kl_weight=0.5, loss=0.181]


Epoch 5/41, Loss: 0.18086476763275183, KL Weight: 0.5, Recon Loss: 0.005140207574058707
F1 Score: 0.7140663753395535
Best model saved with F1 Score: 0.7140663753395535 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 6/41: 100%|██████████| 214/214 [00:42<00:00,  5.09batch/s, kl_weight=0.6, loss=0.172]


Epoch 6/41, Loss: 0.17239067132506414, KL Weight: 0.6, Recon Loss: 0.002600063535571969
F1 Score: 0.7254811036401576
Best model saved with F1 Score: 0.7254811036401576 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 7/41: 100%|██████████| 214/214 [00:19<00:00, 10.81batch/s, kl_weight=0.7, loss=0.166]


Epoch 7/41, Loss: 0.1651265188076786, KL Weight: 0.7, Recon Loss: 0.0016729431292129176
F1 Score: 0.7220326128175958


Epoch 8/41: 100%|██████████| 214/214 [00:20<00:00, 10.44batch/s, kl_weight=0.8, loss=0.16] 


Epoch 8/41, Loss: 0.15957957290321867, KL Weight: 0.8, Recon Loss: 0.0015022244400162007
F1 Score: 0.7503750721292556
Best model saved with F1 Score: 0.7503750721292556 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 9/41: 100%|██████████| 214/214 [00:21<00:00, 10.07batch/s, kl_weight=0.9, loss=0.155]


Epoch 9/41, Loss: 0.15517087628908247, KL Weight: 0.9, Recon Loss: 0.001485395881186335
F1 Score: 0.7544222375373306
Best model saved with F1 Score: 0.7544222375373306 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 10/41: 100%|██████████| 214/214 [00:19<00:00, 10.72batch/s, kl_weight=1, loss=0.154]


Epoch 10/41, Loss: 0.15233464413714187, KL Weight: 1.0, Recon Loss: 0.001313211606863329
F1 Score: 0.7555816686251469
Best model saved with F1 Score: 0.7555816686251469 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 11/41: 100%|██████████| 214/214 [00:20<00:00, 10.68batch/s, kl_weight=1, loss=0.149]


Epoch 11/41, Loss: 0.14885449217997979, KL Weight: 1.0, Recon Loss: 0.0012395972297331056
F1 Score: 0.7651098901098901
Best model saved with F1 Score: 0.7651098901098901 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 12/41: 100%|██████████| 214/214 [00:19<00:00, 10.94batch/s, kl_weight=1, loss=0.147]


Epoch 12/41, Loss: 0.1467670657994034, KL Weight: 1.0, Recon Loss: 0.0011348838386970146
F1 Score: 0.7547533092659446


Epoch 13/41: 100%|██████████| 214/214 [00:19<00:00, 10.90batch/s, kl_weight=1, loss=0.144]


Epoch 13/41, Loss: 0.1442870909922591, KL Weight: 1.0, Recon Loss: 0.001082408928723623
F1 Score: 0.7566718995290423


Epoch 14/41: 100%|██████████| 214/214 [00:19<00:00, 10.81batch/s, kl_weight=1, loss=0.142]


Epoch 14/41, Loss: 0.14179466602122673, KL Weight: 1.0, Recon Loss: 0.0010440721533970612
F1 Score: 0.7616237904177484


Epoch 15/41: 100%|██████████| 214/214 [00:22<00:00,  9.73batch/s, kl_weight=1, loss=0.139]


Epoch 15/41, Loss: 0.13864554423038092, KL Weight: 1.0, Recon Loss: 0.0010282586271021619
F1 Score: 0.7639325189063408


Epoch 16/41: 100%|██████████| 214/214 [00:19<00:00, 10.71batch/s, kl_weight=1, loss=0.138]


Epoch 16/41, Loss: 0.13824635778910646, KL Weight: 1.0, Recon Loss: 0.0010008637224341455
F1 Score: 0.764946764946765


Epoch 17/41: 100%|██████████| 214/214 [00:19<00:00, 10.72batch/s, kl_weight=1, loss=0.136]


Epoch 17/41, Loss: 0.13534692277975172, KL Weight: 1.0, Recon Loss: 0.0009454859556112369
F1 Score: 0.7536373639809267


Epoch 18/41: 100%|██████████| 214/214 [00:19<00:00, 11.23batch/s, kl_weight=1, loss=0.134]


Epoch 18/41, Loss: 0.13315668987615087, KL Weight: 1.0, Recon Loss: 0.0009211634166612698
F1 Score: 0.7625933388645253


Epoch 19/41: 100%|██████████| 214/214 [00:19<00:00, 10.77batch/s, kl_weight=1, loss=0.131]


Epoch 19/41, Loss: 0.13121189647049547, KL Weight: 1.0, Recon Loss: 0.0009132333262497124
F1 Score: 0.767052767052767
Best model saved with F1 Score: 0.767052767052767 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 20/41: 100%|██████████| 214/214 [00:19<00:00, 10.93batch/s, kl_weight=1, loss=0.13] 


Epoch 20/41, Loss: 0.12907484723864315, KL Weight: 1.0, Recon Loss: 0.00090190511665777
F1 Score: 0.7626818020574672


Epoch 21/41: 100%|██████████| 214/214 [00:20<00:00, 10.56batch/s, kl_weight=1, loss=0.127]


Epoch 21/41, Loss: 0.1271569448584151, KL Weight: 1.0, Recon Loss: 0.0008974798215731595
F1 Score: 0.7581434196396611


Epoch 22/41: 100%|██████████| 214/214 [00:24<00:00,  8.90batch/s, kl_weight=1, loss=0.126]


Epoch 22/41, Loss: 0.12581503607123812, KL Weight: 1.0, Recon Loss: 0.0008859405879945294
F1 Score: 0.7671676300578034
Best model saved with F1 Score: 0.7671676300578034 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 23/41: 100%|██████████| 214/214 [00:37<00:00,  5.76batch/s, kl_weight=1, loss=0.122]


Epoch 23/41, Loss: 0.12237850343373334, KL Weight: 1.0, Recon Loss: 0.0008462339426960936
F1 Score: 0.769865597736383
Best model saved with F1 Score: 0.769865597736383 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 24/41: 100%|██████████| 214/214 [00:19<00:00, 10.95batch/s, kl_weight=1, loss=0.121]


Epoch 24/41, Loss: 0.12091985142955156, KL Weight: 1.0, Recon Loss: 0.0008541256440672849
F1 Score: 0.7709923664122137
Best model saved with F1 Score: 0.7709923664122137 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 25/41: 100%|██████████| 214/214 [00:18<00:00, 11.29batch/s, kl_weight=1, loss=0.12] 


Epoch 25/41, Loss: 0.11951195455600168, KL Weight: 1.0, Recon Loss: 0.0008770960348301401
F1 Score: 0.755967224795155


Epoch 26/41: 100%|██████████| 214/214 [00:19<00:00, 10.73batch/s, kl_weight=1, loss=0.119]


Epoch 26/41, Loss: 0.11854949398575543, KL Weight: 1.0, Recon Loss: 0.0008651116595246663
F1 Score: 0.7758462946020128
Best model saved with F1 Score: 0.7758462946020128 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 27/41: 100%|██████████| 214/214 [00:20<00:00, 10.49batch/s, kl_weight=1, loss=0.118]


Epoch 27/41, Loss: 0.11842268169204766, KL Weight: 1.0, Recon Loss: 0.000850560431364297
F1 Score: 0.7684100661022846


Epoch 28/41: 100%|██████████| 214/214 [00:18<00:00, 11.66batch/s, kl_weight=1, loss=0.115]


Epoch 28/41, Loss: 0.11472476378222492, KL Weight: 1.0, Recon Loss: 0.0008691812040736882
F1 Score: 0.7729232203192357


Epoch 29/41: 100%|██████████| 214/214 [00:20<00:00, 10.49batch/s, kl_weight=1, loss=0.113]


Epoch 29/41, Loss: 0.11295928377832208, KL Weight: 1.0, Recon Loss: 0.0008120936451471089
F1 Score: 0.7764705882352941
Best model saved with F1 Score: 0.7764705882352941 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 30/41: 100%|██████████| 214/214 [00:21<00:00, 10.15batch/s, kl_weight=1, loss=0.113]


Epoch 30/41, Loss: 0.11230519351279625, KL Weight: 1.0, Recon Loss: 0.0008189267120933351
F1 Score: 0.7756432646408197


Epoch 31/41: 100%|██████████| 214/214 [00:20<00:00, 10.20batch/s, kl_weight=1, loss=0.111]


Epoch 31/41, Loss: 0.11095400635048608, KL Weight: 1.0, Recon Loss: 0.0008318326966380091
F1 Score: 0.7787610619469026
Best model saved with F1 Score: 0.7787610619469026 at /home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth


Epoch 32/41: 100%|██████████| 214/214 [00:33<00:00,  6.42batch/s, kl_weight=1, loss=0.11] 


Epoch 32/41, Loss: 0.10965102810029671, KL Weight: 1.0, Recon Loss: 0.0008155837419516816
F1 Score: 0.7723029658513529


Epoch 33/41: 100%|██████████| 214/214 [00:35<00:00,  5.99batch/s, kl_weight=1, loss=0.11] 


Epoch 33/41, Loss: 0.10952699111304551, KL Weight: 1.0, Recon Loss: 0.0008277536917120984
F1 Score: 0.7723369691343699


Epoch 34/41: 100%|██████████| 214/214 [00:35<00:00,  6.07batch/s, kl_weight=1, loss=0.108]


Epoch 34/41, Loss: 0.10822168678463062, KL Weight: 1.0, Recon Loss: 0.0007802124984375262
F1 Score: 0.7693389592123769


Epoch 35/41: 100%|██████████| 214/214 [00:34<00:00,  6.25batch/s, kl_weight=1, loss=0.105]


Epoch 35/41, Loss: 0.10519002827110692, KL Weight: 1.0, Recon Loss: 0.0007920977281445689
F1 Score: 0.7777006708304418


Epoch 36/41: 100%|██████████| 214/214 [00:22<00:00,  9.64batch/s, kl_weight=1, loss=0.105]


Epoch 36/41, Loss: 0.10402593176777118, KL Weight: 1.0, Recon Loss: 0.0007994765011598051
F1 Score: 0.766903073286052


Epoch 37/41: 100%|██████████| 214/214 [00:19<00:00, 10.97batch/s, kl_weight=1, loss=0.104]


Epoch 37/41, Loss: 0.10391852506828085, KL Weight: 1.0, Recon Loss: 0.0008202914546230791
F1 Score: 0.7774071091682964


Epoch 38/41: 100%|██████████| 214/214 [00:21<00:00, 10.15batch/s, kl_weight=1, loss=0.102]


Epoch 38/41, Loss: 0.10151361520880851, KL Weight: 1.0, Recon Loss: 0.0007852204167120872
F1 Score: 0.7787528868360277


Epoch 39/41: 100%|██████████| 214/214 [00:19<00:00, 10.89batch/s, kl_weight=1, loss=0.1]   


Epoch 39/41, Loss: 0.09985604419190193, KL Weight: 1.0, Recon Loss: 0.0007848216685309797
F1 Score: 0.7705184840734712


Epoch 40/41: 100%|██████████| 214/214 [00:19<00:00, 10.89batch/s, kl_weight=1, loss=0.1]   


Epoch 40/41, Loss: 0.1003656875446578, KL Weight: 1.0, Recon Loss: 0.0007883615206457048
F1 Score: 0.7697383926467122


Epoch 41/41: 100%|██████████| 214/214 [00:19<00:00, 10.92batch/s, kl_weight=1, loss=0.0988]


Epoch 41/41, Loss: 0.09880477861962586, KL Weight: 1.0, Recon Loss: 0.0007757751350552192
F1 Score: 0.7682613768961494
F1 Score: 0.7682613768961494


In [2]:
######初始的
import matplotlib.pyplot as plt
# Define the evaluation function
def evalute_model(X_test, y_test, model_path, smiles, ir, num_fgs, weighted=False, batch_size=41,
                  annealing_epochs=10, max_lambda_kl=1.0, lambda_cmi=0.5, lambda_recon=0.1):
    """Load a saved CNNModelWithVAE checkpoint, plot its channel importance and score it.

    For every batch the function plots the model's ``channal_importance_3``
    output and the reference spectrum ``ir``, then thresholds the model's ``x``
    output at 0.5 and prints the micro-averaged F1 score against ``y_test``.

    Args:
        X_test: Input features; converted to a float32 tensor and fed to the model.
        y_test: Iterable of per-sample label vectors; converted to a float32 array.
        model_path: Path of the ``state_dict`` checkpoint to load.
        smiles: Used as the title of both plots.
        ir: Spectrum drawn in the second plot (any length).
        num_fgs: Passed to the ``CNNModelWithVAE`` constructor.
        weighted: Not read by this function; kept so existing calls keep working.
        batch_size: Batch size of the evaluation ``DataLoader``.
        annealing_epochs, max_lambda_kl, lambda_cmi, lambda_recon: Not read by
            this function; kept so existing calls keep working.

    Returns:
        Integer (0/1) array of thresholded predictions, concatenated over all
        batches in loader order.
    """
    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
    model = CNNModelWithVAE(num_fgs).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Build the DataLoader; labels arrive as an object array of lists, so they
    # are first stacked into a dense float32 matrix.
    y_test = np.array([np.array(item, dtype=np.float32) for item in y_test], dtype=np.float32)
    test_data = TensorDataset(torch.tensor(X_test, dtype=torch.float32),
                              torch.tensor(y_test, dtype=torch.float32))
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    predictions = []
    with torch.no_grad():
        for inputs, _targets in test_loader:
            outputs = model(inputs.to(device))
            x_pred = outputs['x']

            # NOTE(review): the original comment describes this tensor as
            # [1, N, 1]; squeeze() then leaves a 1-D curve for batch_size == 1.
            # For larger batches the transpose below draws one curve per sample
            # -- confirm against the model definition.
            importance = outputs['channal_importance_3'].squeeze().cpu().numpy()

            # Visualise the importance curve. The x-axis length is taken from
            # the data instead of being hard-coded, so the plot keeps working
            # when the model width changes.
            plt.plot(np.arange(importance.shape[-1]), importance.T)
            plt.title(smiles)
            plt.xlabel('Wavelength Index')
            plt.ylabel('Importance')
            plt.show()

            # Reference spectrum, again with the axis sized from the data.
            plt.plot(np.arange(len(ir)), ir)
            plt.title(smiles)
            plt.xlabel('Wavelength Index')
            plt.ylabel('ir')
            plt.show()

            predictions.append(x_pred.cpu().numpy())

    predictions = np.concatenate(predictions)
    binary_predictions = (predictions > 0.5).astype(int)
    f1 = f1_score(y_test, binary_predictions, average='micro')
    print(f'F1 Score: {f1}')

    return binary_predictions




# Custom loss function with class weights
class WeightedBinaryCrossEntropyLoss(nn.Module):
    """Binary cross-entropy whose negative and positive terms carry separate weights.

    ``class_weights[0]`` scales the negative-label term and ``class_weights[1]``
    scales the positive-label term.

    NOTE(review): ``calculate_class_weights`` in this file returns a tensor laid
    out as (num_classes, 2); indexing that with ``[0]``/``[1]`` selects a class,
    not the negative/positive weights. Presumably the caller transposes it
    first -- confirm.
    """

    def __init__(self, class_weights):
        super().__init__()
        self.class_weights = class_weights

    def forward(self, y_pred, y_true):
        """Return the mean weighted BCE of ``y_pred`` against ``y_true``."""
        eps = 1e-15  # keeps log() finite when a prediction is exactly 0 or 1
        negative_weight = self.class_weights[0]
        positive_weight = self.class_weights[1]
        negative_term = negative_weight * (1 - y_true) * torch.log(1 - y_pred + eps)
        positive_term = positive_weight * y_true * torch.log(y_pred + eps)
        log_likelihood = negative_term + positive_term
        return -log_likelihood.mean()

# Calculate class weights
def calculate_class_weights(y_true):
    """Compute balanced negative/positive weights for every label column.

    For each column the weights are ``n_samples / (2 * count)`` where ``count``
    is the number of 0 labels (negative weight) or 1 labels (positive weight).
    A count of zero is treated as one so the weight stays finite; otherwise the
    division yields ``inf``, and ``inf * 0`` turns the loss into NaN.

    Args:
        y_true: 2-D array-like of 0/1 labels, indexed as (samples, classes).

    Returns:
        float32 tensor of shape (num_classes, 2); column 0 holds the negative
        weight and column 1 the positive weight of each class.
    """
    labels = np.asarray(y_true)
    num_samples = labels.shape[0]

    # Per-column label counts, clamped so an absent label cannot divide by zero.
    negative_counts = np.maximum((labels == 0).sum(axis=0), 1)
    positive_counts = np.maximum((labels == 1).sum(axis=0), 1)

    negative_weights = num_samples / (2 * negative_counts)
    positive_weights = num_samples / (2 * positive_counts)

    class_weights = np.stack([negative_weights, positive_weights], axis=1)
    return torch.tensor(class_weights, dtype=torch.float32)

# Loading data (no change)
# Dataset location, output directory and checkpoint used by this evaluation cell.
analytical_data = Path("/data/zjh2/multimodal-spectroscopic-dataset-main/data/multimodal_spectroscopic_dataset")
out_path = Path("/home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all")
columns = ["h_nmr_spectra", "c_nmr_spectra", "ir_spectra"]
seed = 3245
# NOTE(review): the traceback recorded in this cell's output shows that this
# checkpoint does not match the current CNNModelWithVAE definition (missing and
# unexpected keys plus size mismatches) -- confirm which checkpoint belongs to
# this model before re-running.
model_path = Path("/home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth")

# Only the first few parquet shards are needed for this visual check.
max_files = 3

# Sort the shards so the same files (and therefore the same held-out sample for
# a given seed) are used on every run; Path.glob does not guarantee an order.
parquet_files = sorted(analytical_data.glob("*.parquet"))[:max_files]

# Read each shard once and process all spectrum columns.
all_data = []
for file_index, parquet_file in enumerate(parquet_files, start=1):
    data = pd.read_parquet(parquet_file, columns=columns + ['smiles'])
    for column in columns:
        # Keep the untouched spectrum next to the interpolated one.
        data[column + "ori"] = data[column]
        data[column] = data[column].map(interpolate_to_600)

    # Attach the functional-group labels (one 0/1 vector per molecule).
    data['func_group'] = data.smiles.map(get_functional_groups)
    all_data.append(data)
    print("Loaded Data from: ", file_index)

if not all_data:
    raise FileNotFoundError(f"No parquet files found in {analytical_data}")

# Merge all loaded shards.
training_data = pd.concat(all_data, ignore_index=True)

# Hold out a single sample: an integer test_size is an absolute sample count.
train, test = train_test_split(training_data, test_size=1, random_state=seed)

# Extract the features and labels of the held-out sample.
X_test = np.array(test[columns].values.tolist())
y_test = np.array(test['func_group'].values)  # object array of label lists
print(len(test["ir_spectraori"].values.tolist()[0]))
# Pass the SMILES string itself so the plot title is not a Series repr.
smiles = test["smiles"].iloc[0]
# This cell plots the original (non-interpolated) IR spectrum.
ir = test["ir_spectraori"].values.tolist()[0]

# Evaluate the saved model on the held-out sample.
print(test['func_group'].values)
predictions = evalute_model(X_test, y_test, model_path, smiles, ir, num_fgs=37, weighted=False, batch_size=1,
                            annealing_epochs=10, max_lambda_kl=1.0, lambda_cmi=0.1, lambda_recon=0.1)

# Score the returned predictions against the dense label matrix.
y_test = np.array([np.array(item, dtype=np.float32) for item in y_test], dtype=np.float32)
f1 = f1_score(y_test, predictions, average='micro')
print(f'F1 Score: {f1}')

# Save results
with open(out_path / "results.pickle", "wb") as file:
    pickle.dump({'pred': predictions, 'tgt': y_test}, file)

Loaded Data from:  1
Loaded Data from:  2
Loaded Data from:  3
1800
[list([0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])]


RuntimeError: Error(s) in loading state_dict for CNNModelWithVAE:
	Missing key(s) in state_dict: "cnn1.mlp.0.weight", "cnn1.mlp.0.bias", "cnn1.mlp.2.weight", "cnn1.mlp.2.bias", "cnn2.mlp.0.weight", "cnn2.mlp.0.bias", "cnn2.mlp.2.weight", "cnn2.mlp.2.bias", "cnn3.mlp.0.weight", "cnn3.mlp.0.bias", "cnn3.mlp.2.weight", "cnn3.mlp.2.bias". 
	Unexpected key(s) in state_dict: "cnn1.mlps.0.0.weight", "cnn1.mlps.0.0.bias", "cnn1.mlps.0.2.weight", "cnn1.mlps.0.2.bias", "cnn1.mlps.1.0.weight", "cnn1.mlps.1.0.bias", "cnn1.mlps.1.2.weight", "cnn1.mlps.1.2.bias", "cnn1.mlps.2.0.weight", "cnn1.mlps.2.0.bias", "cnn1.mlps.2.2.weight", "cnn1.mlps.2.2.bias", "cnn1.mlps.3.0.weight", "cnn1.mlps.3.0.bias", "cnn1.mlps.3.2.weight", "cnn1.mlps.3.2.bias", "cnn1.mlps.4.0.weight", "cnn1.mlps.4.0.bias", "cnn1.mlps.4.2.weight", "cnn1.mlps.4.2.bias", "cnn1.mlps.5.0.weight", "cnn1.mlps.5.0.bias", "cnn1.mlps.5.2.weight", "cnn1.mlps.5.2.bias", "cnn1.mlps.6.0.weight", "cnn1.mlps.6.0.bias", "cnn1.mlps.6.2.weight", "cnn1.mlps.6.2.bias", "cnn1.mlps.7.0.weight", "cnn1.mlps.7.0.bias", "cnn1.mlps.7.2.weight", "cnn1.mlps.7.2.bias", "cnn1.mlps.8.0.weight", "cnn1.mlps.8.0.bias", "cnn1.mlps.8.2.weight", "cnn1.mlps.8.2.bias", "cnn1.mlps.9.0.weight", "cnn1.mlps.9.0.bias", "cnn1.mlps.9.2.weight", "cnn1.mlps.9.2.bias", "cnn1.mlps.10.0.weight", "cnn1.mlps.10.0.bias", "cnn1.mlps.10.2.weight", "cnn1.mlps.10.2.bias", "cnn1.mlps.11.0.weight", "cnn1.mlps.11.0.bias", "cnn1.mlps.11.2.weight", "cnn1.mlps.11.2.bias", "cnn1.mlps.12.0.weight", "cnn1.mlps.12.0.bias", "cnn1.mlps.12.2.weight", "cnn1.mlps.12.2.bias", "cnn1.mlps.13.0.weight", "cnn1.mlps.13.0.bias", "cnn1.mlps.13.2.weight", "cnn1.mlps.13.2.bias", "cnn1.mlps.14.0.weight", "cnn1.mlps.14.0.bias", "cnn1.mlps.14.2.weight", "cnn1.mlps.14.2.bias", "cnn1.mlps.15.0.weight", "cnn1.mlps.15.0.bias", "cnn1.mlps.15.2.weight", "cnn1.mlps.15.2.bias", "cnn1.mlps.16.0.weight", "cnn1.mlps.16.0.bias", "cnn1.mlps.16.2.weight", "cnn1.mlps.16.2.bias", "cnn1.mlps.17.0.weight", "cnn1.mlps.17.0.bias", "cnn1.mlps.17.2.weight", "cnn1.mlps.17.2.bias", "cnn1.mlps.18.0.weight", "cnn1.mlps.18.0.bias", "cnn1.mlps.18.2.weight", "cnn1.mlps.18.2.bias", "cnn1.mlps.19.0.weight", "cnn1.mlps.19.0.bias", "cnn1.mlps.19.2.weight", "cnn1.mlps.19.2.bias", "cnn1.mlps.20.0.weight", "cnn1.mlps.20.0.bias", "cnn1.mlps.20.2.weight", 
"cnn1.mlps.20.2.bias", "cnn1.mlps.21.0.weight", "cnn1.mlps.21.0.bias", "cnn1.mlps.21.2.weight", "cnn1.mlps.21.2.bias", "cnn1.mlps.22.0.weight", "cnn1.mlps.22.0.bias", "cnn1.mlps.22.2.weight", "cnn1.mlps.22.2.bias", "cnn1.mlps.23.0.weight", "cnn1.mlps.23.0.bias", "cnn1.mlps.23.2.weight", "cnn1.mlps.23.2.bias", "cnn1.mlps.24.0.weight", "cnn1.mlps.24.0.bias", "cnn1.mlps.24.2.weight", "cnn1.mlps.24.2.bias", "cnn1.mlps.25.0.weight", "cnn1.mlps.25.0.bias", "cnn1.mlps.25.2.weight", "cnn1.mlps.25.2.bias", "cnn1.mlps.26.0.weight", "cnn1.mlps.26.0.bias", "cnn1.mlps.26.2.weight", "cnn1.mlps.26.2.bias", "cnn1.mlps.27.0.weight", "cnn1.mlps.27.0.bias", "cnn1.mlps.27.2.weight", "cnn1.mlps.27.2.bias", "cnn1.mlps.28.0.weight", "cnn1.mlps.28.0.bias", "cnn1.mlps.28.2.weight", "cnn1.mlps.28.2.bias", "cnn1.mlps.29.0.weight", "cnn1.mlps.29.0.bias", "cnn1.mlps.29.2.weight", "cnn1.mlps.29.2.bias", "cnn1.mlps.30.0.weight", "cnn1.mlps.30.0.bias", "cnn1.mlps.30.2.weight", "cnn1.mlps.30.2.bias", "cnn1.mlps.31.0.weight", "cnn1.mlps.31.0.bias", "cnn1.mlps.31.2.weight", "cnn1.mlps.31.2.bias", "cnn1.mlps.32.0.weight", "cnn1.mlps.32.0.bias", "cnn1.mlps.32.2.weight", "cnn1.mlps.32.2.bias", "cnn1.mlps.33.0.weight", "cnn1.mlps.33.0.bias", "cnn1.mlps.33.2.weight", "cnn1.mlps.33.2.bias", "cnn1.mlps.34.0.weight", "cnn1.mlps.34.0.bias", "cnn1.mlps.34.2.weight", "cnn1.mlps.34.2.bias", "cnn1.mlps.35.0.weight", "cnn1.mlps.35.0.bias", "cnn1.mlps.35.2.weight", "cnn1.mlps.35.2.bias", "cnn1.mlps.36.0.weight", "cnn1.mlps.36.0.bias", "cnn1.mlps.36.2.weight", "cnn1.mlps.36.2.bias", "cnn2.mlps.0.0.weight", "cnn2.mlps.0.0.bias", "cnn2.mlps.0.2.weight", "cnn2.mlps.0.2.bias", "cnn2.mlps.1.0.weight", "cnn2.mlps.1.0.bias", "cnn2.mlps.1.2.weight", "cnn2.mlps.1.2.bias", "cnn2.mlps.2.0.weight", "cnn2.mlps.2.0.bias", "cnn2.mlps.2.2.weight", "cnn2.mlps.2.2.bias", "cnn2.mlps.3.0.weight", "cnn2.mlps.3.0.bias", "cnn2.mlps.3.2.weight", "cnn2.mlps.3.2.bias", "cnn2.mlps.4.0.weight", "cnn2.mlps.4.0.bias", "cnn2.mlps.4.2.weight", 
"cnn2.mlps.4.2.bias", "cnn2.mlps.5.0.weight", "cnn2.mlps.5.0.bias", "cnn2.mlps.5.2.weight", "cnn2.mlps.5.2.bias", "cnn2.mlps.6.0.weight", "cnn2.mlps.6.0.bias", "cnn2.mlps.6.2.weight", "cnn2.mlps.6.2.bias", "cnn2.mlps.7.0.weight", "cnn2.mlps.7.0.bias", "cnn2.mlps.7.2.weight", "cnn2.mlps.7.2.bias", "cnn2.mlps.8.0.weight", "cnn2.mlps.8.0.bias", "cnn2.mlps.8.2.weight", "cnn2.mlps.8.2.bias", "cnn2.mlps.9.0.weight", "cnn2.mlps.9.0.bias", "cnn2.mlps.9.2.weight", "cnn2.mlps.9.2.bias", "cnn2.mlps.10.0.weight", "cnn2.mlps.10.0.bias", "cnn2.mlps.10.2.weight", "cnn2.mlps.10.2.bias", "cnn2.mlps.11.0.weight", "cnn2.mlps.11.0.bias", "cnn2.mlps.11.2.weight", "cnn2.mlps.11.2.bias", "cnn2.mlps.12.0.weight", "cnn2.mlps.12.0.bias", "cnn2.mlps.12.2.weight", "cnn2.mlps.12.2.bias", "cnn2.mlps.13.0.weight", "cnn2.mlps.13.0.bias", "cnn2.mlps.13.2.weight", "cnn2.mlps.13.2.bias", "cnn2.mlps.14.0.weight", "cnn2.mlps.14.0.bias", "cnn2.mlps.14.2.weight", "cnn2.mlps.14.2.bias", "cnn2.mlps.15.0.weight", "cnn2.mlps.15.0.bias", "cnn2.mlps.15.2.weight", "cnn2.mlps.15.2.bias", "cnn2.mlps.16.0.weight", "cnn2.mlps.16.0.bias", "cnn2.mlps.16.2.weight", "cnn2.mlps.16.2.bias", "cnn2.mlps.17.0.weight", "cnn2.mlps.17.0.bias", "cnn2.mlps.17.2.weight", "cnn2.mlps.17.2.bias", "cnn2.mlps.18.0.weight", "cnn2.mlps.18.0.bias", "cnn2.mlps.18.2.weight", "cnn2.mlps.18.2.bias", "cnn2.mlps.19.0.weight", "cnn2.mlps.19.0.bias", "cnn2.mlps.19.2.weight", "cnn2.mlps.19.2.bias", "cnn2.mlps.20.0.weight", "cnn2.mlps.20.0.bias", "cnn2.mlps.20.2.weight", "cnn2.mlps.20.2.bias", "cnn2.mlps.21.0.weight", "cnn2.mlps.21.0.bias", "cnn2.mlps.21.2.weight", "cnn2.mlps.21.2.bias", "cnn2.mlps.22.0.weight", "cnn2.mlps.22.0.bias", "cnn2.mlps.22.2.weight", "cnn2.mlps.22.2.bias", "cnn2.mlps.23.0.weight", "cnn2.mlps.23.0.bias", "cnn2.mlps.23.2.weight", "cnn2.mlps.23.2.bias", "cnn2.mlps.24.0.weight", "cnn2.mlps.24.0.bias", "cnn2.mlps.24.2.weight", "cnn2.mlps.24.2.bias", "cnn2.mlps.25.0.weight", "cnn2.mlps.25.0.bias", "cnn2.mlps.25.2.weight", 
"cnn2.mlps.25.2.bias", "cnn2.mlps.26.0.weight", "cnn2.mlps.26.0.bias", "cnn2.mlps.26.2.weight", "cnn2.mlps.26.2.bias", "cnn2.mlps.27.0.weight", "cnn2.mlps.27.0.bias", "cnn2.mlps.27.2.weight", "cnn2.mlps.27.2.bias", "cnn2.mlps.28.0.weight", "cnn2.mlps.28.0.bias", "cnn2.mlps.28.2.weight", "cnn2.mlps.28.2.bias", "cnn2.mlps.29.0.weight", "cnn2.mlps.29.0.bias", "cnn2.mlps.29.2.weight", "cnn2.mlps.29.2.bias", "cnn2.mlps.30.0.weight", "cnn2.mlps.30.0.bias", "cnn2.mlps.30.2.weight", "cnn2.mlps.30.2.bias", "cnn2.mlps.31.0.weight", "cnn2.mlps.31.0.bias", "cnn2.mlps.31.2.weight", "cnn2.mlps.31.2.bias", "cnn2.mlps.32.0.weight", "cnn2.mlps.32.0.bias", "cnn2.mlps.32.2.weight", "cnn2.mlps.32.2.bias", "cnn2.mlps.33.0.weight", "cnn2.mlps.33.0.bias", "cnn2.mlps.33.2.weight", "cnn2.mlps.33.2.bias", "cnn2.mlps.34.0.weight", "cnn2.mlps.34.0.bias", "cnn2.mlps.34.2.weight", "cnn2.mlps.34.2.bias", "cnn2.mlps.35.0.weight", "cnn2.mlps.35.0.bias", "cnn2.mlps.35.2.weight", "cnn2.mlps.35.2.bias", "cnn2.mlps.36.0.weight", "cnn2.mlps.36.0.bias", "cnn2.mlps.36.2.weight", "cnn2.mlps.36.2.bias", "cnn3.mlps.0.0.weight", "cnn3.mlps.0.0.bias", "cnn3.mlps.0.2.weight", "cnn3.mlps.0.2.bias", "cnn3.mlps.1.0.weight", "cnn3.mlps.1.0.bias", "cnn3.mlps.1.2.weight", "cnn3.mlps.1.2.bias", "cnn3.mlps.2.0.weight", "cnn3.mlps.2.0.bias", "cnn3.mlps.2.2.weight", "cnn3.mlps.2.2.bias", "cnn3.mlps.3.0.weight", "cnn3.mlps.3.0.bias", "cnn3.mlps.3.2.weight", "cnn3.mlps.3.2.bias", "cnn3.mlps.4.0.weight", "cnn3.mlps.4.0.bias", "cnn3.mlps.4.2.weight", "cnn3.mlps.4.2.bias", "cnn3.mlps.5.0.weight", "cnn3.mlps.5.0.bias", "cnn3.mlps.5.2.weight", "cnn3.mlps.5.2.bias", "cnn3.mlps.6.0.weight", "cnn3.mlps.6.0.bias", "cnn3.mlps.6.2.weight", "cnn3.mlps.6.2.bias", "cnn3.mlps.7.0.weight", "cnn3.mlps.7.0.bias", "cnn3.mlps.7.2.weight", "cnn3.mlps.7.2.bias", "cnn3.mlps.8.0.weight", "cnn3.mlps.8.0.bias", "cnn3.mlps.8.2.weight", "cnn3.mlps.8.2.bias", "cnn3.mlps.9.0.weight", "cnn3.mlps.9.0.bias", "cnn3.mlps.9.2.weight", "cnn3.mlps.9.2.bias", 
"cnn3.mlps.10.0.weight", "cnn3.mlps.10.0.bias", "cnn3.mlps.10.2.weight", "cnn3.mlps.10.2.bias", "cnn3.mlps.11.0.weight", "cnn3.mlps.11.0.bias", "cnn3.mlps.11.2.weight", "cnn3.mlps.11.2.bias", "cnn3.mlps.12.0.weight", "cnn3.mlps.12.0.bias", "cnn3.mlps.12.2.weight", "cnn3.mlps.12.2.bias", "cnn3.mlps.13.0.weight", "cnn3.mlps.13.0.bias", "cnn3.mlps.13.2.weight", "cnn3.mlps.13.2.bias", "cnn3.mlps.14.0.weight", "cnn3.mlps.14.0.bias", "cnn3.mlps.14.2.weight", "cnn3.mlps.14.2.bias", "cnn3.mlps.15.0.weight", "cnn3.mlps.15.0.bias", "cnn3.mlps.15.2.weight", "cnn3.mlps.15.2.bias", "cnn3.mlps.16.0.weight", "cnn3.mlps.16.0.bias", "cnn3.mlps.16.2.weight", "cnn3.mlps.16.2.bias", "cnn3.mlps.17.0.weight", "cnn3.mlps.17.0.bias", "cnn3.mlps.17.2.weight", "cnn3.mlps.17.2.bias", "cnn3.mlps.18.0.weight", "cnn3.mlps.18.0.bias", "cnn3.mlps.18.2.weight", "cnn3.mlps.18.2.bias", "cnn3.mlps.19.0.weight", "cnn3.mlps.19.0.bias", "cnn3.mlps.19.2.weight", "cnn3.mlps.19.2.bias", "cnn3.mlps.20.0.weight", "cnn3.mlps.20.0.bias", "cnn3.mlps.20.2.weight", "cnn3.mlps.20.2.bias", "cnn3.mlps.21.0.weight", "cnn3.mlps.21.0.bias", "cnn3.mlps.21.2.weight", "cnn3.mlps.21.2.bias", "cnn3.mlps.22.0.weight", "cnn3.mlps.22.0.bias", "cnn3.mlps.22.2.weight", "cnn3.mlps.22.2.bias", "cnn3.mlps.23.0.weight", "cnn3.mlps.23.0.bias", "cnn3.mlps.23.2.weight", "cnn3.mlps.23.2.bias", "cnn3.mlps.24.0.weight", "cnn3.mlps.24.0.bias", "cnn3.mlps.24.2.weight", "cnn3.mlps.24.2.bias", "cnn3.mlps.25.0.weight", "cnn3.mlps.25.0.bias", "cnn3.mlps.25.2.weight", "cnn3.mlps.25.2.bias", "cnn3.mlps.26.0.weight", "cnn3.mlps.26.0.bias", "cnn3.mlps.26.2.weight", "cnn3.mlps.26.2.bias", "cnn3.mlps.27.0.weight", "cnn3.mlps.27.0.bias", "cnn3.mlps.27.2.weight", "cnn3.mlps.27.2.bias", "cnn3.mlps.28.0.weight", "cnn3.mlps.28.0.bias", "cnn3.mlps.28.2.weight", "cnn3.mlps.28.2.bias", "cnn3.mlps.29.0.weight", "cnn3.mlps.29.0.bias", "cnn3.mlps.29.2.weight", "cnn3.mlps.29.2.bias", "cnn3.mlps.30.0.weight", "cnn3.mlps.30.0.bias", "cnn3.mlps.30.2.weight", 
"cnn3.mlps.30.2.bias", "cnn3.mlps.31.0.weight", "cnn3.mlps.31.0.bias", "cnn3.mlps.31.2.weight", "cnn3.mlps.31.2.bias", "cnn3.mlps.32.0.weight", "cnn3.mlps.32.0.bias", "cnn3.mlps.32.2.weight", "cnn3.mlps.32.2.bias", "cnn3.mlps.33.0.weight", "cnn3.mlps.33.0.bias", "cnn3.mlps.33.2.weight", "cnn3.mlps.33.2.bias", "cnn3.mlps.34.0.weight", "cnn3.mlps.34.0.bias", "cnn3.mlps.34.2.weight", "cnn3.mlps.34.2.bias", "cnn3.mlps.35.0.weight", "cnn3.mlps.35.0.bias", "cnn3.mlps.35.2.weight", "cnn3.mlps.35.2.bias", "cnn3.mlps.36.0.weight", "cnn3.mlps.36.0.bias", "cnn3.mlps.36.2.weight", "cnn3.mlps.36.2.bias". 
	size mismatch for cnn1.conv1.weight: copying a param with shape torch.Size([31, 1, 11]) from checkpoint, the shape in current model is torch.Size([64, 1, 11]).
	size mismatch for cnn1.conv1.bias: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn1.conv2.weight: copying a param with shape torch.Size([62, 31, 11]) from checkpoint, the shape in current model is torch.Size([128, 64, 11]).
	size mismatch for cnn1.conv2.bias: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn1.batch_norm1.weight: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn1.batch_norm1.bias: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn1.batch_norm1.running_mean: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn1.batch_norm1.running_var: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn1.batch_norm2.weight: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn1.batch_norm2.bias: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn1.batch_norm2.running_mean: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn1.batch_norm2.running_var: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn2.conv1.weight: copying a param with shape torch.Size([31, 1, 11]) from checkpoint, the shape in current model is torch.Size([64, 1, 11]).
	size mismatch for cnn2.conv1.bias: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn2.conv2.weight: copying a param with shape torch.Size([62, 31, 11]) from checkpoint, the shape in current model is torch.Size([128, 64, 11]).
	size mismatch for cnn2.conv2.bias: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn2.batch_norm1.weight: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn2.batch_norm1.bias: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn2.batch_norm1.running_mean: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn2.batch_norm1.running_var: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn2.batch_norm2.weight: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn2.batch_norm2.bias: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn2.batch_norm2.running_mean: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn2.batch_norm2.running_var: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn3.conv1.weight: copying a param with shape torch.Size([31, 1, 11]) from checkpoint, the shape in current model is torch.Size([64, 1, 11]).
	size mismatch for cnn3.conv1.bias: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn3.conv2.weight: copying a param with shape torch.Size([62, 31, 11]) from checkpoint, the shape in current model is torch.Size([128, 64, 11]).
	size mismatch for cnn3.conv2.bias: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn3.batch_norm1.weight: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn3.batch_norm1.bias: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn3.batch_norm1.running_mean: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn3.batch_norm1.running_var: copying a param with shape torch.Size([31]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for cnn3.batch_norm2.weight: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn3.batch_norm2.bias: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn3.batch_norm2.running_mean: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for cnn3.batch_norm2.running_var: copying a param with shape torch.Size([62]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for fc_fusion.0.weight: copying a param with shape torch.Size([256, 27900]) from checkpoint, the shape in current model is torch.Size([256, 57600]).
	size mismatch for decoder.0.2.weight: copying a param with shape torch.Size([9300, 256]) from checkpoint, the shape in current model is torch.Size([19200, 256]).
	size mismatch for decoder.0.2.bias: copying a param with shape torch.Size([9300]) from checkpoint, the shape in current model is torch.Size([19200]).
	size mismatch for decoder.1.2.weight: copying a param with shape torch.Size([9300, 256]) from checkpoint, the shape in current model is torch.Size([19200, 256]).
	size mismatch for decoder.1.2.bias: copying a param with shape torch.Size([9300]) from checkpoint, the shape in current model is torch.Size([19200]).
	size mismatch for decoder.2.2.weight: copying a param with shape torch.Size([9300, 256]) from checkpoint, the shape in current model is torch.Size([19200, 256]).
	size mismatch for decoder.2.2.bias: copying a param with shape torch.Size([9300]) from checkpoint, the shape in current model is torch.Size([19200]).
	size mismatch for fc_x3.weight: copying a param with shape torch.Size([64, 9300]) from checkpoint, the shape in current model is torch.Size([64, 19200]).
	size mismatch for fc4.weight: copying a param with shape torch.Size([1, 1574]) from checkpoint, the shape in current model is torch.Size([37, 1574]).
	size mismatch for fc4.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([37]).

In [3]:
###### 600-point interpolated version; the 600-point and 1800-point spectra look identical in the plots.
import matplotlib.pyplot as plt
# Define the evaluation function
def evalute_model(X_test, y_test, model_path, smiles, ir, num_fgs, weighted=False, batch_size=41,
                  annealing_epochs=10, max_lambda_kl=1.0, lambda_cmi=0.5, lambda_recon=0.1):
    """Load a saved CNNModelWithVAE checkpoint, plot its channel importance and score it.

    For every batch the function plots the model's ``channal_importance_3``
    output and the reference spectrum ``ir``, then thresholds the model's ``x``
    output at 0.5 and prints the micro-averaged F1 score against ``y_test``.

    Args:
        X_test: Input features; converted to a float32 tensor and fed to the model.
        y_test: Iterable of per-sample label vectors; converted to a float32 array.
        model_path: Path of the ``state_dict`` checkpoint to load.
        smiles: Used as the title of both plots.
        ir: Spectrum drawn in the second plot (any length).
        num_fgs: Passed to the ``CNNModelWithVAE`` constructor.
        weighted: Not read by this function; kept so existing calls keep working.
        batch_size: Batch size of the evaluation ``DataLoader``.
        annealing_epochs, max_lambda_kl, lambda_cmi, lambda_recon: Not read by
            this function; kept so existing calls keep working.

    Returns:
        Integer (0/1) array of thresholded predictions, concatenated over all
        batches in loader order.
    """
    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
    model = CNNModelWithVAE(num_fgs).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Build the DataLoader; labels arrive as an object array of lists, so they
    # are first stacked into a dense float32 matrix.
    y_test = np.array([np.array(item, dtype=np.float32) for item in y_test], dtype=np.float32)
    test_data = TensorDataset(torch.tensor(X_test, dtype=torch.float32),
                              torch.tensor(y_test, dtype=torch.float32))
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    predictions = []
    with torch.no_grad():
        for inputs, _targets in test_loader:
            outputs = model(inputs.to(device))
            x_pred = outputs['x']

            # NOTE(review): the original comment describes this tensor as
            # [1, 150, 1]; squeeze() then leaves a 1-D curve for batch_size == 1.
            # For larger batches the transpose below draws one curve per sample
            # -- confirm against the model definition.
            importance = outputs['channal_importance_3'].squeeze().cpu().numpy()

            # Visualise the importance curve. The x-axis length is taken from
            # the data instead of being hard-coded, so the plot keeps working
            # when the model width changes.
            plt.plot(np.arange(importance.shape[-1]), importance.T)
            plt.title(smiles)
            plt.xlabel('Wavelength Index')
            plt.ylabel('Importance')
            plt.show()

            # Reference spectrum, again with the axis sized from the data.
            plt.plot(np.arange(len(ir)), ir)
            plt.title(smiles)
            plt.xlabel('Wavelength Index')
            plt.ylabel('ir')
            plt.show()

            predictions.append(x_pred.cpu().numpy())

    predictions = np.concatenate(predictions)
    binary_predictions = (predictions > 0.5).astype(int)
    f1 = f1_score(y_test, binary_predictions, average='micro')
    print(f'F1 Score: {f1}')

    return binary_predictions




# Custom loss function with class weights
class WeightedBinaryCrossEntropyLoss(nn.Module):
    """Binary cross-entropy whose negative and positive terms carry separate weights.

    ``class_weights[0]`` scales the negative-label term and ``class_weights[1]``
    scales the positive-label term.

    NOTE(review): ``calculate_class_weights`` in this file returns a tensor laid
    out as (num_classes, 2); indexing that with ``[0]``/``[1]`` selects a class,
    not the negative/positive weights. Presumably the caller transposes it
    first -- confirm.
    """

    def __init__(self, class_weights):
        super().__init__()
        self.class_weights = class_weights

    def forward(self, y_pred, y_true):
        """Return the mean weighted BCE of ``y_pred`` against ``y_true``."""
        eps = 1e-15  # keeps log() finite when a prediction is exactly 0 or 1
        negative_weight = self.class_weights[0]
        positive_weight = self.class_weights[1]
        negative_term = negative_weight * (1 - y_true) * torch.log(1 - y_pred + eps)
        positive_term = positive_weight * y_true * torch.log(y_pred + eps)
        log_likelihood = negative_term + positive_term
        return -log_likelihood.mean()

# Calculate class weights
def calculate_class_weights(y_true):
    """Compute balanced negative/positive weights for every label column.

    For each column the weights are ``n_samples / (2 * count)`` where ``count``
    is the number of 0 labels (negative weight) or 1 labels (positive weight).
    A count of zero is treated as one so the weight stays finite; otherwise the
    division yields ``inf``, and ``inf * 0`` turns the loss into NaN.

    Args:
        y_true: 2-D array-like of 0/1 labels, indexed as (samples, classes).

    Returns:
        float32 tensor of shape (num_classes, 2); column 0 holds the negative
        weight and column 1 the positive weight of each class.
    """
    labels = np.asarray(y_true)
    num_samples = labels.shape[0]

    # Per-column label counts, clamped so an absent label cannot divide by zero.
    negative_counts = np.maximum((labels == 0).sum(axis=0), 1)
    positive_counts = np.maximum((labels == 1).sum(axis=0), 1)

    negative_weights = num_samples / (2 * negative_counts)
    positive_weights = num_samples / (2 * positive_counts)

    class_weights = np.stack([negative_weights, positive_weights], axis=1)
    return torch.tensor(class_weights, dtype=torch.float32)

# Loading data (no change)
# Dataset location, output directory and checkpoint used by this evaluation cell.
analytical_data = Path("/data/zjh2/multimodal-spectroscopic-dataset-main/data/multimodal_spectroscopic_dataset")
out_path = Path("/home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all")
columns = ["h_nmr_spectra", "c_nmr_spectra", "ir_spectra"]
seed = 3245
model_path = Path("/home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/all/best_model.pth")

# Only the first few parquet shards are needed for this visual check.
max_files = 3

# Sort the shards so the same files (and therefore the same held-out sample for
# a given seed) are used on every run; Path.glob does not guarantee an order.
parquet_files = sorted(analytical_data.glob("*.parquet"))[:max_files]

# Read each shard once and process all spectrum columns.
all_data = []
for file_index, parquet_file in enumerate(parquet_files, start=1):
    data = pd.read_parquet(parquet_file, columns=columns + ['smiles'])
    for column in columns:
        # Keep the untouched spectrum next to the interpolated one.
        data[column + "ori"] = data[column]
        data[column] = data[column].map(interpolate_to_600)

    # Attach the functional-group labels (one 0/1 vector per molecule).
    data['func_group'] = data.smiles.map(get_functional_groups)
    all_data.append(data)
    print("Loaded Data from: ", file_index)

if not all_data:
    raise FileNotFoundError(f"No parquet files found in {analytical_data}")

# Merge all loaded shards.
training_data = pd.concat(all_data, ignore_index=True)

# Hold out a single sample: an integer test_size is an absolute sample count.
train, test = train_test_split(training_data, test_size=1, random_state=seed)

# Extract the features and labels of the held-out sample.
X_test = np.array(test[columns].values.tolist())
y_test = np.array(test['func_group'].values)  # object array of label lists
print(len(test["ir_spectraori"].values.tolist()[0]))
# Pass the SMILES string itself so the plot title is not a Series repr.
smiles = test["smiles"].iloc[0]
# This cell plots the interpolated IR spectrum (the output of interpolate_to_600).
ir = test["ir_spectra"].values.tolist()[0]

# Evaluate the saved model on the held-out sample.
print(test['func_group'].values)
predictions = evalute_model(X_test, y_test, model_path, smiles, ir, num_fgs=37, weighted=False, batch_size=1,
                            annealing_epochs=10, max_lambda_kl=1.0, lambda_cmi=0.1, lambda_recon=0.1)

# Score the returned predictions against the dense label matrix.
y_test = np.array([np.array(item, dtype=np.float32) for item in y_test], dtype=np.float32)
f1 = f1_score(y_test, predictions, average='micro')
print(f'F1 Score: {f1}')

# Save results
with open(out_path / "results.pickle", "wb") as file:
    pickle.dump({'pred': predictions, 'tgt': y_test}, file)

Loaded Data from:  1


KeyboardInterrupt: 