# In [None]:
import os
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

# Process-wide environment settings for the TPU/XLA run, applied in one call.
# NOTE(review): these are assigned after the torch_xla imports above have
# already executed - confirm the runtime reads them lazily.
os.environ.update(
    {
        'TPU_NUM_DEVICES': '4',
        'LD_PRELOAD': '',
        'XLA_USE_BF16': '1',
    }
)

def tokenized_sequences(sequences, fasta_or_csv, tokenizer=None):
    """Tokenize protein sequences read from a dataframe or a multi-FASTA file.

    Args:
        sequences: When ``fasta_or_csv`` is ``"csv"``, a dataframe with two
            columns: ``id`` (protein name) and ``sequence`` (amino-acid
            sequence). When ``fasta_or_csv`` is ``"fasta"``, the multi-FASTA
            input handed to ``Bio.SeqIO.parse``.
        fasta_or_csv: Either ``"csv"`` or ``"fasta"``.
        tokenizer: Optional callable used to tokenize each sequence. It is
            called as ``tokenizer(seq, return_tensors="pt",
            add_special_tokens=False)`` and must return a mapping with an
            ``"input_ids"`` entry. When omitted, a module-level ``tokenizer``
            is used if one exists; otherwise the ``facebook/esmfold_v1``
            tokenizer is loaded.

    Returns:
        A list of ``(id, input_ids)`` tuples, one per distinct id, in
        first-seen order. If an id occurs more than once, the last sequence
        seen for it is the one tokenized.

    Raises:
        ValueError: If ``fasta_or_csv`` is neither ``"csv"`` nor ``"fasta"``.
    """
    if fasta_or_csv == "csv":
        seq_by_id = dict(zip(sequences["id"], sequences["sequence"]))
    elif fasta_or_csv == "fasta":
        # Imported lazily so Biopython is only required for FASTA input.
        from Bio import SeqIO
        seq_by_id = {
            record.id: str(record.seq)
            for record in SeqIO.parse(sequences, "fasta")
        }
    else:
        raise ValueError(
            f"fasta_or_csv must be 'csv' or 'fasta', got {fasta_or_csv!r}"
        )

    if tokenizer is None:
        # Backward compatibility: earlier versions of this function relied on
        # a module-level `tokenizer` being defined by the caller.
        tokenizer = globals().get("tokenizer")
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

    return [
        (
            seq_id,
            tokenizer(seq, return_tensors="pt", add_special_tokens=False)["input_ids"],
        )
        for seq_id, seq in seq_by_id.items()
    ]


def _outputs_to_pdb(outputs):
    """Convert one ESMFold forward-pass output into PDB strings, one per batch item."""
    # atom14 -> atom37 works on tensors, so it must run before the numpy
    # conversion below.
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    final_atom_positions = final_atom_positions.cpu().float().numpy()

    arrays = {}
    for key, value in outputs.items():
        if not torch.is_tensor(value):
            continue
        value = value.cpu()
        if value.is_floating_point():
            # The model runs in bfloat16, which numpy cannot represent.
            value = value.float()
        arrays[key] = value.numpy()

    pdbs = []
    for i in range(arrays["aatype"].shape[0]):
        predicted = OFProtein(
            aatype=arrays["aatype"][i],
            atom_positions=final_atom_positions[i],
            atom_mask=arrays["atom37_atom_exists"][i],
            residue_index=arrays["residue_index"][i] + 1,
            b_factors=arrays["plddt"][i],
            chain_index=arrays["chain_index"][i] if "chain_index" in arrays else None,
        )
        pdbs.append(to_pdb(predicted))
    return pdbs


# The device lookup and model instantiation live inside the worker function,
# so each process started by xmp.spawn builds its own copy.
def esmfold_prediction(index, flags, tokenized_sequences, path_out):
    """Fold this worker's share of the sequences and write one file per protein.

    Args:
        index: Process index supplied by ``xmp.spawn``; selects which slice of
            ``tokenized_sequences`` this worker handles.
        flags: Extra options passed through ``xmp.spawn``; currently unused.
        tokenized_sequences: List of ``(id, input_ids)`` tuples as produced by
            the tokenization step.
        path_out: Directory receiving ``<id>.esmfold_out`` files; created if
            it does not exist.
    """
    device = xm.xla_device()
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
    model = model.to(dtype=torch.bfloat16)
    model = model.to(device)
    model.trunk.set_chunk_size(64)

    os.makedirs(path_out, exist_ok=True)

    # Give each worker a disjoint, strided slice of the input. Without this
    # every process would fold every protein and race on the same files.
    world_size = max(xm.xrt_world_size(), 1)
    assigned = tokenized_sequences[index::world_size]

    for seq_id, input_ids in assigned:
        with torch.no_grad():
            output = model(input_ids.to(device))
        pdb_txt = _outputs_to_pdb(output)
        # Progress marker kept from the original implementation.
        print("Is it ?")
        with open(f"{path_out}/{seq_id}.esmfold_out", "w") as outfile:
            # One sequence is folded per forward pass, so only the first
            # (and only) batch entry is written.
            outfile.write(pdb_txt[0])

import pandas as pd
import os

# tokenized_sequences() resolves `tokenizer` from module scope when called as
# below, and no module-level definition is visible earlier in this file, so
# create it here before tokenizing.
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

path_data = "/home/robbyconchaeloko/DpoK-serotypeTropism/data"
# The file is read as tab-separated with no header row; the two columns are
# named explicitly.
df = pd.read_csv(f"{path_data}/Results_III_sequences.v3.csv", sep="\t", names=["id", "sequence"])
eg_tokenized = tokenized_sequences(df, "csv")

def main():
    """Spawn the ESMFold workers over the module-level ``eg_tokenized`` list."""
    # xmp.spawn only accepts nprocs of 1 or the full device count, so reuse
    # the TPU_NUM_DEVICES value configured for this process instead of a
    # separate hard-coded number. None lets torch_xla use every device.
    configured = os.environ.get("TPU_NUM_DEVICES")
    num_tpu_cores = int(configured) if configured else None
    # Set the output path
    path_out = "/home/robbyconchaeloko/output"
    xmp.spawn(esmfold_prediction, args=({}, eg_tokenized, path_out), nprocs=num_tpu_cores, start_method="fork")

# Run the TPU prediction job only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
