In this notebook I want to compare the pseudo-likelihood I obtain using ESM

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from encoded_protein_dataset import EncodedProteinDataset, collate_fn, get_embedding
#from pseudolikelihood import get_npll
import torch
from potts_decoder import PottsDecoder
from torch.utils.data import DataLoader, RandomSampler
from functools import partial
import biotite.structure
from biotite.structure.io import pdbx, pdb
from biotite.structure.residues import get_residues
from biotite.structure import filter_backbone
from biotite.structure import get_chains
from biotite.sequence import ProteinSequence
from typing import Sequence, Tuple, List
import scipy
from Bio import SeqIO

import os
##TURIN HPC
sys.path.insert(1, "/Data/silva/esm/")

## EUROPA
#sys.path.insert(1, "/home/lucasilva/esm/")
import esm
#from esm.inverse_folding import util
import esm.pretrained as pretrained
from ioutils import read_fasta, read_encodings
from torch.nn.utils.rnn import pad_sequence
from collections import defaultdict
import torch.nn.functional as F


  from .autonotebook import tqdm as notebook_tqdm


Perplexity calculation:

- We would like to compare our results with the perplexity of the paper of ESM, unfortunately we cannot calculate the likelihood of our predicted model, as that is the central problem of Potts models.
- We would like to assess if we could calculate the pseudo likelihood for the model given by ESM and compare it with our results to get a benchmark to measure against.
- We shall remember that our task is harder, as we don't give the model the exact native sequence, but give a batch of the MSA in which we often don't have the true sequence. 

Ok now we have everything, we can hence write a code that gives us the the pseudolikelihood of esm to get a benchmark for our model!

The main observation is that:
$$ p(y_i|y_{-i}) = \frac{p(y)}{p(y_{-i})} = \frac{p(y)}{\sum_{a}p(y_i=a, y_{-i})}, $$
which, since it involves marginalizing over just one position of the sequence is clearly computationally feasible!

In [14]:
?decoder.forward

