In [1]:
import pandas as pd
from transformers import BertModel, AutoTokenizer, RobertaModel, RobertaTokenizerFast
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import MinMaxScaler, RobustScaler, normalize
import torch, os, random
from Sophia.sophia import SophiaG 
import numpy as np
from torch.nn.utils import clip_grad_norm_
from graph_aug import mask_nodes, mask_edges, permute_edges, drop_nodes, subgraph
from torch import nn
from torch.nn import functional as F
# from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import mean_squared_error
from torch_ema import ExponentialMovingAverage
from matplotlib import pyplot as plt

In [2]:
from ogb.utils import smiles2graph
from dualgraph.mol import smiles2graphwithface, simles2graphwithface_with_mask
from dualgraph.gnn import GNN, GNN2, GIN_node_Virtual, GNNwithvn

from torch_geometric.data import Dataset, InMemoryDataset
from dualgraph.dataset import DGData
from torch_geometric.loader import DataLoader

Using backend: pytorch


In [3]:
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)

    torch.cuda.manual_seed(seed)  # type: ignore
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = False  # type: ignore
seed_everything(2023)


In [4]:
device = 'cuda'

In [5]:
data = pd.read_csv('data/train.csv', index_col=None)
test = pd.read_csv('data/test.csv', index_col=None)

duplicate_smile = data['SMILES'].value_counts().reset_index().query('count>1')['SMILES'].values
no_duplicate = data.drop_duplicates('SMILES').reset_index(drop=True)

duplicate = data[data['SMILES'].isin(duplicate_smile)].reset_index(drop=True)
duplicate = duplicate.groupby('SMILES')[['MLM', 'HLM']].max().reset_index()

no_duplicate.loc[ no_duplicate['SMILES'].isin(duplicate['SMILES']), 'MLM'] = duplicate['MLM'].values
no_duplicate.loc[ no_duplicate['SMILES'].isin(duplicate['SMILES']), 'HLM'] = duplicate['HLM'].values
data = no_duplicate

In [6]:
data.loc[data['AlogP'].isna(), 'AlogP'] = data['LogD']
test.loc[test['AlogP'].isna(), 'AlogP'] = test['LogD']

# data.loc[data['MLM'] > 100, 'MLM'] = float(100)
# data.loc[data['HLM'] > 100, 'HLM'] = float(100)

In [7]:
pp_target = np.log1p(data['MLM'].values)
norm_pp_target = (pp_target - pp_target.mean()) / (pp_target.std())

