In [1]:
# Single source of truth for the data pipeline and every sub-model's hyperparameters.
config = {
    # Dataset locations and DataLoader settings.
    'data': {
        'qm9_broad_ir_path': '/home2/kanakala.ganesh/ir_data/qm9_broad_ir.pkl',
        'vocab_path': '/home2/kanakala.ganesh/CLIP_PART_1/data/qm9_vocab.pkl',
        'datafiles': {
            'train': '/home2/kanakala.ganesh/ir_data/raw_train.pickle',
            'test': '/home2/kanakala.ganesh/ir_data/raw_test.pickle',
            'val': '/home2/kanakala.ganesh/ir_data/raw_val.pickle',
        },
        'normalization': 'unit',
        'shuffle': True,
        'batch_size': 64,
        'seq_len': 70,
        # NOTE(review): these fractions sum to 1.1 -- confirm how prepare_data interprets them.
        'splits': [0.9, 0.1, 0.1],
        'num_workers': 16,
    },
    # Keyword arguments forwarded to the EGNN molecule encoder.
    'molecule_encoder': {
        'attention': 1,
        'coords_weight': 1.0,
        'device': 'cuda',
        'hidden_nf': 256,
        'in_edge_nf': 0,
        'in_node_nf': 15,
        'n_layers': 3,
        'node_attr': 1,
        'output_size': 512,
    },
    # Settings for the latent-to-SMILES decoder.
    'molecule_decoder': {
        'in_size': 512,
        'latent_size': 512,
        'hidden_size': 512,
        'n_layers': 7,
        'n_heads': 4,
    },
    # Keyword arguments forwarded to the ViT spectra encoder.
    # NOTE(review): h_dim (128) is not a multiple of num_heads (5) -- confirm the ViT implementation accepts this.
    'spectra_encoder': {
        'd_ff': 1024,
        'dropout': 0.0,
        'dropout_emb': 0.1,
        'h_dim': 128,
        'max_time_steps': 1000,
        'num_heads': 5,
        'num_layers': 5,
        'output_size': 512,
        'patch_size': 7,
        'use_clf_token': True,
    },
    # Training-time scalars.
    'train': {
        'temperature': 0.1,
    },
}

In [2]:
# Build the per-split dataloaders from `config`. The call also returns
# `max_charge` and `num_species`, which are passed to the model's forward call
# further down in this notebook.
from PrepareData import prepare_data
dataloaders, max_charge, num_species = prepare_data(config)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [3]:
import wandb
import yaml
import os
import math
import time
import argparse
import utils
import json
import random
import sys

import torch
from torch import nn, optim, Tensor
from torch.nn import functional as F 
from torch.utils.data import DataLoader

import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm

from qm9.data.utils import _get_species, initialize_datasets
from qm9 import utils as qm9_utils
from qm9.data.dataset import ProcessedDataset
from qm9.data.prepare import prepare_dataset
from torch.utils.data import DataLoader
from qm9.data.utils import initialize_datasets
from qm9.args import init_argparse
from qm9.data.collate import collate_fn
from models.vit import ViT
from qm9.models import EGNN
from qm9 import dataset

import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw

import matplotlib.pyplot as plt
import seaborn as sns
import plotly

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Notebook-wide device and dtype. The CLIP.forward_* methods below read these
# globals when moving batch tensors, and the model itself is moved to `device`.
# NOTE(review): hard-coded to CUDA, so the notebook presumably needs a GPU -- confirm.
device = torch.device("cuda")
dtype = torch.float32

from models.decoder import LatentToMol

In [4]:
def set_up_causal_mask(seq_len):
    """Return a float32 additive causal mask of shape (seq_len, seq_len).

    Entry [i, j] is 0.0 when j <= i and -inf when j > i, so position i can
    only attend to itself and earlier positions. The returned tensor does
    not require grad.
    """
    blocked = torch.full((seq_len, seq_len), float('-inf'), dtype=torch.float32)
    # diagonal=1 keeps only the strictly-upper triangle; everything else becomes 0.
    return torch.triu(blocked, diagonal=1)

In [5]:

