# Score antitoxin variants and generate antitoxin sequences at various temperatures using ProteinMPNN
- code adapted from https://github.com/dauparas/ProteinMPNN

### inputs:
- pdb file

### outputs:
- scored antitoxin variants from proteinMPNN
- generated sampled sequences

In [None]:
import sys
import importlib
import re
import numpy as np
import pandas as pd
from os.path import isfile

sys.path.append('./ProteinMPNN/')
sys.path.append('./ProteinMPNN/vanilla_proteinmpnn/')
sys.path.append('../../../src/')
from regressionTools import hamming

import matplotlib.pyplot as plt
import shutil
import warnings
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN

device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
#v_48_010=version with 48 edges 0.10A noise


Mounted at /content/drive


In [None]:
# Output directory for generated samples (also exposed as folder_for_outputs in the parameter cell).
out_folder='../samples/at_samples/'
# PDB exported from PyMOL with a modified header; antitoxin chains A,C,E,G and toxin chains B,D,F,H.
pdb_path = './bio_all_rm_non_chain_dd1_header.pdb'


In [None]:

def make_tied_pos_dict_5ceg(pdb_name='bio_all_rm_non_chain_dd1_header'):
    '''
    Build a ProteinMPNN tied-positions dict for the 5CEG toxin-antitoxin complex.

    The structure has four antitoxin chains (A, C, E, G) and four toxin chains
    (B, D, F, H). Equivalent residues across the copies of each protein are
    tied to each other, so they are decoded jointly by ProteinMPNN's tied
    sampling. Residue numbers are 1-based within each chain.

    Antitoxin: chains C and G start one residue earlier than A and E, so
    residue i of A/E is tied to residue i+1 of C/G (A/E 1..84 <-> C/G 2..85).
    Toxin: chains B and F start two residues earlier than D and H, so
    residue i of B/F is tied to residue i-2 of D/H (B/F 3..103 <-> D/H 1..101).

    The result has the same layout as ProteinMPNN's tied-positions JSONL
    entries, e.g. ``{"3LIS": [{"A": [1], "B": [1]}, {"A": [2], "B": [2]}, ...]}``.

    Args:
        pdb_name: key under which the ties are stored; must match the ``name``
            that parse_PDB assigns to the structure. Defaults to the structure
            file used in this notebook.

    Returns:
        dict ``{pdb_name: [tie, ...]}`` where each tie is a dict mapping chain
        letter to a one-element position list; antitoxin ties come first,
        followed by toxin ties.
    '''
    # Antitoxin copies: A/E residue i <-> C/G residue i+1, for i = 1..84.
    at_list_dicts = [{'A': [i], 'C': [i + 1], 'E': [i], 'G': [i + 1]}
                     for i in range(1, 85)]

    # Toxin copies: B/F residue i <-> D/H residue i-2, for i = 3..103.
    t_list_dicts = [{'B': [i], 'D': [i - 2], 'F': [i], 'H': [i - 2]}
                    for i in range(3, 104)]

    return {pdb_name: at_list_dicts + t_list_dicts}


