In [1]:
import torch
import pandas as pd
from tqdm import tqdm
import numpy as np
import wandb
import os

In [2]:
import rdkit
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import AllChem

from rdkit.Chem import rdMolTransforms
from rdkit.Chem.Draw import rdMolDraw2D, rdDepictor, IPythonConsole
from rdkit import rdBase
blocker = rdBase.BlockLogs()

In [3]:
tqdm.pandas()

In [4]:
!python3 -m wandb login eb7b1964fb84cd81de96b2a273ecf2bb6254aeac

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/nick1899/.netrc


### Upload config

In [5]:
import yaml

config = yaml.load(open("config-qsar-regression-exact.yaml", "r"), Loader=yaml.FullLoader)
print(config)

num_of_shifts = 20

{'batch_size': 32, 'warm_up': 2, 'epochs': 10, 'load_graph_model': 'pretrained_gcn', 'save_every_n_epochs': 1, 'fp16_precision': False, 'init_lr': 0.0005, 'weight_decay': '1e-5', 'gpu': 'cuda:3', 'pretrained_roberta_name': 'molberto_ecfp0_2M', 'roberta_model': {'vocab_size': 30522, 'max_position_embeddings': 514, 'hidden_size': 768, 'num_attention_heads': 12, 'num_hidden_layers': 6, 'type_vocab_size': 1}, 'graph_model_type': 'gin', 'graph_model': {'num_layer': 5, 'emb_dim': 500, 'feat_dim': 768, 'drop_ratio': 0, 'pool': 'mean'}, 'graph_aug': 'node', 'dataset': {'num_workers': 1, 'valid_size': 0.1, 'test_size': 0.1}, 'ntxent_loss': {'temperature': 0.1, 'use_cosine_similarity': True}, 'loss_params': {'alpha': 1.0, 'beta': 1.0, 'gamma': 1.0}}


In [6]:
print('batch_size =', config['batch_size'])

batch_size = 32


In [7]:
print('running on device:', config['gpu'])
device = torch.device(config['gpu']) if torch.cuda.is_available() else torch.device('cpu')

running on device: cuda:3


In [8]:
def _save_config_file(config, log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    with open(os.path.join(log_dir, 'config.yml'), 'w') as outfile:
        yaml.dump(config, outfile, default_flow_style=False, sort_keys=False)

### Upload and Split Dataset

In [9]:
dataframe = pd.read_csv("datasets/nmr_with_ecfp.csv")

In [11]:
#dataframe = dataframe.drop(columns=['mol_id'])
dataframe = dataframe[dataframe['num_of_atoms']<51]
dataframe = dataframe.sample(frac=1).reset_index(drop=True)

In [12]:
#dataframe = dataframe.drop(columns=['num', 'name'])
target = 'arrayed_data' # choosing target data

In [13]:
dataframe.head()

Unnamed: 0,index,ppmdict_real,smiles,arrayed_data,ecfp0,num_of_atoms
0,174226,"{'78.74': [0], '93.93': [5], '126.66': [6, 7],...",C=C=Cc1ccccc1,"[78.74, 209.77, 126.86, 128.58, 128.58, 93.93,...","['2246997334', '2245900962', '2246703798', '32...",9
1,245444,"{'40.0': [12], '45.0': [0, 1], '57.0': [13], '...",CN(C)CCNc1cc(-c2ccc3ccccc3c2)nc2ccccc12,"[45.0, 45.0, 126.8, 126.3, 124.3, 130.0, 128.1...","['2246728737', '848128881', '2246728737', '224...",26
2,627,"{'86.1': [8], '92.8': [7], '120.6': [12], '122...",C(#Cc1cccnc1)c1ccccc1,"[128.9, 128.6, 128.6, 123.2, 131.8, 131.8, 138...","['2245900962', '2245900962', '3217380708', '32...",14
3,179202,"{'13.7': [0], '21.8': [1], '31.0': [2], '39.5'...",CCCC/C(=C\[Se]c1ccccc1)[Se]c1ccccc1,"[13.7, 21.8, 31.0, 127.2, 127.3, 129.2, 129.2,...","['2246728737', '2245384272', '2245384272', '22...",20
4,12775,"{'17.1': [0], '123.7': [16], '128.4': [3, 4], ...",Cc1cc(Br)cc2nc(-c3ccccc3)c(-c3ccccc3)nc12,"[17.1, 129.9, 130.3, 128.4, 128.4, 129.1, 129....","['2246728737', '3217380708', '3218693969', '32...",24


In [14]:
# this because pandas thinks columns with arrays are strings
def preprocess_data_dataset(df, column):
    for row in tqdm(range(len(df))):
        str_ints = eval(df.iloc[row][column])
        str_fingerprint = ' '.join(str_ints)
        df.at[row, column] = str_fingerprint

In [15]:
dropping  = []
for i in tqdm(range(len(dataframe['ecfp0']))):
    if type(dataframe['ecfp0'].iloc[i]) != str:
        dropping.append(i)


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 81856.05it/s]


In [16]:
dataframe = dataframe.drop(dropping)
dataframe = dataframe.reset_index()

In [17]:
preprocess_data_dataset(dataframe, 'ecfp0') # preprocess ecfp due to invalid storing of array

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 9514.45it/s]