class CLIP(nn.Module):
    """Joint molecule / IR-spectrum encoder with a SMILES decoder head.

    Three sub-modules are built from ``config``:

    * ``Molecule_Encoder`` -- an ``EGNN`` configured by ``config['molecule_encoder']``.
    * ``Spectra_Encoder`` -- a ``ViT`` configured by ``config['spectra_encoder']``.
    * ``smiles_decoder`` -- a ``LatentToMol`` configured by
      ``config['molecule_decoder']`` plus ``config['data']['seq_len']`` and the
      vocabulary loaded from ``config['data']['vocab_path']``.

    The ``forward_*`` methods move batch tensors using the module-level
    ``device`` and ``dtype`` globals defined in the notebook's import cell.
    """

    def __init__(self, config):
        """Build the encoders, the decoder and the learnable logit scale.

        Args:
            config: nested dict with the ``data``, ``molecule_encoder``,
                ``molecule_decoder``, ``spectra_encoder`` and ``train`` sections.
        """
        super().__init__()
        self.config = config

        # NOTE(security): pickle can execute arbitrary code while loading; only
        # point vocab_path at files you trust.
        # The context manager closes the file handle once the vocab is loaded.
        with open(config['data']['vocab_path'], 'rb') as vocab_file:
            self.vocab = pickle.load(vocab_file)
        self.temperature = config['train']['temperature']

        mol_cfg = self.config['molecule_encoder']
        self.Molecule_Encoder = EGNN(
            in_node_nf=mol_cfg['in_node_nf'],
            in_edge_nf=mol_cfg['in_edge_nf'],
            hidden_nf=mol_cfg['hidden_nf'],
            device=torch.device(mol_cfg['device']),
            n_layers=mol_cfg['n_layers'],
            coords_weight=mol_cfg['coords_weight'],
            attention=mol_cfg['attention'],
            node_attr=mol_cfg['node_attr'],
            output_size=mol_cfg['output_size'],
        )

        spec_cfg = self.config['spectra_encoder']
        self.Spectra_Encoder = ViT(
            patch_size=spec_cfg['patch_size'],
            num_layers=spec_cfg['num_layers'],
            h_dim=spec_cfg['h_dim'],
            num_heads=spec_cfg['num_heads'],
            output_size=spec_cfg['output_size'],
            d_ff=spec_cfg['d_ff'],
            max_time_steps=spec_cfg['max_time_steps'],
            use_clf_token=spec_cfg['use_clf_token'],
            dropout=spec_cfg['dropout'],
            dropout_emb=spec_cfg['dropout_emb'],
        )

        dec_cfg = self.config['molecule_decoder']
        # The decoder's input width is taken from 'latent_size'; the separate
        # 'in_size' config entry is not read here.
        self.smiles_decoder = LatentToMol(
            in_size=dec_cfg['latent_size'],
            hidden_size=dec_cfg['hidden_size'],
            n_layers=dec_cfg['n_layers'],
            n_heads=dec_cfg['n_heads'],
            seq_len=self.config['data']['seq_len'],
            vocab=self.vocab,
        )

        # NOTE(review): the parameter is initialised to the raw temperature and
        # exponentiated in forward(). The reference CLIP implementation
        # initialises it to log(1 / temperature) -- confirm which is intended.
        self.logit_scale = nn.Parameter(torch.ones([]) * self.temperature)

    def forward_mol(self, data, max_charge, num_species):
        """Encode the molecules of a batch into L2-normalised embeddings.

        Args:
            data: batch dict providing 'positions', 'atom_mask', 'edge_mask',
                'one_hot' and 'charges' tensors. 'positions' must be 3-D
                (batch, nodes, coordinates).
            max_charge: scale passed to ``qm9_utils.preprocess_input``.
            num_species: accepted for interface compatibility; not used here.

        Returns:
            The molecule encoder output, normalised along dim 1.
        """
        # The batch size is read from the tensor itself, so a smaller final
        # batch is handled correctly.
        batch_size, n_nodes, _ = data['positions'].size()
        atom_positions = data['positions'].view(batch_size * n_nodes, -1).to(device, dtype)
        atom_mask = data['atom_mask'].view(batch_size * n_nodes, -1).to(device, dtype)
        edge_mask = data['edge_mask'].to(device, dtype)
        one_hot = data['one_hot'].to(device, dtype)
        charges = data['charges'].to(device, dtype)

        charge_scale = max_charge

        # NOTE(review): the literal 2 is presumably the charge power used to
        # build node features -- confirm against qm9_utils.preprocess_input.
        nodes = qm9_utils.preprocess_input(one_hot,
                                           charges,
                                           2,
                                           charge_scale,
                                           device)

        nodes = nodes.view(batch_size * n_nodes, -1)
        edges = qm9_utils.get_adj_matrix(n_nodes, batch_size, device)

        mol_features = self.Molecule_Encoder(h0=nodes,
                                             x=atom_positions,
                                             edges=edges,
                                             edge_attr=None,
                                             node_mask=atom_mask,
                                             edge_mask=edge_mask,
                                             n_nodes=n_nodes)

        # F.normalize clamps the norm at a small epsilon, so an all-zero
        # feature row yields zeros rather than NaN.
        return F.normalize(mol_features, p=2, dim=1)

    def forward_spec(self, data):
        """Encode the IR spectra of a batch into L2-normalised embeddings.

        Args:
            data: batch dict providing an 'IR' tensor.

        Returns:
            The spectra encoder output, normalised along dim 1.
        """
        spectra = data['IR'].to(device, dtype)
        # Insert two singleton axes after the batch axis before calling the ViT.
        spectra = spectra.unsqueeze(1).unsqueeze(1)

        spectra_features = self.Spectra_Encoder(spectra)
        # F.normalize avoids NaN for an all-zero feature row.
        return F.normalize(spectra_features, p=2, dim=1)

    def forward_decoder(self, data, spec_latents):
        """Run the SMILES decoder conditioned on spectrum embeddings.

        Args:
            data: batch dict providing 'decoder_inp' and 'tgt_padding_mask'.
            spec_latents: spectrum embeddings from ``forward_spec``.

        Returns:
            The decoder's predictions.
        """
        smi = data['decoder_inp'].to(device)
        tgt_padding_mask = data['tgt_padding_mask'].to(device)
        # The causal mask is sized from the configured sequence length.
        tgt_mask = set_up_causal_mask(self.config['data']['seq_len']).to(device)

        return self.smiles_decoder(spec_latents,
                                   smi,
                                   tgt_mask,
                                   tgt_padding_mask)

    def forward(self, data, max_charge, num_species):
        """Compute both embeddings and the decoder predictions for a batch.

        Returns:
            Tuple ``(mol_latents, spec_latents, smile_preds, logits_scale,
            data['index'])``, where ``logits_scale`` is ``exp(self.logit_scale)``.
        """
        logits_scale = self.logit_scale.exp()

        mol_latents = self.forward_mol(data, max_charge, num_species)
        spec_latents = self.forward_spec(data)

        smile_preds = self.forward_decoder(data, spec_latents)

        return mol_latents, spec_latents, smile_preds, logits_scale, data['index']
        
        

