# Load fine-tuned model and test on repo examples

In [None]:
def extend(a, b, c, L, A, D):
    """
    input:  3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral
    output: 4th coord
    """

    def normalize(x):
        return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True)

    bc = normalize(b - c)
    n = normalize(np.cross(b - a, bc))
    m = [bc, np.cross(n, bc), n]
    d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]
    return c + sum([m * d for m, d in zip(m, d)])


def contacts_from_pdb(
    structure: bs.AtomArray,
    distance_threshold: float = 15.0,
    chain: Optional[str] = None,
) -> np.ndarray:
    mask = ~structure.hetero
    if chain is not None:
        mask &= structure.chain_id == chain

    N = structure.coord[mask & (structure.atom_name == "N")]
    CA = structure.coord[mask & (structure.atom_name == "CA")]
    C = structure.coord[mask & (structure.atom_name == "C")]

    Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143)
    dist = squareform(pdist(Cbeta))
    
    contacts = dist < distance_threshold
    contacts = contacts.astype(np.int64)
    contacts[np.isnan(dist)] = -1
    return contacts

In [None]:
# precision and prediction evaluation

def compute_precisions(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    src_lengths: Optional[torch.Tensor] = None,
    minsep: int = 6,
    maxsep: Optional[int] = None,
    override_length: Optional[int] = None,  # for casp
):
    if isinstance(outputs, np.ndarray):
        outputs = torch.from_numpy(outputs)
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    if outputs.dim() == 2:
        outputs = outputs.unsqueeze(0)
    if targets.dim() == 2:
        targets = targets.unsqueeze(0)
    override_length = (targets[0, 0] >= 0).sum()

    # Check sizes
    if outputs.size() != targets.size():
        raise ValueError(
            f"Size mismatch. Received predictions of size {outputs.size()}, "
            f"targets of size {targets.size()}"
        )
    device = outputs.device

    batch_size, seqlen, _ = outputs.size()
    seqlen_range = torch.arange(seqlen, device=device)

    sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1)
    sep = sep.unsqueeze(0)
    valid_mask = sep >= minsep
    valid_mask = valid_mask & (targets >= 0)  # negative targets are invalid

    if maxsep is not None:
        valid_mask &= sep < maxsep

    if src_lengths is not None:
        valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1)
        valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2)
    
    # Check what the batch size is, comes from "batch_size, seqlen, _ = outputs.size()"
    else:
        src_lengths = torch.full([batch_size], seqlen, device=device, dtype=torch.long) 

    # check the masked_fill function in repo
    outputs = outputs.masked_fill(~valid_mask, float("-inf"))

    x_ind, y_ind = np.triu_indices(seqlen, minsep)
    predictions_upper = outputs[:, x_ind, y_ind]
    targets_upper = targets[:, x_ind, y_ind]

    topk = seqlen if override_length is None else max(seqlen, override_length)
    indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk]
    topk_targets = targets_upper[torch.arange(batch_size).unsqueeze(1), indices]
    if topk_targets.size(1) < topk:
        topk_targets = F.pad(topk_targets, [0, topk - topk_targets.size(1)])

    cumulative_dist = topk_targets.type_as(outputs).cumsum(-1)

    gather_lengths = src_lengths.unsqueeze(1)
    if override_length is not None:
        gather_lengths = override_length * torch.ones_like(
            gather_lengths, device=device
        )

    gather_indices = (
        torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths
    ).type(torch.long) - 1

    binned_cumulative_dist = cumulative_dist.gather(1, gather_indices)
    binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as(
        binned_cumulative_dist
    )

    pl5 = binned_precisions[:, 1]
    pl2 = binned_precisions[:, 4]
    pl = binned_precisions[:, 9]
    auc = binned_precisions.mean(-1)

    return {"AUC": auc, "P@L": pl, "P@L2": pl2, "P@L5": pl5}


def evaluate_prediction(
    predictions: torch.Tensor,
    targets: torch.Tensor,
) -> Dict[str, float]:
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    contact_ranges = [
        ("local", 3, 6),
        ("short", 6, 12),
        ("medium", 12, 24),
        ("long", 24, None),
    ]
    metrics = {}
    targets = targets.to(predictions.device)
    for name, minsep, maxsep in contact_ranges:
        rangemetrics = compute_precisions(
            predictions,
            targets,
            minsep=minsep,
            maxsep=maxsep,
        )
        for key, val in rangemetrics.items():
            metrics[f"{name}_{key}"] = val.item()
    return metrics

In [1]:
# Load the example fastas to compare with repo results

from Bio import SeqIO
import os