In [18]:
dataframe['arrayed_data'] = dataframe['arrayed_data'].progress_apply(lambda x: eval(x))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 26584.42it/s]


In [19]:
dataframe.head()

Unnamed: 0,level_0,index,ppmdict_real,smiles,arrayed_data,ecfp0,num_of_atoms
0,0,174226,"{'78.74': [0], '93.93': [5], '126.66': [6, 7],...",C=C=Cc1ccccc1,"[78.74, 209.77, 126.86, 128.58, 128.58, 93.93,...",2246997334 2245900962 2246703798 3217380708 32...,9
1,1,245444,"{'40.0': [12], '45.0': [0, 1], '57.0': [13], '...",CN(C)CCNc1cc(-c2ccc3ccccc3c2)nc2ccccc12,"[45.0, 45.0, 126.8, 126.3, 124.3, 130.0, 128.1...",2246728737 848128881 2246728737 2245384272 224...,26
2,2,627,"{'86.1': [8], '92.8': [7], '120.6': [12], '122...",C(#Cc1cccnc1)c1ccccc1,"[128.9, 128.6, 128.6, 123.2, 131.8, 131.8, 138...",2245900962 2245900962 3217380708 3218693969 32...,14
3,3,179202,"{'13.7': [0], '21.8': [1], '31.0': [2], '39.5'...",CCCC/C(=C\[Se]c1ccccc1)[Se]c1ccccc1,"[13.7, 21.8, 31.0, 127.2, 127.3, 129.2, 129.2,...",2246728737 2245384272 2245384272 2245384272 22...,20
4,4,12775,"{'17.1': [0], '123.7': [16], '128.4': [3, 4], ...",Cc1cc(Br)cc2nc(-c3ccccc3)c(-c3ccccc3)nc12,"[17.1, 129.9, 130.3, 128.4, 128.4, 129.1, 129....",2246728737 3217380708 3218693969 3217380708 36...,24


In [20]:
dataframe['arrayed_data'] = dataframe['arrayed_data'].progress_apply(lambda x: np.asarray(x))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 150177.38it/s]


### Fix of bad padding protocol

In [21]:
dropping  = []
for i in tqdm(range(len(dataframe['arrayed_data']))):
    if len(dataframe['arrayed_data'].iloc[i]) != num_of_shifts:
        dropping.append(i)

dataframe = dataframe.drop(dropping)
#dataframe = dataframe.reset_index()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 84572.81it/s]


### Normalization

In [22]:
mn = dataframe[target].to_numpy().mean()
std = dataframe[target].to_numpy().std()
matrix = np.array(dataframe[target].to_numpy().tolist())

In [23]:
number_of_preds = len(dataframe[target].iloc[0])

In [24]:
for i in range(number_of_preds):
    matrix[:,i] -= mn[i]
    matrix[:,i] /= std[i]

In [25]:
dataframe[target] = matrix.tolist()
#dataframe['arrayed_data'] = dataframe['arrayed_data'].progress_apply(lambda x: np.array(x))

In [26]:
dataframe['arrayed_data'].iloc[0]

[0.3657146214263286,
 2.4528125756053263,
 0.523606308827743,
 0.43746912497137075,
 0.3890613820924653,
 -0.6801269234931228,
 0.30470232017903615,
 0.3124663943357695,
 0.5621099149371945,
 -1.6942891576572239,
 -1.3746231117720253,
 -1.214154661859889,
 -1.0528204243330668,
 -0.9034024672949567,
 -0.7411973030081994,
 -0.596013776419713,
 -0.45686966925848904,
 -0.34558708607568833,
 -0.2514298432508667,
 -0.16883025283670966]

