In [None]:
# Notebook shell escapes: install the runtime dependencies.
# The torch_xla wheel below is the 1.9 build for CPython 3.7 on linux x86_64.
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!pip install pandas==1.3.0 --user
# NOTE(review): transformers, torch and accelerate are installed unpinned —
# confirm they are compatible with the torch_xla 1.9 wheel installed above.
!pip install transformers
!pip install torch
!pip install accelerate

# Kept for reference, not executed: TPU worker address setting. Presumably it
# has to be exported in the shell before Python starts — confirm for this setup.
# export XRT_TPU_CONFIG="tpu_worker;0;10.128.0.4:8470"

from transformers import AutoTokenizer, EsmForProteinFolding
import torch
import torch_xla.core.xla_model as xm

# XLA device that the model and every tokenized input are moved to.
device = xm.xla_device()

from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

# Module-level tokenizer and model; the functions in this file read them as globals.
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)

# Run a garbage-collection pass before the model is moved to the device.
import gc
gc.collect()

model = model.to(device)
# NOTE(review): chunk size 64 presumably trades speed for lower peak memory in
# the folding trunk — confirm against the transformers ESMFold documentation.
model.trunk.set_chunk_size(64)

def tokenized_sequences(sequences, fasta_or_csv):
    """Tokenize protein sequences read from a dataframe or a FASTA file.

    Args:
        sequences: When ``fasta_or_csv`` is ``"csv"``, a dataframe with two
            columns: ``id`` (the protein name) and ``sequence`` (the
            amino-acid sequence). When it is ``"fasta"``, whatever
            ``Bio.SeqIO.parse`` accepts as a (multi-)FASTA source.
        fasta_or_csv: Input kind, either ``"csv"`` or ``"fasta"``.

    Returns:
        A list of ``(id, input_ids)`` tuples, one per distinct id, where
        ``input_ids`` is the ``"input_ids"`` entry produced by the
        module-level ``tokenizer`` with ``return_tensors="pt"`` and no
        special tokens. When an id occurs more than once, only its last
        sequence is kept.

    Raises:
        ValueError: If ``fasta_or_csv`` is neither ``"csv"`` nor ``"fasta"``.
    """
    if fasta_or_csv == "csv":
        # Column-wise zip avoids building one Series per row (iterrows).
        seq_by_id = dict(zip(sequences["id"], sequences["sequence"]))
    elif fasta_or_csv == "fasta":
        # Imported lazily so Biopython is only required for FASTA input.
        from Bio import SeqIO
        seq_by_id = {
            record.id: str(record.seq)
            for record in SeqIO.parse(sequences, "fasta")
        }
    else:
        # Fail early with a clear message instead of an unbound-variable error.
        raise ValueError(
            f"fasta_or_csv must be 'csv' or 'fasta', got {fasta_or_csv!r}"
        )

    return [
        (
            seq_id,
            tokenizer(seq, return_tensors="pt", add_special_tokens=False)["input_ids"],
        )
        for seq_id, seq in seq_by_id.items()
    ]

def convert_outputs_to_pdb(outputs):
    """Render every structure in a batch of folding outputs as PDB text.

    ``outputs`` is a mapping of tensors; the keys read here are
    ``positions``, ``aatype``, ``atom37_atom_exists``, ``residue_index``,
    ``plddt`` and, when present, ``chain_index``.

    Returns a list of PDB strings, one per entry along the first axis of
    ``outputs["aatype"]``.
    """
    # Called on the outputs as received, before they are converted to NumPy.
    # NOTE(review): positions[-1] is presumably the final refinement step — confirm.
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    arrays = {name: value.to("cpu").numpy() for name, value in outputs.items()}
    atom37_positions = atom37_positions.cpu().numpy()
    atom37_mask = arrays["atom37_atom_exists"]
    has_chain_index = "chain_index" in arrays

    def _protein_at(idx):
        # Residue indices are shifted by one; pLDDT is stored as the B-factor.
        return OFProtein(
            aatype=arrays["aatype"][idx],
            atom_positions=atom37_positions[idx],
            atom_mask=atom37_mask[idx],
            residue_index=arrays["residue_index"][idx] + 1,
            b_factors=arrays["plddt"][idx],
            chain_index=arrays["chain_index"][idx] if has_chain_index else None,
        )

    batch_size = arrays["aatype"].shape[0]
    return [to_pdb(_protein_at(idx)) for idx in range(batch_size)]

def esmfold_prediction(tokenized_sequences, path_out):
    """Fold each tokenized protein and write one PDB file per protein.

    Args:
        tokenized_sequences: Iterable of ``(id, input_ids)`` pairs, i.e. the
            output of the ``tokenized_sequences`` function.
        path_out: Directory that receives the ``<id>.pdb`` files. It is
            created (including parents) when it does not exist yet.

    Returns:
        None. The function's effect is the PDB files written to ``path_out``.
    """
    import os

    # Create the target directory up front so the first write cannot fail
    # on a missing folder. os.makedirs("") raises, so an empty path is
    # passed through to open() unchanged.
    if path_out:
        os.makedirs(path_out, exist_ok=True)

    for protein in tokenized_sequences:
        protein_id, input_ids = protein[0], protein[1]
        # Inference only: no gradients are needed.
        with torch.no_grad():
            output = model(input_ids.to(device))
        pdb_txt = convert_outputs_to_pdb(output)
        # One sequence is folded per call, so only the first PDB is written.
        with open(f"{path_out}/{protein_id}.pdb", "w") as outfile:
            outfile.write(pdb_txt[0])
        # NOTE(review): the model runs on the XLA device, so this CUDA cache
        # call likely has no effect here — confirm before removing it.
        torch.cuda.empty_cache()
            
import pandas as pd
import os

# Input location; hard-coded to one user's home directory.
path_data = "/home/robbyconchaeloko/DpoK-serotypeTropism/data"

# The file is tab-separated despite its .csv extension. Because column names
# are passed explicitly, pandas reads the first line as data, not as a header.
df = pd.read_csv(f"{path_data}/Results_III_sequences.v3.csv" , sep = "\t", names = ["id","sequence"])

# Tokenize every row, then fold each sequence and write <id>.pdb into the
# output directory given below.
eg_tokenized = tokenized_sequences(df , "csv")
esmfold_prediction(eg_tokenized, "/home/robbyconchaeloko/output" )

