In [42]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
import pickle
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from pathlib import Path
from rdkit import Chem
from rdkit import RDLogger
from scipy.interpolate import interp1d
from torch.utils.data import DataLoader, TensorDataset

# Silence every RDKit log channel matching 'rdApp.*' for the rest of the session.
RDLogger.DisableLog('rdApp.*')
import os

# NOTE(review): this looks like a TensorFlow oneDNN switch, but no TensorFlow
# import is visible in this notebook -- confirm it is still needed.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
# SMARTS definition for every functional group used as a classification label.
# The insertion order of this table fixes the order of the label vector built
# by get_functional_groups, so entries must not be reordered.
_FUNCTIONAL_GROUP_SMARTS = {
    'Acid anhydride': '[CX3](=[OX1])[OX2][CX3](=[OX1])',
    'Acyl halide': '[CX3](=[OX1])[F,Cl,Br,I]',
    'Alcohol': '[#6][OX2H]',
    'Aldehyde': '[CX3H1](=O)[#6,H]',
    'Alkane': '[CX4;H3,H2]',
    'Alkene': '[CX3]=[CX3]',
    'Alkyne': '[CX2]#[CX2]',
    'Amide': '[NX3][CX3](=[OX1])[#6]',
    'Amine': '[NX3;H2,H1,H0;!$(NC=O)]',
    'Arene': '[cX3]1[cX3][cX3][cX3][cX3][cX3]1',
    'Azo compound': '[#6][NX2]=[NX2][#6]',
    'Carbamate': '[NX3][CX3](=[OX1])[OX2H0]',
    'Carboxylic acid': '[CX3](=O)[OX2H]',
    'Enamine': '[NX3][CX3]=[CX3]',
    'Enol': '[OX2H][#6X3]=[#6]',
    'Ester': '[#6][CX3](=O)[OX2H0][#6]',
    'Ether': '[OD2]([#6])[#6]',
    'Haloalkane': '[#6][F,Cl,Br,I]',
    'Hydrazine': '[NX3][NX3]',
    'Hydrazone': '[NX3][NX2]=[#6]',
    'Imide': '[CX3](=[OX1])[NX3][CX3](=[OX1])',
    'Imine': '[$([CX3]([#6])[#6]),$([CX3H][#6])]=[$([NX2][#6]),$([NX2H])]',
    'Isocyanate': '[NX2]=[C]=[O]',
    'Isothiocyanate': '[NX2]=[C]=[S]',
    'Ketone': '[#6][CX3](=O)[#6]',
    'Nitrile': '[NX1]#[CX2]',
    'Phenol': '[OX2H][cX3]:[c]',
    'Phosphine': '[PX3]',
    'Sulfide': '[#16X2H0]',
    'Sulfonamide': '[#16X4]([NX3])(=[OX1])(=[OX1])[#6]',
    'Sulfonate': '[#16X4](=[OX1])(=[OX1])([#6])[OX2H0]',
    'Sulfone': '[#16X4](=[OX1])(=[OX1])([#6])[#6]',
    'Sulfonic acid': '[#16X4](=[OX1])(=[OX1])([#6])[OX2H]',
    'Sulfoxide': '[#16X3]=[OX1]',
    'Thial': '[CX3H1](=S)[#6,H]',
    'Thioamide': '[NX3][CX3]=[SX1]',
    'Thiol': '[#16X2H]',
}

# Functional-group name -> pre-parsed RDKit query molecule, parsed once up front.
functional_groups = {
    group_name: Chem.MolFromSmarts(smarts_pattern)
    for group_name, smarts_pattern in _FUNCTIONAL_GROUP_SMARTS.items()
}
def match_group(mol: Chem.Mol, func_group) -> int:
    """Return 1 if ``func_group`` is present in ``mol``, otherwise 0.

    Args:
        mol: The molecule to inspect.
        func_group: Either an RDKit query molecule (e.g. built with
            ``Chem.MolFromSmarts``) or a callable that takes ``mol`` and
            returns a count; a count equal to 0 means "absent".

    Returns:
        1 when the group is found at least once, 0 otherwise.
    """
    # isinstance (not an exact type comparison) so Mol subclasses are also
    # treated as query molecules instead of being called like a function.
    if isinstance(func_group, Chem.Mol):
        # Only presence matters, so stop at the first hit rather than
        # enumerating every substructure match.
        return int(mol.HasSubstructMatch(func_group))
    return 0 if func_group(mol) == 0 else 1