In [27]:
mn

array([ 61.0300655 ,  87.2510262 , 103.57351528, 112.85737991,
       116.22031659, 117.36626638, 115.75804585, 113.34661572,
       106.43558952,  96.37207424,  85.9991048 ,  80.40157205,
        72.7894869 ,  63.45159389,  51.06505459,  37.95661572,
        24.92104803,  15.36378821,   8.21463974,   3.87098253])

In [28]:
std

array([48.42555769, 49.95040184, 44.47326994, 35.93995368, 31.76795224,
       34.45866583, 35.77903228, 42.60741162, 48.84171183, 56.8805353 ,
       62.56195176, 66.22020619, 69.13760905, 70.23624152, 68.89535941,
       63.6841248 , 54.54739001, 44.45706691, 32.67169733, 22.92825171])

### Create Molecule Dataset
##### It will generate torch_geometric.data.Data objects for both bert and GIN/GCN models.

In [29]:
from rdkit import Chem

ATOM_LIST = list(range(1,119))
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
    Chem.rdchem.BondType.SINGLE, 
    Chem.rdchem.BondType.DOUBLE, 
    Chem.rdchem.BondType.TRIPLE, 
    Chem.rdchem.BondType.AROMATIC
]
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]

In [30]:
tqdm.pandas()

In [31]:
import random
import math
from copy import deepcopy
from torch_geometric.data import Data, Dataset

