In [1]:
# Experiment configuration, grouped by component: data loading, the three
# networks (molecule encoder / molecule decoder / spectra encoder), training
# hyper-parameters with running loss logs, and wandb run metadata.
config = {
    'data': {
        "qm9_broad_ir_path": '/home2/kanakala.ganesh/ir_data/qm9_broad_ir.pkl',
        "vocab_path": '/home2/kanakala.ganesh/CLIP_PART_1/data/qm9_vocab.pkl',
        "datafiles": {
            'train': '/home2/kanakala.ganesh/ir_data/raw_train.pickle',
            'test': '/home2/kanakala.ganesh/ir_data/raw_test.pickle',
            'val': '/home2/kanakala.ganesh/ir_data/raw_val.pickle',
        },
        "normalization": "unit",
        "shuffle": True,
        # NOTE(review): the recorded run hit CUDA OOM on GPU 0 (23.70 GiB) with
        # this batch size; consider lowering it if that recurs.
        "batch_size": 200,
        "seq_len": 70,
        "splits": [0.8, 0.1, 0.1],
        "num_workers": 20,
        # Placeholders; run() overwrites them with the values returned by
        # prepare_data().
        "max_charge": None,
        "num_species": None,
    },
    'molecule_encoder': {
        'attention': 1,
        'coords_weight': 1.0,
        'device': "cuda",
        'hidden_nf': 256,
        'in_edge_nf': 0,
        'in_node_nf': 15,
        'n_layers': 3,
        'node_attr': 1,
        'output_size': 512,
    },
    'molecule_decoder': {
        'in_size': 512,
        'latent_size': 512,
        'hidden_size': 512,
        'n_layers': 5,
        'n_heads': 4,
    },
    'spectra_encoder': {
        'd_ff': 1024,
        'dropout': 0.0,
        'dropout_emb': 0.1,
        'h_dim': 256,
        'max_time_steps': 1000,
        # NOTE(review): 256 is not divisible by 7; confirm models.vit.ViT
        # supports this head split.
        'num_heads': 7,
        'num_layers': 5,
        'output_size': 512,
        'patch_size': 7,
        'use_clf_token': True,
    },
    'train': {
        'lr': 0.0001,
        'temperature': 0.1,
        'checkpoint_dir': "checkpoints/temp",
        'device': "cuda",
        'num_epochs': 100,
        'threshold': 0.9999,
        'weight_decay': 1.0e-06,
        # Per-split loss histories plus best-so-far trackers (1000 is the
        # initial "best loss" sentinel, -1 means "no best epoch yet").
        'logs': {
            'train_total_loss': [],
            'train_clip_loss': [],
            'train_recon_loss': [],

            'val_total_loss': [],
            'val_clip_loss': [],
            'val_recon_loss': [],

            'test_total_loss': [],
            'test_clip_loss': [],
            'test_recon_loss': [],

            'best_epoch': -1,
            'best_clip_epoch': -1,
            'best_recon_epoch': -1,

            'best_total_loss': 1000,
            'best_clip_loss': 1000,
            'best_recon_loss': 1000,
        },
    },
    'wandb': {
        "dir": "/scratch/kanakala.ganesh/",
        "job_type": "sample",
        "project_name": "CLIP_Full_testing",
        "run_name": "RUN_testing",
    },
}
from PrepareData import prepare_data
import torch
from torch import nn, optim, Tensor
from torch.nn import functional as F
import pickle 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
import seaborn as sns
import plotly
import wandb

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float32

from train_utils import CombinedLoss
from train_utils import train_clip, train_total, train_recon


# Module-level slots that run() rebinds through its `global` statement.
logs = max_charge = num_species = None

In [2]:
import torch
from torch import nn
import pickle
from qm9 import utils as qm9_utils
from models.vit import ViT
from qm9.models import EGNN

# Keep the CPU fallback instead of hard-coding "cuda": an unconditional
# torch.device("cuda") makes every later `.to(device)` fail on CPU-only hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

