In [2]:
import json
import os

import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sequence_models.utils import parse_fasta

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

def read_fasta(data_path: str,
               sep: str =" ",
               ignore_labels = False,
               n_seqs: int = None,
               max_seq_length: int = None
               ):
    """
    Read a FASTA file into a list of ``(sequence, sequence_id, labels)`` tuples.

    Each record's description line is split on ``sep``. The first token is the
    record ID (already available as ``record.id``) and is dropped; any remaining
    tokens are treated as labels (e.g. GO terms).

    :param data_path: Path to the FASTA file.
    :param sep: Separator used to split the description line into tokens.
    :param ignore_labels: If True, every tuple gets an empty label list, even
        when the description carries labels.
    :param n_seqs: Maximum number of records to return. ``None`` means no limit;
        ``0`` (or a negative value) returns no records.
    :param max_seq_length: Records whose sequence is longer than this are skipped.
        ``None`` means no limit.
    :return: ``(sequences_with_ids_and_labels, has_labels)``. ``has_labels`` reports
        whether the last returned record had labels in its description. It is
        computed regardless of ``ignore_labels`` and is False when no records
        are returned.
    """

    n_seqs = n_seqs if n_seqs is not None else float("inf")
    max_seq_length = max_seq_length if max_seq_length is not None else float("inf")

    sequences_with_ids_and_labels = []
    # Initialised before the loop so that an empty file (or one where every record
    # is filtered out) returns False instead of raising UnboundLocalError.
    has_labels = False

    # The limit is otherwise checked only after an append, which would let
    # n_seqs <= 0 return one record.
    if n_seqs <= 0:
        return sequences_with_ids_and_labels, has_labels

    for record in SeqIO.parse(data_path, "fasta"):
        # If the sequence is too long, skip it
        if len(record.seq) > max_seq_length:
            continue

        sequence = str(record.seq)
        sequence_id = record.id

        # Always return empty labels unless labels are present and not ignored.
        labels = []

        # The first token of the description is the sequence ID; the rest are labels (e.g. GO terms).
        temp = record.description.split(sep)[1:]
        has_labels = len(temp) > 0

        if has_labels and not ignore_labels:
            labels = temp

        # Return a tuple of sequence, sequence_id, and labels
        sequences_with_ids_and_labels.append((sequence, sequence_id, labels))

        if len(sequences_with_ids_and_labels) >= n_seqs:
            break

    return sequences_with_ids_and_labels, has_labels


def save_to_fasta(sequence_id_labels_tuples,
                  output_file,
                  no_annotations = False):
    """
    Save ``(sequence, sequence_id, labels)`` tuples to a FASTA file.

    The tuple layout matches the first value returned by ``read_fasta``. Each
    record is written with ``sequence_id`` as its ID and the labels joined by
    single spaces as its description. Biopython's FASTA writer emits this as a
    ``>id description`` header.

    :param sequence_id_labels_tuples: Iterable of ``(sequence, sequence_id, labels)``
        tuples, where ``labels`` is a list of strings.
    :param output_file: Path (``str`` or path-like) of the FASTA file to write.
        It is opened in ``"w"`` mode, so an existing file is overwritten.
    :param no_annotations: If True, write an empty description for every record
        (IDs only, no labels).
    """
    records = [
        SeqRecord(
            Seq(sequence),
            id=sequence_id,
            # Labels are joined by a single space; an empty description drops them entirely.
            description="" if no_annotations else " ".join(labels),
        )
        for sequence, sequence_id, labels in sequence_id_labels_tuples
    ]

    # Write the SeqRecord objects to a FASTA file
    with open(output_file, "w") as output_handle:
        SeqIO.write(records, output_handle, "fasta")
    # Report only after the file is closed.
    # An f-string (not "+") so a pathlib.Path output_file does not raise TypeError.
    print(f"Saved FASTA file to {output_file}")