class MoleculeDataset(Dataset):
    def __init__(self, dataset: pd.DataFrame, tokenizer, node_mask_percent=0.15, edge_mask_percent=0.25):
        super(Dataset, self).__init__()
        self.dataset = dataset
        self.node_mask_percent = node_mask_percent
        self.edge_mask_percent = edge_mask_percent

        self.tokenizer = tokenizer
        self.tokenizer.model_max_len = 512

        self.dataset['graph'] = self.dataset['smiles'].progress_apply(self.get_graph_from_smiles)
        #print(self.dataset['graph'].iloc[0])
        self.dataset['graph_copy1'] = self.dataset['graph'].progress_apply(lambda x: self.get_augmented_graph_copy(x[0], x[1], x[2], x[3], x[4]))
        self.dataset['graph_copy2'] = self.dataset['graph'].progress_apply(lambda x: self.get_augmented_graph_copy(x[0], x[1], x[2], x[3], x[4]))


        self.dataset['tokens'] = self.dataset['ecfp0'].progress_apply(self.tokenize)
        self.dataset['mlm'] = self.dataset['tokens'].progress_apply(self.apply_mlm)


    def get_graph_from_smiles(self, smiles):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return torch.tensor([[], []], dtype=torch.long), \
                    torch.tensor(np.array([]), dtype=torch.long), \
                    torch.tensor(np.array([]), dtype=torch.long), \
                    0
    
        N = mol.GetNumAtoms()
        M = mol.GetNumBonds()
    
        type_idx = []
        chirality_idx = []
        atomic_number = []
        
        for atom in mol.GetAtoms():
            type_idx.append(ATOM_LIST.index(atom.GetAtomicNum()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
            atomic_number.append(atom.GetAtomicNum())
        
        x1 = torch.tensor(type_idx, dtype=torch.long).view(-1,1)
        x2 = torch.tensor(chirality_idx, dtype=torch.long).view(-1,1)
        node_feat = torch.cat([x1, x2], dim=-1)
    
        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])
            edge_feat.append([
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir())
            ])
    
        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.long)
        num_nodes = N
        num_edges = M
        return node_feat, edge_index, edge_attr, num_nodes, num_edges

    def get_augmented_graph_copy(self, node_feat, edge_index, edge_attr, N, M):
        num_mask_nodes = max([1, math.floor(self.node_mask_percent * N)])
        
        mask_nodes = random.sample(list(range(N)), num_mask_nodes)

        node_feat_new = deepcopy(node_feat)
        for atom_idx in mask_nodes:
            node_feat_new[atom_idx, :] = torch.tensor([len(ATOM_LIST), 0])
        
        return Data(x=node_feat_new, edge_index=edge_index, edge_attr=edge_attr)

    def tokenize(self, item):
        return self.tokenizer(item, truncation=True, max_length=512, padding='max_length')

    def mlm(self, tensor):
        rand = torch.rand(tensor.shape)
        # mask random 15% where token is not 0 <s>, 1 <pad>, or 2 <s/>
        mask_arr = (rand < .15) * (tensor != 0) * (tensor != 1) * (tensor != 2)
        selection = torch.flatten(mask_arr.nonzero()).tolist()
        tensor[selection] = 4
        return tensor

    def apply_mlm(self, sample):
        labels = torch.tensor(sample.input_ids)
        attention_mask = torch.tensor(sample.attention_mask)
        input_ids = self.mlm(labels.detach().clone())
        return Data(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def __getitem__(self, index):
        return self.dataset['mlm'].iloc[index], self.dataset['graph_copy1'].iloc[index], self.dataset['graph_copy2'].iloc[index], self.dataset[target].iloc[index]

    def __len__(self):
        return len(self.dataset)

    def get(self):
        pass
    def len(self):
        pass

  from .autonotebook import tqdm as notebook_tqdm


In [32]:
from transformers import AutoTokenizer

model_name_bert = 'molberto_ecfp0_2M'
tokenizer = AutoTokenizer.from_pretrained(model_name_bert)
dataset = MoleculeDataset(dataframe, tokenizer)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 916/916 [00:00<00:00, 2375.00it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 916/916 [00:00<00:00, 8595.17it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 916/916 [00:00<00:00, 3192.72it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 916/916 [00:00<00:00, 6039.24it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████

In [33]:
from torch_geometric.loader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler


num_train = len(dataset)
indices = list(range(num_train))
np.random.shuffle(indices)
    
split_tr = int(np.floor(config['dataset']['valid_size'] * num_train))
split_test = int(np.floor(config['dataset']['test_size'] * num_train))
    
train_idx, valid_idx, test_idx = indices[split_tr + split_test : ], indices[: split_tr], indices[split_tr : split_tr + split_test]
    
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
test_sampler = SubsetRandomSampler(test_idx)
    
    
train_dataloader = DataLoader(
    dataset, batch_size=config['batch_size'], sampler=train_sampler,
    num_workers=config['dataset']['num_workers'], drop_last=True
)
    
eval_dataloader = DataLoader(
    dataset, batch_size=config['batch_size'], sampler=valid_sampler,
    num_workers=config['dataset']['num_workers'], drop_last=True
)
    
test_dataloader = DataLoader(
    dataset, batch_size=config['batch_size'], sampler=test_sampler,
    num_workers=config['dataset']['num_workers'], drop_last=True
)

### Create Transformer Model

In [34]:
import torch
import numpy as np


class NTXentLoss(torch.nn.Module):

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self):
        diag = np.eye(2 * self.batch_size)
        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
        mask = torch.from_numpy((diag + l1 + l2))
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (N, 2N)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        representations = torch.cat([zjs, zis], dim=0)

        similarity_matrix = self.similarity_function(representations, representations)

        # filter out the scores from the positive samples
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1)

        logits = torch.cat((positives, negatives), dim=1)
        logits = logits.abs() + 0.0001
        logits = torch.log(logits)
        logits /= self.temperature
        
        labels = torch.zeros(2 * self.batch_size).to(self.device).long()
        loss = self.criterion(logits, labels)

        return loss / (2 * self.batch_size)

In [35]:
from transformers import RobertaForMaskedLM
from transformers import RobertaConfig
from torch import nn

if config['graph_model_type'] == 'gin':
    from MolCLR.models.ginnet_old import GINet as GraphModel
elif config['graph_model_type'] == 'gcn':
    from MolCLR.models.gcn_molclr import GCN as GraphModel
else:
    raise ValueError('GNN model is not defined in config.')

class MolecularBertGraph(torch.nn.Module):
    def __init__(self):
        super(MolecularBertGraph, self).__init__()
        self.batch_size = config['batch_size']

        roberta_config = roberta_config = RobertaConfig(
            vocab_size=30_522,
            max_position_embeddings=514,
            hidden_size=768,
            num_attention_heads=12,
            num_hidden_layers=6,
            type_vocab_size=1
        )
        
        self.bert = RobertaForMaskedLM(roberta_config).to(device)

        self.graph_model = GraphModel(**config['graph_model']).to(device)
        # self.graph_model = self._load_graph_pretrained_weights(self.graph_model)

        self.out_graph_linear = torch.nn.Linear(2 * config['graph_model']['feat_dim'], 
                                                768, bias=True)

        self.out_graph_projection1 = torch.nn.Linear(768, 768, bias=True)

        self.bn1_graph = nn.BatchNorm1d(768)

        self.out_graph_projection2 = torch.nn.Linear(768, 768, bias=True)

        self.bn2_graph = nn.BatchNorm1d(768)

        self.out_bert_projection1 = torch.nn.Linear(768, 768, bias=True)

        self.bn1_bert = nn.BatchNorm1d(768)

        self.out_bert_projection2 = torch.nn.Linear(768, 768, bias=True)
        
        self.bn2_bert = nn.BatchNorm1d(768)
        
        # contrastive loss for MolCLR
        self.nt_xent_criterion = NTXentLoss(device, self.batch_size, **config['ntxent_loss'])

    def forward(self, bert_batch, graph_batch1, graph_batch2):
        bert_output = self.bert(input_ids=bert_batch['input_ids'].view(self.batch_size, -1), 
                                 attention_mask=bert_batch['attention_mask'].view(self.batch_size, -1),
                                 labels=bert_batch['labels'].view(self.batch_size, -1), output_hidden_states=True)
        bert_loss = bert_output.loss
        bert_emb = bert_output.hidden_states[0][:, 0, :] # take emb for CLS token

        graph_loss, hidden_states_1, hidden_states_2 = self.graph_step(graph_batch1, graph_batch2)
        
        graph_emb = self.out_graph_linear(torch.cat((hidden_states_1, hidden_states_2), dim=-1))

        graph_emb_projected1 = self.out_graph_projection1(graph_emb)
        
        graph_emb_projected_bn1 = self.bn1_graph(graph_emb_projected1)

        graph_emb_projected2 = self.out_graph_projection2(torch.nn.functional.relu(graph_emb_projected_bn1))
        
        graph_emb_projected_bn2 = self.bn2_graph(graph_emb_projected2)
        
        #bert projections:
        bert_emb_projected1 = self.out_bert_projection1(bert_emb)
        
        bert_emb_projected_bn1 = self.bn1_bert(bert_emb_projected1)

        bert_emb_projected2 = self.out_bert_projection2(torch.nn.functional.relu(bert_emb_projected_bn1))
        
        bert_emb_projected_bn2 = self.bn2_bert(bert_emb_projected2)

        # bimodal_loss = ((1 - self.cosine_sim(bert_emb, graph_emb))**2).mean()
        bimodal_loss = self.nt_xent_criterion(bert_emb_projected_bn2, graph_emb_projected_bn2)
        return bert_loss, graph_loss, bimodal_loss, graph_emb_projected_bn2, bert_emb_projected_bn2

    def graph_step(self, xis, xjs):
        # get the representations and the projections
        ris, zis = self.graph_model(xis)  # [N,C]
    
        # get the representations and the projections
        rjs, zjs = self.graph_model(xjs)  # [N,C]
    
        # normalize projection feature vectors
        zis = torch.nn.functional.normalize(zis, dim=1)
        zjs = torch.nn.functional.normalize(zjs, dim=1)

        loss = self.nt_xent_criterion(zis, zjs)
        return loss, ris, rjs

In [36]:
#model = MolecularBertGraph().to(device)

In [37]:
#print(model)

### Define utils

In [38]:
wandb.init(
    project="efcp_transformer",
    name="NMR " + target + " Multiple-all metrics " + config['graph_model_type'] + " (bert+molclr)",
    config=config
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[34m[1mwandb[0m: Currently logged in as: [33morlov-aleksei53[0m ([33mmoleculary-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...

### Training (with validation)

In [39]:
epoch_counter = 0

In [40]:
from datetime import datetime

model_checkpoints_folder = os.path.join('ckpts')
dir_name = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join(model_checkpoints_folder, dir_name)
_save_config_file(config, log_dir)

In [41]:
class MolecularPropertiesRegression(torch.nn.Module):
    def __init__(self):
        super(MolecularPropertiesRegression, self).__init__()

        self.test = MolecularBertGraph().to(device)
        self.test.load_state_dict(torch.load('weights_pretr/' + config['graph_model_type'] + '.pth'))

        self.linear1 = torch.nn.Linear(768 * 2, 768, bias=True)
        self.linear2 = torch.nn.Linear(768, 512, bias=True)
        self.bn = nn.BatchNorm1d(512)
        self.linear3 = torch.nn.Linear(512, 256, bias=True)
        
        # Creating output heads: one for each property
        self.output_heads = nn.ModuleList([torch.nn.Linear(256, 1) for _ in range(number_of_preds)])

    def forward(self, b, g1, g2):
        l1, l2, l3, graph_emb, bert_emb = self.test(b, g1, g2)

        first_linear_out = self.linear1(torch.cat((graph_emb, bert_emb), dim=-1))
        sec_linear_out = self.linear2(torch.nn.functional.leaky_relu(first_linear_out))
        batchnormed = self.bn(sec_linear_out)
        thd_linear_out = self.linear3(torch.nn.functional.leaky_relu(batchnormed))

        # Output for each head
        logits = [head(thd_linear_out) for head in self.output_heads]
        
        return torch.cat(logits, dim=-1)  # Concatenate outputs along the last dimension

In [42]:
model = MolecularPropertiesRegression().to(device)

In [43]:
model

MolecularPropertiesRegression(
  (test): MolecularBertGraph(
    (bert): RobertaForMaskedLM(
      (roberta): RobertaModel(
        (embeddings): RobertaEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=1)
          (position_embeddings): Embedding(514, 768, padding_idx=1)
          (token_type_embeddings): Embedding(1, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): RobertaEncoder(
          (layer): ModuleList(
            (0-5): 6 x RobertaLayer(
              (attention): RobertaAttention(
                (self): RobertaSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
 

In [44]:
num_epoch = config['epochs']

optimizer = torch.optim.Adam(
    model.parameters(), float(config['init_lr']), 
    weight_decay=eval(config['weight_decay'])
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config['epochs']-config['warm_up'], 
    eta_min=0, last_epoch=-1
)
#loss_func = torch.nn.L1Loss()

In [45]:
def reform(tens):
    stacked_tensor = torch.stack(tens)  # This will give you a shape of (30, 30, 16)
    return stacked_tensor.t()

In [46]:
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # Importing functional module

class MultiHeadLoss(nn.Module):
    def __init__(self):
        super(MultiHeadLoss, self).__init__()

    def forward(self, predictions, targets):
        losses = []
        
        for i in range(predictions.size(1)):  # Iterate over each property
            property_predictions = predictions[:, i]
            property_targets = targets[:, i]
            
            # Compute MSE for this property
            loss = F.mse_loss(property_predictions, property_targets)
            #loss = F.l1_loss(property_predictions, property_targets)
            losses.append(loss)

        # Return a list of losses for each property
        return losses

def train_loop():
    train_tqdm = tqdm(train_dataloader, unit="batch")
    train_tqdm.set_description(f'Epoch {epoch_counter}')
    loss_sum = 0
    total_pred_labels = []
    total_true_labels = []
    model.train()
    criterion = MultiHeadLoss()  # Initialize the multi-head loss

    for (bert_batch, graph_batch1, graph_batch2, targets) in train_tqdm:
        
        optimizer.zero_grad()
        
        bert_batch = bert_batch.to(device)
        graph_batch1 = graph_batch1.to(device)
        graph_batch2 = graph_batch2.to(device)
        targets = reform(targets).clone().detach().to(device)

        #print(bert_batch)
    
        with autocast():  # Mixed precision context
            pred_labels = model(bert_batch, graph_batch1, graph_batch2)
            # Compute the separate losses
            losses = criterion(pred_labels, targets)
    
        # Backward pass for each loss
        for loss in losses:
            scaler.scale(loss).backward(retain_graph=True)
                #loss.backward(retain_graph=True)  # Retain the graph for multiple backward passes
    
            # After accumulating all gradients, optimize
        scaler.step(optimizer)
        scaler.update()
        #optimizer.step()  # Update the model parameters
    
        pred_labels = pred_labels
        true_labels = targets
        total_pred_labels.append(pred_labels.detach().cpu())
        total_true_labels.append(true_labels.detach().cpu())

        loss_sum += sum([loss.item() for loss in losses])  # Sum for total loss log
    
        wandb.log({"loss/train": sum([loss.item() for loss in losses])})  # Log the total loss
    
        train_tqdm.set_postfix(loss=loss_sum / len(total_pred_labels))

        del pred_labels
        del true_labels
        del losses

        torch.cuda.empty_cache()
            
       # optimizer.zero_grad(set_to_none=True)

    return loss_sum / len(train_dataloader), total_pred_labels, total_true_labels

In [48]:
'''def train_loop():
    train_tqdm = tqdm(train_dataloader, unit="batch")
    train_tqdm.set_description(f'Epoch {epoch_counter}')
    loss_sum = 0
    total_pred_labels = []
    total_true_labels = []
    
    model.train()
    for (bert_batch, graph_batch1, graph_batch2, targets) in train_tqdm:
        optimizer.zero_grad()

        bert_batch = bert_batch.to(device)
        graph_batch1 = graph_batch1.to(device)
        graph_batch2 = graph_batch2.to(device)
        #print(targets)
        #print(reform(targets))
        targets = reform(targets).clone().detach().to(device)

        pred_labels = model(bert_batch, graph_batch1, graph_batch2)
        #loss = loss_func(pred_labels.view(-1), targets.float())
        loss = loss_func(pred_labels, targets.float())
        loss.backward()


        pred_labels = pred_labels.reshape(-1)
        true_labels = targets.reshape(-1)
        
        total_pred_labels.append(pred_labels)
        total_true_labels.append(true_labels)
        
        loss_sum += loss.item()

        #wandb.log({“bert_loss/train”:bert_loss, “graph_loss/train”: graph_loss, "bimodal_loss/train": bimodal_loss, "loss/train": loss})
        wandb.log({"loss/train": loss})


        optimizer.step()
        train_tqdm.set_postfix(loss=loss.item())
    return loss_sum / len(train_dataloader), total_pred_labels, total_true_labels'''

'def train_loop():\n    train_tqdm = tqdm(train_dataloader, unit="batch")\n    train_tqdm.set_description(f\'Epoch {epoch_counter}\')\n    loss_sum = 0\n    total_pred_labels = []\n    total_true_labels = []\n    \n    model.train()\n    for (bert_batch, graph_batch1, graph_batch2, targets) in train_tqdm:\n        optimizer.zero_grad()\n\n        bert_batch = bert_batch.to(device)\n        graph_batch1 = graph_batch1.to(device)\n        graph_batch2 = graph_batch2.to(device)\n        #print(targets)\n        #print(reform(targets))\n        targets = reform(targets).clone().detach().to(device)\n\n        pred_labels = model(bert_batch, graph_batch1, graph_batch2)\n        #loss = loss_func(pred_labels.view(-1), targets.float())\n        loss = loss_func(pred_labels, targets.float())\n        loss.backward()\n\n\n        pred_labels = pred_labels.reshape(-1)\n        true_labels = targets.reshape(-1)\n        \n        total_pred_labels.append(pred_labels)\n        total_true_labels.app

In [49]:
loss_func = MultiHeadLoss()

In [50]:
def eval_loop():
    eval_tqdm = tqdm(eval_dataloader, unit="batch")
    eval_tqdm.set_description(f'Epoch {epoch_counter}')
    loss_sum = 0
    total_pred_labels = []
    total_true_labels = []
    
    model.eval()
    for (bert_batch, graph_batch1, graph_batch2, targets) in eval_tqdm:
        optimizer.zero_grad()

        bert_batch = bert_batch.to(device)
        graph_batch1 = graph_batch1.to(device)
        graph_batch2 = graph_batch2.to(device)
        targets = reform(targets).clone().detach().to(device)

        with torch.no_grad():
            pred_labels = model(bert_batch, graph_batch1, graph_batch2)


        losses = loss_func(pred_labels, targets.float())

        pred_labels = pred_labels
        true_labels = targets
        total_pred_labels.append(pred_labels.detach().cpu())
        total_true_labels.append(true_labels.detach().cpu())

        loss_sum += sum([loss.item() for loss in losses]) 

        eval_tqdm.set_postfix(loss=loss_sum / len(total_pred_labels))

        del pred_labels
        del true_labels
        del losses
        torch.cuda.empty_cache()

        
    return loss_sum / len(eval_dataloader), total_pred_labels, total_true_labels

In [51]:
def test_loop():
    test_tqdm = tqdm(test_dataloader, unit="batch")
    test_tqdm.set_description(f'Epoch {epoch_counter}')
    loss_sum = 0
    total_pred_labels = []
    total_true_labels = []

    criterion = MultiHeadLoss()

    
    model.eval()
    for (bert_batch, graph_batch1, graph_batch2, targets) in test_tqdm:
        optimizer.zero_grad()

        bert_batch = bert_batch.to(device)
        graph_batch1 = graph_batch1.to(device)
        graph_batch2 = graph_batch2.to(device)
        targets = reform(targets).clone().detach().to(device)

        with torch.no_grad():
            pred_labels = model(bert_batch, graph_batch1, graph_batch2)

        losses = loss_func(pred_labels, targets.float())

        pred_labels = pred_labels
        true_labels = targets
        total_pred_labels.append(pred_labels.detach().cpu())
        total_true_labels.append(true_labels.detach().cpu())

        loss_sum += sum([loss.item() for loss in losses]) 

        eval_tqdm.set_postfix(loss=loss_sum / len(total_pred_labels))

        del pred_labels
        del true_labels
        del losses
        torch.cuda.empty_cache()

        
    return loss_sum / len(test_tqdm), total_pred_labels, total_true_labels

In [52]:
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, roc_auc_score, roc_curve, auc
import torch.nn.functional as F 

In [53]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [54]:
def invert_norm(tensors_list):
    transformed_list = tensors_list#[batch[i, :].clone() for batch in tensors_list for i in range(np.shape(batch)[0])]
    for i in range(number_of_preds):
        transformed_list[:,i] *= std[i]
        transformed_list[:,i] += mn[i]
    return transformed_list

def inverse(total_true_labels, total_pred_labels):
    losses =  loss_func(torch.tensor(invert_norm(total_pred_labels)), torch.tensor(invert_norm(total_true_labels)))
    loss_sum = sum([loss.item() for loss in losses])
    return float(loss_sum/number_of_preds)

In [55]:
n_iter = 0
valid_n_iter = 0
best_valid_loss = np.inf

for epoch_counter in range(num_epoch):
    loss, total_pred_labels, total_true_labels = train_loop()
    
    total_pred_labels = torch.cat(total_pred_labels).cpu().detach().numpy()
    total_true_labels = torch.cat(total_true_labels).cpu().detach().numpy()

    wandb.log({"loss/train-average": inverse(total_pred_labels, total_true_labels)})

    loss, total_pred_labels, total_true_labels = eval_loop()

    total_pred_labels = torch.cat(total_pred_labels).cpu().detach().numpy()
    total_true_labels = torch.cat(total_true_labels).cpu().detach().numpy()

    wandb.log({"loss/eval": inverse(total_pred_labels, total_true_labels)})

    if loss < best_valid_loss:
        best_valid_loss = loss
        torch.save(model.state_dict(), os.path.join(log_dir, 'model_nmr.pth'))
    
    if (epoch_counter + 1) % config['save_every_n_epochs'] == 0:
        torch.save(model.state_dict(), os.path.join(log_dir, 'model_nmr_{}.pth'.format(str(epoch_counter))))

    if epoch_counter >= config['warm_up']:
                scheduler.step()


Epoch 0:   0%|                                                                                                                                                                                        | 0/22 [00:00<?, ?batch/s]

DataBatch(input_ids=[16384], attention_mask=[16384], labels=[16384])


Epoch 0:   5%|███████▌                                                                                                                                                             | 1/22 [00:01<00:34,  1.66s/batch, loss=17.2]

DataBatch(input_ids=[16384], attention_mask=[16384], labels=[16384])


Epoch 0:   9%|███████████████▏                                                                                                                                                       | 2/22 [00:02<00:23,  1.16s/batch, loss=19]

DataBatch(input_ids=[16384], attention_mask=[16384], labels=[16384])


Epoch 0:  14%|██████████████████████▌                                                                                                                                              | 3/22 [00:03<00:18,  1.01batch/s, loss=18.9]

DataBatch(input_ids=[16384], attention_mask=[16384], labels=[16384])


Epoch 0:  18%|██████████████████████████████                                                                                                                                       | 4/22 [00:04<00:18,  1.04s/batch, loss=18.6]


KeyboardInterrupt: 

In [None]:
loss, total_pred_labels, total_true_labels = test_loop()

total_pred_labels = torch.cat(total_pred_labels).cpu().detach().numpy()
total_true_labels = torch.cat(total_true_labels).cpu().detach().numpy()

wandb.log({"loss/test": inverse(total_pred_labels, total_true_labels)})


In [None]:
wandb.finish()

In [None]:
targets

In [None]:
np.shape(total_pred_labels)

In [None]:
type(invert_norm(total_pred_labels))