In [1]:
import os
import numpy as np

In [3]:
# Directory containing the assay TSV files that are read in.
assays_path = "binding_datasets/"
# Directory where the h5py embedding datasets are written out.
embeddings_path = "embedding_binaries/"

# Load Assay TSVs

In [4]:
# List every TSV assay file in the input directory. os.listdir() returns entries in
# arbitrary, filesystem-dependent order, so the names are sorted to keep the listing
# (and any positional index into it) reproducible across machines and runs.
assays = sorted(x for x in os.listdir(assays_path) if x.endswith('.tsv'))
print(assays)

['652067.tsv', '1053197.tsv', 'hiv1_protease.tsv']


In [35]:
# Pick the assay by file name rather than by position in `assays`: the listing
# comes from os.listdir(), whose order is arbitrary, so a positional index can
# silently select a different file on another machine.
assay_name = 'hiv1_protease.tsv'
if assay_name not in assays:
    raise FileNotFoundError(
        "{!r} is not among the available assays: {}".format(assay_name, assays)
    )
assay_name

'hiv1_protease.tsv'

In [14]:
def read_assay(assay_path):
    """Read a tab-separated assay file into a column-oriented dict.

    The first non-blank line is the header. Each header name maps to the list of
    that column's values (as strings), in row order.

    Args:
        assay_path: Path of the TSV file to read.

    Returns:
        dict mapping column name -> list of str with one entry per data row.
        An empty dict if the file contains no non-blank lines.

    Raises:
        IndexError: if a data row has fewer fields than the header.
    """
    # Whitespace removed from both ends of a line. Tabs are deliberately NOT in
    # this set: stripping them would drop empty first/last fields and shift or
    # shorten the row.
    non_tab_whitespace = ' \n\r\x0b\x0c'

    with open(assay_path, "r") as f:
        # Whitespace-only lines (e.g. a blank line at end of file) are skipped.
        rows = [
            line.strip(non_tab_whitespace).split('\t')
            for line in f
            if line.strip()
        ]

    if not rows:
        return {}

    header, records = rows[0], rows[1:]
    return {
        column: [record[i] for record in records]
        for i, column in enumerate(header)
    }

In [15]:
# Build the path with os.path.join so it stays valid even if assays_path is
# configured without a trailing separator.
assay = read_assay(os.path.join(assays_path, assay_name))

In [16]:
# Show the column names that read_assay() took from the TSV header line.
assay.keys()

dict_keys(['CID', 'CanonicalSMILES', 'IUPACName', 'result', 'binding'])

In [17]:
# Convert the string-valued columns into typed NumPy arrays:
# `result` as float32 and `binding` as int32.
result = np.array(list(map(float, assay['result'])), dtype=np.float32)
binding = np.array(list(map(int, assay['binding'])), dtype=np.int32)

In [18]:
# Sum of the integer binding labels (equals the number of positives if the
# labels are 0/1 -- TODO confirm the label encoding).
binding.sum()

2159

In [19]:
# Number of data rows in the assay, counted from the CanonicalSMILES column.
smiles_column = assay['CanonicalSMILES']
num_instances = len(smiles_column)
print(num_instances)

7462


# Load Pretrained Transformer

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

import sys
sys.path.insert(1, '../') #make parent folder visible
from transformer import Transformer, create_masks, nopeak_mask

In [21]:
# Directory that find_ckpts() lists when searching for checkpoint files.
checkpoint_dir = "../checkpoints/"

In [22]:
# MAX_LEN caps the token sequence fed to the encoder; MODEL_DIM and N_LAYERS are
# passed to the Transformer constructor.
MAX_LEN, MODEL_DIM, N_LAYERS = 256, 512, 6

In [23]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

# Device the transformer runs on. Defaults to DEVICE; assign torch.device("cpu")
# here instead to force the transformer onto the CPU.
TRANSFORMER_DEVICE = DEVICE

cuda


In [24]:
# Size of the base alphabet that encode_char() produces codes for.
PRINTABLE_ASCII_CHARS = 95

# Special tokens get the code points directly after the base alphabet
# (95, 96, 97), each stored as the one-character string chr(code).
_extra_chars = ["seq_start", "seq_end", "pad"]
EXTRA_CHARS = dict(
    (token_name, chr(PRINTABLE_ASCII_CHARS + offset))
    for offset, token_name in enumerate(_extra_chars)
)
ALPHABET_SIZE = PRINTABLE_ASCII_CHARS + len(EXTRA_CHARS)