# Create a small sample FASTA for testing: the first 10 records of at most
# 2,048 residues, saved once with header annotations and once with IDs only.
uniref_sample, _ = read_fasta(
    'dayhoffdata/uniref50_202401/consensus.fasta',
    n_seqs=10,
    max_seq_length=2_048,
)
for sample_path, strip_annotations in (
    ('dayhoffdata/uniref50_202401/consensus_sample.fasta', False),
    ('dayhoffdata/uniref50_202401/consensus_sample_no_annotations.fasta', True),
):
    save_to_fasta(uniref_sample, output_file=sample_path, no_annotations=strip_annotations)


Saved FASTA file to dayhoffdata/uniref50_202401/consensus_sample.fasta
Saved FASTA file to dayhoffdata/uniref50_202401/consensus_sample_no_annotations.fasta


In [16]:
# Load the annotated sample and key each sequence by its accession,
# i.e. the first whitespace-delimited token of its FASTA header.
seqs, seq_names = parse_fasta('dayhoffdata/uniref50_202401/consensus_sample.fasta', return_names=True)
seqs = dict(zip((header.split()[0] for header in seq_names), seqs))
# Accessions chosen for structure prediction; each must be present in the sample above.
selected = ['UniRef50_A0A401TRQ8','UniRef50_A0A7R8YPT0']
selected_seqs = dict((accession, seqs[accession]) for accession in selected)

In [None]:
# Run ESMFold on the selected sequences.

# The model and the tokenized inputs are moved to the GPU below (`.cuda()`), so a CUDA device is required.

from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from transformers import AutoTokenizer, EsmForProteinFolding

def convert_outputs_to_pdb(outputs):
    """Turn a batched ESMFold forward-pass output into PDB-format strings.

    :param outputs: The object returned by ``EsmForProteinFolding(...)``. It is
        indexed by key (``positions``, ``aatype``, ``residue_index``, ``plddt``,
        ``atom37_atom_exists``, optionally ``chain_index``), and every value
        must support ``.to("cpu").numpy()``.
    :return: One PDB string per batch element. The per-atom pLDDT goes in the
        B-factor column, and residue numbering is shifted to start at 1.
    """
    # The atom14 -> atom37 mapping needs the original (tensor) outputs,
    # so run it before anything is converted to numpy.
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs).cpu().numpy()
    arrays = {key: value.to("cpu").numpy() for key, value in outputs.items()}

    atom_exists = arrays["atom37_atom_exists"]
    chain_ids = arrays["chain_index"] if "chain_index" in arrays else None
    batch_size = arrays["aatype"].shape[0]

    return [
        to_pdb(
            OFProtein(
                aatype=arrays["aatype"][b],
                atom_positions=atom37_positions[b],
                atom_mask=atom_exists[b],
                residue_index=arrays["residue_index"][b] + 1,
                b_factors=arrays["plddt"][b],
                chain_index=None if chain_ids is None else chain_ids[b],
            )
        )
        for b in range(batch_size)
    ]

model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
model.eval()
model = model.cuda()
model.esm = model.esm.half()  # switch the ESM stem (`model.esm`) to half precision
torch.backends.cuda.matmul.allow_tf32 = True  # allow TensorFloat-32 matmuls if the hardware supports it
model.trunk.set_chunk_size(64)  # chunk the folding trunk: less memory, but slower

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")


model_name = "esmfold"
pdb_dir = "dayhoffdata/uniref50_202401/pdb"
# open(..., "w") below raises FileNotFoundError if the output directory is missing.
os.makedirs(pdb_dir, exist_ok=True)

# Fold each selected sequence individually (batch of one) and write one PDB file per sequence.
for seq_id, seq in selected_seqs.items():
    inputs = tokenizer([seq], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
    with torch.no_grad():
        outputs = model(inputs)
    pdb = convert_outputs_to_pdb(outputs)

    with open(f"{pdb_dir}/{model_name}_{seq_id}.pdb", "w") as f:
        f.write("".join(pdb))

In [None]:
# python fidelity.py --path_to_input_fasta ../dayhoffdata/uniref50_202401/consensus_sample.fasta --output_path ../dayhoffdata/uniref50_202401/ --fold_method omegafold --subbatch_size 20 --restart