def generate_seqs(dataset_valid,
                  temp,
                  fixed_positions_dict,
                  BATCH_COPIES,
                  NUM_BATCHES,
                  device,
                  chain_id_dict,
                  omit_AA_dict,
                  tied_positions_dict,
                  pssm_dict,
                  bias_by_res_dict,
                  pssm_threshold,
                  omit_AAs_np,
                  bias_AAs_np,
                  pssm_multi,
                  pssm_log_odds_flag,
                  pssm_bias_flag
                  ):
    '''
    Sample sequences with ProteinMPNN and return the chain-A part of each.

    For every protein in ``dataset_valid`` the structure is featurized with
    ``tied_featurize`` (``BATCH_COPIES`` copies), then ``NUM_BATCHES`` rounds
    of sampling are run at temperature ``temp``. ``model.tied_sample`` is used
    when ``tied_positions_dict`` is given, otherwise ``model.sample``. The
    remaining arguments are passed through unchanged to ``tied_featurize`` /
    the sampler (see the ProteinMPNN repository for their meaning).

    Only sampling is done here; sequences are not scored.

    Relies on the notebook globals ``model``, ``alphabet``, ``pdb_dict_list``
    and ``hamming``. The wild type is ``pdb_dict_list[0]['seq_chain_A']``.

    Returns:
        list of str, one per sampled sequence (all proteins, in sampling
        order): the first ``len(wt_seq)`` residues of the decoded sample.
        This assumes chain A is the leading segment of the concatenated
        chains. The Hamming distance of each to the wild type is printed.
    '''
    sampled_S = []
    with torch.no_grad():
        print('Generating sequences...')
        for protein in dataset_valid:
            batch_clones = [copy.deepcopy(protein) for _ in range(BATCH_COPIES)]
            (X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list,
             visible_list_list, masked_list_list, masked_chain_length_list_list,
             chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask,
             tied_pos_list_of_lists_list, pssm_coef, pssm_bias,
             pssm_log_odds_all, bias_by_res_all, tied_beta) = tied_featurize(
                batch_clones,
                device,
                chain_id_dict,
                fixed_positions_dict,
                omit_AA_dict,
                tied_positions_dict,
                pssm_dict,
                bias_by_res_dict)
            # 1.0 where the PSSM log-odds exceed the threshold, 0.0 elsewhere.
            pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float()

            # Keyword arguments shared by model.sample and model.tied_sample.
            sample_kwargs = dict(
                mask=mask,
                temperature=temp,
                omit_AAs_np=omit_AAs_np,
                bias_AAs_np=bias_AAs_np,
                chain_M_pos=chain_M_pos,
                omit_AA_mask=omit_AA_mask,
                pssm_coef=pssm_coef,
                pssm_bias=pssm_bias,
                pssm_multi=pssm_multi,
                pssm_log_odds_flag=bool(pssm_log_odds_flag),
                pssm_log_odds_mask=pssm_log_odds_mask,
                pssm_bias_flag=bool(pssm_bias_flag),
                bias_by_res=bias_by_res_all,
            )
            if tied_positions_dict is None:
                sample_fn = model.sample
            else:
                sample_fn = model.tied_sample
                sample_kwargs.update(tied_pos=tied_pos_list_of_lists_list[0],
                                     tied_beta=tied_beta)

            for _ in range(NUM_BATCHES):
                randn_2 = torch.randn(chain_M.shape, device=X.device)
                sample_dict = sample_fn(X,
                                        randn_2,
                                        S,
                                        chain_M,
                                        chain_encoding_all,
                                        residue_idx,
                                        **sample_kwargs)
                sampled_S.append(sample_dict["S"].cpu().data.numpy())

    if not sampled_S:
        return []
    S_sample_concat = np.concatenate(sampled_S)

    # Decode token indices back to amino-acid letters.
    wt_seq = pdb_dict_list[0]['seq_chain_A']
    designed_seqs = []
    for s in S_sample_concat:
        seq = ''.join(alphabet[k] for k in s)
        at_chain_a = seq[:len(wt_seq)]
        designed_seqs.append(at_chain_a)
        print(hamming(at_chain_a, wt_seq))
    return designed_seqs

def write_seqs_fasta(list_seqs, pout):
    '''
    Write ``list_seqs`` to ``pout`` in FASTA format.

    Records are named ``s0``, ``s1``, ... in input order; an existing file at
    ``pout`` is overwritten.
    '''
    with open(pout, 'w') as handle:
        for idx, seq in enumerate(list_seqs):
            header = f'>s{idx}\n'
            handle.write(header)
            handle.write(seq + '\n')

In [None]:
# setup model