In [8]:
test['MLM'] = 0
test['HLM'] = 0
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [9]:
class CustomDataset(InMemoryDataset):
    def __init__(self, root='dataset_path', transform=None, pre_transform=None, df=None, target_type='MLM', mode='train'):
        self.df = df
        self.target_type = target_type
        self.mode = mode
        super().__init__(root, transform, pre_transform, df)
        

    @property
    def raw_file_names(self):        
        return [f'raw_{i+1}.pt' for i in range(self.df.shape[0])]

    @property
    def processed_file_names(self):
        return [f'data_{i+1}.pt' for i in range(self.df.shape[0])]        

    def len(self):
        return len(self.graph_list)

    def get(self, idx):        
        if self.mode=='train' and random.random() > 10:
            data = DGData()
            smiles = self.smiles_list[idx]
            smiles = sme.randomize_smiles(smiles)                                

            targets = self.targets_list[idx]

            graph = smiles2graphwithface(smiles)

            data.__num_nodes__ = int(graph["num_nodes"])
            data.edge_index = torch.from_numpy(graph["edge_index"]).to(torch.int64)
            data.edge_attr = torch.from_numpy(graph["edge_feat"]).to(torch.int64)
            data.x = torch.from_numpy(graph["node_feat"]).to(torch.int64)
            data.y = torch.Tensor([targets])

            data.ring_mask = torch.from_numpy(graph["ring_mask"]).to(torch.bool)
            data.ring_index = torch.from_numpy(graph["ring_index"]).to(torch.int64)
            data.nf_node = torch.from_numpy(graph["nf_node"]).to(torch.int64)
            data.nf_ring = torch.from_numpy(graph["nf_ring"]).to(torch.int64)
            data.num_rings = int(graph["num_rings"])
            data.n_edges = int(graph["n_edges"])
            data.n_nodes = int(graph["n_nodes"])
            data.n_nfs = int(graph["n_nfs"])
            data.tabular = torch.from_numpy(self.tabular_list[idx])

            return data

        else:
            # if self.mode=='train' and random.random() < 0.5:
            #     return mask_nodes(self.graph_list[idx], 0.1)
            return self.graph_list[idx]


    def process(self):        
        smiles_list = self.df["SMILES"].values
        targets_list = self.df[self.target_type].values
        tabular_list = self.df[['AlogP', 'Molecular_Weight', 'Num_H_Acceptors', 'Num_H_Donors', 'Num_RotatableBonds', 'LogD', 'Molecular_PolarSurfaceArea']].values.astype('float32')

        data_list = []
        for i in tqdm(range(len(smiles_list))):            
            data = DGData()
            smiles = smiles_list[i]
            targets = targets_list[i]
            graph = smiles2graphwithface(smiles)

            data.__num_nodes__ = int(graph["num_nodes"])
            data.edge_index = torch.from_numpy(graph["edge_index"]).to(torch.int64)
            data.edge_attr = torch.from_numpy(graph["edge_feat"]).to(torch.int64)
            data.x = torch.from_numpy(graph["node_feat"]).to(torch.int64)
            data.y = torch.Tensor([targets])

            data.ring_mask = torch.from_numpy(graph["ring_mask"]).to(torch.bool)
            data.ring_index = torch.from_numpy(graph["ring_index"]).to(torch.int64)
            data.nf_node = torch.from_numpy(graph["nf_node"]).to(torch.int64)
            data.nf_ring = torch.from_numpy(graph["nf_ring"]).to(torch.int64)
            data.num_rings = int(graph["num_rings"])
            data.n_edges = int(graph["n_edges"])
            data.n_nodes = int(graph["n_nodes"])
            data.n_nfs = int(graph["n_nfs"])        
            data.tabular = torch.from_numpy(tabular_list[i])

            data_list.append(data)
        self.smiles_list = smiles_list  
        self.graph_list = data_list
        self.targets_list = targets_list
        self.tabular_list = tabular_list

In [10]:
def correlation_score(y_true, y_pred):
    y_true_centered = y_true - torch.mean(y_true, dim=1)[:, None]
    y_pred_centered = y_pred - torch.mean(y_pred, dim=1)[:, None]
    cov_tp = torch.sum(y_true_centered * y_pred_centered, dim=1) / (y_true.shape[1] - 1)
    var_t = torch.sum(y_true_centered ** 2, dim=1) / (y_true.shape[1] - 1)
    var_p = torch.sum(y_pred_centered ** 2, dim=1) / (y_true.shape[1] - 1)
    return cov_tp / torch.sqrt(var_t * var_p)


def correlation_loss(pred, target):
    return -torch.mean(correlation_score(target.unsqueeze(0), pred.unsqueeze(0)))

In [11]:
class MedModel(torch.nn.Module):

    def __init__(self):
        super(MedModel, self).__init__()
        self.ddi = True
        self.gnn = GNN2(
                        mlp_hidden_size = 512,
                        mlp_layers = 2,
                        latent_size = 128,
                        use_layer_norm = False,
                        use_face=True,
                        # residual = True,
                        ddi=self.ddi,
                        dropedge_rate = 0.1,
                        dropnode_rate = 0.1,
                        dropout = 0.1,
                        dropnet = 0.1,
                        global_reducer = "sum",
                        node_reducer = "sum",
                        face_reducer = "sum",
                        graph_pooling = "sum",
                        # global_attn = True,
                        node_attn = True,
                        face_attn = True
                        # use_bn=True
                        )
        state_dict=  torch.load('ckpt/ognn_pretrain_best.pt', map_location='cpu')        
        self.gnn.load_state_dict(state_dict)#, strict=False)
        
        self.fc1 = nn.Sequential(
                    nn.LayerNorm(128),
                    nn.Linear(128, 128,),
                    nn.BatchNorm1d(128),
                    nn.Dropout(0.1),
                    nn.ReLU(),
                    nn.Linear(128, 1),
                    )

        # self.fc1[1].weight.data.normal_(mean=0.0, std=0.01)
        self.fc1[-1].weight.data.normal_(mean=0.0, std=0.01)
        
    def forward(self, batch):
        mol = self.gnn(batch)
        out1 = self.fc1(mol).squeeze(1)# .sigmoid() * 100
        return out1


