In this notebook I build a template for evaluating the stability scores of the predicted structures of low-temperature samples drawn from both ESM and the Potts decoder.

In [1]:
import pyrosetta
from collections import defaultdict
import torch
import numpy as np
import scipy
import os
import sys
import biotite.structure
from biotite.structure.io import pdbx, pdb
from biotite.structure.residues import get_residues
from biotite.structure import filter_backbone
from biotite.structure import get_chains
from biotite.sequence import ProteinSequence
from typing import Sequence, Tuple, List
from Bio import SeqIO

git_folder = '/mnt/c/Users/Luca/OneDrive/Phd/Second_year/research/Feinauer/InverseFolding'
esm_folder = '/mnt/c/Users/Luca/OneDrive/Phd/Second_year/research/Feinauer/esm/'
sys.path.insert(1, os.path.join(git_folder, 'model'))
sys.path.insert(1, os.path.join(git_folder, 'util'))

sys.path.insert(1, esm_folder)
from potts_decoder import PottsDecoder
from ioutils import read_fasta, read_encodings
import esm
#from esm.inverse_folding import util
import esm.pretrained as pretrained
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt


In [2]:
def load_structure(fpath, chain=None):
    """
    Read model 1 of a structure file and keep the backbone atoms of the
    requested chain(s).

    Args:
        fpath: path of the file to read; it is parsed with biotite's PDB reader
        chain: a chain id, a list of chain ids, or None to keep every chain
    Returns:
        biotite.structure.AtomArray
    Raises:
        ValueError: when the file holds no chain, or a requested chain is absent
    """
    with open(fpath) as handle:
        pdb_file = pdb.PDBFile.read(handle)
    backbone = pdb.get_structure(pdb_file, model=1)
    backbone = backbone[filter_backbone(backbone)]

    available = get_chains(backbone)
    if len(available) == 0:
        raise ValueError('No chains found in the input file.')

    # Normalise the `chain` argument into a collection of chain ids.
    if chain is None:
        wanted = available
    else:
        wanted = chain if isinstance(chain, list) else [chain]

    for chain_id in wanted:
        if chain_id not in available:
            raise ValueError(f'Chain {chain_id} not found in input file')

    # Per-atom boolean mask selecting the atoms that belong to a wanted chain.
    keep = [atom.chain_id in wanted for atom in backbone]
    return backbone[keep]

def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Pull the backbone coordinates and the one-letter sequence out of a structure.

    Args:
        structure: An instance of biotite AtomArray
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    backbone_atoms = ["N", "CA", "C"]
    coords = get_atom_coords_residuewise(backbone_atoms, structure)

    # get_residues returns (ids, names); the names are three-letter residue codes.
    three_letter_names = get_residues(structure)[1]
    one_letter_codes = [
        ProteinSequence.convert_letter_3to1(name) for name in three_letter_names
    ]
    return coords, ''.join(one_letter_codes)

def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """
    Collect, residue by residue, the coordinates of the atoms named in `atoms`.

    Example for atoms argument: ["N", "CA", "C"]

    An atom name that is absent from a residue yields NaN coordinates; a
    residue containing the same atom name more than once raises RuntimeError.
    """
    def pick_named_atoms(residue, axis=None):
        # One boolean column per requested atom name, one row per atom of the residue.
        name_matches = np.stack([residue.atom_name == name for name in atoms], axis=1)
        hits_per_name = name_matches.sum(0)
        if np.any(hits_per_name > np.ones(name_matches.shape[1])):
            raise RuntimeError("structure has multiple atoms with same name")
        # argmax gives the position of the (single) match, or 0 when there is none;
        # the no-match entries are overwritten with NaN right below.
        first_hit = name_matches.argmax(0)
        picked = residue[first_hit].coord
        picked[hits_per_name == 0] = float("nan")
        return picked

    return biotite.structure.apply_residue_wise(struct, struct, pick_named_atoms)