# Function to map SMILES to functional groups
def get_functional_groups(smiles: str) -> "list[int] | None":
    """Encode a SMILES string as a 0/1 vector of functional-group presence.

    Args:
        smiles: SMILES string; surrounding whitespace and embedded spaces are
            removed before parsing.

    Returns:
        A list with one 0/1 entry per item of ``functional_groups`` (in that
        dict's iteration order), or ``None`` when ``smiles`` is not a string
        or RDKit cannot parse it.
    """
    # Missing values coming out of pandas (None / NaN) are not strings and
    # would otherwise fail on .strip() with an AttributeError.
    if not isinstance(smiles, str):
        return None
    smiles = smiles.strip().replace(' ', '')
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return [match_group(mol, pattern) for pattern in functional_groups.values()]

def interpolate_to_600(spec):
    """Linearly resample ``spec`` onto 600 evenly spaced positions.

    The input samples are treated as sitting at integer positions
    ``0 .. len(spec) - 1``; the output covers the same range with 600 points.

    Args:
        spec: Sequence or array of intensity values.

    Returns:
        NumPy array with the 600 interpolated values.

    Raises:
        ValueError: If ``spec`` is empty.
    """
    n_samples = len(spec)
    if n_samples == 0:
        raise ValueError("cannot interpolate an empty spectrum")
    old_x = np.arange(n_samples)
    # The grid always spans 0 .. n_samples - 1, so there is no need to scan
    # old_x with the Python-level min()/max() builtins.
    new_x = np.linspace(0, n_samples - 1, 600)
    interp = interp1d(old_x, spec)
    return interp(new_x)

def make_msms_spectrum(spectrum):
    """Rasterise a peak list into a fixed-length intensity vector.

    Each peak is indexed as ``peak[0]`` (position) and ``peak[1]`` (intensity).
    The position is multiplied by 10 and truncated to an integer bin; bins
    beyond the last one are folded into the last bin. When several peaks land
    in the same bin, the one that comes last in ``spectrum`` wins.

    Args:
        spectrum: Iterable of indexable peaks.

    Returns:
        NumPy array of length 10000 holding the binned intensities.
    """
    n_bins = 10000
    last_bin = n_bins - 1
    binned = np.zeros(n_bins)
    for peak in spectrum:
        position, intensity = peak[0], peak[1]
        bin_index = int(position * 10)
        if bin_index > last_bin:
            bin_index = last_bin
        binned[bin_index] = intensity
    return binned

# Define CNN Model in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_std