[0;31mSignature:[0m
[0mdecoder[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mprev_output_tokens[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mencoder_out[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mDict[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mList[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mincremental_state[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mDict[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mDict[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m][0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfeatures_only[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreturn_all_hiddens[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;

In [10]:
def get_sequence_loss(model, alphabet, coords, seq, device=None):
    batch_converter = CoordBatchConverter(alphabet)
    batch = [(coords, None, seq)]
    coords, confidence, strs, tokens, padding_mask = batch_converter(batch, device=device)
    
    prev_output_tokens = tokens[:, :-1]
    target = tokens[:, 1:]
    target_padding_mask = (target == alphabet.padding_idx)
    logits, _ = model.forward(coords, padding_mask, confidence, prev_output_tokens)
    loss = F.cross_entropy(logits, target, reduction='none')
    #loss = loss[0].detach().numpy()
    #loss = loss[0].to('cpu').numpy()
    #target_padding_mask = target_padding_mask[0].numpy()
    return loss, target_padding_mask

def get_sequence_loss_decoder(decoder, alphabet, coords, encodings, seq, device=None):
    batch_converter = CoordBatchConverter(alphabet)
    batch = [(coords, None, seq)]
    coords, confidence, strs, tokens, padding_mask = batch_converter(batch, device=device)
    
    prev_output_tokens = tokens[:, :-1]
    target = tokens[:, 1:]
    target_padding_mask = (target == alphabet.padding_idx)
    #encoder_padding_mask = torch.zeros(1, N, dtype=bool).to(device)  ## This will be needed if we batch
    encoder_output = {'encoder_out':[encodings], 'encoder_padding_mask':[]}
    logits, _ = decoder.forward(prev_output_tokens, encoder_output)#encoder_output)
    loss = F.cross_entropy(logits, target, reduction='none')
    return loss, target_padding_mask

class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.
    """

    def __init__(self, alphabet, truncation_seq_length: int = None):
        self.alphabet = alphabet
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
        # RoBERTa uses an eos token, while ESM-1 does not.
        batch_size = len(raw_batch)
        batch_labels, seq_str_list = zip(*raw_batch)
        seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]
        if self.truncation_seq_length:
            seq_encoded_list = [seq_str[:self.truncation_seq_length] for seq_str in seq_encoded_list]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        tokens = torch.empty(
            (
                batch_size,
                max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []

        for i, (label, seq_str, seq_encoded) in enumerate(
            zip(batch_labels, seq_str_list, seq_encoded_list)
        ):
            labels.append(label)
            strs.append(seq_str)
            if self.alphabet.prepend_bos:
                tokens[i, 0] = self.alphabet.cls_idx
            seq = torch.tensor(seq_encoded, dtype=torch.int64)
            tokens[
                i,
                int(self.alphabet.prepend_bos) : len(seq_encoded)
                + int(self.alphabet.prepend_bos),
            ] = seq
            if self.alphabet.append_eos:
                tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx

        return labels, strs, tokens


class CoordBatchConverter(BatchConverter):
    def __call__(self, raw_batch: Sequence[Tuple[Sequence, str]], device=None):
        """
        Args:
            raw_batch: List of tuples (coords, confidence, seq)
            In each tuple,
                coords: list of floats, shape L x 3 x 3
                confidence: list of floats, shape L; or scalar float; or None
                seq: string of length L
        Returns:
            coords: Tensor of shape batch_size x L x 3 x 3
            confidence: Tensor of shape batch_size x L
            strs: list of strings
            tokens: LongTensor of shape batch_size x L
            padding_mask: ByteTensor of shape batch_size x L
        """
        self.alphabet.cls_idx = self.alphabet.get_idx("<cath>") 
        batch = []
        for coords, confidence, seq in raw_batch:
            if confidence is None:
                confidence = 1.
            if isinstance(confidence, float) or isinstance(confidence, int):
                confidence = [float(confidence)] * len(coords)
            if seq is None:
                seq = 'X' * len(coords)
            batch.append(((coords, confidence), seq))

        coords_and_confidence, strs, tokens = super().__call__(batch)

        # pad beginning and end of each protein due to legacy reasons
        coords = [
            F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.inf)
            for cd, _ in coords_and_confidence
        ]
        confidence = [
            F.pad(torch.tensor(cf), (1, 1), value=-1.)
            for _, cf in coords_and_confidence
        ]
        coords = self.collate_dense_tensors(coords, pad_v=np.nan)
        confidence = self.collate_dense_tensors(confidence, pad_v=-1.)
        if device is not None:
            coords = coords.to(device)
            confidence = confidence.to(device)
            tokens = tokens.to(device)
        padding_mask = torch.isnan(coords[:,:,0,0])
        coord_mask = torch.isfinite(coords.sum(-2).sum(-1))
        confidence = confidence * coord_mask + (-1.) * padding_mask
        return coords, confidence, strs, tokens, padding_mask

    def from_lists(self, coords_list, confidence_list=None, seq_list=None, device=None):
        """
        Args:
            coords_list: list of length batch_size, each item is a list of
            floats in shape L x 3 x 3 to describe a backbone
            confidence_list: one of
                - None, default to highest confidence
                - list of length batch_size, each item is a scalar
                - list of length batch_size, each item is a list of floats of
                    length L to describe the confidence scores for the backbone
                    with values between 0. and 1.
            seq_list: either None or a list of strings
        Returns:
            coords: Tensor of shape batch_size x L x 3 x 3
            confidence: Tensor of shape batch_size x L
            strs: list of strings
            tokens: LongTensor of shape batch_size x L
            padding_mask: ByteTensor of shape batch_size x L
        """
        batch_size = len(coords_list)
        if confidence_list is None:
            confidence_list = [None] * batch_size
        if seq_list is None:
            seq_list = [None] * batch_size
        raw_batch = zip(coords_list, confidence_list, seq_list)
        return self.__call__(raw_batch, device)

    @staticmethod
    def collate_dense_tensors(samples, pad_v):
        """
        Takes a list of tensors with the following dimensions:
            [(d_11,       ...,           d_1K),
             (d_21,       ...,           d_2K),
             ...,
             (d_N1,       ...,           d_NK)]
        and stack + pads them into a single tensor of:
        (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})
        """
        if len(samples) == 0:
            return torch.Tensor()
        if len(set(x.dim() for x in samples)) != 1:
            raise RuntimeError(
                f"Samples has varying dimensions: {[x.dim() for x in samples]}"
            )
        (device,) = tuple(set(x.device for x in samples))  # assumes all on same device
        max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
        result = torch.empty(
            len(samples), *max_shape, dtype=samples[0].dtype, device=device
        )
        result.fill_(pad_v)
        for i in range(len(samples)):
            result_i = result[i]
            t = samples[i]
            result_i[tuple(slice(0, k) for k in t.shape)] = t
        return result
    

In [5]:
def get_npll(msas_embedded, couplings, fields, N, q):
    """ Get negative pseudo log likelihood (npll)
    """

    B, M, _ = msas_embedded.shape

    # (B, M, N*q) x (B, N*q, N*q) + (B, 1, N*q) -> (B, M, N*q) -> (B, M, N, q)
    A = (msas_embedded @ couplings + fields.unsqueeze(1)).view(B, M, N, q)

    # (B, M, N, q) -> (B, M, N)
    Z = torch.logsumexp(A, dim=-1)

    # (B, M, N, q) * (B, M, N, q) -> (B, M, N, q) -> (B, M, N)
    C = torch.sum(A * msas_embedded.view(B, M, N, q), dim=-1)

    # (B, M, N) - (B, M, N) -> (B, M, N)
    pll = C - Z

    return -pll
def load_structure(fpath, chain=None):
    """
    Args:
        fpath: filepath to either pdb or cif file
        chain: the chain id or list of chain ids to load
    Returns:
        biotite.structure.AtomArray
    """
    with open(fpath) as fin:
        pdbf = pdb.PDBFile.read(fin)
    structure = pdb.get_structure(pdbf, model=1)
    bbmask = filter_backbone(structure)
    structure = structure[bbmask]
    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError('No chains found in the input file.')
    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, list):
        chain_ids = chain
    else:
        chain_ids = [chain] 
    for chain in chain_ids:
        if chain not in all_chains:
            raise ValueError(f'Chain {chain} not found in input file')
    chain_filter = [a.chain_id in chain_ids for a in structure]
    structure = structure[chain_filter]
    return structure

def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Args:
        structure: An instance of biotite AtomArray
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    coords = get_atom_coords_residuewise(["N", "CA", "C"], structure)
    residue_identities = get_residues(structure)[1]
    seq = ''.join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities])
    return coords, seq

def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """
    Example for atoms argument: ["N", "CA", "C"]
    """
    def filterfn(s, axis=None):
        filters = np.stack([s.atom_name == name for name in atoms], axis=1)
        sum = filters.sum(0)
        if not np.all(sum <= np.ones(filters.shape[1])):
            raise RuntimeError("structure has multiple atoms with same name")
        index = filters.argmax(0)
        coords = s[index].coord
        coords[sum == 0] = float("nan")
        return coords

    return biotite.structure.apply_residue_wise(struct, struct, filterfn)

In [12]:
def marginalize(model, coords, native_seq, restr_alphabet, alphabet, i, device=None):
    """ This gives the normalizing factor at different positions, old version which uses the whole model"""
    if native_seq[i] not in restr_alphabet:
        raise ValueError("Error! Character is not in the resitricted dataset.")
    ## I want to calculate the probability for the restricted dataset 
    norm = 0
    N = len(native_seq)
    for char in restr_alphabet:
        mutated_seq = list(native_seq)
        mutated_seq[i] = char
        mutated_seq = "".join(mutated_seq)
        #res = esm.inverse_folding.util.score_sequence(model, alphabet_esm, coords, mutated_seq)[0]
        res=get_sequence_loss(model, alphabet, coords, mutated_seq, device=device)[0]#.item()
        #norm += np.exp(-torch.sum(res).item()/N)
        norm += np.exp(-torch.sum(res).item())
    return norm

def get_loss_esm(model, restr_alphabet, alphabet_esm, coords, native_seq, device=None):
    """Old version using the whole model at every pass"""
    N = coords.shape[0]
    part = 0
    for i in range(N):
        norm = marginalize(model, coords, native_seq, restr_alphabet, alphabet_esm, i, device=device)
        res=get_sequence_loss(model, alphabet_esm, coords, native_seq, device=device)[0]
        #part += np.exp(-torch.sum(res)[0].item()/N)/norm
        part += (-torch.sum(res).item() - np.log(norm))
    return part/N


def marginalize_decoder(decoder, coords, encodings, native_seq, restr_alphabet, alphabet, i, device=None):
    """ This gives the normalizing factor at different positions, new version using just the decoder"""
    if native_seq[i] not in restr_alphabet:
        raise ValueError("Error! Character is not in the resitricted dataset.")
    ## I want to calculate the probability for the restricted dataset 
    norm = 0
    N = len(native_seq)
    for char in restr_alphabet:
        mutated_seq = list(native_seq)
        mutated_seq[i] = char
        mutated_seq = "".join(mutated_seq)
        #res = esm.inverse_folding.util.score_sequence(model, alphabet_esm, coords, mutated_seq)[0]
        res=get_sequence_loss_decoder(decoder, alphabet, coords, encodings, mutated_seq, device=device)[0]#.item()
        #norm += np.exp(-torch.sum(res).item()/N)
        norm += np.exp(-torch.sum(res).item())
    return norm



def get_loss_esm_decoder(decoder, restr_alphabet, alphabet_esm, coords, encodings, native_seq, device=None):
    """New version using just the decoder"""
    N = len(native_seq)
    part = 0
    for i in range(N):
        norm = marginalize_decoder(decoder, coords, encodings, native_seq, restr_alphabet, alphabet_esm, i, device=device)
        res=get_sequence_loss_decoder(decoder, alphabet_esm, coords, encodings, native_seq, device=device)[0]
        #part += np.exp(-torch.sum(res)[0].item()/N)/norm
        part += (-torch.sum(res).item() - np.log(norm))
    return part/N

In [18]:
model, alphabet = pretrained.esm_if1_gvp4_t16_142M_UR50() 
model.eval();

device=0
model.to(device)

ab = 'ACDEFGHIKLMNPQRSTVWY-'


structure_dir = '/Data/christoph/bocconi/dompdb'
encodings_folder = '/Data/InverseFoldingData/structure_encodings'
checks=1
check=0
npll = np.zeros(checks)

#for pdb_name in (os.listdir(structure_dir)):
for encoding_files in (os.listdir(encodings_folder)):
    pdb_name = encoding_files[0:7]
    print(f"We are at iteration {check} out of {checks}", end="\r")
    aux = {}
    pdb_path = os.path.join(structure_dir, pdb_name)
    structure =  load_structure(pdb_path)
    coords, native_seq = extract_coords_from_structure(structure)
    #coords = torch.from_numpy(coords)
    #lppds[check] = np.log(marginalize(model, coords, native_seq, ab, alphabet, 1, device=device))
    npll[check] = get_loss_esm(model, ab, alphabet, coords, native_seq, device=device)
    check+=1
    if check>=checks:
        break
        





We are at iteration 0 out of 1

  warn("{} elements were guessed from atom_name.".format(rep_num))


In [19]:
len(native_seq)

378

In [20]:
npll

array([-1.00190027])

In [21]:
model, alphabet = pretrained.esm_if1_gvp4_t16_142M_UR50() 
decoder = model.decoder
decoder.eval();

device=0
decoder.to(device)

ab = 'ACDEFGHIKLMNPQRSTVWY-'

structure_dir = '/Data/christoph/bocconi/dompdb'
encodings_folder = '/Data/InverseFoldingData/structure_encodings'
checks=1
check=0
npll2 = np.zeros(checks)

for encoding_files in (os.listdir(encodings_folder)):
    id = encoding_files[0:7]
    encoding_path = os.path.join(encodings_folder, encoding_files)
    print(f"We are at iteration {check} out of {checks}", end="\r")
    aux = {}
    pdb_path = os.path.join(structure_dir, id)
    structure =  load_structure(pdb_path)
    coords, native_seq = extract_coords_from_structure(structure)
    encodings = torch.tensor(read_encodings(encoding_path, trim=False))
    encodings = torch.tensor(encodings).unsqueeze(1).to(device)
    
    npll2[check] = get_loss_esm_decoder(decoder, ab, alphabet, coords, encodings, native_seq, device=device)
    check+=1
    break
    #batch_converter = CoordBatchConverter(alphabet)
    #batch = [(coords, None, native_seq)]
    #coords, confidence, strs, tokens, padding_mask = batch_converter(batch, device=device)
    
    #encodings = torch.tensor(read_encodings(encoding_path, trim=False)).to(device)
    #prev_output_tokens = tokens[:, :-1]
    #target = tokens[:, 1:]
    #encodings = torch.tensor(encodings).unsqueeze(1)
    #encoder_padding_mask = torch.zeros(1, N, dtype=bool).to(device)
    #encoder_output = {'encoder_out':[encodings], 'encoder_padding_mask':[encoder_padding_mask]}
    #encoder_padding_mask = torch.zeros(1, N, dtype=bool)
    #logits, _ = decoder.forward(prev_output_tokens, encodings)#encoder_output)
    #loss = F.cross_entropy(logits, target, reduction='none')
    #coords = torch.from_numpy(coords)
    #lppds[check] = np.log(marginalize(model, coords, native_seq, ab, alphabet, 1, device=device))
    #npll[check] = get_loss_esm(model, ab, alphabet, coords, native_seq, device=device)
    #check+=1
    #if check>=checks:
    #    break
        



We are at iteration 0 out of 1

  warn("{} elements were guessed from atom_name.".format(rep_num))
  encodings = torch.tensor(read_encodings(encoding_path, trim=False))
  encodings = torch.tensor(encodings).unsqueeze(1).to(device)


In [58]:
structure =  load_structure(pdb_path)
coords, native_seq = extract_coords_from_structure(structure)

  warn("{} elements were guessed from atom_name.".format(rep_num))


In [59]:
coords.shape

(378, 3, 3)

In [30]:
len(native_seq)

378

In [100]:
encoder = model.encoder
encoder.eval()
encoder.to(device)
decoder.eval()
structure =  load_structure(pdb_path)
coords, native_seq = extract_coords_from_structure(structure)

#encodings = torch.tensor(read_encodings(encoding_path, trim=False))
#encodings = torch.tensor(encodings).unsqueeze(1).to(device)
batch_converter = CoordBatchConverter(alphabet)
batch = [(coords, None, native_seq)]
coords, confidence, strs, tokens, padding_mask = batch_converter(batch, device=device)

encoder_output = encoder.forward(coords, padding_mask, confidence)

prev_output_tokens = tokens[:, :-1]
target = tokens[:, 1:]
target_padding_mask = (target == alphabet.padding_idx)
#encoder_padding_mask = torch.zeros(1, N, dtype=bool).to(device)  ## This will be needed if we batch
#encoder_output = {'encoder_out':[encodings], 'encoder_padding_mask':[]}
logits, _ = decoder.forward(prev_output_tokens, encoder_output)#encoder_output)


In [86]:
#encoder_output['encoder_out'][0].shape

In [101]:
#model.to(device)
model.eval()
structure =  load_structure(pdb_path)
coords, native_seq = extract_coords_from_structure(structure)
batch_converter = CoordBatchConverter(alphabet)
batch = [(coords, None, native_seq)]
coords, confidence, strs, tokens, padding_mask = batch_converter(batch, device=device)

prev_output_tokens = tokens[:, :-1]
target = tokens[:, 1:]
target_padding_mask = (target == alphabet.padding_idx)
logits2, _ = model.forward(coords, padding_mask, confidence, prev_output_tokens)

## PROBLEM
-   The forward method of ESM does add a beginnng/end of sequence, so to make comparison we have to do a forward pass because for our task we do not add this dimension!

In [102]:
encodings.shape

torch.Size([378, 1, 512])

In [103]:
logits[0, 0:5, 0]

tensor([-21.7387, -21.7923, -14.6354, -21.7338,   0.4261], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [104]:
logits2[0, 0:5, 0]

tensor([-21.7387, -21.7923, -14.6354, -21.7338,   0.4261], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [67]:
logits.shape

torch.Size([1, 35, 378])

In [68]:
logits2.shape

torch.Size([1, 35, 378])

In [46]:
loss = F.cross_entropy(logits, target, reduction='none')

In [48]:
loss.shape

torch.Size([1, 378])

In [35]:
encoder_output['encoder_out'].shape

torch.Size([378, 1, 512])

In [29]:
encodings.unsqueeze(1).shape

torch.Size([378, 1, 512])

In [9]:
npll

array([-2.60031303])

In [24]:
decoder = model.decoder
?decoder.forward

[0;31mSignature:[0m
[0mdecoder[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mprev_output_tokens[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mencoder_out[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mDict[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mList[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mincremental_state[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mDict[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mDict[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m][0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mfeatures_only[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreturn_all_hiddens[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;

In [17]:
encoder = model.encoder
?encoder.forward

[0;31mSignature:[0m
[0mencoder[0m[0;34m.[0m[0mforward[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mcoords[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mencoder_padding_mask[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mconfidence[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreturn_all_hiddens[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Args:
    coords (Tensor): backbone coordinates
        shape batch_size x num_residues x num_atoms (3 for N, CA, C) x 3
    encoder_padding_mask (ByteTensor): the positions of
          padding elements of shape `(batch_size x num_residues)`
    confidence (Tensor): the confidence score of shape (batch_size x
        num_residues). The value is between 0. and 1. for each residue
        coordinate, or -1. if no coordinate is given
    return_all_hiddens (bool, optional): also return all of the
        intermediate hidden states (default