from models.decoder import LatentToMol

def set_up_causal_mask(seq_len):
    """Build a ``(seq_len, seq_len)`` float32 additive attention mask.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` (position ``i`` may attend to
    ``j``) and ``-inf`` for every future position ``j > i``. The returned
    tensor does not require grad.
    """
    # Fill with -inf, then keep it only strictly above the diagonal;
    # triu sets everything on/below the diagonal to 0.
    future = torch.full((seq_len, seq_len), float('-inf'), dtype=torch.float32)
    return torch.triu(future, diagonal=1)

In [3]:

import torch
from torch import nn
import pickle
from qm9 import utils as qm9_utils
from models.vit import ViT
from qm9.models import EGNN

# Keep the CPU fallback instead of hard-coding "cuda": an unconditional
# torch.device("cuda") makes every later `.to(device)` (including the ones in
# CLIP below) fail on CPU-only hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

from models.decoder import LatentToMol

def set_up_causal_mask(seq_len):
    """Return the causal (look-ahead) mask used by the SMILES decoder.

    Shape ``(seq_len, seq_len)``, dtype float32, no grad. Allowed positions
    (``j <= i``) hold ``0.0``; future positions (``j > i``) hold ``-inf``.
    """
    visible = torch.tril(torch.ones(seq_len, seq_len)) == 1
    additive = torch.zeros(seq_len, seq_len, dtype=torch.float32)
    return additive.masked_fill(~visible, float('-inf'))

class PositionalEncodings(nn.Module):
    """Fixed sinusoidal positional encodings ("Attention Is All You Need").

    A ``(1, seq_len, d_model)`` table is built once and added to the input in
    ``forward``, followed by dropout. Dimensions ``2i`` and ``2i + 1`` share the
    angular frequency ``1 / 10000 ** (2i / d_model)``; even dimensions hold the
    cosine and odd dimensions the sine of ``position * frequency``.
    """

    def __init__(self, seq_len, d_model, p_dropout):
        """Precompute the encoding table.

        Args:
            seq_len: number of positions covered by the table (maximum input length).
            d_model: embedding width, i.e. the last dimension of the input.
            p_dropout: dropout probability applied after adding the encodings.
        """
        super(PositionalEncodings, self).__init__()
        token_positions = torch.arange(start=0, end=seq_len).view(-1, 1)
        dim_positions = torch.arange(start=0, end=d_model).view(1, -1)
        # Each sin/cos pair shares one frequency, so the exponent must use the
        # pair index k // 2. Using k itself doubled the exponent and pushed the
        # upper half of the dimensions below 1e-4 rad/position, making those
        # columns almost constant.
        pair_index = dim_positions // 2
        angles = token_positions / (10000 ** ((2 * pair_index) / d_model))

        encodings = torch.zeros(1, seq_len, d_model)
        encodings[0, :, ::2] = torch.cos(angles[:, ::2])
        encodings[0, :, 1::2] = torch.sin(angles[:, 1::2])
        # Persistent buffer: follows .to()/replication and is saved in the
        # state_dict, but is never trained. Checkpoints saved with the old
        # formula restore their own table through load_state_dict.
        self.register_buffer("positional_encodings", encodings)

        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x):
        """Add the encodings for the first ``x.size(-2)`` positions, then apply dropout.

        ``x`` is expected to end in ``(seq, d_model)`` with ``seq <= seq_len``.
        For example, ``(batch, seq, d_model)`` gives an output of the same shape.
        Sequences shorter than ``seq_len`` are accepted; previously the length
        had to match exactly.
        """
        seq = x.size(-2)
        x = x + self.positional_encodings[:, :seq]
        return self.dropout(x)

    