def get_loss_new(decoder, inputs, eta):
    """
    Penalised negative pseudo-log-likelihood computed from `decoder.forward_new`.

    eta is the multiplicative term in front of the penalty.

    Args:
        decoder: model exposing `forward_new(encodings, padding_mask)`, which
            returns `(param_embeddings, fields)`
        inputs: triple `(msas, encodings, padding_mask)`; each element is moved
            to the global `device`, and `msas` is unpacked as (B, M, N)
        eta: weight of the penalty term
    Returns:
        Tuple (loss, npll): the penalised loss as a tensor and the un-penalised
        mean npll as a Python number.
    """
    msas, encodings, padding_mask = (tensor.to(device) for tensor in inputs)
    B, M, N = msas.shape
    param_embeddings, fields = decoder.forward_new(encodings, padding_mask)
    msas_embedded = embedding(msas)

    # Zero out padded positions (this is probably not necessary), then average
    # over the M sequences and the non-padded positions.
    valid_positions = ~padding_mask
    npll = get_npll2(msas_embedded, param_embeddings, fields, N, q) * valid_positions.unsqueeze(1)
    npll_mean = torch.sum(npll) / (M * torch.sum(valid_positions))

    Q = torch.einsum(
        'bkuia, buhia->bkhia',
        param_embeddings.unsqueeze(2),
        param_embeddings.unsqueeze(1),
    ).sum(axis=-1)
    coupling_term = torch.sum(torch.sum(Q, axis=-1)**2) - torch.sum(Q**2)
    penalty = eta*(coupling_term + torch.sum(fields**2))/B
    return npll_mean + penalty, npll_mean.item()

def get_loss(decoder, inputs, eta):
    """
    Penalised negative pseudo-log-likelihood computed from explicit couplings.

    eta is the multiplicative term in front of the penalty.

    Args:
        decoder: callable `decoder(encodings, padding_mask)` returning
            `(couplings, fields)`
        inputs: triple `(msas, encodings, padding_mask)`; each element is moved
            to the global `device`, and `msas` is unpacked as (B, M, N)
        eta: weight of the squared-parameter penalty
    Returns:
        Tuple (loss, npll): the penalised loss as a tensor and the un-penalised
        mean npll as a Python number.
    """
    msas, encodings, padding_mask = (tensor.to(device) for tensor in inputs)
    B, M, N = msas.shape
    couplings, fields = decoder(encodings, padding_mask)

    # Embed the MSA and collapse the trailing dimensions: (B, M, N*q)
    msas_embedded = embedding(msas).view(B, M, -1)

    # Zero out padded positions (this is probably not necessary).
    valid_positions = ~padding_mask
    npll = get_npll(msas_embedded, couplings, fields, N, q) * valid_positions.unsqueeze(1)
    penalty = eta*(torch.sum(couplings**2) + torch.sum(fields**2))/B

    # The padding mask has no MSA dimension, hence the extra factor M.
    npll_mean = torch.sum(npll) / (M * torch.sum(valid_positions))
    return npll_mean + penalty, npll_mean.item()

def get_loss_loader(decoder, loader, eta):
    """
    Average, over every batch of `loader`, the un-penalised npll reported by
    `get_loss_new`.

    The decoder is switched to eval mode (and left in it) and gradients are
    disabled while the batches are evaluated.

    Args:
        decoder: model forwarded to `get_loss_new`
        loader: iterable of batches in the format `get_loss_new` expects
        eta: penalty weight forwarded to `get_loss_new`; only the un-penalised
            value is accumulated here
    Returns:
        Mean over batches of the second value returned by `get_loss_new`.
    Raises:
        ValueError: if `loader` yields no batch (previously this surfaced as a
            bare ZeroDivisionError)
    """
    decoder.eval()
    total_npll = 0
    n_batches = 0
    with torch.no_grad():
        for inputs in loader:
            _, npll = get_loss_new(decoder, inputs, eta)
            total_npll += npll
            n_batches += 1

    if n_batches == 0:
        raise ValueError("loader yielded no batches; cannot average the loss")
    return total_npll / n_batches