# Vanilla ProteinMPNN checkpoint to load. Per the printed output below,
# v_48_002 uses 48 edges (neighbors) and 0.02A training noise.
model_name = "v_48_002" #param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]
backbone_noise=0.00    # Standard deviation of Gaussian noise to add to backbone atoms

path_to_model_weights='./ProteinMPNN/vanilla_model_weights'          
# Architecture sizes; these must match the checkpoint or load_state_dict below will fail.
hidden_dim = 128
num_layers = 3 
model_folder_path = path_to_model_weights
# Ensure a trailing slash before appending the checkpoint file name.
if model_folder_path[-1] != '/':
    model_folder_path = model_folder_path + '/'
checkpoint_path = model_folder_path + f'{model_name}.pt'

checkpoint = torch.load(checkpoint_path, map_location=device) 
print('Number of edges:', checkpoint['num_edges'])
noise_level_print = checkpoint['noise_level']
print(f'Training noise level: {noise_level_print}A')
# num_letters=21 matches the 21-letter alphabet 'ACDEFGHIKLMNPQRSTVWYX';
# backbone_noise is applied via augment_eps; k_neighbors comes from the checkpoint.
model = ProteinMPNN(num_letters=21, 
                    node_features=hidden_dim, 
                    edge_features=hidden_dim, 
                    hidden_dim=hidden_dim, 
                    num_encoder_layers=num_layers, 
                    num_decoder_layers=num_layers, 
                    augment_eps=backbone_noise, 
                    k_neighbors=checkpoint['num_edges'])
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Inference mode for all scoring/sampling below.
model.eval()
print("Model loaded")

Number of edges: 48
Training noise level: 0.02A
Model loaded


In [None]:
# setup proteinMPNN parameters

homomer = False #param {type:"boolean"}

num_seqs = 8 #param ["1", "2", "4", "8", "16", "32", "64"] {type:"raw"}
num_seq_per_target = num_seqs

#markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.
# Several temperatures may be given, separated by whitespace (parsed into `temperatures` below).
sampling_temp = "0.1" #param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5"]


# Output flags must be integer literals (a bare `!` is an IPython shell escape, not a value).
save_score=0                      # 0 for False, 1 for True; save score=-log_prob to npy files
save_probs=0                      # 0 for False, 1 for True; save MPNN predicted probabilities per position
score_only=0                      # 0 for False, 1 for True; score input backbone-sequence pairs
conditional_probs_only=0          # 0 for False, 1 for True; output conditional probabilities p(s_i given the rest of the sequence and backbone)
conditional_probs_only_backbone=0 # 0 for False, 1 for True; if true output conditional probabilities p(s_i given backbone)

batch_size=1                      # Batch size; can set higher for titan, quadro GPUs, reduce this if running out of GPU memory
max_length=20000                  # Max sequence length

folder_for_outputs = out_folder

jsonl_path=''                     # Path to a folder with parsed pdb into jsonl
omit_AAs='X'                      # Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cysteine.

pssm_multi=0.0                    # A value between [0.0, 1.0], 0.0 means do not use pssm, 1.0 ignore MPNN predictions
pssm_threshold=0.0                # A value between -inf + inf to restrict per position AAs
pssm_log_odds_flag=0              # 0 for False, 1 for True
pssm_bias_flag=0                  # 0 for False, 1 for True

##############################################################

NUM_BATCHES = num_seq_per_target//batch_size
BATCH_COPIES = batch_size
temperatures = [float(item) for item in sampling_temp.split()]
omit_AAs_list = omit_AAs
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'

# 1.0 for each alphabet letter listed in omit_AAs, 0.0 otherwise.
omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)

chain_id_dict = None # created downstream
# Default fixed (non-designed) residues: the first four tied antitoxin positions
# (A/E 1-4 and the equivalent C/G 2-5). get_mut_s_score reads this global when scoring.
# Format: {pdb_name: {chain: [1-based positions]}}.
fixed_positions_dict = {
    'bio_all_rm_non_chain_dd1_header':
    {'A': [1,2,3,4],
     'C': [2,3,4,5],
     'E': [1,2,3,4],
     'G': [2,3,4,5]}
     }