class bottle(nn.Module):
    """Collapse a whole token sequence into a single ``hidden_size`` vector.

    Each sample's ``seq_len * hidden_size`` values (e.g. a ``(batch, seq_len,
    hidden_size)`` encoder output) are flattened into one vector and projected
    by a two-layer MLP: ``seq_len*hidden_size -> 1024 -> hidden_size``.
    """

    def __init__(self, seq_len, hidden_size):
        super(bottle, self).__init__()
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.down = nn.Sequential(
            nn.Linear(seq_len * hidden_size, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, hidden_size),
        )

    def forward(self, inp):
        """Return an ``(N, hidden_size)`` tensor, N = inp.numel() // (seq_len * hidden_size)."""
        # reshape instead of view: view raises on non-contiguous inputs
        # (e.g. transposed tensors); reshape copies only when it has to.
        flat = inp.reshape(-1, self.seq_len * self.hidden_size)
        return self.down(flat)
       

class CLIP(nn.Module):
    """Contrastive spectra/molecule model with an auxiliary SMILES decoder.

    Sub-networks built from ``config``:

    * molecule branch: token embedding -> ``PositionalEncodings`` ->
      ``nn.TransformerEncoder`` -> ``bottle``, applied to ``data['decoder_inp']``;
    * spectra branch: ``ViT`` applied to ``data['IR']``;
    * ``LatentToMol`` decoder that predicts SMILES tokens from the spectra
      latents under a causal mask.

    Both latent vectors are L2-normalised along dim 1.
    """

    def __init__(self, config):
        """Build all sub-networks.

        Reads ``config['data']`` (vocab_path, seq_len, max_charge, num_species),
        ``config['train']['temperature']``, ``config['molecule_decoder']`` and
        ``config['spectra_encoder']``.
        """
        super().__init__()
        self.config = config
        # The context manager closes the vocab file (a bare open() leaked it).
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # vocab files from a trusted location.
        with open(config['data']['vocab_path'], 'rb') as vocab_file:
            self.vocab = pickle.load(vocab_file)
        self.temperature = config['train']['temperature']
        # Read once at construction time: set these in config['data'] before
        # building the model, otherwise they keep the config-cell default (None).
        self.max_charge = config['data']['max_charge']
        self.num_species = config['data']['num_species']

        hidden_size = config['molecule_decoder']['hidden_size']
        seq_len = config['data']['seq_len']
        self.embed = nn.Embedding(len(self.vocab), hidden_size, padding_idx=self.vocab.pad_index)
        self.pe = PositionalEncodings(d_model=hidden_size, p_dropout=0.1, seq_len=seq_len)

        transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size,
                                                               nhead=8,
                                                               dropout=0.1,
                                                               batch_first=True)

        self.trfmencoder = nn.TransformerEncoder(encoder_layer=transformer_encoder_layer,
                                                 num_layers=3)

        spec_cfg = self.config['spectra_encoder']
        self.Spectra_Encoder = ViT(
            patch_size=spec_cfg['patch_size'],
            num_layers=spec_cfg['num_layers'],
            h_dim=spec_cfg['h_dim'],
            num_heads=spec_cfg['num_heads'],
            output_size=spec_cfg['output_size'],
            d_ff=spec_cfg['d_ff'],
            max_time_steps=spec_cfg['max_time_steps'],
            use_clf_token=spec_cfg['use_clf_token'],
            dropout=spec_cfg['dropout'],
            dropout_emb=spec_cfg['dropout_emb'],
        )

        dec_cfg = self.config['molecule_decoder']
        self.smiles_decoder = LatentToMol(
            in_size=dec_cfg['latent_size'],
            hidden_size=dec_cfg['hidden_size'],
            n_layers=dec_cfg['n_layers'],
            n_heads=dec_cfg['n_heads'],
            seq_len=self.config['data']['seq_len'],
            vocab=self.vocab)

        self.bottle = bottle(seq_len, hidden_size)

        # NOTE(review): the parameter starts at `temperature` and forward()
        # returns its exp(), so the initial multiplier is exp(0.1) ~= 1.105
        # rather than 1/temperature (= 10) as in the original CLIP recipe.
        # If 1/temperature is intended, initialise with log(1 / temperature);
        # confirm against how CombinedLoss consumes `logits_scale`.
        self.logit_scale = nn.Parameter(torch.ones([]) * self.temperature)

    @property
    def _device(self):
        """Device of this module's own parameters; incoming batch tensors are moved here."""
        # Used instead of the notebook-global `device`, which is hard-coded to
        # "cuda" in this cell.
        return self.logit_scale.device

    def forward_mol(self, data):
        """Encode the token ids in ``data['decoder_inp']`` into unit-norm molecule latents."""
        tokens = data['decoder_inp'].to(self._device)
        hidden = self.pe(self.embed(tokens))
        memory = self.trfmencoder(hidden)
        mol_features = self.bottle(memory)
        # normalize() clamps the norm at eps, so an all-zero row yields zeros
        # instead of NaN (plain division by .norm() had no such guard).
        return nn.functional.normalize(mol_features, dim=1)

    def forward_spec(self, data):
        """Encode ``data['IR']`` with the ViT into unit-norm spectra latents."""
        spectra = data['IR'].to(self._device, dtype)
        # Insert two singleton axes after the batch axis: (B, L) -> (B, 1, 1, L).
        # Assumes data['IR'] is (batch, length); TODO confirm against the ViT input contract.
        spectra = spectra.unsqueeze(1).unsqueeze(1)

        spectra_features = self.Spectra_Encoder(spectra)
        return nn.functional.normalize(spectra_features, dim=1)

    def forward_decoder(self, data, spec_latents):
        """Run the SMILES decoder on the spectra latents with teacher-forced inputs.

        Uses ``data['decoder_inp']`` as decoder input, ``data['tgt_padding_mask']``
        as the padding mask, and a causal mask of length ``config['data']['seq_len']``.
        """
        smi = data['decoder_inp'].to(self._device)
        tgt_padding_mask = data['tgt_padding_mask'].to(self._device)
        tgt_mask = set_up_causal_mask(self.config['data']['seq_len']).to(self._device)

        return self.smiles_decoder(spec_latents,
                                   smi,
                                   tgt_mask,
                                   tgt_padding_mask)

    def forward(self, data):
        """Return ``(mol_latents, spec_latents, smile_preds, exp(logit_scale), data['index'])``."""
        logits_scale = self.logit_scale.exp()

        mol_latents = self.forward_mol(data)
        spec_latents = self.forward_spec(data)

        smile_preds = self.forward_decoder(data, spec_latents)

        return mol_latents, spec_latents, smile_preds, logits_scale, data['index']
        
        
        
        

