In [1]:
import sys
import os
# Extend the import search path with the sibling ProteinMPNN checkout and the
# parent directory (presumably so `protein_mpnn_utils` and `DomainPrediction`
# imported below can be resolved — confirm against the repository layout).
sys.path.append('../../ProteinMPNN')
sys.path.append('..')

In [2]:
import re
import copy
import numpy as np
import torch

In [3]:
from protein_mpnn_utils import _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import StructureDatasetPDB, ProteinMPNN

In [4]:
from DomainPrediction.utils import helper

In [5]:
root = '../..'
data_path = os.path.join(root, 'Data')
pmpnn_path = os.path.join(root, 'ProteinMPNN')
fasta_file = os.path.join(data_path, 'pmpnn_experiments/6mfw_exp/6mfw_pmpnn_1000.fasta')

In [6]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoint naming: v_48_010 = 48 edges, 0.10 A training noise.
# The one loaded here, v_48_002, reports 48 edges and 0.02 A when printed below.
model_name = "v_48_002"
backbone_noise = 0.00  # std-dev of Gaussian noise added to backbone atoms (passed as augment_eps)

hidden_dim = 128
num_layers = 3
model_folder_path = os.path.join(pmpnn_path, 'vanilla_model_weights')
checkpoint_path = os.path.join(model_folder_path, f'{model_name}.pt')

# Load the checkpoint straight onto the target device and echo its metadata.
checkpoint = torch.load(checkpoint_path, map_location=device)
print('Number of edges:', checkpoint['num_edges'])
noise_level_print = checkpoint['noise_level']
print(f'Training noise level: {noise_level_print} A')

# Node, edge and hidden widths all share `hidden_dim`; encoder and decoder share
# `num_layers`; the neighbour count is taken from the checkpoint itself.
model = ProteinMPNN(
    num_letters=21,
    node_features=hidden_dim,
    edge_features=hidden_dim,
    hidden_dim=hidden_dim,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    augment_eps=backbone_noise,
    k_neighbors=checkpoint['num_edges'],
)
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.eval()

Number of edges: 48
Training noise level: 0.02 A


In [7]:
## In our case we always design a single chain of a monomer.

def parse_chain_ids(chain_spec):
    """Return the chain identifiers contained in `chain_spec`.

    An identifier is a run of ASCII letters; every other character (spaces,
    commas, ...) acts as a separator. An empty or separator-only spec yields
    an empty list, and leading/trailing separators never produce empty ids.
    """
    return re.findall(r"[A-Za-z]+", chain_spec)


designed_chain = "A"
fixed_chain = ""

designed_chain_list = parse_chain_ids(designed_chain)
fixed_chain_list = parse_chain_ids(fixed_chain)

# De-duplicate, then sort: iterating a bare set of strings depends on hash
# randomisation, so the chain order could otherwise change between kernels.
chain_list = sorted(set(designed_chain_list + fixed_chain_list))

In [8]:
## Design options

NUM_BATCHES = 250      # sampling rounds per temperature
BATCH_COPIES = 1       # copies of the structure featurized for each round
temperatures = [0.3]   # sampling temperatures to iterate over

alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
omit_AAs_list = 'X'
# float32 vector over `alphabet`: 1.0 for letters listed in omit_AAs_list, else 0.0
omit_AAs_np = np.array([float(letter in omit_AAs_list) for letter in alphabet], dtype=np.float32)

pssm_multi = 0.0          # value in [0.0, 1.0]; 0.0 = do not use the PSSM, 1.0 = ignore MPNN predictions
pssm_threshold = 0.0      # value in (-inf, +inf) used to restrict per-position AAs
pssm_log_odds_flag = 0    # 0 for False, 1 for True
pssm_bias_flag = 0        # 0 for False, 1 for True

# Optional per-structure inputs; none of them is used in this experiment.
pssm_dict = None
omit_AA_dict = None
bias_by_res_dict = None
tied_positions_dict = None
bias_AAs_np = np.zeros(len(alphabet))