pssm_dict = None
omit_AA_dict = None
bias_AA_dict = None
tied_positions_dict = None
bias_by_res_dict = None
bias_AAs_np = np.zeros(len(alphabet))


###############################################################
print('----pdb file from pymol with modified header: ----')
## parsing the chains
designed_chain = "A,C,E,G" #param {type:"string"}
fixed_chain = "B,D,F,H" #param {type:"string"}

# Accept any run of non-letters as a separator ("A,C", "A C", "A, C"); drop empty
# pieces so leading/trailing separators do not create a bogus '' chain.
designed_chain_list = [c for c in re.split("[^A-Za-z]+", designed_chain) if c]
fixed_chain_list = [c for c in re.split("[^A-Za-z]+", fixed_chain) if c]
# Sorted for a deterministic order (set iteration order depends on the hash seed).
chain_list = sorted(set(designed_chain_list + fixed_chain_list))


# creating the dataset
pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)

print('pdb_dict_list',pdb_dict_list)
dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)

# Designed and fixed chains, keyed by the parsed structure name.
chain_id_dict = {pdb_dict_list[0]['name']: (designed_chain_list, fixed_chain_list)}
print(chain_id_dict)

for chain in chain_list:
    chain_len = len(pdb_dict_list[0][f"seq_chain_{chain}"])
    print(f"Length of chain {chain} is {chain_len}")

##############################################

# Tie equivalent residues across the antitoxin and toxin copies.
tied_positions_dict = make_tied_pos_dict_5ceg()


