In [1]:
import pickle 
import yaml
import pandas as pd
from PrepareData import prepare_data


import torch
from torch import nn, optim, Tensor
from torch.nn import functional as F
import pickle 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
import seaborn as sns
from architecture import CLIP
from tqdm import tqdm

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Floating-point type constant for the notebook.
dtype = torch.float32


In [2]:
from train_utils import load_model

In [3]:
def make_deterministic(random_seed = 0):
    """Seed the random number generators and force deterministic cuDNN behaviour.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (CPU generator and every CUDA device), then switches cuDNN to
    deterministic mode with auto-tuning disabled.

    Args:
        random_seed: Integer seed applied to every generator. Defaults to 0.
    """
    # Local import: the notebook's import cell does not bring in `random`.
    import random

    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    # Seed all CUDA devices, not only the current one; this is a no-op without CUDA.
    torch.cuda.manual_seed_all(random_seed)
    # Deterministic cuDNN kernels, and no benchmark-based algorithm selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

make_deterministic(0)

In [4]:
# Load the run configuration and the training logs of the FULL_ONLY_DECODER checkpoint.
# Context managers make sure both file handles are closed after reading.
with open('./checkpoints/FULL_ONLY_DECODER/config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# checkpoints that come from a trusted source.
with open('./checkpoints/FULL_ONLY_DECODER/logs.pickle', 'rb') as logs_file:
    logs = pickle.load(logs_file)

# Print only the summary entries whose key contains "best".
for key in logs:
    if "best" in key:
        print(key, logs[key])

best_epoch 78
best_clip_epoch 0
best_recon_epoch 78
best_total_loss 6.046065216064453
best_clip_loss 5.991506538391113
best_recon_loss 0.0475520508736372


In [5]:
# Restore the "best_latest" checkpoint from the configured directory and put it in eval mode.
checkpoint_dir = config['train']['checkpoint_dir']
model = load_model(checkpoint_dir, type="best_latest").eval()

# Build the dataloaders (plus the max_charge / num_species values) from the same config.
dataloaders, max_charge, num_species = prepare_data(config)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [6]:
# Load the validation ids saved next to the checkpoint; the context manager
# closes the file handle once the pickle has been read.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('./checkpoints/FULL_ONLY_DECODER/val_ids.pickle', 'rb') as val_ids_file:
    val_ids = pickle.load(val_ids_file)

In [7]:
# Inspect the validation ids loaded from val_ids.pickle.
val_ids

tensor([31720, 11277, 27101,  ..., 61049, 83142, 87043])

In [8]:
# Collect the sample ids that the validation dataloader actually yields.
# Only the 'index' entry is needed, so the batch is not copied to `device`
# first, and tqdm wraps the dataloader directly so the bar can show a total.
id_batches = []
with torch.no_grad():
    for batch in tqdm(dataloaders['val']):
        id_batches.append(batch['index'].detach().cpu())
all_ids = torch.cat(id_batches, 0)

50it [00:00, 70.71it/s] 


In [9]:
# Show the saved validation ids in ascending order (sorted values plus their original positions).
torch.sort(val_ids)

torch.return_types.sort(
values=tensor([     7,      9,     17,  ..., 133795, 133804, 133806]),
indices=tensor([15904, 19349,  1738,  ...,  2109, 17500,  1599]))

In [10]:
# Show the ids yielded by the validation dataloader in ascending order.
torch.sort(all_ids)

torch.return_types.sort(
values=tensor([     7,      9,     17,  ..., 133795, 133804, 133806]),
indices=tensor([18840, 12013, 11830,  ...,  9803, 12095, 19162]))

In [11]:
# Fraction of positions at which the sorted saved ids equal the sorted dataloader ids;
# 1.0 means the two sorted id lists are identical.
(
    torch.sort(val_ids).values.eq(torch.sort(all_ids).values).sum()
    / all_ids.shape[0]
).item()

1.0

In [12]:
from train_utils import decoder_performance
from train_utils import top_scores, decoder_performance, distance_distribution

def _collect_latents(model, loader, keep_spec=True):
    """Run ``model`` over every batch of ``loader`` and gather its latents on the CPU.

    Args:
        model: Callable that takes a batch dict and returns a 5-tuple whose
            first two items are the molecule latents and the spectrum latents.
        loader: Iterable of batch dicts; every value is moved to the
            notebook-level ``device`` before the forward pass.
        keep_spec: When False the spectrum latents are discarded instead of
            being accumulated.

    Returns:
        ``(mol_latents, spec_latents)`` concatenated along dim 0;
        ``spec_latents`` is ``None`` when ``keep_spec`` is False.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    mol_chunks, spec_chunks = [], []
    with torch.no_grad():
        # Wrap the loader itself so tqdm can show a total when the loader has a length.
        for batch in tqdm(loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            mol_latents, spec_latents, _, _, _ = model(batch)
            mol_chunks.append(mol_latents.detach().cpu())
            if keep_spec:
                spec_chunks.append(spec_latents.detach().cpu())

    if not mol_chunks:
        raise ValueError("dataloader yielded no batches; nothing to embed")

    mol_all = torch.cat(mol_chunks, 0)
    spec_all = torch.cat(spec_chunks, 0) if keep_spec else None
    return mol_all, spec_all


def clip_performance(config, model, dataloaders, epoch):
    """Embed the validation and training splits with ``model``.

    The model is switched to eval mode but is not moved to another device.

    Args:
        config: Unused; kept so existing callers keep working.
        model: Model returning ``(mol_latents, spec_latents, smile_preds,
            logit_scale, ids)`` for a batch dict.
        dataloaders: Mapping with ``'val'`` and ``'train'`` dataloaders.
        epoch: Unused; kept so existing callers keep working.

    Returns:
        ``(test_molembeds, train_molembeds, test_specembeds)``: molecule latents
        of the validation split, molecule latents of the training split, and
        spectrum latents of the validation split, all on the CPU.

    Raises:
        ValueError: If either dataloader yields no batches.
    """
    model.eval()

    test_molembeds, test_specembeds = _collect_latents(model, dataloaders['val'])
    # Only molecule latents are needed for the training split.
    train_molembeds, _ = _collect_latents(model, dataloaders['train'], keep_spec=False)

    return test_molembeds, train_molembeds, test_specembeds

In [13]:
from train_utils import Sampler, calculate_decoder_accuracy

In [14]:
# Build the sampler from the SMILES decoder and vocabulary exposed on model.module
# (presumably the model sits behind a DataParallel-style wrapper — TODO confirm).
sampler = Sampler(
    model.module.smiles_decoder,
    model.module.vocab,
)

In [15]:
# Decoder accuracy with k=1. The recorded run took about 15 minutes for 50 batches,
# so expect this cell to be slow.
acc = calculate_decoder_accuracy(
    model,
    dataloaders,
    k=1,
)

50it [15:39, 18.79s/it]

No of Hits :  6688





In [23]:
# Report the accuracy returned by calculate_decoder_accuracy (called with k=1).
print(acc)

0.3344


In [22]:
# Decode SMILES from the spectrum latents of the first validation batch only
# (the outer loop breaks after one iteration).
greedy_smiles_list, og_smiles_list, random_smiles_list = [], [], []
with torch.no_grad():
    for i, data in tqdm(enumerate(dataloaders['val'])):
        data = {name: tensor.to(device) for name, tensor in data.items()}
        spec_latents = model.module.forward_spec(data)
        for spec, og in zip(spec_latents, data['smiles']):
            # Rebuild the reference SMILES from its token sequence, dropping the
            # special tokens, then canonicalise it with RDKit.
            chars = model.module.vocab.from_seq(og)
            og_smile = "".join(
                char for char in chars
                if char not in ("<pad>", "<eos>", "<sos>", "<unk>")
            )
            og_smile = Chem.CanonSmiles(og_smile)

            # One greedy decode for this spectrum latent.
            greedy_smiles = sampler.sample_multi(n=1, embed=spec, greedy_decode=True)
            greedy_smiles_list.append(greedy_smiles)

            # Three non-greedy samples for the same latent.
            random_smiles = sampler.sample_multi(n=3, embed=spec, greedy_decode=False)
            random_smiles_list.append(random_smiles)

            og_smiles_list.append(og_smile)
        break
            

0it [01:11, ?it/s]


In [25]:
# NOTE(review): the recorded output of this cell is a NameError, so it was last run in a
# kernel where the cell defining random_smiles_list had not been executed (presumably
# after a kernel restart). Re-run the sampling cell first, ideally via Restart & Run All.
random_smiles_list

NameError: name 'random_smiles_list' is not defined