In [3]:
import torch
import torch.nn as nn
import sys
import pandas as pd
from torchmetrics.regression import R2Score
import pytorch_lightning as pl
import torch.utils.data as data
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor,
    ModelSummary,
)
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset
from msdatasets import MSDataset
from tqdm import tqdm as tqdm

# from train import SSModel
from pretrain_MAE import SSModel
from models.utils import jaccard_index
from einops import rearrange
import torch.nn.functional as F
import contextlib
import numpy as np

# from train import SSModel
original_sys_path = sys.path.copy()
sys.path.append("./lsm/hf_pretrain")
from models.conditional_gpt2_model import ConditionalGPT2LMHeadModel
from pretrain_smiles_bert import MolBert
from pretrain_smiles_decoder import SelfiesDecoder
sys_path = sys.path

from transformers import (
    AutoModelForCausalLM,
    RobertaTokenizerFast,
    RobertaForMaskedLM,
    DataCollatorWithPadding,
    AutoTokenizer
)
import os
from rdkit.Chem import AllChem
from rdkit import Chem
from rdkit.DataStructs.cDataStructs import ExplicitBitVect

import selfies as sf

torch.set_float32_matmul_precision("medium")

pl.seed_everything(42)

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

Global seed set to 42


In [4]:
path = '../results/gen/leafy-leaf-979-beam/'
# read in csvs with all of the generated smiles
casmi_smiles = pd.read_csv(f'{path}/casmi/casmi_rerank_batch_preds.csv')
casmi2017_smiles = pd.read_csv(f'{path}/casmi2017/casmi2017_rerank_batch_preds.csv')
unknown_smiles = pd.read_csv(f'{path}/unknown/unknown_rerank_batch_preds.csv')

# read in corresponding smiles targets
casmi_targets = pd.read_csv(f'{path}/casmi/casmi_rerank_targets.csv')
casmi2017_targets = pd.read_csv(f'{path}/casmi2017/casmi2017_rerank_targets.csv')
unknown_targets = pd.read_csv(f'{path}/unknown/unknown_rerank_targets.csv')
#rename columns in ttargets to ground_truth
casmi_targets.rename(columns={'0':'ground_truth'}, inplace=True)
casmi2017_targets.rename(columns={'0':'ground_truth'}, inplace=True)
unknown_targets.rename(columns={'0':'ground_truth'}, inplace=True)

# collapse these into one df. First column is the targets, and the remainder are the generated smiles, with first column named target
casmi = pd.concat([casmi_targets, casmi_smiles], axis=1)
casmi2017 = pd.concat([casmi2017_targets, casmi2017_smiles], axis=1)
unknown = pd.concat([unknown_targets, unknown_smiles], axis=1)

# add a row to beginning specifying where each set comes from
casmi.insert(0, 'set', 'casmi')
casmi2017.insert(0, 'set', 'casmi2017')
unknown.insert(0, 'set', 'unknown')

# merge into one
df = pd.concat([casmi, casmi2017, unknown])

In [5]:
# inport BeRT style encoder
encoder_model = MolBert.load_from_checkpoint(
    "/home/gabriel/lsm_ms2/MS2_LSM/lsm/hf_pretrain/trained_models/selfies_bert_best.ckpt"
)
encoder_model.mask_pct = 0.0
encoder_model.eval()
encoder_model.cuda()

tokenizer = AutoTokenizer.from_pretrained(
    "zjunlp/MolGen-large", max_len=256
)
collator = DataCollatorWithPadding(
    tokenizer, padding=True, return_tensors="pt"
)
pl.seed_everything(42)

Global seed set to 42


42