----pdb file from pymol with modified header: ----
pdb_dict_list [{'seq_chain_E': 'NVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLTE', 'coords_chain_E': {'N_chain_E': [[-5.93, -53.112, -55.965], [-9.055, -53.198, -57.35], [-11.611, -55.721, -57.375], [-13.925, -57.239, -59.42], [-16.943, -59.1, -59.225], [-19.55, -61.417, -60.087], [-22.405, -61.972, -58.286], [-24.951, -63.898, -56.574], [-27.709, -62.928, -54.827], [-30.476, -63.011, -52.446], [-33.858, -62.054, -53.366], [-35.051, -60.519, -51.406], [-32.855, -59.629, -49.912], [-31.596, -58.067, -51.888], [-33.481, -56.102, -52.194], [-33.498, -54.706, -49.746], [-30.888, -53.77, -49.722], [-31.005, -52.014, -51.91], [-32.9, -50.2, -50.909], [-31.393, -49.081, -48.822], [-29.32, -47.806, -50.142], [-30.794, -45.982, -51.75], [-31.833, -44.405, -49.74], [-29.647, -43.046, -48.916], [-27.824, -44.089, -47.121], [-25.713, -45.526, -48.096], [-23.517, -45.341, -50.974], [-23.417, -47.141, -53.149], [

In [None]:

# Design libraries expressed as ProteinMPNN fixed-position dicts: every listed
# residue is held fixed, everything not listed is redesigned.
#
# Variable positions are given in M0 numbering. Converting to ProteinMPNN's
# 1-based numbering of each antitoxin chain:
#   chains A and E (sequence starts NVEK...):  subtract 1
#   chains C and G (sequence starts ANVEK...): use as-is
#
# 10-position library: L47, D51, I52, R54, L55, F73, R77, E79, A80, R81 (M0)
# 3-position library:  D60, K63, E79 (M0)


def _fixed_except(chAE_var_pos, chCG_var_pos):
    '''Fix every antitoxin residue 1..85 except the given variable positions.'''
    keep_AE = [pos for pos in range(1, 86) if pos not in chAE_var_pos]
    keep_CG = [pos for pos in range(1, 86) if pos not in chCG_var_pos]
    # Each chain gets its own list object.
    return {'bio_all_rm_non_chain_dd1_header': {
        'A': list(keep_AE),
        'C': list(keep_CG),
        'E': list(keep_AE),
        'G': list(keep_CG),
    }}


m0_var_pos_10x = [47, 51, 52, 54, 55, 73, 77, 79, 80, 81]  # = chain C/G numbering
chAE_var_pos_10x = [pos - 1 for pos in m0_var_pos_10x]

m0_var_pos_3x = [60, 63, 79]  # = chain C/G numbering
chAE_var_pos_3x = [pos - 1 for pos in m0_var_pos_3x]

fixed_positions_dict_3pos_lib = _fixed_except(chAE_var_pos_3x, m0_var_pos_3x)
fixed_positions_dict_10pos_lib = _fixed_except(chAE_var_pos_10x, m0_var_pos_10x)

# Fix only the N-terminal part: A/E residues 1..42 and the equivalent C/G
# residues 2..43; the remainder of each antitoxin chain is redesigned.
fixed_positions_dict_n_term = {'bio_all_rm_non_chain_dd1_header': {
    'A': list(range(1, 43)),
    'C': list(range(2, 44)),
    'E': list(range(1, 43)),
    'G': list(range(2, 44)),
}}


# scoring seqs

In [None]:
# Input DMS tables and output directory for ProteinMPNN scores.
data_dir = '../../../data/DMS_data/'
data_dir_out = '../../../data/coves/scores/protein_mpnn_scores/'

# 10-position and 3-position variant libraries; the scoring cell reads their
# `muts_m1` column (':'-joined mutation keys such as 'L48L:D52D:...').
df_10pos = pd.read_csv(data_dir + 'df_at_10pos.csv')
df_3pos = pd.read_csv(data_dir + 'df_mut_all_norm.csv') 
# Wild-type antitoxin chains as they appear in the structure:
# at_ch1_seq -> chains A/E (mutation position p sits at index p-3),
# at_ch2_seq -> chains C/G (one extra leading Ala; index p-2).
at_ch1_seq = 'NVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLTE'
at_ch2_seq = 'ANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLT'


In [None]:


def get_full_str(at_ch_seq, offset, mutkey):
    '''
    Apply a mutation key to one antitoxin chain sequence.

    Args:
        at_ch_seq: wild-type chain sequence.
        offset: subtracted from each mutation position to get its 0-based
            index in ``at_ch_seq``. Use 3 for at_ch1_seq (chains A/E, starts
            'NVEK...') and 2 for at_ch2_seq (chains C/G, starts 'ANVEK...').
        mutkey: ':'-separated mutations ``<wt><pos><mut>``, e.g. 'L48A:D52D'.
            Empty tokens are ignored.

    Returns:
        The mutated sequence (same length as ``at_ch_seq``).

    Raises:
        ValueError: if a position falls outside the sequence, or the stated
            wild-type residue does not match ``at_ch_seq`` at that position.
    '''
    list_at_ch_seq = list(at_ch_seq)
    for m in mutkey.split(':'):
        if not m:
            continue
        wt_aa, pos, mut_aa = m[0], int(m[1:-1]), m[-1]
        idx = pos - offset
        # Reject negative indices explicitly: Python would silently wrap them
        # to the end of the sequence and mutate the wrong residue.
        if not 0 <= idx < len(at_ch_seq):
            raise ValueError(
                f'mutation {m!r}: position {pos} (index {idx}) is outside '
                f'the sequence of length {len(at_ch_seq)}')
        if at_ch_seq[idx] != wt_aa:
            raise ValueError(
                f'mutation {m!r}: expected wild-type {wt_aa!r} at position '
                f'{pos}, found {at_ch_seq[idx]!r}')
        list_at_ch_seq[idx] = mut_aa
    return ''.join(list_at_ch_seq)

def get_full_S_string(mutkey):
    '''
    Build the full multi-chain sequence for a mutant.

    The four antitoxin chains carry ``mutkey``. The first and third use
    at_ch1_seq with offset 3; the second and fourth use at_ch2_seq with
    offset 2. They are followed by the four wild-type toxin chains. The
    order is intended to match the chain order of the featurized structure
    (A, C, E, G then B, D, F, H); TODO confirm against tied_featurize.
    '''
    chain_type_1 = get_full_str(at_ch1_seq, 3, mutkey)
    chain_type_2 = get_full_str(at_ch2_seq, 2, mutkey)

    toxin_chains = 'MAVRLVWSPTAKADLIDIYVMIGSENIRAADRYYDQLEARALQLADQPRMGVRRPDIRPSARMLVEAPFVLLYETVPDTDDGPVEWVEIVRVVDGRRDLNRLFVRLVWSPTAKADLIDIYVMIGSENIRAADRYYDQLEARALQLADQPRMGVRRPDIRPSARMLVEAPFVLLYETVPDTDDGPVEWVEIVRVVDGRRDLNRLFMAVRLVWSPTAKADLIDIYVMIGSENIRAADRYYDQLEARALQLADQPRMGVRRPDIRPSARMLVEAPFVLLYETVPDTDDGPVEWVEIVRVVDGRRDLNRLFVRLVWSPTAKADLIDIYVMIGSENIRAADRYYDQLEARALQLADQPRMGVRRPDIRPSARMLVEAPFVLLYETVPDTDDGPVEWVEIVRVVDGRRDLNRLF'

    return ''.join((chain_type_1, chain_type_2, chain_type_1, chain_type_2, toxin_chains))


def get_mutkey_from_seq(seq, wt_mutkey, offset=3):
    '''
    Read a mutation key back out of a sequence (the reverse of get_full_str).

    For each ``<wt><pos><aa>`` token in ``wt_mutkey``, keep the wild-type
    residue and position, and take the trailing residue from
    ``seq[pos - offset]``. Tokens are re-joined with ':'.
    '''
    def _retarget(token):
        observed = seq[int(token[1:-1]) - offset]
        return token[:-1] + observed

    return ':'.join(_retarget(token) for token in wt_mutkey.split(':'))


def get_S_idx(all_sequence):
    '''
    Encode a sequence as ProteinMPNN token indices.

    Returns an int32 array of shape (1, len(all_sequence)) holding each
    residue's position in 'ACDEFGHIKLMNPQRSTVWYX'. A character outside that
    alphabet raises ValueError (from str.index).
    '''
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    token_row = list(map(alphabet.index, all_sequence))
    return np.array([token_row], dtype=np.int32)

def get_mut_s_score(s_string):
    '''
    Score a full multi-chain sequence on the loaded structure with ProteinMPNN.

    ``s_string`` is encoded with get_S_idx and replaces the native sequence.
    The model's log-probabilities are reduced with ``_scores`` over the
    positions selected by ``mask * chain_M * chain_M_pos``. That presumably
    excludes the fixed chains and the residues in ``fixed_positions_dict``.

    Relies on the notebook globals pdb_dict_list, max_length, device,
    chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict,
    pssm_dict, bias_by_res_dict and model.

    Args:
        s_string: concatenated sequence of all chains, in the order the
            structure is featurized (see get_full_S_string).

    Returns:
        The score as a numpy scalar. The parameter cell describes the score
        as -log_prob, so lower means more likely; TODO confirm.

    Raises:
        ValueError: if there is no structure to score, or ``s_string`` does
            not have exactly one residue per featurized position.
    '''
    dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
    proteins = list(dataset_valid)
    if not proteins:
        raise ValueError('no structure to score: the dataset is empty')

    with torch.no_grad():
        # S built from s_string has batch size 1, so featurize a single copy of
        # the (last) structure; BATCH_COPIES clones would not match S.
        batch_clones = [copy.deepcopy(proteins[-1])]
        (X, S_native, mask, lengths, chain_M, chain_encoding_all, chain_list_list,
         visible_list_list, masked_list_list, masked_chain_length_list_list,
         chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask,
         tied_pos_list_of_lists_list, pssm_coef, pssm_bias,
         pssm_log_odds_all, bias_by_res_all, tied_beta) = tied_featurize(
            batch_clones,
            device,
            chain_id_dict,
            fixed_positions_dict,
            omit_AA_dict,
            tied_positions_dict,
            pssm_dict,
            bias_by_res_dict)

        S = torch.from_numpy(get_S_idx(s_string)).to(dtype=torch.long, device=device)
        if S.shape[1] != X.shape[1]:
            raise ValueError(
                f'sequence has {S.shape[1]} residues but the featurized '
                f'structure has {X.shape[1]} positions')

        randn_1 = torch.randn(chain_M.shape, device=X.device)
        log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)

        mask_for_loss = mask*chain_M*chain_M_pos
        scores = _scores(S, log_probs, mask_for_loss)
    mut_score = scores.cpu().data.numpy()
    return mut_score[0]

def calc_scores(df, fout_name):
    '''
    Score every sequence in ``df['full_s_string']`` with get_mut_s_score,
    checkpointing results to a CSV so interrupted runs can resume.

    Args:
        df: DataFrame with a ``full_s_string`` column.
        fout_name: output file name. A bare name is placed under the
            notebook-global ``data_dir_out``. A path that already starts
            with ``data_dir_out`` is used as-is instead of being prefixed
            a second time.

    The file holds one ``<full_s_string>,<score>`` line per sequence.
    Sequences already in the file are skipped. Each new score is appended
    and flushed immediately, so a crash loses at most the mutant being
    scored.

    Returns:
        dict mapping each sequence to its score as a float.
    '''
    if not fout_name.startswith(data_dir_out):
        fout_name = data_dir_out + fout_name

    # Read in the scores already computed.
    dic_mutkey_score = {}
    if isfile(fout_name):
        with open(fout_name, 'r') as fin:
            for line in fin:
                line = line.rstrip()
                if not line:
                    continue
                ms, score = line.split(',')
                dic_mutkey_score[ms] = float(score)
        print('read {} mutants'.format(len(dic_mutkey_score)))

    # Score the remaining sequences, appending one line per new result.
    c = 0
    with open(fout_name, 'a') as fout:
        for mut_seq in df['full_s_string']:
            if mut_seq in dic_mutkey_score:
                continue
            mut_score = get_mut_s_score(mut_seq)
            dic_mutkey_score[mut_seq] = float(mut_score)
            fout.write(','.join([mut_seq, str(mut_score)]) + '\n')
            fout.flush()

            if c % 100 == 0:
                print(len(dic_mutkey_score)/len(df), ' done.')
            c += 1
    return dic_mutkey_score
def sanity_check_mutkey_conversions(df=None, test_row=7000):
    '''
    Sanity-check the mutation-key <-> sequence conversions (not run automatically).

    1. Every wild-type residue in the first ``muts_m1`` key of ``df`` must be
       found at index pos-3 of at_ch1_seq and at index pos-2 of at_ch2_seq.
    2. The key in row ``test_row`` must survive a round trip through
       get_full_S_string and get_mutkey_from_seq.

    Args:
        df: library DataFrame with a ``muts_m1`` column; defaults to the
            notebook-global ``df_10pos``.
        test_row: row whose mutation key is used for the round-trip check.

    Raises:
        AssertionError: if any check fails.
    '''
    if df is None:
        df = df_10pos

    wt_mutkey = df.muts_m1[0]
    print(wt_mutkey)
    for m in wt_mutkey.split(':'):
        wt_aa = m[0]
        pos = int(m[1:-1])
        # Print the residue that is actually compared (index pos-3).
        print(at_ch1_seq[pos-3])
        assert at_ch1_seq[pos-3] == wt_aa
        assert at_ch2_seq[pos-2] == wt_aa

    # Round trip: mutation key -> full sequence -> mutation key.
    test_mutkey = df.muts_m1[test_row]
    seq = get_full_S_string(test_mutkey)
    reversed_mutkey = get_mutkey_from_seq(seq, wt_mutkey)
    assert test_mutkey == reversed_mutkey


L48L:D52D:I53I:R55R:L56L:F74F:R78R:E80E:A81A:R82R
H
R
R
R
Q
A
E
R
Q
K


In [None]:
# Add the concatenated multi-chain sequence for every mutant.
df_10pos['full_s_string'] = df_10pos['muts_m1'].map(get_full_S_string)
df_3pos['full_s_string'] = df_3pos['muts_m1'].map(get_full_S_string)

# calc_scores prepends data_dir_out to a bare file name, so pass just the name.
# NOTE: runs that passed data_dir_out + name wrote to
# ../../../data/data/coves/scores/protein_mpnn_scores/; move those files into
# data_dir_out to reuse the already-computed scores.
print('calculating scores for 10 position library...')
calc_scores(df_10pos, '10pos_score.csv')
print('calculating scores for 3 position library...')
calc_scores(df_3pos, '3pos_score.csv')


calculating scores for 10 position library...
read 7923 mutants
calculating scores for 3 position library...
read 5861 mutants
0.73275  done.
0.74525  done.
0.75775  done.
0.77025  done.
0.78275  done.
0.79525  done.
0.80775  done.
0.82025  done.
0.83275  done.
0.84525  done.
0.85775  done.
0.87025  done.
0.88275  done.
0.89525  done.
0.90775  done.



# generating sequences

In [None]:
# Generate sequences for the 10-position library only: every antitoxin residue
# outside the 10 library positions is fixed, and equivalent residues are tied.

print('generating sequences only for the 10 positions')

wt_seq = pdb_dict_list[0]['seq_chain_A']

n_seqs = 30  # sequences per temperature (NUM_BATCHES, with BATCH_COPIES=1)
for t in [0.1, 0.3, 0.5, 0.7]:
    full_gen_seq = generate_seqs(dataset_valid,
                                 t,
                                 fixed_positions_dict_10pos_lib,
                                 1,       # BATCH_COPIES --> how many to do at once
                                 n_seqs,  # NUM_BATCHES --> number of sequences to generate
                                 device,
                                 chain_id_dict,
                                 omit_AA_dict,
                                 tied_positions_dict,
                                 pssm_dict,
                                 bias_by_res_dict,
                                 pssm_threshold,
                                 omit_AAs_np,
                                 bias_AAs_np,
                                 pssm_multi,
                                 pssm_log_odds_flag,
                                 pssm_bias_flag
                                 )
    print(' temp: {}, avg. hamming {}'.format(t, np.mean([hamming(s, wt_seq) for s in full_gen_seq])))
    write_seqs_fasta(full_gen_seq, out_folder + 't{}_ta_n{}_10x.fa'.format(t, n_seqs))

generating sequences only for the 10 positions
Generating sequences...
0
0
0
0
1
0
1
1
0
0
2
1
1
1
0
0
1
0
1
1
1
1
0
0
1
0
1
1
0
0
 temp: 0.1, avg. hamming 0.5333333333333333
Generating sequences...
1
1
3
2
2
2
1
1
1
2
1
0
2
2
0
1
2
1
2
0
1
2
1
0
1
1
2
1
1
1
 temp: 0.3, avg. hamming 1.2666666666666666
Generating sequences...
2
3
1
3
2
2
1
0
2
1
0
1
2
1
1
3
1
3
1
1
1
3
2
1
1
1
0
2
2
0
 temp: 0.5, avg. hamming 1.4666666666666666
Generating sequences...
3
4
2
4
4
3
1
3
4
2
2
2
2
4
4
3
4
5
4
2
2
4
2
3
2
2
2
1
2
4
 temp: 0.7, avg. hamming 2.8666666666666667