In [4]:
from train_utils import validate, train_one_epoch
def run(config, start_epoch=0, end_epoch=200):
    """Train the CLIP model inside a single wandb run.

    Args:
        config: nested configuration dict. It is passed to ``wandb.init``, and
            the function then works on ``wandb.config``.
        start_epoch: first argument of the epoch range handed to ``train_clip``
            (default 0, the previously hard-coded value).
        end_epoch: second argument of that range (default 200, the previously
            hard-coded value).

    Side effects:
        Rebinds the module-level ``logs``, ``max_charge`` and ``num_species``,
        writes ``max_charge``/``num_species`` into ``config['data']``, and
        trains the model via ``train_clip``.
    """
    global logs, max_charge, num_species
    with wandb.init(project=config['wandb']['project_name'],
                    dir=config['wandb']['dir'],
                    name=config['wandb']['run_name'],
                    config=config,
                    job_type=config['wandb']['job_type'],
                    save_code=True):
        config = wandb.config
        num_gpus = torch.cuda.device_count()
        print("No of GPUs available", num_gpus)

        # Prepare the data before building the model, so the dataset-derived
        # max_charge / num_species are already in config['data'] when CLIP's
        # constructor reads them. Building the model first meant it always saw None.
        dataloaders, max_charge, num_species = prepare_data(config)
        for split in dataloaders:
            print("no of batches ", len(dataloaders[split]))

        config['data']['max_charge'] = max_charge
        config['data']['num_species'] = num_species

        model = CLIP(config)
        model.to(device)
        model = torch.nn.parallel.DataParallel(model)

        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=config['train']['lr'],
                                      weight_decay=config['train']['weight_decay'])
        # Close the vocab file after loading (bare open() leaked the handle).
        with open(config['data']['vocab_path'], 'rb') as vocab_file:
            vocab = pickle.load(vocab_file)
        loss_fn = CombinedLoss(vocab).to(device)

        logs = config['train']['logs']

        print("Starting Training")

        wandb.watch(model, loss_fn, log='all', log_freq=100, log_graph=True)
        # NOTE(review): config['train']['num_epochs'] (100) is not used here; the
        # range handed to train_clip presumably is the epoch span. Confirm
        # train_clip's contract.
        train_clip(config, model, dataloaders, optimizer, loss_fn, logs, start_epoch, end_epoch)
        # Later training stages, currently disabled:
        # train_recon(config, model, dataloaders, optimizer, loss_fn, logs, 200, 300)
        # train_total(config, model, dataloaders, optimizer, loss_fn, logs, 300,400)