In [12]:
def norm_mse_loss(preds, targets):
    p_mean, p_var = preds.mean(), preds.var()
    t_mean, t_var = targets.mean(), targets.var()
    norm_preds = (preds - p_mean) / (p_var + 1.e-6)**.5
    norm_targets = (targets - t_mean) / (t_var + 1.e-6)**.5
    loss = (norm_preds - norm_targets) ** 2
    return loss.mean()

In [13]:
ksplit = KFold(n_splits=5, shuffle=True, random_state=2023)

In [14]:
def norm(x):
    return (x-target_mean) / target_std

def denorm(x):
    return x * target_std + target_mean

In [15]:
mlm_chem_df = pd.read_csv('data2/train-aug-MLM-0906.csv.csv')
hlm_chem_df = pd.read_csv('data2/train-aug-HLM-0906.csv.csv')

In [16]:
# for k, (t_idx, v_idx) in enumerate(ksplit.split(range(data.shape[0]))):
#     train, valid = data.loc[t_idx].reset_index(drop=True), data.loc[v_idx].reset_index(drop=True)
    
#     target_mean = train['MLM'].mean()
#     target_std = train['MLM'].std()

#     train = pd.concat([train, mlm_chem_df]).reset_index(drop=True)

#     train_dataset = CustomDataset(df = train, mode='train')
#     train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers = 8)

#     valid_dataset = CustomDataset(df = valid, mode='test')
#     valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False, num_workers = 8)

#     model = MedModel().to(device)
#     # model.load_state_dict(torch.load('best_gnn_semi_MLM.pt'))
#     hub_loss = nn.HuberLoss()
#     mse_loss = nn.MSELoss()
#     optim = torch.optim.AdamW(model.parameters(), lr=3e-4)
#     ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, T_0=10, T_mult=1, verbose=False)
#     # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=80, verbose=False)
#     best_val_loss = 1e6

#     for epoch in range(50):
#         model.train()
#         train_loss = 0
#         for batch in train_loader:
#             batch = batch.to(device)
            
#             preds = model(batch)
#             targets = batch.y.to(device)
#             targets = norm(targets)

#             loss = mse_loss(preds, targets) * 0.8 + correlation_loss(preds, targets) * 0.2

#             optim.zero_grad()
#             loss.backward()
#             optim.step()
#             ema.update()
            
#             train_loss += loss.cpu().item()
        
#         model.eval()
#         valid_preds = []
#         valid_label = []
            
#         for batch in valid_loader:
#             batch = batch.to(device)
#             with torch.no_grad():
#                 preds = model(batch)                
#                 targets = batch.y.to(device)                            
#                 preds = denorm(preds) 
                
#                 valid_label += targets.cpu().tolist()
#                 valid_preds += preds.cpu().tolist()
        
#         val_loss = mean_squared_error(valid_preds, valid_label) ** (1/2)

#         if val_loss < best_val_loss:
#             best_val_loss = val_loss
#             # np.save(f'ckpt_fold/best_gnn_val_preds_{k}.npy', np.array(valid_preds))
#             # np.save(f'ckpt_fold/val_label_{k}.npy', np.array(valid_label))
#             # torch.save(model.state_dict(), f'ckpt_fold/best_gnn_{train_dataset.target_type}_{k}.pt')
#         scheduler.step()
#         print(f'EPOCH : {epoch} | T_LOSS : {train_loss / len(train_loader):.4f} | MLM_RMSE : {val_loss:.2f} | BEST : {best_val_loss:.2f}')    


