In [1]:
# use source activate /n/groups/marks/software/anaconda_o2/envs/dd_torch

import esm

import pickle
import numpy as np
import importlib
import pandas as pd
import time
import torch

In [2]:
print(esm.__file__)

/n/groups/marks/software/anaconda_o2/envs/dd_torch/lib/python3.7/site-packages/esm/__init__.py


In [1]:

#
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()

# to get rid of random dropout
model= model.eval()



print('reading structure in')
# read structure in 
cifpath = '/n/groups/marks/users/david/esm_if/data/bio_all_rm_non_chain.cif' # .pdb format is also acceptable
coords, seqs = esm.inverse_folding.multichain_util.load_complex_coords(
    cifpath, 
    ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
)


  "Regression weights not found, predicting contacts will not produce correct results."


reading structure in




MVVRLRVAVTPEQAARMRELVEAGWYATESEIVREAVFRWELEERLRRRDVRRLRELWEEGRRSGEPRPVDFGELRERAEEALRG
gen 1 seq took 14.74303936958313 seconds


In [17]:
# fix positions in esm_if

def _concatenate_coords(coords, target_chain_id, padding_length=10):
    """
    Args:
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        padding_length: Length of padding between concatenated chains
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates, a
              concatenation of the chains with padding in between
            - seq is the extracted sequence, with padding tokens inserted
              between the concatenated chains
    """
    pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
    # For best performance, put the target chain first in concatenation.
    coords_list = [coords[target_chain_id]]
    for chain_id in coords:
        if chain_id == target_chain_id:
            continue
        coords_list.append(pad_coords)
        coords_list.append(coords[chain_id])
    coords_concatenated = np.concatenate(coords_list, axis=0)
    return coords_concatenated

def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1.,
        padding_length=10, mask_pattern = None):
    """
    Samples sequence for one chain in a complex.
    Args:
        model: An instance of the GVPTransformer model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        padding_length: padding length in between chains
    Returns:
        Sampled sequence for the target chain
    """
    target_chain_len = coords[target_chain_id].shape[0]
    all_coords = _concatenate_coords(coords, target_chain_id) # puts the target chain first

    # Supply padding tokens for other chains to avoid unused sampling for speed
    padding_pattern = ['<pad>'] * all_coords.shape[0]
    for i in range(target_chain_len):
        padding_pattern[i] = '<mask>'
    
    if mask_pattern != None:
        # make sure the supplied mask pattern is the correct length for the sequence

        assert len(mask_pattern) == target_chain_len
        for i in range(len(mask_pattern)):
            padding_pattern[i] = mask_pattern[i]

        
    sampled = model.sample(all_coords, partial_seq=padding_pattern,
            temperature=temperature)
    sampled = sampled[:target_chain_len]
    return sampled

# try really low temperature
start = time.time()
with torch.no_grad():
    sampled_seq = sample_sequence_in_complex(
        model,
        coords,
        'C',
        temperature = 1e-30
    )
    print(sampled_seq)
end = time.time()

print('gen 1 seq took {} seconds'.format(end-start))

MITERLSVRVTPEQARVMDELVAAGRYATRSEIVREAVFRWRLAQERYRRDVRTLRRLWEEGRASGEPRPVDFAELREEARARLG
gen 1 seq took 14.071446180343628 seconds


In [13]:
seqs['C']

'ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT'

In [22]:
# set the mask to wildtype 
mask_list_chC = list(seqs['C'])
mask_list_chC[5] = '<mask>'
# try really low temperature
start = time.time()
with torch.no_grad():
    sampled_seq = sample_sequence_in_complex(
        model,
        coords,
        'C',
        temperature = 10, 
        mask_pattern = mask_list_chC
    )
    print(sampled_seq)
    print(seqs['C'])
end = time.time()

print('gen 1 seq took {} seconds'.format(end-start))

ANVEKNSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
gen 1 seq took 12.993638753890991 seconds


In [38]:
def generate_mask(wt_seq, mut_str_m1, offset = 1):
    # makes a list of positions to mask based on a mustring
    mask_list_chC = list(wt_seq)

    for m in mut_str_m1.split(':'):
        wt_aa = m[0]
        aa_pos = int(m[1:-1])
        aa_pos_off = aa_pos - offset
        
        assert mask_list_chC[aa_pos_off] == wt_aa
        mask_list_chC[aa_pos_off] = '<mask>'
    return mask_list_chC

ch_c_mask_3_pos = generate_mask(
    seqs['C'], 
    'D61A:K64A:E80A',
    offset = 2)
