In [None]:
import torch
import esm
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import pandas as pd
import numpy as np
import os

# Load the ESM-2 checkpoint named esm2_t6_8M_UR50D together with its token alphabet.
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
# Callable that turns a list of (label, sequence) pairs into (labels, strings, token tensor).
batch_converter = alphabet.get_batch_converter()
# train(False) is exactly what eval() does: dropout is switched off so inference is deterministic.
# It is kept as the cell's last expression so the notebook still displays the model repr.
model.train(False)

In [None]:
# Demo batch of (label, sequence) pairs. The first two sequences come from
# ESMStructuralSplitDataset (superfamily / 4); the last two exercise the <mask> token,
# and "protein3" is written with whitespace between residues.
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLGGGGGPPPPQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
# Non-padding token count per sequence; this includes the BOS token (index 0) and the final EOS token.
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations from layer 6 (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[6], return_contacts=True)
token_representations = results["representations"][6]

# Per-sequence representation = mean over residue positions only, i.e. excluding
# BOS (index 0) and EOS (index tokens_len - 1).
sequence_representations = [
    token_representations[i, 1 : tokens_len - 1].mean(0)
    for i, tokens_len in enumerate(batch_lens)
]

# Look at the unsupervised self-attention map contact predictions
import matplotlib.pyplot as plt
for (label, _), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
    # ESM's contact head drops the BOS/EOS positions, so each map is padded to (max_len - 2)
    # per side; keep only the tokens_len - 2 real positions so shorter sequences do not show
    # padding rows/columns.
    n_res = int(tokens_len) - 2
    fig, ax = plt.subplots()
    ax.matshow(attention_contacts[:n_res, :n_res])
    # Label instead of the raw sequence: long sequences make an unreadable title.
    ax.set_title(label)
    plt.show()

In [None]:
# `data` is a list of (label, sequence) tuples; check that one per-sequence embedding was
# produced for each input and report how many there are.
assert len(sequence_representations) == len(data)
print(len(sequence_representations))

In [None]:
from Bio import SeqIO
from tqdm import tqdm
import os

def filter_fasta(input_file, output_file, min_length=10, tax_filter=None, keyword=None, max_len=2000,
                 max_writes=100000, total_records=69_000_000, verbose=False):
    """Write an evenly strided sample of FASTA records that pass length/header filters.

    Only every ``write_every``-th record of the input is considered, where
    ``write_every = max(1, total_records // max_writes)`` and records are counted 1-based
    over *all* input records (filtered or not). A considered record is written only if it
    also passes every filter below. Writing stops once ``max_writes`` records are written.

    Args:
        input_file: path of the source FASTA file.
        output_file: path of the FASTA file to create (overwritten if it exists).
        min_length: skip records whose sequence is shorter than this.
        tax_filter: if truthy, keep only records whose description contains ``"Tax=<tax_filter>"``.
        keyword: if truthy, keep only records whose description contains it (case-insensitive).
        max_len: skip records whose sequence is longer than this.
        max_writes: maximum number of records to write; also sets the sampling stride.
        total_records: (approximate) number of records in ``input_file``, used only to compute
            the sampling stride. The default is the value previously hard-coded here —
            presumably the UniRef50 record count; TODO confirm.
        verbose: if True, print the id of every record written.

    Returns:
        int: the number of records written.

    Raises:
        ValueError: if ``max_writes`` is less than 1.
    """
    if max_writes < 1:
        raise ValueError(f"max_writes must be >= 1, got {max_writes}")

    write_every = max(1, total_records // max_writes)
    num_records = 0

    # Open the input ourselves so the handle is closed even when we `break` early.
    with open(input_file, "r") as in_handle, open(output_file, "w") as out_handle:
        records = tqdm(SeqIO.parse(in_handle, "fasta"), desc="filtering", unit="rec")
        for record_index, record in enumerate(records, start=1):
            # Throttle first: it is the cheapest test, and since every check just skips the
            # record, the order does not change which records are written.
            if record_index % write_every != 0:
                continue

            seq_len = len(record.seq)
            if seq_len < min_length or seq_len > max_len:
                continue
            desc = record.description
            if tax_filter and f"Tax={tax_filter}" not in desc:
                continue
            if keyword and keyword.lower() not in desc.lower():
                continue

            SeqIO.write(record, out_handle, "fasta")
            if verbose:
                print("writing", record.id)
            num_records += 1

            if num_records >= max_writes:
                break

    return num_records

In [None]:
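# # Filter the full UniRef50 FASTA down to a smaller sample (uncomment to regenerate).
# # NOTE: the output name says "len_less_100", but this call uses the default max_len=2000;
# # pass max_len=100 if that was the intent.
# filter_fasta("../data/uniref50.fasta", "filtered_len_less_100.fasta")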

In [None]:
def fasta_batches(file_path, batch_size=1):
    """Lazily yield lists of up to ``batch_size`` records parsed from a FASTA file.

    Every batch is full except possibly the last, which holds the leftover records;
    nothing is yielded for a file with no records. The file stays open while the generator
    is suspended and is closed when it is exhausted or explicitly ``.close()``d.

    Args:
        file_path: path of the FASTA file to read.
        batch_size (int, optional): maximum number of records per batch. Defaults to 1.

    Yields:
        list: the next batch of records produced by ``SeqIO.parse(handle, "fasta")``.

    Raises:
        ValueError: if ``batch_size`` is less than 1 (raised on first iteration, since this
            is a generator).
    """
    import itertools

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    with open(file_path, "r") as handle:
        record_iter = SeqIO.parse(handle, "fasta")
        while True:
            # islice stops cleanly at end of input. Calling next() inside a generator
            # expression would instead turn StopIteration into RuntimeError (PEP 479),
            # crashing at EOF and losing the final partial batch.
            batch = list(itertools.islice(record_iter, batch_size))
            if not batch:
                break
            yield batch

In [None]:
# Read only the first batch (up to 10 records) of the filtered FASTA file and keep
# (id, sequence) pairs for the ESM batch converter.
# batch_1 = fasta_batches("../data/uniref50.fasta", 1)
batch_1 = fasta_batches("filtered_len_less_100.fasta", batch_size=10)
try:
    # next(..., []) yields an empty batch for an empty file instead of raising StopIteration.
    batch = next(batch_1, [])
finally:
    # Close the generator now so the file it opened with `with open(...)` is released,
    # rather than staying open while the generator sits suspended.
    batch_1.close()

data_1 = []
for record in batch:
    print(f"ID: {record.id}")
    print(f"Description: {record.description}")
    print(f"Sequence: {record.seq}\n")
    data_1.append((record.id, str(record.seq)))

In [None]:
# Embed the records read from the FASTA file (data_1) with the same ESM-2 pipeline as the demo batch.
assert data_1, "data_1 is empty: no records were read from the FASTA file"
batch_labels, batch_strs, batch_tokens = batch_converter(data_1)
# Non-padding token count per sequence; this includes the BOS token (index 0) and the final EOS token.
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations from layer 6 (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[6], return_contacts=True)
token_representations = results["representations"][6]

# Per-sequence representation = mean over residue positions only, i.e. excluding
# BOS (index 0) and EOS (index tokens_len - 1).
sequence_representations = [
    token_representations[i, 1 : tokens_len - 1].mean(0)
    for i, tokens_len in enumerate(batch_lens)
]

# Look at the unsupervised self-attention map contact predictions
import matplotlib.pyplot as plt
for (label, _), tokens_len, attention_contacts in zip(data_1, batch_lens, results["contacts"]):
    # ESM's contact head drops the BOS/EOS positions, so each map is padded to (max_len - 2)
    # per side; keep only the tokens_len - 2 real positions so shorter sequences do not show
    # padding rows/columns.
    n_res = int(tokens_len) - 2
    fig, ax = plt.subplots()
    ax.matshow(attention_contacts[:n_res, :n_res])
    # Record id instead of the raw sequence: long sequences make an unreadable title.
    ax.set_title(label)
    plt.show()