In [25]:
def find_ckpts(*args, **kwargs):
    """Return the paths of checkpoints in ``checkpoint_dir`` that match every given token.

    A file name is tokenised by treating "." like "_" and splitting on "_". A file
    matches when the string form of every positional argument and of every keyword
    *value* is one of its tokens (keyword names themselves are ignored). With no
    arguments, every file in the directory matches.

    Returns:
        List of matching paths (``checkpoint_dir`` joined with the file name),
        sorted by file name so the result -- and in particular its first element --
        is deterministic.
    """
    wanted = [str(x) for x in itertools.chain(args, kwargs.values())]

    matches = []
    # sorted(): os.listdir() returns entries in arbitrary order.
    for ckpt in sorted(os.listdir(checkpoint_dir)):
        tokens = ckpt.replace(".", "_").split("_")
        if all(arg in tokens for arg in wanted):
            matches.append(os.path.join(checkpoint_dir, ckpt))
    return matches

In [26]:
_ASCII_OFFSET = 32  # subtracted from / added to code points by the two helpers below


def encode_char(c):
    """Map a character to its integer code: its code point minus 32."""
    code_point = ord(c)
    return code_point - _ASCII_OFFSET


def decode_char(n):
    """Inverse of encode_char: map an integer code back to its character."""
    return chr(_ASCII_OFFSET + n)

In [37]:
def encode_string(string, start_char=chr(0)):
    """Encode a string as a 1-D tensor of integer codes with a leading start token.

    The first element is ord(start_char); each following element is
    encode_char() applied to the corresponding character of ``string``.
    """
    token_ids = [ord(start_char)]
    token_ids.extend(encode_char(c) for c in string)
    return torch.tensor(token_ids)

def encode_string_np(string, start_char=chr(0), pad_char=chr(0), max_len=256):
    """Encode a string as a fixed-length float32 NumPy vector.

    Layout: ord(start_char), then encode_char() of each kept character, then
    ord(pad_char) repeated up to the end. The output always has exactly
    ``max_len`` entries; strings longer than ``max_len - 1`` characters are
    truncated so the start token still fits.

    Args:
        string: Text to encode.
        start_char: Character whose code point is written at position 0.
        pad_char: Character whose code point fills the unused tail.
        max_len: Total output length including the start token. Defaults to 256,
            the value that was previously hard-coded.

    Returns:
        np.ndarray of shape (max_len,) and dtype float32.

    Raises:
        ValueError: if max_len is smaller than 1 (no room for the start token).
    """
    if max_len < 1:
        raise ValueError("max_len must be at least 1, got {}".format(max_len))

    # Reserve one slot for the start token.
    kept = string[:max_len - 1]

    arr = np.full((max_len,), ord(pad_char), dtype=np.float32)
    arr[:len(kept) + 1] = np.array([ord(start_char)] + [encode_char(c) for c in kept])
    return arr

In [28]:
def pad_tensors(tensors, pad_char=chr(0), max_len=None):
    """Stack 1-D tensors into one right-padded ``torch.long`` matrix.

    Args:
        tensors: Sequence of 1-D integer tensors (may be empty).
        pad_char: Character whose code point fills the unused positions.
        max_len: Row length of the result. When omitted (or 0) it is the length
            of the longest tensor plus one, so every row ends in at least one
            pad value.

    Returns:
        Tensor of shape (len(tensors), max_len) and dtype torch.long. A tensor
        longer than an explicitly given ``max_len`` is truncated to fit.
    """
    if not max_len:
        # default=0 keeps an empty input from raising inside max().
        max_len = max((t.shape[0] for t in tensors), default=0) + 1

    padded_tensors = torch.full((len(tensors), max_len), ord(pad_char), dtype=torch.long)
    for i, tensor in enumerate(tensors):
        # Truncate instead of failing on a shape mismatch when a tensor exceeds
        # an explicit max_len.
        kept = tensor[:max_len]
        padded_tensors[i, :kept.shape[0]] = kept

    return padded_tensors

### Select Weights Checkpoint