ch_c_mask_4_pos = generate_mask(
    seqs['C'], 
    'L59A:W60L:D61A:K64L',
    offset = 2)

ch_c_mask_10_pos = generate_mask(
    seqs['C'], 
    'L48L:D52D:I53I:R55R:L56L:F74F:R78R:E80E:A81A:R82R',
    offset = 2)



In [40]:
start = time.time()
with torch.no_grad():
    sampled_seq = sample_sequence_in_complex(
        model,
        coords,
        'C',
        temperature = 1, 
        mask_pattern = ch_c_mask_10_pos
    )
    print(sampled_seq)
    print(seqs['C'])
    print(hamming(sampled_seq, seqs['C']))
end = time.time()


ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDIRRRRQLWDEGKASGRPEPVDYDALRKKAEQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT


In [41]:
def hamming(str1, str2):
    assert len(str1) == len(str2)
    return sum(c1 != c2 for c1, c2 in zip(str1, str2))



In [42]:
# figuring out how many mutations roughly for each temperature
for t in [1e-3, 1e-2, 1e-1,1,10]:
    start = time.time()
    with torch.no_grad():
        sampled_seq = sample_sequence_in_complex(
            model,
            coords,
            'C',
            temperature = t, 
            mask_pattern = ch_c_mask_10_pos
        )
        print(sampled_seq)
        print(seqs['C'])
        print('temperature {}, hamming dist {}'.format(t,hamming(sampled_seq, seqs['C'])))
    end = time.time()

ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDERRRRQLWDEGKASGRPEPVDYDALRKKAKQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
temperature 0.001, hamming dist 6
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDERRRRQLWDEGKASGRPEPVDYDALRKKAKQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
temperature 0.01, hamming dist 6
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDERRRRQLWDEGKASGRPEPVDYDALRKKAKQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
temperature 0.1, hamming dist 6
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDEERDRRQLWDEGKASGRPEPVDEDALRKKKKQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
temperature 1, hamming dist 9
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDETRSKRQLWDEGKASGRPEPVDSDALHKLDSQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
temperature

In [47]:
for t in [1e-6, 1e-5, 1e-4]:
    for i in range(2):
        start = time.time()
        with torch.no_grad():
            sampled_seq = sample_sequence_in_complex(
                model,
                coords,
                'C',
                temperature = t, 
                mask_pattern = ch_c_mask_10_pos
            )
            print(sampled_seq)
            print(seqs['C'])
            print('temperature {}, hamming dist {}'.format(t,hamming(sampled_seq, seqs['C'])))
        end = time.time()

ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDERRRRQLWDEGKASGRPEPVDYDALRKKAKQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
temperature 1e-06, hamming dist 6
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDERRRRQLWDEGKASGRPEPVDYDALRKKAKQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
temperature 1e-06, hamming dist 6
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDERRRRQLWDEGKASGRPEPVDYDALRKKAKQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
temperature 1e-05, hamming dist 6
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDERRRRQLWDEGKASGRPEPVDYDALRKKAKQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
temperature 1e-05, hamming dist 6
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDERRRRQLWDEGKASGRPEPVDYDALRKKAKQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
temp

In [44]:
for t in [1e-30]:
    start = time.time()
    with torch.no_grad():
        sampled_seq = sample_sequence_in_complex(
            model,
            coords,
            'C',
            temperature = t, 
            mask_pattern = ch_c_mask_10_pos
        )
        print(sampled_seq)
        print(seqs['C'])
        print('temperature {}, hamming dist {}'.format(t,hamming(sampled_seq, seqs['C'])))
    end = time.time()

ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDERRRRQLWDEGKASGRPEPVDYDALRKKAKQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT
temperature 1e-30, hamming dist 6


In [46]:
for t in [1e-40]:
    start = time.time()
    with torch.no_grad():
        sampled_seq = sample_sequence_in_complex(
            model,
            coords,
            'C',
            temperature = t, 
            mask_pattern = ch_c_mask_10_pos
        )
        print(sampled_seq)
        print(seqs['C'])
        print('temperature {}, hamming dist {}'.format(t,hamming(sampled_seq, seqs['C'])))
    end = time.time()

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

In [5]:

import esm

import pickle
import numpy as np
import importlib
import pandas as pd
import time
import sys
import torch

models_dir = 'models'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_id = float(time.time())
print(device)


# batch score
datapath = '/n/groups/marks/users/david/esm_if/data/gen_seqs/'
num_pos_mut = 10#int(sys.argv[1])
t = 1 #float(sys.argv[2]) # temperature to sample at
n_seqs = 100