In [6]:
# Build the model and move its parameters to the notebook-wide device.
# nn.Module.to returns the module itself, so the call can be chained.
model = CLIP(config).to(device)
# The bare name as the last expression displays the module structure.
model

CLIP(
  (Molecule_Encoder): EGNN(
    (embedding): Linear(in_features=15, out_features=256, bias=True)
    (embedding_out): Linear(in_features=256, out_features=15, bias=True)
    (gcl_0): E_GCL_mask(
      (edge_mlp): Sequential(
        (0): Linear(in_features=513, out_features=256, bias=True)
        (1): SiLU()
        (2): Linear(in_features=256, out_features=256, bias=True)
        (3): SiLU()
      )
      (node_mlp): Sequential(
        (0): Linear(in_features=527, out_features=256, bias=True)
        (1): SiLU()
        (2): Linear(in_features=256, out_features=256, bias=True)
      )
      (att_mlp): Sequential(
        (0): Linear(in_features=256, out_features=1, bias=True)
        (1): Sigmoid()
      )
      (act_fn): SiLU()
    )
    (gcl_1): E_GCL_mask(
      (edge_mlp): Sequential(
        (0): Linear(in_features=513, out_features=256, bias=True)
        (1): SiLU()
        (2): Linear(in_features=256, out_features=256, bias=True)
        (3): SiLU()
      )
      (node

In [7]:
# Grab one training batch for the smoke tests below.
# next(iter(...)) raises StopIteration right here if the loader is empty,
# rather than leaving `data` undefined for later cells.
data = next(iter(dataloaders['train']))

In [8]:
# Smoke-test a full forward pass on the single batch fetched above.
# NOTE(review): no torch.no_grad() or model.eval() is visible here, so this
# presumably runs with autograd enabled and the model in training mode -- confirm intended.
mol_latents, spec_latents, smile_preds, logits_scale, ids = model(data, max_charge, num_species)

In [9]:
# Molecule embeddings: (batch_size, molecule_encoder output_size) = (64, 512) with this config.
mol_latents.shape

torch.Size([64, 512])

In [10]:
# Spectrum embeddings: (batch_size, spectra_encoder output_size) = (64, 512), the same shape as mol_latents.
spec_latents.shape

torch.Size([64, 512])

In [11]:
# Decoder output: (batch_size, seq_len, 25) with batch_size=64 and seq_len=70.
# The last axis is presumably the vocabulary size -- verify against the vocab pickle.
smile_preds.shape

torch.Size([64, 70, 25])

In [12]:
# data['index'] passed through unchanged by CLIP.forward: one entry per sample in the batch.
ids.shape

torch.Size([64])