# Parse the input structure (restricted to the selected chains) and wrap it in a dataset.
protein_file = os.path.join(data_path, '6mfw_conformations/hm_6mg0_cB_ATC.pdb')
pdb_dict_list = parse_PDB(protein_file, input_chain_list=chain_list)
max_length = 20000
dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)

# Structure name -> (designed chains, fixed chains)
chain_id_dict = {pdb_dict_list[0]['name']: (designed_chain_list, fixed_chain_list)}

for chain in chain_list:
    chain_length = len(pdb_dict_list[0][f"seq_chain_{chain}"])
    print(f"Length of chain {chain} is {chain_length}")

Length of chain A is 1000


In [9]:
# Sanity check: sequence length in the parsed PDB entry vs. in the dataset entry.
len(pdb_dict_list[0]['seq']), len(dataset_valid[0]['seq'])

(1000, 1000)

In [10]:
# Index lists for the three domains; they are used below as 0-based positions
# into dataset_valid[0]['seq'] (each range excludes its end value).
A = list(range(464))
C = list(range(571, 1000))
T = list(range(492, 556))

# Alternative boundaries, kept for reference:
# A_gxps_atc = list(range(0, 489))
# C_gxps_atc = list(range(604, 1034))
# T_gxps_atc = list(range(505, 575))

In [11]:
# Display the residues selected by the A index list.
''.join(dataset_valid[0]['seq'][pos] for pos in A)

'FEQQVEMTPDHVAVVDRGQSLTYKQLNERANQLAHHLRGKGVKPDDQVAIMLDKSLDMIVSILAVMKAGGAYVPIDPDYPGERIAYMLADSSAAILLTNALHEEKANGACDIIDVHDPDSYSENTNNLPHVNRPDDLVYVMYTSGSTGLAKGVMIEHHNLVNFCEWYRPYFGVTPADKALVYSSFSFDGSALDIFTHLLAGAALHIVPSERKYDLDALNDYCNQEGITISYLPTGAAEQFMQMDNQSFRVVITGGDVLKKIERNGTYKLYNGYGPTECTIMVTMFEVDKPYANIPIGKPIDRTRILILDEALALQPIGVAGELFIVGEGLGRGYLNRPELTAEKFIVHPQTGERMYRTGDRARFLPDGNIEFLGRLDNLVKIRGYRIEPGEIEPFLMNHPLIELTTVLAKEQADGRKYLVGYYVAPEEIPHGELREWLGNDLPDYMIPTYFVHMKAFPLTANGK'

In [12]:
# Display the residues selected by the T index list.
''.join(dataset_valid[0]['seq'][pos] for pos in T)

'QQLAQVWSHVLGIPQMGIDDHFLERGGDSIKVMQLIHQLKNIGLSLRYDQLFTHPTIRQLKRLL'

In [13]:
# Space-separated, 1-based positions of every residue that is NOT in T.
# These end up in fixed_positions_dict, i.e. they are held fixed during design.
fix_cond = ' '.join(
    map(str, (pos + 1 for pos in range(len(dataset_valid[0]['seq'])) if pos not in T))
)

In [14]:
# One list of integer positions per comma-separated group of fix_cond.
fixed_list = [list(map(int, group.split())) for group in fix_cond.split(",")]
# Whitespace-separated chain letters that the groups above refer to, in order.
global_designed_chain_list = 'A'.split()
# Filled in the next cell: structure name -> {chain: fixed positions}.
my_dict = {}

In [15]:
# For every parsed structure, map each chain letter to its fixed positions:
# designed chains get their entry from fixed_list, every other chain gets [].
for result in pdb_dict_list:
    # Chain letters are the last character of the 'seq_chain_<X>' keys.
    all_chain_list = [key[-1:] for key in result if key.startswith('seq_chain')]
    fixed_position_dict = {
        chain: fixed_list[idx] for idx, chain in enumerate(global_designed_chain_list)
    }
    for chain in all_chain_list:
        fixed_position_dict.setdefault(chain, [])
    my_dict[result['name']] = fixed_position_dict