pout = datapath + 'esm_t{}_pos{}_n{}_gen_seq.csv'.format(t, num_pos_mut, n_seqs)
print('writing to {}'.format(pout))


print('loading model in')
# load model
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
# to get rid of random dropout
model= model.eval()

print('reading structure in')
# read structure in 
cifpath = '/n/groups/marks/users/david/esm_if/data/bio_all_rm_non_chain.cif' # .pdb format is also acceptable
coords, seqs = esm.inverse_folding.multichain_util.load_complex_coords(
    cifpath, 
    ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
)

# getting the right sequence position mask
def generate_mask(wt_seq, mut_str_m1, offset = 1):
    # makes a list of positions to mask based on a mustring
    mask_list_chC = list(wt_seq)

    for m in mut_str_m1.split(':'):
        wt_aa = m[0]
        aa_pos = int(m[1:-1])
        aa_pos_off = aa_pos - offset
        
        assert mask_list_chC[aa_pos_off] == wt_aa
        mask_list_chC[aa_pos_off] = '<mask>'
    return mask_list_chC

ch_c_mask_3_pos = generate_mask(
    seqs['C'], 
    'D61A:K64A:E80A',
    offset = 2)
ch_c_mask_4_pos = generate_mask(
    seqs['C'], 
    'L59A:W60L:D61A:K64L',
    offset = 2)

ch_c_mask_10_pos = generate_mask(
    seqs['C'], 
    'L48L:D52D:I53I:R55R:L56L:F74F:R78R:E80E:A81A:R82R',
    offset = 2)

if num_pos_mut == 3:
    ch_c_mask = ch_c_mask_3_pos
elif num_pos_mut == 4:
    ch_c_mask = ch_c_mask_4_pos
elif num_pos_mut == 10:
    ch_c_mask = ch_c_mask_10_pos
else:
    print('wrong input given for num positions to mutate')

    
##############################################
##### sampling ###############################
def _concatenate_coords(coords, target_chain_id, padding_length=10):
    """
    Args:
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        padding_length: Length of padding between concatenated chains
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates, a
              concatenation of the chains with padding in between
            - seq is the extracted sequence, with padding tokens inserted
              between the concatenated chains
    """
    pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
    # For best performance, put the target chain first in concatenation.
    coords_list = [coords[target_chain_id]]
    for chain_id in coords:
        if chain_id == target_chain_id:
            continue
        coords_list.append(pad_coords)
        coords_list.append(coords[chain_id])
    coords_concatenated = np.concatenate(coords_list, axis=0)
    return coords_concatenated

def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1.,
        padding_length=10, mask_pattern = None):
    """
    Samples sequence for one chain in a complex.
    Args:
        model: An instance of the GVPTransformer model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        padding_length: padding length in between chains
    Returns:
        Sampled sequence for the target chain
    """
    target_chain_len = coords[target_chain_id].shape[0]
    all_coords = _concatenate_coords(coords, target_chain_id) # puts the target chain first

    # Supply padding tokens for other chains to avoid unused sampling for speed
    padding_pattern = ['<pad>'] * all_coords.shape[0]
    for i in range(target_chain_len):
        padding_pattern[i] = '<mask>'
    
    if mask_pattern != None:
        # make sure the supplied mask pattern is the correct length for the sequence

        assert len(mask_pattern) == target_chain_len
        for i in range(len(mask_pattern)):
            padding_pattern[i] = mask_pattern[i]

        
    sampled = model.sample(all_coords, partial_seq=padding_pattern,
            temperature=temperature)
    sampled = sampled[:target_chain_len]
    return sampled

sampled_seqs = []
for i in range(n_seqs):
    with torch.no_grad():
            sampled_seq = sample_sequence_in_complex(
                model,
                coords,
                'C',
                temperature = t, 
                mask_pattern = ch_c_mask
            )
            sampled_seqs.append(sampled_seq)
            fout = open(pout, 'w')
            fout.write('\n'.join(sampled_seqs))
            fout.close()
            print(sampled_seq)


cuda
writing to /n/groups/marks/users/david/esm_if/data/gen_seqs/esm_t1_pos10_n100_gen_seq.csv
loading model in
reading structure in
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDERRRRQLWDEGKASGRPEPVDFDALRKRAAQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDKRRRRQLWDEGKASGRPEPVDFDALRKKAAQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDETRKRRQLWDEGKASGRPEPVDFDALRKEAIQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDDQRRERQLWDEGKASGRPEPVDYDALEKRARQKLT
ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKREERHDEERRERQLWDEGKASGRPEPVDRDALRKRARQKLT


KeyboardInterrupt: 