def compute_covariance(msa, q):
    """
    Compute covariance matrix of a given MSA having q different amino acids

    Args:
        msa: 2-D integer tensor (M sequences x N positions) of class indices
            accepted by `torch.nn.functional.one_hot` with `num_classes=q`
        q: number of classes (amino-acid symbols)
    Returns:
        float32 tensor of shape (N*(q-1), N*(q-1)): pairwise frequencies minus
        the product of single-site frequencies, with the last class dropped.
    """
    M, N = msa.shape

    # One hot encode classes, remove one amino acid (the last class) and
    # flatten to the (M, N*(q-1)) data matrix. `one_hot` is reached through
    # torch.nn.functional because the bare name is not imported in this notebook.
    D = torch.nn.functional.one_hot(msa, num_classes=q)[:, :, :q-1].to(torch.float32).flatten(1)

    # Compute bivariate frequencies
    bivariate_freqs = D.T @ D / M

    # Product of univariate frequencies (the diagonal holds the single-site frequencies)
    site_freqs = torch.diagonal(bivariate_freqs)
    univariate_freqs = site_freqs[:, None] * site_freqs[None, :]

    return bivariate_freqs - univariate_freqs

## Generate low-temperature samples from Potts Decoder

In [12]:
# Paths consumed by the sampling command at the end of this cell.
auxiliary_model_dir = '/mnt/d/Data/InverseFolding/Auxiliary_Potts'
out_dir = '/mnt/d/Data/InverseFolding/Auxiliary_Samples_Potts'
out_file = 'samples_lowT.txt'
# NOTE(review): despite its name, `samples_path` is the file handed to `bmdca_sample -p`
# (a couplings/fields file, judging by its name), not a file of samples.
samples_path = os.path.join(auxiliary_model_dir, "YAP1_HUMAN_couplings_fields.txt")
## The leading ! runs the line as a shell command; Python variables are interpolated with curly braces, e.g. {out_dir}
## NOTE(review): presumably -d/-o are the output directory/file and -c the config file -- confirm against the bmdca documentation
!(bmdca_sample -p {samples_path} -d {out_dir} -o {out_file} -n 100 -c bmdca_config)

initializing sampler... 0.0110323 sec

sampling model with mcmc... 1.15917 sec
updating mcmc stats with samples... 0.00599692 sec
computing sequence energies and correlations... 0.0126774 sec
decreasing wait time to 600
writing final sequences... done


In [13]:
# Amino-acid alphabet; the final symbol '-' is the gap.
alphabet = 'ACDEFGHIKLMNPQRSTVWY-'
default_index = alphabet.index('-')

# Symbol -> index. Any symbol outside the alphabet falls back to the gap index.
aa_index = defaultdict(
    lambda: default_index,
    ((symbol, position) for position, symbol in enumerate(alphabet)),
)

# Index -> symbol, built from the entries `aa_index` holds at this point.
aa_index_inv = {position: symbol for symbol, position in aa_index.items()}

In [14]:
file = 'samples_lowT_numerical.txt'
with open(os.path.join(out_dir, file), mode='r') as f:
    lines = f.readlines()

# One list of alphabet indices per sampled sequence (36 is the length of YAP).
# The first line of the file is skipped, as before. Rows are split on any
# whitespace, so a missing final newline, a trailing space or a blank line no
# longer corrupts or crashes the parsing (the old `[0:-1]` slice blindly dropped
# the last character of every row).
# Each index is translated to its letter and back again: valid indices come out
# unchanged, while an index outside the alphabet raises KeyError.
char_seq = [
    [aa_index[aa_index_inv[int(token)]] for token in row.split()]
    for row in lines[1:]
    if row.strip()
]

In [17]:
# Alphabet index found at position 4 in every Potts sample.
np.asarray(char_seq)[:, 4]