def load_fasta_files(folder_path):
    
    fasta_files = [file for file in os.listdir(folder_path) if file.endswith(".fasta")]
    all_fasta_data = {}

    for fasta_file in fasta_files:
        file_path = os.path.join(folder_path, fasta_file)
        with open(file_path, "r") as handle:
            for record in SeqIO.parse(handle, "fasta"):
                all_fasta_data[record.id] = str(record.seq)

    return all_fasta_data

folder_path = r"C:\Users\neil_\Documents\GitHub\contact-prediction\data\Original"
fasta_data_dict = load_fasta_files(folder_path)

fasta_data_dict 

{'1A3A_1|Chains': 'MANLFKLGAENIFLGRKAATKEEAIRFAGEQLVKGGYVEPEYVQAMLDREKLTPTYLGESIAVPHGTVEAKDRVLKTGVVFCQYPEGVRFGEEEDDIARLVIGIAARNNEHIQVITSLTNALDDESVIERLAHTTSVDEVLELLAGRK',
 '1XCR_1|Chains': 'GSACAEFSFHVPSLEELAGVMQKGLKDNFADVQVSVVDCPDLTKEPFTFPVKGICGKTRIAEVGGVPYLLPLVNQKKVYDLNKIAKEIKLPGAFILGAGAGPFQTLGFNSEFMPVIQTESEHKPPVNGSYFAHVNPADGGCLLEKYSEKCHDFQCALLANLFASEGQPGKVIEVKAKRRTGPLNFVTCMRETLEKHYGNKPIGMGGTFIIQKGKVKSHIMPAEFSSCPLNSDEEVNKWLHFYEMKAPLVCLPVFVSRDPGFDLRLEHTHFFSRHGEGGHYHYDTTPDIVEYLGYFLPAEFLYRIDQPKETHSIGRD',
 '5AHW_1|Chains': 'MSAYQTVVVGTDGSDSSLRAVDRAGQIAAASNAKLIIATAYFPQSEDSRAADVLKDEGYKMAGNAPIYAILREANDRAKAAGATDIEERPVVGAPVDALVELADEVKADLLVVGNVGLSTIAGRLLGSVPANVARRSKTDVLIVHTS'}

In [2]:
fasta_data_dict["1A3A_1|Chains"]

'MANLFKLGAENIFLGRKAATKEEAIRFAGEQLVKGGYVEPEYVQAMLDREKLTPTYLGESIAVPHGTVEAKDRVLKTGVVFCQYPEGVRFGEEEDDIARLVIGIAARNNEHIQVITSLTNALDDESVIERLAHTTSVDEVLELLAGRK'

In [27]:
sequences_list = [sequence for sequence in fasta_data_dict.values()]

In [29]:
sequences_list[0]

'MANLFKLGAENIFLGRKAATKEEAIRFAGEQLVKGGYVEPEYVQAMLDREKLTPTYLGESIAVPHGTVEAKDRVLKTGVVFCQYPEGVRFGEEEDDIARLVIGIAARNNEHIQVITSLTNALDDESVIERLAHTTSVDEVLELLAGRK'

In [33]:
indexed_sequences = list(enumerate(sequences_list))

In [34]:
indexed_sequences

[(0,
  'MANLFKLGAENIFLGRKAATKEEAIRFAGEQLVKGGYVEPEYVQAMLDREKLTPTYLGESIAVPHGTVEAKDRVLKTGVVFCQYPEGVRFGEEEDDIARLVIGIAARNNEHIQVITSLTNALDDESVIERLAHTTSVDEVLELLAGRK'),
 (1,
  'GSACAEFSFHVPSLEELAGVMQKGLKDNFADVQVSVVDCPDLTKEPFTFPVKGICGKTRIAEVGGVPYLLPLVNQKKVYDLNKIAKEIKLPGAFILGAGAGPFQTLGFNSEFMPVIQTESEHKPPVNGSYFAHVNPADGGCLLEKYSEKCHDFQCALLANLFASEGQPGKVIEVKAKRRTGPLNFVTCMRETLEKHYGNKPIGMGGTFIIQKGKVKSHIMPAEFSSCPLNSDEEVNKWLHFYEMKAPLVCLPVFVSRDPGFDLRLEHTHFFSRHGEGGHYHYDTTPDIVEYLGYFLPAEFLYRIDQPKETHSIGRD'),
 (2,
  'MSAYQTVVVGTDGSDSSLRAVDRAGQIAAASNAKLIIATAYFPQSEDSRAADVLKDEGYKMAGNAPIYAILREANDRAKAAGATDIEERPVVGAPVDALVELADEVKADLLVVGNVGLSTIAGRLLGSVPANVARRSKTDVLIVHTS')]

In [50]:
import torch
import torch.nn as nn
import torch.optim as optim