In [17]:
for k, (t_idx, v_idx) in enumerate(ksplit.split(range(data.shape[0]))):
    train, valid = data.loc[t_idx].reset_index(drop=True), data.loc[v_idx].reset_index(drop=True)

    target_mean = train['HLM'].mean()
    target_std = train['HLM'].std()

    train = pd.concat([train, hlm_chem_df]).reset_index(drop=True)

    train_dataset = CustomDataset(df = train, mode='train', target_type='HLM')
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers = 8)

    valid_dataset = CustomDataset(df = valid, mode='test', target_type='HLM')
    valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False, num_workers = 8)
    

    model = MedModel().to(device)
    mse_loss = nn.MSELoss()
    optim = torch.optim.AdamW(model.parameters(), lr=3e-4)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, T_0=10, T_mult=1, verbose=False)
    best_val_loss = 1e6

    for epoch in range(50):
        model.train()
        train_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            
            preds = model(batch)
            targets = batch.y.to(device)
            targets = norm(targets)
            
            loss = mse_loss(preds, targets) * 0.8 + correlation_loss(preds, targets) * 0.2

            optim.zero_grad()
            loss.backward()
            optim.step()
            ema.update()
            
            train_loss += loss.cpu().item()
        
        model.eval()
        valid_preds = []
        valid_label = []
            
        for batch in valid_loader:
            batch = batch.to(device)
            with torch.no_grad():
                preds = model(batch)                
                targets = batch.y.to(device)                            
                preds = denorm(preds) 

                valid_label += targets.cpu().tolist()
                valid_preds += preds.cpu().tolist()
        
        val_loss = mean_squared_error(valid_preds, valid_label) ** (1/2)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            np.save(f'ckpt_fold/best_gnn_val_preds_HLM_{k}.npy', np.array(valid_preds))
            np.save(f'ckpt_fold/val_label_HLM_{k}.npy', np.array(valid_label))
            torch.save(model.state_dict(), f'ckpt_fold/best_gnn_{train_dataset.target_type}_{k}.pt')
        scheduler.step()
        print(f'EPOCH : {epoch} | T_LOSS : {train_loss / len(train_loader):.4f} | HLM_RMSE : {val_loss:.2f} | BEST : {best_val_loss:.2f}')    


Processing...
100%|██████████| 4373/4373 [00:04<00:00, 1086.21it/s]
Done!
Processing...
100%|██████████| 695/695 [00:00<00:00, 1163.90it/s]
Done!


EPOCH : 0 | T_LOSS : 3246.8391 | HLM_RMSE : 62.92 | BEST : 62.92


In [None]:
# test_dataset = CustomDataset(df = test, mode='test', target_type='MLM')
# test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers = 8)

In [None]:
# ognn_MLM = []
# ognn_HLM = []

# for k in tqdm(range(5)):
#     model = MedModel().to(device)
#     model.load_state_dict(torch.load(f'ckpt_fold/best_gnn_MLM_{k}.pt'))
#     model.eval()
#     preds = []
#     for batch in test_loader:
#         with torch.no_grad():            
#             preds += model.fc1(model.gnn(batch.to(device))).cpu().tolist()
#     ognn_MLM.append(preds)

# for k in tqdm(range(5)):
#     model = MedModel().to(device)
#     model.load_state_dict(torch.load(f'ckpt_fold/best_gnn_HLM_{k}.pt'))
#     model.eval()
#     preds = []

#     for batch in test_loader:
#         with torch.no_grad():            
#             preds += model.fc1(model.gnn(batch.to(device))).cpu().tolist()  
#     ognn_HLM.append(preds)

# ognn_MLM = torch.tensor(ognn_MLM)
# ognn_HLM = torch.tensor(ognn_HLM)