In [None]:
# Python packages: upgrade transformers, py3Dmol and accelerate to their latest releases.
! pip install --upgrade transformers py3Dmol accelerate
# Biopython provides the `Bio.PDB` parser and DSSP wrapper imported later in this notebook.
!pip3 install biopython
# System packages: both the `dssp` and `mkdssp` apt package names are requested
# (presumably so a DSSP executable is available whichever name the distribution ships — TODO confirm).
!apt-get install dssp
!apt-get install mkdssp

In [None]:
from transformers import AutoTokenizer, EsmForProteinFolding
import torch

# Load the tokenizer and the folding model for the "facebook/esmfold_v1" checkpoint.
# Both names are notebook globals used by later cells.
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
# low_cpu_mem_usage=True is passed straight to from_pretrained
# (presumably to reduce peak RAM while the weights are loaded — see the transformers docs).
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)

In [None]:
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

def convert_outputs_to_pdb(outputs):
    """Build one PDB record per batch entry of a folding-model output.

    Args:
        outputs: mapping of tensors. The keys read here are "positions",
            "atom37_atom_exists", "aatype", "residue_index", "plddt" and,
            when present, "chain_index". Every value must support
            ``.to("cpu").numpy()``.

    Returns:
        A list with one ``to_pdb(...)`` result per entry along the first
        axis of ``outputs["aatype"]``.
    """
    # atom14_to_atom37 is given the original (tensor-valued) mapping, so it is
    # called before anything is converted to numpy.
    # positions[-1] is presumably the final refinement step — TODO confirm.
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs).cpu().numpy()

    arrays = {name: tensor.to("cpu").numpy() for name, tensor in outputs.items()}
    atom37_mask = arrays["atom37_atom_exists"]
    has_chain_index = "chain_index" in arrays

    pdb_strings = []
    for idx, aatype in enumerate(arrays["aatype"]):
        protein = OFProtein(
            aatype=aatype,
            atom_positions=atom37_positions[idx],
            atom_mask=atom37_mask[idx],
            # +1 presumably converts 0-based indices to PDB's 1-based numbering — TODO confirm.
            residue_index=arrays["residue_index"][idx] + 1,
            b_factors=arrays["plddt"][idx],
            chain_index=arrays["chain_index"][idx] if has_chain_index else None,
        )
        pdb_strings.append(to_pdb(protein))
    return pdb_strings

In [None]:
import pandas as pd

# Read the input spreadsheet from its fixed Colab location.
df = pd.read_excel('/content/AMPS.xlsx')

print(df)

# Tokenise every entry of the "Sequence" column without padding or special
# tokens and keep only the token ids (one list per sequence).
bulk_tokenized = tokenizer(
    df["Sequence"].tolist(),
    padding=False,
    add_special_tokens=False,
)['input_ids']

In [None]:
from tqdm import tqdm

# Collects one dict of CPU tensors per input sequence, in input order.
outputs = []

# Inference only: torch.no_grad() disables gradient tracking inside the block.
with torch.no_grad():
    for input_ids in tqdm(bulk_tokenized):
        # Wrap the token-id list in a tensor and add a leading dimension of size 1
        # (presumably the batch dimension the model expects — TODO confirm).
        input_ids = torch.tensor(input_ids).unsqueeze(0)
        output = model(input_ids)
        # Keep a copy of the result with every value moved to CPU via .cpu().
        outputs.append({key: val.cpu() for key, val in output.items()})

In [None]:
import pandas as pd

def read_pdb_data(pdb_data):
    """Parse ATOM/HETATM records from PDB-format text into a DataFrame.

    Args:
        pdb_data: an iterable of PDB-format strings, or a single such string
            (which is treated as one entry). Records from all entries are
            concatenated in the order they are encountered.

    Returns:
        pandas.DataFrame with one row per ATOM/HETATM line and the columns
        listed in ``fields`` below, in that order. The frame is empty (but
        still has all columns) when no matching line is found.

    Raises:
        ValueError: if a numeric field (serial, res_seq, coordinates,
            occupancy, b_factor) of a matching line cannot be converted.
    """
    # A bare string would otherwise be iterated character by character and
    # silently yield an empty frame.
    if isinstance(pdb_data, str):
        pdb_data = [pdb_data]

    # (output column, slice start, slice end, converter) for the fixed-width
    # ATOM/HETATM record layout. This single table drives both the parsing and
    # the column order of the result.
    fields = [
        ("ATOM/HETATM", 0, 6, str.strip),
        ("serial", 6, 11, int),
        ("atom_name", 12, 16, str.strip),
        ("alt_loc", 16, 17, str.strip),
        ("res_name", 17, 20, str.strip),
        ("chain_id", 21, 22, str.strip),
        ("res_seq", 22, 26, int),
        ("insertion_code", 26, 27, str.strip),
        ("x_coord", 30, 38, float),
        ("y_coord", 38, 46, float),
        ("z_coord", 46, 54, float),
        ("occupancy", 54, 60, float),
        ("b_factor", 60, 66, float),
        ("element_symbol", 76, 78, str.strip),
        ("charge", 78, 80, str.strip),
    ]
    columns = [name for name, _start, _end, _convert in fields]

    records = []
    for pdb_entry in pdb_data:
        for line in pdb_entry.split('\n'):
            if line.startswith(("ATOM", "HETATM")):
                records.append(
                    {name: convert(line[start:end]) for name, start, end, convert in fields}
                )

    return pd.DataFrame(records, columns=columns)