class CNNModel(nn.Module):
    """1-D CNN multi-label classifier with an auxiliary KL-style penalty.

    Two Conv1d + BatchNorm + ReLU + max-pool stages turn a single-channel
    spectrum into a ``(batch, 62, input_length // 4)`` feature map, which is
    flattened and passed through four fully connected layers ending in a
    sigmoid. In parallel, a small MLP scores every channel and a KL-style
    penalty is computed from those scores.

    Args:
        num_fgs: Number of output labels (size of the final layer).
        input_length: Length of the input spectrum. Defaults to 600, the
            length this model was originally hard-wired for.
    """

    def __init__(self, num_fgs, input_length=600):
        super().__init__()
        if input_length < 4:
            raise ValueError("input_length must be at least 4 (two stride-2 poolings)")
        # forward() halves the length twice with max_pool1d(kernel=2).
        pooled_length = input_length // 4

        self.conv1 = nn.Conv1d(in_channels=1, out_channels=31, kernel_size=11, padding='same')
        self.conv2 = nn.Conv1d(in_channels=31, out_channels=62, kernel_size=11, padding='same')
        self.fc1 = nn.Linear(62 * pooled_length, 4927)
        self.fc2 = nn.Linear(4927, 2785)
        self.fc3 = nn.Linear(2785, 1574)
        self.fc4 = nn.Linear(1574, num_fgs)
        self.dropout = nn.Dropout(0.48599073736368)
        self.batch_norm1 = nn.BatchNorm1d(31)
        self.batch_norm2 = nn.BatchNorm1d(62)

        # Scores each of the 62 channels from its pooled feature vector.
        self.mlp = nn.Sequential(
            nn.Linear(pooled_length, 128),  # one pooled feature vector per channel
            nn.ReLU(),
            nn.Linear(128, 1)               # one importance logit per channel
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, 1, input_length)``.

        Returns:
            Tuple ``(probs, kl_loss)``: ``probs`` has shape
            ``(batch, num_fgs)`` with sigmoid outputs, ``kl_loss`` is a scalar
            tensor (mean of the element-wise KL-style penalty).
        """
        x = F.relu(self.batch_norm1(self.conv1(x)))  # (B, 31, L)
        x = F.max_pool1d(x, 2)                       # (B, 31, L/2)

        x = F.relu(self.batch_norm2(self.conv2(x)))  # (B, 62, L/2)
        x = F.max_pool1d(x, 2)                       # (B, 62, L/4)

        # Statistics are taken ACROSS the 62 channels (dim=1), i.e. one mean
        # and one std per position: shape (B, L/4), broadcast back as (B, 1, L/4).
        mean_ref = x.mean(dim=1).unsqueeze(1)
        std_ref = x.std(dim=1).unsqueeze(1)

        # Per-channel importance in (0, 1); shape (B, 62, 1).
        channel_importance = torch.sigmoid(self.mlp(x))

        # Blend each channel with the cross-channel statistics: important
        # channels keep their own values, unimportant ones drift to the mean
        # and receive more noise.
        ib_x_mean = x * channel_importance + (1 - channel_importance) * mean_ref
        ib_x_std = (1 - channel_importance) * std_ref
        # NOTE(review): ib_x is computed but the classifier below consumes the
        # unperturbed x, so the KL term only shapes self.mlp. The draw is kept
        # so the random stream (and therefore dropout) matches earlier runs.
        # Confirm whether fc1 was meant to receive ib_x, and whether uniform
        # noise (rand_like) rather than Gaussian noise (randn_like) is intended.
        ib_x = ib_x_mean + torch.rand_like(ib_x_mean) * ib_x_std

        # KL-style penalty between the blended distribution and the reference
        # statistics; epsilon guards the divisions against zero std.
        epsilon = 1e-8
        ref_var = (std_ref + epsilon) ** 2
        kl_tensor = 0.5 * ((ib_x_std ** 2) / ref_var + (std_ref ** 2) / (ib_x_std + epsilon) ** 2 - 1) + \
                    ((ib_x_mean - mean_ref) ** 2) / ref_var
        kl_loss = torch.mean(kl_tensor)

        # Flatten and pass through the fully connected head.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = F.relu(self.fc3(x))
        x = self.dropout(x)
        x = torch.sigmoid(self.fc4(x))

        return x, kl_loss

# Example of how to use the model
# model = CNNModel(num_fgs=10)
# output, loss = model(torch.randn(41, 1, 600))





In [43]:
from tqdm import tqdm  # 引入 tqdm

# Default weight of the model's auxiliary KL term in the training loss.
b = 0.0001


def train_model(X_train, y_train, X_test, num_fgs, weighted=False, batch_size=41, epochs=41,
                device=None, kl_weight=None):
    """Train a fresh ``CNNModel`` and return binary predictions for ``X_test``.

    Args:
        X_train: Training spectra, array-like of shape (n_train, length).
        y_train: Training labels, array-like of shape (n_train, num_fgs).
        X_test: Spectra to predict, array-like of shape (n_test, length).
        num_fgs: Number of output labels.
        weighted: If True, use a class-weighted BCE loss instead of ``nn.BCELoss``.
        batch_size: Mini-batch size for both training and prediction.
        epochs: Number of passes over the training data.
        device: Optional torch device (or device string). When omitted,
            ``cuda:2`` is used if that GPU exists, otherwise ``cuda:0``,
            otherwise the CPU.
        kl_weight: Weight of the KL term returned by the model. When omitted,
            the module-level ``b`` is used.

    Returns:
        Integer NumPy array of shape (n_test, num_fgs) with 0/1 predictions
        (sigmoid output thresholded at 0.5).
    """
    if device is None:
        if torch.cuda.is_available():
            # Keep the original GPU choice when that index exists; fall back
            # to the first GPU instead of failing on smaller machines.
            device = 'cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0'
        else:
            device = 'cpu'
    device = torch.device(device)
    if kl_weight is None:
        kl_weight = b

    model = CNNModel(num_fgs).to(device)

    # Define optimizer and loss
    optimizer = optim.Adam(model.parameters())

    if weighted:
        # NOTE(review): calculate_class_weights and WeightedBinaryCrossEntropyLoss
        # are not defined in the visible part of this notebook -- confirm they
        # exist before calling with weighted=True.
        class_weights = calculate_class_weights(y_train)
        criterion = WeightedBinaryCrossEntropyLoss(class_weights).to(device)
    else:
        criterion = nn.BCELoss().to(device)

    # Create DataLoaders. Prediction needs only the inputs, so the test set is
    # built without labels (this also removes the former dependency on a
    # global y_test that was never passed to this function).
    train_data = TensorDataset(torch.tensor(X_train, dtype=torch.float32),
                               torch.tensor(y_train, dtype=torch.float32))
    test_data = TensorDataset(torch.tensor(X_test, dtype=torch.float32))
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # Train the model
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        with tqdm(train_loader, unit='batch', desc=f"Epoch {epoch+1}/{epochs}") as tepoch:
            # Count steps explicitly: tqdm only refreshes its own counter when
            # the bar is redrawn, so it can lag behind the real batch index.
            for step, (inputs, targets) in enumerate(tepoch, start=1):
                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()
                outputs, kl_loss = model(inputs.unsqueeze(1))  # Add channel dimension
                bce_loss = criterion(outputs, targets)
                loss = bce_loss + kl_loss * kl_weight
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

                # Show the running mean loss of this epoch on the progress bar
                tepoch.set_postfix(loss=running_loss / step)

        # After every epoch, print the average loss
        print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss / max(len(train_loader), 1)}')

    # Predict on the held-out inputs
    model.eval()
    predictions = []
    with torch.no_grad():
        for (inputs,) in test_loader:
            inputs = inputs.to(device)
            outputs, _ = model(inputs.unsqueeze(1))
            predictions.append(outputs.cpu().numpy())

    if not predictions:
        # Empty X_test: return an empty prediction matrix instead of letting
        # np.concatenate fail on an empty list.
        return np.zeros((0, num_fgs), dtype=int)
    predictions = np.concatenate(predictions)
    return (predictions > 0.5).astype(int)


In [3]:
# Loading data
# NOTE(review): both paths are machine-specific absolute paths; adjust them
# when running elsewhere.
analytical_data = Path("/data/zjh2/multimodal-spectroscopic-dataset-main/data/multimodal_spectroscopic_dataset")
out_path = Path("/home/dwj/icml_guangpu/multimodal-spectroscopic-dataset-main/runs/runs_f_groups/h_nmr")
column = "h_nmr_spectra"
seed = 3245

# glob() yields files in filesystem order; sorting fixes the row order so the
# seeded train/test split further down is the same on every machine and run.
parquet_files = sorted(analytical_data.glob("*.parquet"))
if not parquet_files:
    raise FileNotFoundError(f"No parquet files found in {analytical_data}")

# Collect the per-file frames and concatenate once; concatenating inside the
# loop would copy the growing frame on every iteration.
loaded_frames = []
for i, parquet_file in enumerate(parquet_files):
    data = pd.read_parquet(parquet_file, columns=[column, 'smiles'])
    data[column] = data[column].map(interpolate_to_600)
    data['func_group'] = data.smiles.map(get_functional_groups)
    print("Loaded Data: ", i)
    loaded_frames.append(data)

training_data = pd.concat(loaded_frames)

# get_functional_groups returns None for SMILES it cannot parse; such rows
# cannot be stacked into a label matrix later, so drop them here.
n_rows_loaded = len(training_data)
training_data = training_data.dropna(subset=['func_group'])
if len(training_data) != n_rows_loaded:
    print(f"Dropped {n_rows_loaded - len(training_data)} rows with unparsable SMILES")



Loaded Data:  0
Loaded Data:  1
Loaded Data:  2
Loaded Data:  3
Loaded Data:  4
Loaded Data:  5
Loaded Data:  6
Loaded Data:  7
Loaded Data:  8
Loaded Data:  9
Loaded Data:  10
Loaded Data:  11
Loaded Data:  12
Loaded Data:  13
Loaded Data:  14
Loaded Data:  15
Loaded Data:  16
Loaded Data:  17
Loaded Data:  18
Loaded Data:  19
Loaded Data:  20
Loaded Data:  21
Loaded Data:  22
Loaded Data:  23
Loaded Data:  24
Loaded Data:  25
Loaded Data:  26
Loaded Data:  27
Loaded Data:  28
Loaded Data:  29
Loaded Data:  30
Loaded Data:  31
Loaded Data:  32
Loaded Data:  33
Loaded Data:  34
Loaded Data:  35
Loaded Data:  36
Loaded Data:  37
Loaded Data:  38
Loaded Data:  39
Loaded Data:  40
Loaded Data:  41
Loaded Data:  42
Loaded Data:  43
Loaded Data:  44
Loaded Data:  45
Loaded Data:  46
Loaded Data:  47
Loaded Data:  48
Loaded Data:  49
Loaded Data:  50
Loaded Data:  51
Loaded Data:  52
Loaded Data:  53
Loaded Data:  54
Loaded Data:  55
Loaded Data:  56
Loaded Data:  57
Loaded Data:  58
Loaded 

In [36]:
# Hold out 10% of the rows for testing; `seed` keeps the split repeatable.
train, test = train_test_split(training_data, random_state=seed, test_size=0.1)

# Stack the per-row spectra and label vectors into arrays, one row per sample.
X_train, y_train = (np.stack(train[name].tolist()) for name in (column, 'func_group'))
X_test, y_test = (np.stack(test[name].tolist()) for name in (column, 'func_group'))



In [None]:
# Make sure the results directory exists BEFORE the long training run, so a
# missing or unwritable path fails immediately rather than after training.
out_path.mkdir(parents=True, exist_ok=True)

# Train extended model; the number of outputs is taken from the label matrix
# so it cannot drift from the functional-group definitions.
predictions = train_model(X_train, y_train, X_test, num_fgs=y_train.shape[1], weighted=False)

# Evaluate the model
f1 = f1_score(y_test, predictions, average='micro')
print(f'F1 Score: {f1}')

# Save results
with open(out_path / "results.pickle", "wb") as file:
    pickle.dump({'pred': predictions, 'tgt': y_test}, file)

Epoch 1/41: 100%|██████████| 17439/17439 [14:47<00:00, 19.65batch/s, loss=0.172]


Epoch 1/41, Loss: 0.17175170208311824


Epoch 2/41: 100%|██████████| 17439/17439 [13:55<00:00, 20.88batch/s, loss=0.159]


Epoch 2/41, Loss: 0.15904590773002442


Epoch 3/41: 100%|██████████| 17439/17439 [09:38<00:00, 30.16batch/s, loss=0.154]


Epoch 3/41, Loss: 0.1536876848821445


Epoch 4/41: 100%|██████████| 17439/17439 [11:08<00:00, 26.07batch/s, loss=0.15]


Epoch 4/41, Loss: 0.1496898248826728


Epoch 5/41: 100%|██████████| 17439/17439 [10:22<00:00, 28.02batch/s, loss=0.147]


Epoch 5/41, Loss: 0.1465038159850732


Epoch 6/41: 100%|██████████| 17439/17439 [10:03<00:00, 28.89batch/s, loss=0.144]


Epoch 6/41, Loss: 0.14386171982759552


Epoch 7/41: 100%|██████████| 17439/17439 [09:33<00:00, 30.40batch/s, loss=0.142]


Epoch 7/41, Loss: 0.14159161970510661


Epoch 8/41: 100%|██████████| 17439/17439 [09:31<00:00, 30.50batch/s, loss=0.14] 


Epoch 8/41, Loss: 0.13976322551445305


Epoch 9/41: 100%|██████████| 17439/17439 [09:34<00:00, 30.36batch/s, loss=0.138]


Epoch 9/41, Loss: 0.13782890476519255


Epoch 10/41: 100%|██████████| 17439/17439 [09:59<00:00, 29.10batch/s, loss=0.136]


Epoch 10/41, Loss: 0.1361409236132104


Epoch 11/41:  72%|███████▏  | 12542/17439 [06:46<02:19, 35.01batch/s, loss=0.134]