array([12, 12, 12, 20, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
       17, 12, 12, 12, 12, 12, 12, 14, 12, 12, 20,  8, 12, 12,  5, 12, 14,
       12, 20, 20, 20, 12, 12, 12, 12, 12, 12, 12, 18, 12, 12, 20, 12, 12,
       20, 12,  9, 12, 12, 12, 12, 12, 20,  3, 12, 12, 12, 12, 12, 12, 12,
       12, 12,  0, 20,  4,  0, 10, 20, 12, 12, 12, 12, 12, 12, 12,  7, 20,
       12,  7,  4, 12, 12,  1, 20, 12, 20,  1, 12, 20, 12, 12, 12, 12, 12,
       12, 12, 12, 12, 12, 12, 20, 12, 12, 12, 12, 20, 12,  5, 13, 20, 12,
       12, 12, 12,  3, 12, 20, 12, 10, 12, 12, 12, 12,  0, 12, 17, 16, 17,
        6, 19, 12, 12,  3, 12, 18, 20, 12, 12, 12, 12, 12, 12, 12, 12, 12,
       12, 20, 12, 12, 12, 12, 20, 12, 12, 12,  0, 12, 12, 12, 12, 20, 12,
       12, 20, 12, 12, 12, 12, 20, 10, 12, 12, 20,  2, 12, 12, 12, 19, 18,
       12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
       12, 12, 12, 12, 12, 12, 12, 12,  6, 12, 12, 12, 12, 20, 20, 12, 12,
       12, 12, 12,  8, 12

In [52]:
# Full index sequence of the Potts sample stored at row 99.
np.asarray(char_seq)[99]

array([20, 20, 20, 20, 12, 12,  5, 18,  3, 14, 17, 16, 16,  2,  2,  5, 14,
        9, 19, 19, 19,  2,  6, 14, 16,  5, 16, 16, 16, 18, 19,  2, 12, 20,
       20, 20])

**TODO:** the AlphaFold2 (AF2) part still has to be inserted here.

## ESM sampling

In [129]:
# --- Directories ---------------------------------------------------------
mutational_dir = '/mnt/d/Data/InverseFolding/Mutational_Data'
msas_folder = '/mnt/d/Data/InverseFolding/Mutational_Data/alphafold_results_wildtype'
folder_fasta = os.path.join(mutational_dir, 'alignments')
structure_folder = os.path.join(mutational_dir, 'alphafold_results_wildtype')

# --- File names ----------------------------------------------------------
protein_original_DMS = 'YAP1_HUMAN_1_b0.5.a2m.wildtype.fasta'
structure_name = 'YAP1_HUMAN_1_b0.5.a2m_unrelaxed_rank_1_model_5.pdb'

# --- Full paths ----------------------------------------------------------
native_path = os.path.join(mutational_dir, 'alignments', protein_original_DMS)
structure_path = os.path.join(mutational_dir, 'alphafold_results_wildtype', structure_name)

# --- Load the wild-type fasta and the backbone of the structure file -----
num_seq = read_fasta(native_path, mutated_exp=True)
structure = load_structure(structure_path)
coords, native_seq = extract_coords_from_structure(structure)

In [130]:
# Load the pretrained ESM inverse-folding model together with its alphabet object.
# NOTE(review): this rebinds `alphabet`, which an earlier cell defines as the amino-acid
# string used to build `aa_index`; re-running that earlier cell afterwards rebinds it again.
model, alphabet = pretrained.esm_if1_gvp4_t16_142M_UR50() 
# Switch the model to evaluation mode; the trailing ';' suppresses the cell output.
model.eval();
# Encoder output for the loaded backbone coordinates; `rep` is not used by the cells shown here.
rep = esm.inverse_folding.util.get_encoder_output(model, alphabet, coords)



In [132]:
# Number of sequences to draw from ESM and the (low) sampling temperature.
n_esm_samples = 100
esm_temperature = 0.1

# Each entry of `samples_esm` is one sampled sequence, stored as alphabet indices.
samples_esm = []
for attempt in range(n_esm_samples):
    # The progress line now reports the real total (it used to say 10000 while
    # the loop only ran 100 times).
    print(f"We are at sample {attempt} out of {n_esm_samples}", end="\r")
    sample = model.sample(coords, temperature=esm_temperature)
    # Letters that are not in `aa_index` fall back to its default (gap) index.
    samples_esm.append([aa_index[char] for char in sample])

We are at sample 99 out of 10000

In [142]:
# Alphabet index found at position 35 in every ESM sample.
np.asarray(samples_esm)[:, 35]

array([8, 8, 3, 8, 8, 3, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
       8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
       8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
       8, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
       8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8])