In [None]:
# One convert_outputs_to_pdb(...) result per stored model output, in the same order.
pdb_list = list(map(convert_outputs_to_pdb, outputs))

# Parse each of those results into its own atom-level DataFrame.
dfs = [read_pdb_data(pdb_strings) for pdb_strings in pdb_list]

# Threshold applied to the mean per-residue b_factor parsed from the PDB text:
# residues at or above it are flagged as folded.
# NOTE(review): 0.7 assumes pLDDT is written on a 0-1 scale — confirm for the
# installed transformers version.
plddt_cutoff = 0.7

# Three-letter residue name (lower case) -> one-letter code for the 20
# standard amino acids. Keys are already lower case, so no further
# normalisation of the mapping is needed.
amino_acids = dict(
    ala='A',  # Alanine
    arg='R',  # Arginine
    asn='N',  # Asparagine
    asp='D',  # Aspartic acid
    cys='C',  # Cysteine
    gln='Q',  # Glutamine
    glu='E',  # Glutamic acid
    gly='G',  # Glycine
    his='H',  # Histidine
    ile='I',  # Isoleucine
    leu='L',  # Leucine
    lys='K',  # Lysine
    met='M',  # Methionine
    phe='F',  # Phenylalanine
    pro='P',  # Proline
    ser='S',  # Serine
    thr='T',  # Threonine
    trp='W',  # Tryptophan
    tyr='Y',  # Tyrosine
    val='V',  # Valine
)

# Per structure: average b_factor over the atoms of each residue, then build
#   * the one-letter sequence string, and
#   * a same-length string with 'F' where the mean b_factor reaches
#     plddt_cutoff and '-' elsewhere.
sequence = []
folded = []
for atoms_df in dfs:  # distinct loop name so `df` is not shadowed while iterating
    per_residue = (
        atoms_df.groupby(['res_seq', 'res_name'])
        .agg({'b_factor': 'mean'})
        .reset_index()
        .sort_values('res_seq')  # sort once; both strings share this order
    )
    # Residue names missing from `amino_acids` (e.g. 'UNK') map to NaN, which
    # would make ''.join raise TypeError; represent them as 'X' instead.
    one_letter = per_residue['res_name'].str.lower().map(amino_acids).fillna('X')
    folded_flags = per_residue['b_factor'].apply(lambda x: 'F' if x >= plddt_cutoff else '-')
    sequence.append(''.join(one_letter.tolist()))
    folded.append(''.join(folded_flags.tolist()))

# Results table, one row per structure; later cells extend and export `df`.
df = pd.DataFrame({'sequence': sequence, 'folded': folded})

In [None]:
import pandas as pd
from Bio.PDB import PDBParser
from Bio.PDB.DSSP import DSSP

# Run DSSP on every predicted structure and store the resulting
# secondary-structure string in `df`, the results table with one row per
# entry of `pdb_list`.
pdb_file_path = "PDB.pdb"  # scratch file, overwritten for every structure
parser = PDBParser()       # one parser is enough for all structures
secondary_structure_strings = []

for pdb_data_list in pdb_list:
    pdb_data_string = "\n".join(pdb_data_list)
    # Swap the "PARENT N/A" header line for a dummy CRYST1 record
    # (presumably because DSSP rejects files without one — TODO confirm).
    pdb_data_string = pdb_data_string.replace("PARENT N/A\n", "CRYST1    1.000    1.000    1.000  90.00  90.00  90.00 P 1           1\n")

    with open(pdb_file_path, 'w') as pdb_file:
        pdb_file.write(pdb_data_string)

    # Read back exactly the file that was just written. The path used to be
    # hard-coded as '/content/PDB.pdb', which only matched the written file
    # when the working directory happened to be /content.
    structure = parser.get_structure('pdb_structure', pdb_file_path)

    # Deliberately NOT named `model`: that global holds the ESMFold network,
    # and rebinding it here would break any re-run of the inference cell.
    structure_model = structure[0]
    dssp = DSSP(structure_model, pdb_file_path)

    # Element 2 of each DSSP entry is the secondary-structure code.
    secondary_structure_strings.append(''.join(dssp[key][2] for key in dssp.keys()))

# `df` has one row per entry of `pdb_list`, in the same order, so the whole
# column can be assigned in one step.
df['secondary_structures'] = secondary_structure_strings


In [None]:
# Show the results table, then save it as an Excel workbook in the working
# directory (index=False leaves the row index out of the file).
print(df)
df.to_excel('plddtscores_dsspstructures.xlsx', index=False)

In [None]:
import py3Dmol

# Index (into `pdb_list` and `outputs`) of the protein to display;
# change it to see a different prediction.
protein_index = 0

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=800, height=400)
view.addModel("".join(pdb_list[protein_index]), 'pdb')  # more blue = higher confidence

# Choose colour-scale bounds that match the scale pLDDT is reported on
# (0-1 or 0-100). The values come from the protein actually being shown,
# not from whatever `output` the inference loop left behind.
if torch.max(outputs[protein_index]['plddt']) <= 1.0:
    vmin = 0.5
    vmax = 0.95
else:
    vmin = 50
    vmax = 95

# Colour the cartoon by the b-factor column ('b'), which carries pLDDT in the
# generated PDB text.
view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min': vmin,'max': vmax}}})