# GOAL: generate ESM-2 embeddings for each protein
### I. Generate the sequence file
### II. The embedding script
### III. Run the ESM-2 extraction on the cluster and export the embeddings
***

I.

In [None]:
# Build the sequence index files from the Millard depolymerase FASTA:
#   df_sequences.index.v2.csv       : <seq_index>\t<protein_id>\t<sequence>, no header,
#                                     one row per protein
#   df_sequences.index.clean.v2.csv : "index"/"sequence" columns with a header,
#                                     one row per distinct sequence
#   millard_depo.indexed.v2.fasta   : the distinct sequences, named by their index
import os 
from Bio import SeqIO
import pandas as pd 

path_millard = "/home/conchae/ML_depolymerase/get_candidates/millard"

# Map each distinct sequence to the IDs (text before the first "__") of every
# record carrying it; dict insertion order fixes the sequence indices below.
data = {}
for record in SeqIO.parse(f"{path_millard}/millard_depo.v2.fasta" , "fasta"):
    data.setdefault(record.seq, []).append(record.id.split("__")[0])

# Long-format index: one line per (sequence index, protein ID) pair.
with open(f"{path_millard}/df_sequences.index.v2.csv", "w") as index_file:
    index_file.writelines(
        f"{seq_number}\t{protein_id}\t{sequence}\n"
        for seq_number, (sequence, protein_ids) in enumerate(data.items())
        for protein_id in protein_ids
    )

# Keep only the first row per sequence index -> one row per distinct sequence.
raw_index = pd.read_csv(
    f"{path_millard}/df_sequences.index.v2.csv",
    sep="\t",
    names=["index", "id", "sequence"],
)
df = raw_index.drop_duplicates(subset=["index"], keep="first")
df.to_csv(
    f"{path_millard}/df_sequences.index.clean.v2.csv",
    sep="\t",
    columns=["index", "sequence"],
    index=False,
)

# FASTA of the distinct sequences, with the sequence index as the record name.
df = pd.read_csv(f"{path_millard}/df_sequences.index.clean.v2.csv", sep="\t")
dico_interest = df.to_dict("records")
with open(f"{path_millard}/millard_depo.indexed.v2.fasta", "w") as fasta_file:
    fasta_file.write(
        "".join(f">{row['index']}\n{row['sequence']}\n" for row in dico_interest)
    )

***
II. 

In [None]:
import torch
import esm

# Load the ESM-2 650M model (33 layers) together with its alphabet/tokenizer.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()

batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

# Load the deduplicated sequences written in part I. The file has a header row
# ("index", "sequence"), so let pandas consume it rather than reading it as a
# data row and slicing it off afterwards.
df = pd.read_csv(f"{path_millard}/df_sequences.index.clean.v2.csv", sep="\t")
# (label, sequence) pairs for the ESM batch converter; labels are the sequence
# indices as strings, matching the FASTA headers written in part I.
data = [(str(label), sequence) for label, sequence in zip(df["index"], df["sequence"])]

# NOTE(review): every sequence goes into a single batch here and the loop at
# the end draws one contact map per sequence. That is fine for a quick check,
# but probably too heavy for the full dataset -- the batched extract.py job in
# part III is the route for that.
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

# Generate per-sequence representations via averaging
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):
    sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))

# Look at the unsupervised self-attention map contact predictions
import matplotlib.pyplot as plt
for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
    plt.matshow(attention_contacts[: tokens_len, : tokens_len])
    plt.title(seq)
    plt.show()

***
III.

In [1]:
#!/bin/bash
#SBATCH --job-name=ESM_2__
#SBATCH --qos=long
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=50
#SBATCH --mem=200gb
#SBATCH --time=10-00:00:00
#SBATCH --output=ESM_2__%j.log

# Extract ESM-2 (650M) embeddings for every sequence of the indexed FASTA:
# representations from layers 0, 32 and 33, both mean-pooled and per-token,
# written as one .pt file per sequence into the output directory.

source /storage/apps/ANACONDA/anaconda3/etc/profile.d/conda.sh
conda activate embeddings

python /home/conchae/software/esm/scripts/extract.py \
esm2_t33_650M_UR50D \
/home/conchae/ML_depolymerase/get_candidates/millard/millard_depo.indexed.v2.fasta \
/home/conchae/ML_depolymerase/get_candidates/millard/millard_depo.indexed.v2.esm2_out \
--repr_layers 0 32 33 \
--include mean per_tok

   Courses    Fee Duration  Discount
0    Spark  20000   30days      1000
1  PySpark  25000   40days      2300
2   Python  22000   35days      1200
3   pandas  30000   50days      2000
4   Python  22000   40days      2300
5    Spark  20000   30days      1000
6   pandas  30000   50days      2000


The output directory millard_depo.indexed.v2.esm2_out/ now contains one .pt file per FASTA sequence; load them with torch.load().

In [None]:
import torch
import os 
import pandas as pd

# Directory written by the extract.py job (one <label>.pt file per sequence).
path_millard = "/home/conchae/ML_depolymerase/get_candidates/millard/millard_depo.indexed.v2.esm2_out"


def _label_sort_key(label):
    """Order purely numeric labels numerically, then any other labels alphabetically."""
    return (0, int(label), "") if label.isdigit() else (1, 0, label)


# Collect the mean-pooled layer-33 embedding of every sequence, keyed by its
# label (the file name without the .pt suffix). Only .pt files are read, so
# stray files in the directory cannot break torch.load, and they are processed
# in a stable order instead of os.listdir's arbitrary one.
pt_files = sorted(
    (name for name in os.listdir(path_millard) if name.endswith(".pt")),
    key=lambda name: _label_sort_key(os.path.splitext(name)[0]),
)
embeddings_esm = {}
for file in pt_files:
    index = os.path.splitext(file)[0]
    embb = torch.load(f"{path_millard}/{file}", map_location="cpu")["mean_representations"][33].tolist()
    embeddings_esm[index] = embb

# One CSV row per sequence: the label followed by the embedding values.
# Every value (including the last) is followed by a comma, exactly as in the
# original format, so readers see an empty trailing column; kept unchanged in
# case downstream loaders account for it.
with open("/home/conchae/ML_depolymerase/get_candidates/millard/embeddings.proteins.v2.csv", "w") as outfile:
    for index, embedding in embeddings_esm.items():
        outfile.write(f"{index}," + "".join(f"{value}," for value in embedding) + "\n")
    


In [None]:
# Copy the sequence, embedding and annotation files from the cluster to the local workstation.
REMOTE_DIR=conchae@garnatxa.srv.cpd:/home/conchae/ML_depolymerase/get_candidates/millard
LOCAL_DIR=/media/concha-eloko/Linux/depolymerase_building
for artefact in millard_depo.indexed.v2.fasta embeddings.proteins.v2.csv df_sequences.index.v2.csv proteinID_annotation.v2.json; do
    rsync -avzhe ssh "${REMOTE_DIR}/${artefact}" "${LOCAL_DIR}"
done

    
    