In [None]:
embedding_store = torch.empty((len(df), 101, 1024))
pl.seed_everything(42)
# make a simple training loop
for i in tqdm(range(len(df))):
    embeddings = torch.empty((101, 1024))
    j = 0
    for _ in range(1):
        # extract last 100 rows
        smiles = df.iloc[i, 1+j:1+j+101].values
        if len(smiles) == 0 :
            break
        # convert smiles to selfies
        selfies = []
        for smile in smiles:
            try:
                selfies.append(sf.encoder(smile))
            except:
                pass

        # encode the selfies
        tokenized_selfies = tokenizer(
                selfies,
                padding=True,
                max_length=256,
                truncation=True,
                return_tensors="pt",
            )
        inputs = collator(tokenized_selfies)
        inputs = {k: v.cuda() for k, v in inputs.items()}
        with torch.no_grad():
            outputs = encoder_model(
                input_ids=inputs['input_ids'], 
                attention_mask=inputs['attention_mask'], 
                output_hidden_states=True
            )
        full_embeddings = outputs[1][-1]
        mask = inputs["attention_mask"]
        mean_embeddings = (full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(
            -1
        ).unsqueeze(-1)
        embeddings[j:j+len(selfies)] = mean_embeddings
        
        
        
        # Free up memory
        del inputs, outputs, full_embeddings, mean_embeddings, tokenized_selfies
        torch.cuda.empty_cache()  # Clear cache to free unused memory


    embedding_store[i] = embeddings.cpu()

In [None]:
target = df['ground_truth'].iloc[365]
gens = df.iloc[365, 2:].values
target_selfie = sf.encoder(target)
gen_selfies = []
for smile in range(len(gens)):
    try:
        gen_selfies.append(sf.encoder(gens[smile]))
    except:
        print(smile)
        pass

s = [target_selfie, gen_selfies]
embs = []
for a in s:
    tokenized_selfies = tokenizer(
            a,
            padding=True,
            max_length=256,
            truncation=True,
            return_tensors="pt",
        )
    inputs = collator(tokenized_selfies)
    inputs = {k: v.cuda() for k, v in inputs.items()}
    with torch.no_grad():
        outputs = encoder_model(
            input_ids=inputs['input_ids'], 
            attention_mask=inputs['attention_mask'], 
            output_hidden_states=True
        )
    full_embeddings = outputs[1][-1]
    mask = inputs["attention_mask"]
    mean_embeddings = (full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(
        -1
    ).unsqueeze(-1)
    embs.append(mean_embeddings)
embs

In [14]:
os.makedirs(f'{path}/embeddings', exist_ok=True)
# first column is target embeddings, next 100 colums are generated embeddings
target_embeds = embedding_store[:, 0]
generated_embeds = embedding_store[:, 1:]

print(target_embeds.shape, generated_embeds.shape)

# save the embeddings as a numpy array along with the corresponding targets
np.save(f'{path}/embeddings/target_embeddings.npy', target_embeds.numpy())
np.save(f'{path}/embeddings/generated_embeddings.npy', generated_embeds.numpy())
df.to_csv(f'{path}/embeddings/targets_and_generated_smiles.csv', index=False)

torch.Size([12981, 1024]) torch.Size([12981, 100, 1024])


### Use 1M random smiles

In [15]:
zinc100m = pd.read_parquet('/mnt/data/gabriel_data/ai_ds_gabriel_ms2_workspace/datasets/hf_smiles/zinc100m.parquet')
# filter out 1M smiles
zinc1m = zinc100m.sample(1000000)

In [16]:
# create simple dataset
class EmbeddingDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        return self.df.iloc[idx]['SELFIES']
    
selfies_data = EmbeddingDataset(zinc1m)

loader = DataLoader(
    selfies_data,
    batch_size=1024,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)


In [None]:
zinc_embedding_store = torch.empty((len(zinc1m),  1024))
pl.seed_everything(42)

j= 0
for batch in tqdm(loader):
    tokenized_selfies = tokenizer(
        batch,
        padding=True,
        max_length=256,
        truncation=True,
        return_tensors="pt",
    )
    inputs = collator(tokenized_selfies)
    inputs = {k: v.cuda() for k, v in inputs.items()}
    with torch.no_grad():
        outputs = encoder_model(
            input_ids=inputs['input_ids'], 
            attention_mask=inputs['attention_mask'], 
            output_hidden_states=True
        )
    full_embeddings = outputs[1][-1]
    mask = inputs["attention_mask"]
    mean_embeddings = (full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(
        -1
    ).unsqueeze(-1)
    zinc_embedding_store[j:j+len(batch)] = mean_embeddings.cpu()
    j += len(batch)


In [18]:
# add embeddings as a column to the zinc1m dataframe
zinc1m['embeddings'] = zinc_embedding_store.numpy().tolist()
zinc1m.to_parquet(f'{path}/embeddings/zinc1m_embeddings.parquet')

In [20]:
zinc1m

Unnamed: 0,smiles,zinc_id,SELFIES,deepsmiles,embeddings
45300403,CC1CCC(N(C(=O)COn2nnc3ccc(S(=O)(=O)N(C)C)cc32)...,500410842,[C][C][C][C][C][Branch2][Ring2][O][N][Branch2]...,CCCCCNC=O)COnnnccccS=O)=O)NC)C)))cc69)))))))))...,"[0.5256675481796265, 0.3143625557422638, -0.38..."
90756616,Cc1cccc(NNC(=O)NC[C@@H](NS(=O)(=O)N(C)C)C2CCCC...,1636569630,[C][C][=C][C][=C][C][Branch2][Ring1][S][N][N][...,CcccccNNC=O)NC[C@@H]NS=O)=O)NC)C))))CCCCC5))))...,"[-0.13607320189476013, 0.4302113652229309, -0...."
85566921,CO[C@@H](CN1CC[C@H]1CNC(=O)C1(C)CC=CC1)C1CCCC1,1754493515,[C][O][C@@H1][Branch2][Ring1][Branch2][C][N][C...,CO[C@@H]CNCC[C@H]4CNC=O)CC)CC=CC5)))))))))))))...,"[0.5098204612731934, -0.12734751403331757, -0...."
50096874,C[C@H](c1nccn1C)n1cncc1[C@@H]1CCNC1,2061159209,[C][C@H1][Branch1][=Branch2][C][=N][C][=C][N][...,C[C@H]cnccn5C))))))ncncc5[C@@H]CCNC5,"[0.6597678065299988, -0.21915557980537415, -0...."
55130686,CC[C@H]1[C@@H](O)CCN1C(=O)C[C@@H]1CC[C@@H]2C[C...,2251000541,[C][C][C@H1][C@@H1][Branch1][C][O][C][C][N][Ri...,CC[C@H][C@@H]O)CCN5C=O)C[C@@H]CC[C@@H]C[C@H]63,"[0.7782033681869507, -0.051961664110422134, -0..."
...,...,...,...,...,...
20810104,C=C[C@](C)(CC)C(=O)NCC[C@H](C)NC(=O)[C@H]1CCCN1C,1078221725,[C][=C][C@][Branch1][C][C][Branch1][Ring1][C][...,C=C[C@]C)CC))C=O)NCC[C@H]C)NC=O)[C@H]CCCN5C,"[0.29643702507019043, 0.05870760977268219, 0.0..."
11442819,COc1cccc(/C=C/C(=O)NC[C@@H]2OC[C@@H]3CCN(C(=O)...,1790999571,[C][O][C][=C][C][=C][C][Branch2][Ring2][Ring2]...,COccccc/C=C/C=O)NC[C@@H]OC[C@@H]CCNC=O)CCSC=N)...,"[0.47248637676239014, 0.4747932553291321, 0.06..."
31598141,COC(=O)COc1cccc(CNC(=O)C=Cc2ccccc2C)c1,2297580483,[C][O][C][=Branch1][C][=O][C][O][C][=C][C][=C]...,COC=O)COcccccCNC=O)C=Ccccccc6C))))))))))))c6,"[0.4460662305355072, -0.16228066384792328, -0...."
64222092,Cc1cnc(CNCC[C@H](C)NC(=O)c2sc(C)cc2C)cn1,1497195651,[C][C][=C][N][=C][Branch2][Ring1][=Branch2][C]...,CccncCNCC[C@H]C)NC=O)cscC)cc5C)))))))))))))cn6,"[0.5250552892684937, -0.23900504410266876, -0...."