class TunedModel(nn.Module):
    def __init__(self):
        super(TunedModel, self).__init__()
        self.esm = esm.pretrained.esm2_t33_650M_UR50D()
        
    def forward(self, batch_tokens):
        outputs = self.esm(batch_tokens, return_contacts=True)
        # contacts = outputs["contacts"][0]
        # sequence_output = outputs["representations"][33]  
        return outputs["contacts"][0]

saved_model_dict = torch.load(r"../contact-prediction/models/trained_model_aw.pth")

model = TunedModel()

model_state_dict = saved_model_dict["model_state_dict"]
optimizer_state_dict = saved_model_dict["optimizer_state_dict"]

#optimizer = optim.AdamW(model.parameters(), lr=0.001)
    
model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict)

model.eval()

RuntimeError: Error(s) in loading state_dict for TunedModel:
	Unexpected key(s) in state_dict: "embed_tokens.weight", "layers.0.self_attn.k_proj.weight", "layers.0.self_attn.k_proj.bias", "layers.0.self_attn.v_proj.weight", "layers.0.self_attn.v_proj.bias", "layers.0.self_attn.q_proj.weight", "layers.0.self_attn.q_proj.bias", "layers.0.self_attn.out_proj.weight", "layers.0.self_attn.out_proj.bias", "layers.0.self_attn.rot_emb.inv_freq", "layers.0.self_attn_layer_norm.weight", "layers.0.self_attn_layer_norm.bias", "layers.0.fc1.weight", "layers.0.fc1.bias", "layers.0.fc2.weight", "layers.0.fc2.bias", "layers.0.final_layer_norm.weight", "layers.0.final_layer_norm.bias", "layers.1.self_attn.k_proj.weight", "layers.1.self_attn.k_proj.bias", "layers.1.self_attn.v_proj.weight", "layers.1.self_attn.v_proj.bias", "layers.1.self_attn.q_proj.weight", "layers.1.self_attn.q_proj.bias", "layers.1.self_attn.out_proj.weight", "layers.1.self_attn.out_proj.bias", "layers.1.self_attn.rot_emb.inv_freq", "layers.1.self_attn_layer_norm.weight", "layers.1.self_attn_layer_norm.bias", "layers.1.fc1.weight", "layers.1.fc1.bias", "layers.1.fc2.weight", "layers.1.fc2.bias", "layers.1.final_layer_norm.weight", "layers.1.final_layer_norm.bias", "layers.2.self_attn.k_proj.weight", "layers.2.self_attn.k_proj.bias", "layers.2.self_attn.v_proj.weight", "layers.2.self_attn.v_proj.bias", "layers.2.self_attn.q_proj.weight", "layers.2.self_attn.q_proj.bias", "layers.2.self_attn.out_proj.weight", "layers.2.self_attn.out_proj.bias", "layers.2.self_attn.rot_emb.inv_freq", "layers.2.self_attn_layer_norm.weight", "layers.2.self_attn_layer_norm.bias", "layers.2.fc1.weight", "layers.2.fc1.bias", "layers.2.fc2.weight", "layers.2.fc2.bias", "layers.2.final_layer_norm.weight", "layers.2.final_layer_norm.bias", "layers.3.self_attn.k_proj.weight", "layers.3.self_attn.k_proj.bias", "layers.3.self_attn.v_proj.weight", "layers.3.self_attn.v_proj.bias", "layers.3.self_attn.q_proj.weight", "layers.3.self_attn.q_proj.bias", "layers.3.self_attn.out_proj.weight", "layers.3.self_attn.out_proj.bias", "layers.3.self_attn.rot_emb.inv_freq", "layers.3.self_attn_layer_norm.weight", "layers.3.self_attn_layer_norm.bias", "layers.3.fc1.weight", "layers.3.fc1.bias", "layers.3.fc2.weight", "layers.3.fc2.bias", "layers.3.final_layer_norm.weight", "layers.3.final_layer_norm.bias", "layers.4.self_attn.k_proj.weight", "layers.4.self_attn.k_proj.bias", "layers.4.self_attn.v_proj.weight", "layers.4.self_attn.v_proj.bias", "layers.4.self_attn.q_proj.weight", "layers.4.self_attn.q_proj.bias", "layers.4.self_attn.out_proj.weight", "layers.4.self_attn.out_proj.bias", "layers.4.self_attn.rot_emb.inv_freq", "layers.4.self_attn_layer_norm.weight", "layers.4.self_attn_layer_norm.bias", "layers.4.fc1.weight", "layers.4.fc1.bias", "layers.4.fc2.weight", "layers.4.fc2.bias", "layers.4.final_layer_norm.weight", "layers.4.final_layer_norm.bias", "layers.5.self_attn.k_proj.weight", "layers.5.self_attn.k_proj.bias", "layers.5.self_attn.v_proj.weight", "layers.5.self_attn.v_proj.bias", "layers.5.self_attn.q_proj.weight", "layers.5.self_attn.q_proj.bias", "layers.5.self_attn.out_proj.weight", "layers.5.self_attn.out_proj.bias", "layers.5.self_attn.rot_emb.inv_freq", "layers.5.self_attn_layer_norm.weight", "layers.5.self_attn_layer_norm.bias", "layers.5.fc1.weight", "layers.5.fc1.bias", "layers.5.fc2.weight", "layers.5.fc2.bias", "layers.5.final_layer_norm.weight", "layers.5.final_layer_norm.bias", "layers.6.self_attn.k_proj.weight", "layers.6.self_attn.k_proj.bias", "layers.6.self_attn.v_proj.weight", "layers.6.self_attn.v_proj.bias", "layers.6.self_attn.q_proj.weight", "layers.6.self_attn.q_proj.bias", "layers.6.self_attn.out_proj.weight", "layers.6.self_attn.out_proj.bias", "layers.6.self_attn.rot_emb.inv_freq", "layers.6.self_attn_layer_norm.weight", "layers.6.self_attn_layer_norm.bias", "layers.6.fc1.weight", "layers.6.fc1.bias", "layers.6.fc2.weight", "layers.6.fc2.bias", "layers.6.final_layer_norm.weight", "layers.6.final_layer_norm.bias", "layers.7.self_attn.k_proj.weight", "layers.7.self_attn.k_proj.bias", "layers.7.self_attn.v_proj.weight", "layers.7.self_attn.v_proj.bias", "layers.7.self_attn.q_proj.weight", "layers.7.self_attn.q_proj.bias", "layers.7.self_attn.out_proj.weight", "layers.7.self_attn.out_proj.bias", "layers.7.self_attn.rot_emb.inv_freq", "layers.7.self_attn_layer_norm.weight", "layers.7.self_attn_layer_norm.bias", "layers.7.fc1.weight", "layers.7.fc1.bias", "layers.7.fc2.weight", "layers.7.fc2.bias", "layers.7.final_layer_norm.weight", "layers.7.final_layer_norm.bias", "layers.8.self_attn.k_proj.weight", "layers.8.self_attn.k_proj.bias", "layers.8.self_attn.v_proj.weight", "layers.8.self_attn.v_proj.bias", "layers.8.self_attn.q_proj.weight", "layers.8.self_attn.q_proj.bias", "layers.8.self_attn.out_proj.weight", "layers.8.self_attn.out_proj.bias", "layers.8.self_attn.rot_emb.inv_freq", "layers.8.self_attn_layer_norm.weight", "layers.8.self_attn_layer_norm.bias", "layers.8.fc1.weight", "layers.8.fc1.bias", "layers.8.fc2.weight", "layers.8.fc2.bias", "layers.8.final_layer_norm.weight", "layers.8.final_layer_norm.bias", "layers.9.self_attn.k_proj.weight", "layers.9.self_attn.k_proj.bias", "layers.9.self_attn.v_proj.weight", "layers.9.self_attn.v_proj.bias", "layers.9.self_attn.q_proj.weight", "layers.9.self_attn.q_proj.bias", "layers.9.self_attn.out_proj.weight", "layers.9.self_attn.out_proj.bias", "layers.9.self_attn.rot_emb.inv_freq", "layers.9.self_attn_layer_norm.weight", "layers.9.self_attn_layer_norm.bias", "layers.9.fc1.weight", "layers.9.fc1.bias", "layers.9.fc2.weight", "layers.9.fc2.bias", "layers.9.final_layer_norm.weight", "layers.9.final_layer_norm.bias", "layers.10.self_attn.k_proj.weight", "layers.10.self_attn.k_proj.bias", "layers.10.self_attn.v_proj.weight", "layers.10.self_attn.v_proj.bias", "layers.10.self_attn.q_proj.weight", "layers.10.self_attn.q_proj.bias", "layers.10.self_attn.out_proj.weight", "layers.10.self_attn.out_proj.bias", "layers.10.self_attn.rot_emb.inv_freq", "layers.10.self_attn_layer_norm.weight", "layers.10.self_attn_layer_norm.bias", "layers.10.fc1.weight", "layers.10.fc1.bias", "layers.10.fc2.weight", "layers.10.fc2.bias", "layers.10.final_layer_norm.weight", "layers.10.final_layer_norm.bias", "layers.11.self_attn.k_proj.weight", "layers.11.self_attn.k_proj.bias", "layers.11.self_attn.v_proj.weight", "layers.11.self_attn.v_proj.bias", "layers.11.self_attn.q_proj.weight", "layers.11.self_attn.q_proj.bias", "layers.11.self_attn.out_proj.weight", "layers.11.self_attn.out_proj.bias", "layers.11.self_attn.rot_emb.inv_freq", "layers.11.self_attn_layer_norm.weight", "layers.11.self_attn_layer_norm.bias", "layers.11.fc1.weight", "layers.11.fc1.bias", "layers.11.fc2.weight", "layers.11.fc2.bias", "layers.11.final_layer_norm.weight", "layers.11.final_layer_norm.bias", "layers.12.self_attn.k_proj.weight", "layers.12.self_attn.k_proj.bias", "layers.12.self_attn.v_proj.weight", "layers.12.self_attn.v_proj.bias", "layers.12.self_attn.q_proj.weight", "layers.12.self_attn.q_proj.bias", "layers.12.self_attn.out_proj.weight", "layers.12.self_attn.out_proj.bias", "layers.12.self_attn.rot_emb.inv_freq", "layers.12.self_attn_layer_norm.weight", "layers.12.self_attn_layer_norm.bias", "layers.12.fc1.weight", "layers.12.fc1.bias", "layers.12.fc2.weight", "layers.12.fc2.bias", "layers.12.final_layer_norm.weight", "layers.12.final_layer_norm.bias", "layers.13.self_attn.k_proj.weight", "layers.13.self_attn.k_proj.bias", "layers.13.self_attn.v_proj.weight", "layers.13.self_attn.v_proj.bias", "layers.13.self_attn.q_proj.weight", "layers.13.self_attn.q_proj.bias", "layers.13.self_attn.out_proj.weight", "layers.13.self_attn.out_proj.bias", "layers.13.self_attn.rot_emb.inv_freq", "layers.13.self_attn_layer_norm.weight", "layers.13.self_attn_layer_norm.bias", "layers.13.fc1.weight", "layers.13.fc1.bias", "layers.13.fc2.weight", "layers.13.fc2.bias", "layers.13.final_layer_norm.weight", "layers.13.final_layer_norm.bias", "layers.14.self_attn.k_proj.weight", "layers.14.self_attn.k_proj.bias", "layers.14.self_attn.v_proj.weight", "layers.14.self_attn.v_proj.bias", "layers.14.self_attn.q_proj.weight", "layers.14.self_attn.q_proj.bias", "layers.14.self_attn.out_proj.weight", "layers.14.self_attn.out_proj.bias", "layers.14.self_attn.rot_emb.inv_freq", "layers.14.self_attn_layer_norm.weight", "layers.14.self_attn_layer_norm.bias", "layers.14.fc1.weight", "layers.14.fc1.bias", "layers.14.fc2.weight", "layers.14.fc2.bias", "layers.14.final_layer_norm.weight", "layers.14.final_layer_norm.bias", "layers.15.self_attn.k_proj.weight", "layers.15.self_attn.k_proj.bias", "layers.15.self_attn.v_proj.weight", "layers.15.self_attn.v_proj.bias", "layers.15.self_attn.q_proj.weight", "layers.15.self_attn.q_proj.bias", "layers.15.self_attn.out_proj.weight", "layers.15.self_attn.out_proj.bias", "layers.15.self_attn.rot_emb.inv_freq", "layers.15.self_attn_layer_norm.weight", "layers.15.self_attn_layer_norm.bias", "layers.15.fc1.weight", "layers.15.fc1.bias", "layers.15.fc2.weight", "layers.15.fc2.bias", "layers.15.final_layer_norm.weight", "layers.15.final_layer_norm.bias", "layers.16.self_attn.k_proj.weight", "layers.16.self_attn.k_proj.bias", "layers.16.self_attn.v_proj.weight", "layers.16.self_attn.v_proj.bias", "layers.16.self_attn.q_proj.weight", "layers.16.self_attn.q_proj.bias", "layers.16.self_attn.out_proj.weight", "layers.16.self_attn.out_proj.bias", "layers.16.self_attn.rot_emb.inv_freq", "layers.16.self_attn_layer_norm.weight", "layers.16.self_attn_layer_norm.bias", "layers.16.fc1.weight", "layers.16.fc1.bias", "layers.16.fc2.weight", "layers.16.fc2.bias", "layers.16.final_layer_norm.weight", "layers.16.final_layer_norm.bias", "layers.17.self_attn.k_proj.weight", "layers.17.self_attn.k_proj.bias", "layers.17.self_attn.v_proj.weight", "layers.17.self_attn.v_proj.bias", "layers.17.self_attn.q_proj.weight", "layers.17.self_attn.q_proj.bias", "layers.17.self_attn.out_proj.weight", "layers.17.self_attn.out_proj.bias", "layers.17.self_attn.rot_emb.inv_freq", "layers.17.self_attn_layer_norm.weight", "layers.17.self_attn_layer_norm.bias", "layers.17.fc1.weight", "layers.17.fc1.bias", "layers.17.fc2.weight", "layers.17.fc2.bias", "layers.17.final_layer_norm.weight", "layers.17.final_layer_norm.bias", "layers.18.self_attn.k_proj.weight", "layers.18.self_attn.k_proj.bias", "layers.18.self_attn.v_proj.weight", "layers.18.self_attn.v_proj.bias", "layers.18.self_attn.q_proj.weight", "layers.18.self_attn.q_proj.bias", "layers.18.self_attn.out_proj.weight", "layers.18.self_attn.out_proj.bias", "layers.18.self_attn.rot_emb.inv_freq", "layers.18.self_attn_layer_norm.weight", "layers.18.self_attn_layer_norm.bias", "layers.18.fc1.weight", "layers.18.fc1.bias", "layers.18.fc2.weight", "layers.18.fc2.bias", "layers.18.final_layer_norm.weight", "layers.18.final_layer_norm.bias", "layers.19.self_attn.k_proj.weight", "layers.19.self_attn.k_proj.bias", "layers.19.self_attn.v_proj.weight", "layers.19.self_attn.v_proj.bias", "layers.19.self_attn.q_proj.weight", "layers.19.self_attn.q_proj.bias", "layers.19.self_attn.out_proj.weight", "layers.19.self_attn.out_proj.bias", "layers.19.self_attn.rot_emb.inv_freq", "layers.19.self_attn_layer_norm.weight", "layers.19.self_attn_layer_norm.bias", "layers.19.fc1.weight", "layers.19.fc1.bias", "layers.19.fc2.weight", "layers.19.fc2.bias", "layers.19.final_layer_norm.weight", "layers.19.final_layer_norm.bias", "layers.20.self_attn.k_proj.weight", "layers.20.self_attn.k_proj.bias", "layers.20.self_attn.v_proj.weight", "layers.20.self_attn.v_proj.bias", "layers.20.self_attn.q_proj.weight", "layers.20.self_attn.q_proj.bias", "layers.20.self_attn.out_proj.weight", "layers.20.self_attn.out_proj.bias", "layers.20.self_attn.rot_emb.inv_freq", "layers.20.self_attn_layer_norm.weight", "layers.20.self_attn_layer_norm.bias", "layers.20.fc1.weight", "layers.20.fc1.bias", "layers.20.fc2.weight", "layers.20.fc2.bias", "layers.20.final_layer_norm.weight", "layers.20.final_layer_norm.bias", "layers.21.self_attn.k_proj.weight", "layers.21.self_attn.k_proj.bias", "layers.21.self_attn.v_proj.weight", "layers.21.self_attn.v_proj.bias", "layers.21.self_attn.q_proj.weight", "layers.21.self_attn.q_proj.bias", "layers.21.self_attn.out_proj.weight", "layers.21.self_attn.out_proj.bias", "layers.21.self_attn.rot_emb.inv_freq", "layers.21.self_attn_layer_norm.weight", "layers.21.self_attn_layer_norm.bias", "layers.21.fc1.weight", "layers.21.fc1.bias", "layers.21.fc2.weight", "layers.21.fc2.bias", "layers.21.final_layer_norm.weight", "layers.21.final_layer_norm.bias", "layers.22.self_attn.k_proj.weight", "layers.22.self_attn.k_proj.bias", "layers.22.self_attn.v_proj.weight", "layers.22.self_attn.v_proj.bias", "layers.22.self_attn.q_proj.weight", "layers.22.self_attn.q_proj.bias", "layers.22.self_attn.out_proj.weight", "layers.22.self_attn.out_proj.bias", "layers.22.self_attn.rot_emb.inv_freq", "layers.22.self_attn_layer_norm.weight", "layers.22.self_attn_layer_norm.bias", "layers.22.fc1.weight", "layers.22.fc1.bias", "layers.22.fc2.weight", "layers.22.fc2.bias", "layers.22.final_layer_norm.weight", "layers.22.final_layer_norm.bias", "layers.23.self_attn.k_proj.weight", "layers.23.self_attn.k_proj.bias", "layers.23.self_attn.v_proj.weight", "layers.23.self_attn.v_proj.bias", "layers.23.self_attn.q_proj.weight", "layers.23.self_attn.q_proj.bias", "layers.23.self_attn.out_proj.weight", "layers.23.self_attn.out_proj.bias", "layers.23.self_attn.rot_emb.inv_freq", "layers.23.self_attn_layer_norm.weight", "layers.23.self_attn_layer_norm.bias", "layers.23.fc1.weight", "layers.23.fc1.bias", "layers.23.fc2.weight", "layers.23.fc2.bias", "layers.23.final_layer_norm.weight", "layers.23.final_layer_norm.bias", "layers.24.self_attn.k_proj.weight", "layers.24.self_attn.k_proj.bias", "layers.24.self_attn.v_proj.weight", "layers.24.self_attn.v_proj.bias", "layers.24.self_attn.q_proj.weight", "layers.24.self_attn.q_proj.bias", "layers.24.self_attn.out_proj.weight", "layers.24.self_attn.out_proj.bias", "layers.24.self_attn.rot_emb.inv_freq", "layers.24.self_attn_layer_norm.weight", "layers.24.self_attn_layer_norm.bias", "layers.24.fc1.weight", "layers.24.fc1.bias", "layers.24.fc2.weight", "layers.24.fc2.bias", "layers.24.final_layer_norm.weight", "layers.24.final_layer_norm.bias", "layers.25.self_attn.k_proj.weight", "layers.25.self_attn.k_proj.bias", "layers.25.self_attn.v_proj.weight", "layers.25.self_attn.v_proj.bias", "layers.25.self_attn.q_proj.weight", "layers.25.self_attn.q_proj.bias", "layers.25.self_attn.out_proj.weight", "layers.25.self_attn.out_proj.bias", "layers.25.self_attn.rot_emb.inv_freq", "layers.25.self_attn_layer_norm.weight", "layers.25.self_attn_layer_norm.bias", "layers.25.fc1.weight", "layers.25.fc1.bias", "layers.25.fc2.weight", "layers.25.fc2.bias", "layers.25.final_layer_norm.weight", "layers.25.final_layer_norm.bias", "layers.26.self_attn.k_proj.weight", "layers.26.self_attn.k_proj.bias", "layers.26.self_attn.v_proj.weight", "layers.26.self_attn.v_proj.bias", "layers.26.self_attn.q_proj.weight", "layers.26.self_attn.q_proj.bias", "layers.26.self_attn.out_proj.weight", "layers.26.self_attn.out_proj.bias", "layers.26.self_attn.rot_emb.inv_freq", "layers.26.self_attn_layer_norm.weight", "layers.26.self_attn_layer_norm.bias", "layers.26.fc1.weight", "layers.26.fc1.bias", "layers.26.fc2.weight", "layers.26.fc2.bias", "layers.26.final_layer_norm.weight", "layers.26.final_layer_norm.bias", "layers.27.self_attn.k_proj.weight", "layers.27.self_attn.k_proj.bias", "layers.27.self_attn.v_proj.weight", "layers.27.self_attn.v_proj.bias", "layers.27.self_attn.q_proj.weight", "layers.27.self_attn.q_proj.bias", "layers.27.self_attn.out_proj.weight", "layers.27.self_attn.out_proj.bias", "layers.27.self_attn.rot_emb.inv_freq", "layers.27.self_attn_layer_norm.weight", "layers.27.self_attn_layer_norm.bias", "layers.27.fc1.weight", "layers.27.fc1.bias", "layers.27.fc2.weight", "layers.27.fc2.bias", "layers.27.final_layer_norm.weight", "layers.27.final_layer_norm.bias", "layers.28.self_attn.k_proj.weight", "layers.28.self_attn.k_proj.bias", "layers.28.self_attn.v_proj.weight", "layers.28.self_attn.v_proj.bias", "layers.28.self_attn.q_proj.weight", "layers.28.self_attn.q_proj.bias", "layers.28.self_attn.out_proj.weight", "layers.28.self_attn.out_proj.bias", "layers.28.self_attn.rot_emb.inv_freq", "layers.28.self_attn_layer_norm.weight", "layers.28.self_attn_layer_norm.bias", "layers.28.fc1.weight", "layers.28.fc1.bias", "layers.28.fc2.weight", "layers.28.fc2.bias", "layers.28.final_layer_norm.weight", "layers.28.final_layer_norm.bias", "layers.29.self_attn.k_proj.weight", "layers.29.self_attn.k_proj.bias", "layers.29.self_attn.v_proj.weight", "layers.29.self_attn.v_proj.bias", "layers.29.self_attn.q_proj.weight", "layers.29.self_attn.q_proj.bias", "layers.29.self_attn.out_proj.weight", "layers.29.self_attn.out_proj.bias", "layers.29.self_attn.rot_emb.inv_freq", "layers.29.self_attn_layer_norm.weight", "layers.29.self_attn_layer_norm.bias", "layers.29.fc1.weight", "layers.29.fc1.bias", "layers.29.fc2.weight", "layers.29.fc2.bias", "layers.29.final_layer_norm.weight", "layers.29.final_layer_norm.bias", "layers.30.self_attn.k_proj.weight", "layers.30.self_attn.k_proj.bias", "layers.30.self_attn.v_proj.weight", "layers.30.self_attn.v_proj.bias", "layers.30.self_attn.q_proj.weight", "layers.30.self_attn.q_proj.bias", "layers.30.self_attn.out_proj.weight", "layers.30.self_attn.out_proj.bias", "layers.30.self_attn.rot_emb.inv_freq", "layers.30.self_attn_layer_norm.weight", "layers.30.self_attn_layer_norm.bias", "layers.30.fc1.weight", "layers.30.fc1.bias", "layers.30.fc2.weight", "layers.30.fc2.bias", "layers.30.final_layer_norm.weight", "layers.30.final_layer_norm.bias", "layers.31.self_attn.k_proj.weight", "layers.31.self_attn.k_proj.bias", "layers.31.self_attn.v_proj.weight", "layers.31.self_attn.v_proj.bias", "layers.31.self_attn.q_proj.weight", "layers.31.self_attn.q_proj.bias", "layers.31.self_attn.out_proj.weight", "layers.31.self_attn.out_proj.bias", "layers.31.self_attn.rot_emb.inv_freq", "layers.31.self_attn_layer_norm.weight", "layers.31.self_attn_layer_norm.bias", "layers.31.fc1.weight", "layers.31.fc1.bias", "layers.31.fc2.weight", "layers.31.fc2.bias", "layers.31.final_layer_norm.weight", "layers.31.final_layer_norm.bias", "layers.32.self_attn.k_proj.weight", "layers.32.self_attn.k_proj.bias", "layers.32.self_attn.v_proj.weight", "layers.32.self_attn.v_proj.bias", "layers.32.self_attn.q_proj.weight", "layers.32.self_attn.q_proj.bias", "layers.32.self_attn.out_proj.weight", "layers.32.self_attn.out_proj.bias", "layers.32.self_attn.rot_emb.inv_freq", "layers.32.self_attn_layer_norm.weight", "layers.32.self_attn_layer_norm.bias", "layers.32.fc1.weight", "layers.32.fc1.bias", "layers.32.fc2.weight", "layers.32.fc2.bias", "layers.32.final_layer_norm.weight", "layers.32.final_layer_norm.bias", "contact_head.regression.weight", "contact_head.regression.bias", "emb_layer_norm_after.weight", "emb_layer_norm_after.bias", "lm_head.weight", "lm_head.bias", "lm_head.dense.weight", "lm_head.dense.bias", "lm_head.layer_norm.weight", "lm_head.layer_norm.bias". 