In [16]:
# Alias under the name that is passed to tied_featurize in the sampling cell.
fixed_positions_dict = my_dict

In [17]:
# Echo the sampling settings, one per line, before the long-running cell.
print(NUM_BATCHES, BATCH_COPIES, temperatures, sep='\n')

250
1
[0.3]


In [18]:

def _join_designed_chains(flat_seq, chain_lengths, chain_letters):
    """Re-order the designed chains by chain letter and join them with '/'.

    `flat_seq` is the concatenation of the designed chains; its k-th piece has
    length `chain_lengths[k]` and belongs to chain `chain_letters[k]`.
    Returns the pieces sorted by chain letter, separated by '/'.
    """
    pieces = []
    start = 0
    for chain_len in chain_lengths:
        pieces.append(flat_seq[start:start + chain_len])
        start += chain_len
    return '/'.join(pieces[k] for k in np.argsort(chain_letters))


# NOTE(review): no RNG seed is set, so the sampled sequences differ between runs.
# NOTE(review): sequences are appended to `fasta_file`; re-running this cell adds
# a second copy of every record unless the file is removed first.
with torch.no_grad():
    print('Generating sequences...')
    # File-derived part of the FASTA header; identical for every sample, so computed once.
    _hm = os.path.basename(protein_file).replace('.pdb', '').replace('gxps_ATC_', '')

    for ix, protein in enumerate(dataset_valid):
        score_list = []
        all_probs_list = []
        all_log_probs_list = []
        S_sample_list = []
        batch_clones = [copy.deepcopy(protein) for _ in range(BATCH_COPIES)]
        (X, S, mask, lengths, chain_M, chain_encoding_all,
         chain_list_list, visible_list_list, masked_list_list,
         masked_chain_length_list_list, chain_M_pos, omit_AA_mask,
         residue_idx, dihedral_mask, tied_pos_list_of_lists_list,
         pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta
         ) = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict,
                            omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict)

        pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float()  # 1.0 for true, 0.0 for false
        name_ = batch_clones[0]['name']
        # Only positions that are valid, on a designed chain and not fixed are scored.
        mask_for_loss = mask * chain_M * chain_M_pos

        # Score the native sequence.
        randn_1 = torch.randn(chain_M.shape, device=X.device)
        log_probs = model(X, S, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all, randn_1)
        scores = _scores(S, log_probs, mask_for_loss)
        native_score = scores.cpu().data.numpy()

        # Native header, printed once per structure ahead of its samples.
        # Only the residues indexed by T are shown.
        native_seq = _join_designed_chains(
            _S_to_seq(S[0], chain_M[0]), masked_chain_length_list_list[0], masked_list_list[0]
        )
        sorted_masked_chain_letters = np.argsort(masked_list_list[0])
        print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]
        sorted_visible_chain_letters = np.argsort(visible_list_list[0])
        print_visible_chains = [visible_list_list[0][i] for i in sorted_visible_chain_letters]
        native_score_print = np.format_float_positional(np.float32(native_score.mean()), unique=False, precision=4)
        line = '>{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\n{}\n'.format(
            name_, native_score_print, print_visible_chains, print_masked_chains, model_name,
            ''.join(native_seq[i] for i in T))
        print(line.rstrip())

        for temp in temperatures:
            # Keyword arguments shared by model.sample and model.tied_sample.
            sample_kwargs = dict(
                mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np,
                chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef,
                pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag),
                pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag),
                bias_by_res=bias_by_res_all,
            )
            for j in range(NUM_BATCHES):
                randn_2 = torch.randn(chain_M.shape, device=X.device)
                if tied_positions_dict is None:
                    sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                                               **sample_kwargs)
                else:
                    sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                                                    tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta,
                                                    **sample_kwargs)
                S_sample = sample_dict["S"]

                # Score the sample using the decoding order it was generated with.
                log_probs = model(X, S_sample, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all,
                                  randn_2, use_input_decoding_order=True,
                                  decoding_order=sample_dict["decoding_order"])
                scores = _scores(S_sample, log_probs, mask_for_loss)
                scores = scores.cpu().data.numpy()
                all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
                all_log_probs_list.append(log_probs.cpu().data.numpy())
                S_sample_list.append(S_sample.cpu().data.numpy())

                for b_ix in range(BATCH_COPIES):
                    masked_chain_length_list = masked_chain_length_list_list[b_ix]
                    masked_list = masked_list_list[b_ix]
                    # Fraction of scored positions where the sample equals the native residue.
                    matches = torch.sum(
                        torch.nn.functional.one_hot(S[b_ix], 21) * torch.nn.functional.one_hot(S_sample[b_ix], 21),
                        axis=-1,
                    )
                    seq_recovery_rate = torch.sum(matches * mask_for_loss[b_ix]) / torch.sum(mask_for_loss[b_ix])
                    score = scores[b_ix]
                    score_list.append(score)

                    seq = _join_designed_chains(
                        _S_to_seq(S_sample[b_ix], chain_M[b_ix]), masked_chain_length_list, masked_list
                    )
                    score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)
                    seq_rec_print = np.format_float_positional(
                        np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
                    # Only the T residues are printed; the full sequence goes to the FASTA file.
                    line = '>T={}, sample={}, batch={}, score={}, seq_recovery={}\n{}\n'.format(
                        temp, j, b_ix, score_print, seq_rec_print, ''.join(seq[i] for i in T))
                    print(line.rstrip())

                    header = f'6mfw-{_hm}-{model_name}-temp_{temp}-gen-{j}'
                    if BATCH_COPIES > 1:
                        # Keep headers unique when one round yields several sequences.
                        header += f'-batch-{b_ix}'
                    helper.create_fasta(sequences={header: seq}, file=fasta_file, append=True)


# Stack the per-round arrays; these lists are reset per structure, so they
# hold the samples of the last structure processed above.
all_probs_concat = np.concatenate(all_probs_list)
all_log_probs_concat = np.concatenate(all_log_probs_list)
S_sample_concat = np.concatenate(S_sample_list)

Generating sequences...
>hm_6mg0_cB_ATC, score=1.6188, fixed_chains=[], designed_chains=['A'], model_name=v_48_002
QQLAQVWSHVLGIPQMGIDDHFLERGGDSIKVMQLIHQLKNIGLSLRYDQLFTHPTIRQLKRLL
>T=0.3, sample=0, batch=0, score=1.1273, seq_recovery=0.3906
KRLSDVWSEIFGRPDIGLDTHFFDRGGTREKITELIRKLKQLGINISYEEIYKHPTLTLFKKYL
>T=0.3, sample=1, batch=0, score=1.0244, seq_recovery=0.3750
KKLCEVWAEILGRPNVGLDDHFFDLGGTQYRVSVLLRKLKELGIKIKYNEIYKNPTLKKLKKYL
>T=0.3, sample=2, batch=0, score=1.0594, seq_recovery=0.4375
QKLADVWSEVFGRPQMGLDAHFFDLGGTEDRVSVLLRKLRQLGIRIRYEEIYKNPTILKFKKYL
>T=0.3, sample=3, batch=0, score=1.1449, seq_recovery=0.3438
AKLCEIMAGLFGKPHQGLTAHFFDRGGTRELVSELLRALRSEGIRLSYEEIFKNPTIEELKKYL
>T=0.3, sample=4, batch=0, score=1.1134, seq_recovery=0.3906
TELADIWAKVFGVPDMGLKEHYFDRGGTSDRVSVLLRELRGKGFRIKYDEIYKNPTIDALKKYL
>T=0.3, sample=5, batch=0, score=1.0498, seq_recovery=0.4688
ARLCQVWSKLFGVPDMGLRDHFFDRGGTSELVSRLLRELRGLGLRLRYDEIYENPTLEKLRRWL
>T=0.3, sample=6, batch=0, score=1.0773, seq_recovery=0.3594
KEL