In [30]:
# Look up the available checkpoints and take the first one; when none is found
# the path is left as an empty string.
found = find_ckpts()
print(found)
if found:
    load_path = found[0]
else:
    load_path = ""

['../checkpoints/pretrained.ckpt']


In [31]:
# Build the transformer, wrap it in DataParallel and move it to the target device
# in one expression. The DataParallel wrapper is why the encoder is later reached
# through `model.module` (presumably the checkpoint was saved from a wrapped
# model too -- TODO confirm).
model = nn.DataParallel(
    Transformer(ALPHABET_SIZE, MODEL_DIM, N_LAYERS)
).to(TRANSFORMER_DEVICE)

In [32]:
# map_location remaps the stored tensors onto the device the model lives on;
# without it, a checkpoint saved from a GPU cannot be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
checkpoint = torch.load(load_path, map_location=TRANSFORMER_DEVICE)
model.load_state_dict(checkpoint['state_dict'])
# Evaluation mode: the model is only used for inference below.
model = model.eval()

# Create H5PY Dataset File

In [33]:
import h5py

In [36]:
transformer_epoch = 2
# Derive the output file name from the TSV name WITHOUT overwriting `assay_name`.
# Rebinding assay_name to the .hdf5 name made this cell stale on re-run: after
# changing transformer_epoch, the old suffix stayed because ".tsv" was already gone.
embedding_name = "{}_{}.hdf5".format(os.path.splitext(assay_name)[0], transformer_epoch)
# os.path.join keeps the path valid even without a trailing separator on embeddings_path.
assay_path = os.path.join(embeddings_path, embedding_name)
print(assay_path)

embedding_binaries/hiv1_protease_2.hdf5


In [None]:
# Open the output HDF5 file. Mode 'w-' creates a new file and fails if one already
# exists at assay_path, so an earlier export is never silently overwritten.
f = h5py.File(assay_path, 'w-')

In [None]:
# One (MAX_LEN, MODEL_DIM) float32 slot per assay row. The dimensions come from the
# shared constants instead of the literals 256/512, so the dataset always agrees
# with the truncation length and model width used when the embeddings are computed.
embeddings = f.create_dataset("embeddings", (num_instances, MAX_LEN, MODEL_DIM), dtype=np.float32)

In [None]:
# Per-row targets: one float32 assay result and one int32 binding label per instance.
result_dset = f.create_dataset("result", shape=(num_instances,), dtype=np.float32)
binding_dset = f.create_dataset("binding", shape=(num_instances,), dtype=np.int32)

In [None]:
# One fixed-width float32 row per instance for the encoded SMILES string; the
# width 256 matches the fixed output length of encode_string_np().
smiles_enc = f.create_dataset("smiles", shape=(num_instances, 256), dtype=np.float32)

In [None]:
# Inference only, so gradient tracking is switched off for the whole export.
with torch.no_grad():
    for i, smiles in enumerate(assay['CanonicalSMILES']):
        # Transformer input: token ids with a leading start token, shaped as a
        # batch of one and cut to at most MAX_LEN positions.
        token_ids = encode_string(smiles, start_char=EXTRA_CHARS['seq_start'])
        token_ids = token_ids.unsqueeze(0).to(TRANSFORMER_DEVICE)
        token_ids = token_ids[:, :MAX_LEN]
        mask = create_masks(token_ids)

        # The model is DataParallel-wrapped, hence `.module`. [0] selects the
        # single batch entry (assumes the encoder output is batch-first -- TODO confirm).
        embedding = model.module.encoder(token_ids, mask)[0].cpu().numpy()

        # Only the first `seq_len` rows of this instance's slot are written;
        # the remaining rows are left untouched.
        seq_len = embedding.shape[0]
        embeddings[i, :seq_len, :] = embedding
        result_dset[i] = result[i]
        binding_dset[i] = binding[i]

        # Fixed-length SMILES vector, scaled by dividing by the alphabet size.
        smiles_vector = encode_string_np(
            smiles,
            start_char=EXTRA_CHARS['seq_start'],
            pad_char=EXTRA_CHARS['pad'],
        )
        smiles_enc[i, :] = smiles_vector / ALPHABET_SIZE

        # Progress marker every 1000 instances.
        if i % 1000 == 0:
            print(i)

In [None]:
# Close the HDF5 output file so everything written to its datasets is flushed to disk.
f.close()