In [40]:
# load saved finetuned model
import torch
model = torch.load(r"../contact-prediction/models/trained_model_aw.pth")

In [46]:
model

{'model_state_dict': OrderedDict([('embed_tokens.weight',
               tensor([[ 0.0536, -0.0583, -0.1300,  ..., -0.2219,  0.1180, -0.0660],
                       [ 0.0614,  0.0292, -0.1028,  ..., -0.0330,  0.0668, -0.0575],
                       [-0.0895, -0.0448, -0.0575,  ..., -0.0585,  0.0651, -0.1090],
                       ...,
                       [ 0.0023,  0.0143,  0.0514,  ..., -0.0193,  0.0112,  0.0053],
                       [ 0.0731,  0.0470,  0.0346,  ...,  0.1118,  0.0465, -0.0243],
                       [ 0.0492,  0.0254, -0.1112,  ..., -0.0257,  0.0599, -0.0614]])),
              ('layers.0.self_attn.k_proj.weight',
               tensor([[-0.0165,  0.0163,  0.0124,  ...,  0.0110,  0.0016, -0.0176],
                       [-0.0015, -0.0213,  0.0479,  ..., -0.0359,  0.0135,  0.0209],
                       [-0.0050,  0.0462, -0.0268,  ..., -0.0402, -0.0085, -0.0233],
                       ...,
                       [-0.0136,  0.0134, -0.0017,  ...,  0.0211,  

In [42]:
import esm
esm2, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()

In [43]:
batch_labels, batch_strs, batch_tokens = batch_converter(indexed_sequences)

In [51]:
device = torch.device("cpu")

outputs = model(batch_tokens.to(device)) # return_contacts=True

TypeError: 'tuple' object is not callable

In [None]:
#N = len(fasta_data_dict["1A3A_1|Chains"])

fig, ax = plt.subplots(1, 2, figsize=(8, 3))
im = ax[0].imshow(outputs_conts[0])
fig.colorbar(im)
ax[0].set_title("Predicted")

In [None]:
outputs = torch.tensor(np.array(outputs_conts[0]))
targets_met = torch.tensor(np.array(rand_target["dist"]))

example_metrics = evaluate_prediction(outputs_met, targets_met)

print(f"Metrics for: {rand_example}, {rand_target['seq']}")
for key, value in example_metrics.items():
    print(f"{key}: {value}")  