# Launch the CLIP training stage; run() opens (and closes) its own wandb run.
run(config)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkganeshchandan[0m. Use [1m`wandb login --relogin`[0m to force relogin


No of GPUs available 2


0it [00:00, ?it/s]

0it [00:00, ?it/s]

[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


no of batches  250
no of batches  30
no of batches  25
Starting Training




Saved to checkpoints/tempration: 24/25 | Loss: 8.679645538330078
Saved to checkpoints/temp
Saved to checkpoints/temp


2it [00:00,  2.47it/s]


0,1
epoch,▁
train_clip_loss,▁
train_recon_loss,▁
train_total_loss,▁
val_clip_loss,▁
val_recon_loss,▁
val_total_loss,▁

0,1
epoch,0.0
train_clip_loss,5.93478
train_recon_loss,3.08123
train_total_loss,9.01601
val_clip_loss,5.62017
val_recon_loss,3.06288
val_total_loss,8.68305


OutOfMemoryError: Caught OutOfMemoryError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home2/kanakala.ganesh/miniconda3/envs/sbdd-env/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home2/kanakala.ganesh/miniconda3/envs/sbdd-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/tmp/ipykernel_23124/2753486839.py", line 145, in forward
    spec_latents = self.forward_spec(data)
  File "/tmp/ipykernel_23124/2753486839.py", line 124, in forward_spec
    spectra_features = self.Spectra_Encoder(spectra)
  File "/home2/kanakala.ganesh/miniconda3/envs/sbdd-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home2/kanakala.ganesh/CLIP_PART_1/models/vit.py", line 124, in forward
    x = self.enc(x)
  File "/home2/kanakala.ganesh/miniconda3/envs/sbdd-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home2/kanakala.ganesh/CLIP_PART_1/models/vit.py", line 92, in forward
    x = layer(x, mask=mask)
  File "/home2/kanakala.ganesh/miniconda3/envs/sbdd-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home2/kanakala.ganesh/CLIP_PART_1/models/vit.py", line 56, in forward
    x = self.ffn(x_) + x
  File "/home2/kanakala.ganesh/miniconda3/envs/sbdd-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home2/kanakala.ganesh/miniconda3/envs/sbdd-env/lib/python3.10/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home2/kanakala.ganesh/miniconda3/envs/sbdd-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home2/kanakala.ganesh/miniconda3/envs/sbdd-env/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 52.00 MiB (GPU 0; 23.70 GiB total capacity; 21.64 GiB already allocated; 33.69 MiB free; 22.61 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
