In [2]:
import pathlib
import torch
from absl import app
from absl import flags
import esm
from esm import pretrained, MSATransformer, FastaBatchedDataset
from tqdm import tqdm
from Bio import PDB
from typing import List

In [3]:
from rdkit import Chem
path_protein = "../../dataset2016/1a4k/1a4k_protein.pdb"
path_pocket = "../../dataset2016/1a4k/1a4k_pocket.pdb"

In [4]:
def find_correct_element_of_chain(residue_name, residue_id, atom_symbol, atom_coords, structure):
    for model in structure:
        for i, chain in enumerate(model):
            for j,residue in enumerate(chain):
                if residue.get_resname() == residue_name and residue.get_id()[1] == residue_id:
                    for atom in residue:
                        if atom.element == atom_symbol:
                            if (atom.get_coord() - atom_coords < 0.1).all():
                                return i,j
    return None, None

In [None]:
correct_chain_residue = []
for n, atom in enumerate(molecule.GetAtoms()):
    atom_coord = coords[n]
    atom_symbol = atom.GetSymbol()
    atom_res = atom.GetPDBResidueInfo()
    atom_res_name = atom_res.GetResidueName()
    atom_res_id = atom_res.GetResidueNumber()
    (i, j) = find_correct_element_of_chain(atom_res_name, atom_res_id, atom_symbol, atom_coord, structure)
    correct_chain_residue.append((i, j))
print(correct_chain_residue)

In [37]:
lista = pdb_to_sequences("bla", path_protein)
lista

(<Structure id=bla>,
 [(0,
   'ELVMTQTPLSLPVSLGDQASISCRSSQSLLHSNGNTYLHWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYFCSQVTHVPPTFGGGTKLEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRG'),
  (1,
   'QVQLLESGPELKKPGETVKISCKASGYTFTNYGMNWVKQAPGKGLKWMGWINTYTGEPTYADDFKGRFAFSLETSASTAYLQINNLKNEDTATYFCVQAERLRRTFDYWGAGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKV'),
  (2,
   'ELVMTQTPLSLPVSLGDQASISCRSSQSLLHSNGNTYLHWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYFCSQVTHVPPTFGGGTKLEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNR'),
  (3,
   'QVQLLESGPELKKPGETVKISCKASGYTFTNYGMNWVKQAPGKGLKWMGWINTYTGEPTYADDFKGRFAFSLETSASTAYLQINNLKNEDTATYFCVQAERLRRTFDYWGAGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEP')])

In [9]:
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

ESM2(
  (embed_tokens): Embedding(33, 320, padding_idx=1)
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=320, out_features=320, bias=True)
        (v_proj): Linear(in_features=320, out_features=320, bias=True)
        (q_proj): Linear(in_features=320, out_features=320, bias=True)
        (out_proj): Linear(in_features=320, out_features=320, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=320, out_features=1280, bias=True)
      (fc2): Linear(in_features=1280, out_features=320, bias=True)
      (final_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=320, out_features=320, bias=True)
        (v_proj): Linear(in_features=320, out_features=320, bias=True)
        (

In [11]:
batch_labels, batch_strs, _ = batch_converter(lista)

In [13]:
batch_strs

['ELVMTQTPLSLPVSLGDQASISCRSSQSLLHSNGNTYLHWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYFCSQVTHVPPTFGGGTKLEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRG',
 'QVQLLESGPELKKPGETVKISCKASGYTFTNYGMNWVKQAPGKGLKWMGWINTYTGEPTYADDFKGRFAFSLETSASTAYLQINNLKNEDTATYFCVQAERLRRTFDYWGAGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKV',
 'ELVMTQTPLSLPVSLGDQASISCRSSQSLLHSNGNTYLHWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYFCSQVTHVPPTFGGGTKLEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNR',
 'QVQLLESGPELKKPGETVKISCKASGYTFTNYGMNWVKQAPGKGLKWMGWINTYTGEPTYADDFKGRFAFSLETSASTAYLQINNLKNEDTATYFCVQAERLRRTFDYWGAGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEP']

In [15]:
dataset = FastaBatchedDataset(batch_labels, batch_strs)

In [16]:
dataset

<esm.data.FastaBatchedDataset at 0x7f6054acdba0>

In [17]:
batches = dataset.get_batch_indices(1, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader(dataset,collate_fn=batch_converter,batch_sampler=batches)

In [19]:
repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in [8]]
repr_layers

[1]

In [22]:
results = dict()
with torch.no_grad():
    for i, (labels, strs, toks) in enumerate(data_loader):
        print(i, labels, strs)
        out = model(toks, repr_layers=repr_layers, return_contacts=False)
        results[labels[0]] = out['representations'][1]

0 [2] ['ELVMTQTPLSLPVSLGDQASISCRSSQSLLHSNGNTYLHWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYFCSQVTHVPPTFGGGTKLEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNR']
1 [0] ['ELVMTQTPLSLPVSLGDQASISCRSSQSLLHSNGNTYLHWYLQKPGQSPKLLIYKVSNRFSGVPDRFSGSGSGTDFTLKISRVEAEDLGVYFCSQVTHVPPTFGGGTKLEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRG']
2 [1] ['QVQLLESGPELKKPGETVKISCKASGYTFTNYGMNWVKQAPGKGLKWMGWINTYTGEPTYADDFKGRFAFSLETSASTAYLQINNLKNEDTATYFCVQAERLRRTFDYWGAGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKV']
3 [3] ['QVQLLESGPELKKPGETVKISCKASGYTFTNYGMNWVKQAPGKGLKWMGWINTYTGEPTYADDFKGRFAFSLETSASTAYLQINNLKNEDTATYFCVQAERLRRTFDYWGAGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEP']


In [23]:
def read_dataset2(directory, ligand_file_extention, protein_file_extention,
                 aff_dict):
    '''
    from directory returns a list of pdb_id, path to protein, path to ligand
    The directory contains compound folders (each has an ID with 4 letters,
    ex. abcd) with 4 files:
    - abcd_protein.pdb
    - abcd_pocket.pdb
    - abcd_ligand.sdf
    - abcd_ligand.mol2
    '''
    assert ligand_file_extention in ('sdf', 'mol2')
    assert protein_file_extention in ('protein', 'pocket', 'processed')
    molecules_files = []
    for folder_name in os.listdir(directory):
        if len(folder_name) == 4:
            folder_dir = os.path.join(directory, folder_name)
            files = os.listdir(folder_dir)
            compound_id = folder_name

            for file in files:
                if file.endswith(protein_file_extention + '.pdb'):
                    file_pocket = file
                elif file.endswith('ligand.' + ligand_file_extention):
                    file_ligand = file
                elif file.endswith('protein.pdb'):
                    file_protein = file
            if aff_dict[compound_id][4]:
                # only add molecule if affinity is not uncertain
                molecules_files += [
                    (compound_id, os.path.join(folder_dir, file_pocket),
                     os.path.join(folder_dir, file_ligand), os.path.join(folder_dir, file_protein), aff_dict[compound_id][1])
                ]
    return molecules_files


In [24]:
def get_affinities(affinity_directory):
    affinity_dict = {}
    with open(affinity_directory, 'r', encoding='utf-8') as f:
        for line in f:
            aff_not_uncertain = True
            if line[0] != '#':
                fields = line.split()
                pdb_id = fields[0]
                log_aff = float(fields[3])
                aff_str = fields[4]
                if '<' in aff_str or '>' in aff_str or '~' in aff_str:
                    aff_not_uncertain = False
                aff_tokens = re.split('[=<>~]+', aff_str)
                assert len(aff_tokens) == 2
                label, aff_and_unity = aff_tokens
                assert label in ['Kd', 'Ki', 'IC50']
                affinity_value = float(aff_and_unity[:-2])
                aff_unity = aff_and_unity[-2:]
                aff = float(affinity_value)
                affinity_dict[pdb_id] = [
                    label, log_aff, aff, aff_unity, aff_not_uncertain
                ]
    return affinity_dict

In [None]:
def pdb_to_sequences(pdb_id, pdb_filepath: str):
    pdb_parser = PDB.PDBParser(QUIET=True)
    structure = pdb_parser.get_structure(pdb_id, pdb_filepath)
    # Using C-N
    ppb = PDB.PPBuilder()
    polypeptide_sequences = [
        (i,str(pp.get_sequence())) for i,pp in enumerate(ppb.build_peptides(structure))
    ]
    return structure, polypeptide_sequences

In [117]:
import re
import os
aff_dict = get_affinities("../../dataset2016/index/INDEX_general_PL_data.2016")
pdb_files2 = read_dataset2("../../dataset2016", "mol2", "pocket", aff_dict)
import Bio
from Bio import PDB
parser = PDB.PDBParser(QUIET=True)

In [143]:
elements = []
pdb_ids_failed = []
for k, (pdb_id, path_pocket, path_ligand, path_protein, aff) in enumerate(pdb_files2[0:2]):
    print(k, pdb_id, path_pocket, path_ligand, path_protein, aff)
    # pocket molecule
    molecule = Chem.MolFromPDBFile(path_pocket, flavor=2, sanitize=True, removeHs=True)
    conformer = molecule.GetConformer(0)
    coords = conformer.GetPositions()

    # protein molecule
    structure, protein_sequences = pdb_to_sequences2(pdb_id, path_protein)

    # associate atoms of pocket to protein residues
    correct_chain_residue = []
    for n, atom in enumerate(molecule.GetAtoms()):
        atom_coord = coords[n]
        atom_symbol = atom.GetSymbol()
        atom_res = atom.GetPDBResidueInfo()
        atom_res_name = atom_res.GetResidueName()
        atom_res_id = atom_res.GetResidueNumber()
        (i, j) = find_correct_element_of_chain(atom_res_name, atom_res_id, atom_symbol, atom_coord, structure)
        correct_chain_residue.append((i, j))
        if i is not None and j is not None:
            if i < len(protein_sequences):
                if j < len(protein_sequences[i]):
                    # print(i, j, PDB.Polypeptide.one_to_three(protein_sequences[i][j]), atom_res_name)
                    if PDB.Polypeptide.one_to_three(protein_sequences[i][j]) != atom_res_name:
                        pdb_ids_failed.append(pdb_id)
                        break
    
    elements.append(correct_chain_residue)
elements


0 2xei ../../dataset2016/2xei/2xei_pocket.pdb ../../dataset2016/2xei/2xei_ligand.mol2 ../../dataset2016/2xei/2xei_protein.pdb 10.62
1 3qts ../../dataset2016/3qts/3qts_pocket.pdb ../../dataset2016/3qts/3qts_ligand.mol2 ../../dataset2016/3qts/3qts_protein.pdb 5.51




[[(0, 108),
  (0, 108),
  (0, 108),
  (0, 108),
  (0, 108),
  (0, 109),
  (0, 109),
  (0, 109),
  (0, 109),
  (0, 109),
  (0, 109),
  (0, 109),
  (0, 109),
  (0, 109),
  (0, 109),
  (0, 109),
  (0, 151),
  (0, 151),
  (0, 151),
  (0, 151),
  (0, 152),
  (0, 152),
  (0, 152),
  (0, 152),
  (0, 152),
  (0, 152),
  (0, 152),
  (0, 152),
  (0, 152),
  (0, 153),
  (0, 153),
  (0, 153),
  (0, 153),
  (0, 153),
  (0, 153),
  (0, 153),
  (0, 154),
  (0, 154),
  (0, 154),
  (0, 154),
  (0, 154),
  (0, 154),
  (0, 154),
  (0, 154),
  (0, 154),
  (0, 154),
  (0, 154),
  (0, 155),
  (0, 155),
  (0, 155),
  (0, 155),
  (0, 155),
  (0, 155),
  (0, 155),
  (0, 155),
  (0, 155),
  (0, 155),
  (0, 155),
  (0, 156),
  (0, 156),
  (0, 156),
  (0, 156),
  (0, 179),
  (0, 179),
  (0, 179),
  (0, 179),
  (0, 179),
  (0, 179),
  (0, 179),
  (0, 179),
  (0, 179),
  (0, 179),
  (0, 179),
  (0, 179),
  (0, 199),
  (0, 199),
  (0, 199),
  (0, 199),
  (0, 199),
  (0, 199),
  (0, 199),
  (0, 199),
  (0, 199),
  (0

In [141]:
pdb_ids_failed

['2am2',
 '2pvk',
 '4s3e',
 '2fkf',
 '1ydt',
 '1uu8',
 '3qkk',
 '1xh7',
 '3aqt',
 '4gqq',
 '2xpc',
 '2c1a',
 '1xh9',
 '4qir',
 '2qln',
 '1xh8',
 '3fy0',
 '2c1b',
 '4ob1',
 '1pu8',
 '3ag9',
 '3krr',
 '1gag',
 '2pfy',
 '1fzk',
 '3igg']

In [130]:
structure, protein_sequences = pdb_to_sequences("2pvk", "../../dataset2016/2pvk/2pvk_protein.pdb")
print(protein_sequences)
for i in protein_sequences:
    print(len(i[1]))

[(0, 'MSKARVYADVNVLRPKEYWDYEALTVQWGEQDDYEVVRKVGRGKYSEVFEGINVNNNEKCIIKILKPVKKKKIKREIKILQNL'), (1, 'GGPNIVKLLDIVRDQHSKTPSLIFEYVNNTDFKVLYPTLTDYDIRYYIYELLKALDYCHSQGIMHRDVKPHNVMIDHELRKLRLIDWGLAEFYHPGKEYNVRVASRYFKGPELLVDLQDYDYSLDMWSLGCMFAGMIFRKEPFFYGHDNHDQLVKIAKVLGTDGLNAYLNKYRIELDPQLEALVGRHSRKPWLKFMNADNQHLVSPEAIDFLDKLLRYDHQERLTALEAMTHPYFQQVRAAENS')]
83
244


In [125]:
PDB.Polypeptide.protein_letters_3to1["MSE"]

KeyError: 'MSE'

In [82]:
def pdb_to_sequences2(pdb_id, pdb_filepath: str):
    pdb_parser = PDB.PDBParser(QUIET=True)
    structure = pdb_parser.get_structure(pdb_id, pdb_filepath)
    
    for model in structure:
        chains = []
        for i, chain in enumerate(model):
            chain_str = ""
            for j, residue in enumerate(chain):
                try:
                    chain_str += PDB.Polypeptide.protein_letters_3to1[residue.get_resname()]
                except KeyError:
                    chain_str += ""
                # print(residue.get_resname(), residue.get_id()[1])
            chains.append(chain_str)
    polypeptide_sequences = [
        chain for chain in chains if chain !=""
    ]
    return structure, polypeptide_sequences

In [83]:
structure, protein_sequences = pdb_to_sequences2("2xei", "../../dataset2016/2xei/2xei_protein.pdb")

In [84]:
protein_sequences

['KHNMKAFLDELKAENIKKFLYNFTQIPHLAGTEQNFQLAKQIQSQWKEFGLDSVELAHYDVLLSYPNKTHPNYISIINEDGNEIFNTSLFEPPPPGYENVSDIVPPFSAFSPQGMPEGDLVYVNYARTEDFFKLERDMKINCSGKIVIARYGKVFRGNKVKNAQLAGAKGVILYSDPADYFAPGVKSYPDGWNLPGGGVQRGNILNLNGAGDPLTPGYPANEYAYRRGIAEAVGLPSIPVHPIGYYDAQKLLEKMGGSAPPDSSWRGSLKVPYNVGPGFTGNFSTQKVKMHIHSTNEVTRIYNVIGTLRGAVEPDRYVILGGHRDSWVFGGIDPQSGAAVVHEIVRSFGTLKKEGWRPRRTILFASWDAEEFGLLGSTEWAEENSRLLQERGVAYINADSSIEGNYTLRVDCTPLMYSLVHNLTKELKSPDEGFEGKSLYESWTKKSPSPEFSGMPRISKLGSGNDFEVFFQRLGIASGRARYTKNWETNKFSGYPLYHSVYETYELVEKFYDPMFKYHLTVAQVRGGMVFELANSIVLPFDCRDYAVVLRKYADKIYSISMKHPQEMKTYSVSFDSLFSAVKNFTEIASKFSERLQDFSNPIVLRMMNDQLMFLERAFIDPLGLPDRPFYRHVIYAPSSHNKYAGESFPGIYDALFDIESKVDPSKAWGEVKRQIYVAAFTVQAAAETLSEVA',
 'KHNMKAFLDELKAENIKKFLYNFTQIPHLAGTEQNFQLAKQIQSQWKEFGLDSVELAHYDVLLSYPNKTHPNYISIINEDGNEIFNTSLFEPPPPGYENVSDIVPPFSAFSPQGMPEGDLVYVNYARTEDFFKLERDMKINCSGKIVIARYGKVFRGNKVKNAQLAGAKGVILYSDPADYFAPGVKSYPDGWNLPGGGVQRGNILNLNGAGDPLTPGYPANEYAYRRGIAEAVGLPSIPVHPIGYYDAQKLLEKMGGSAPPDSSWRGSLKVPYNVGPGFTGNFSTQKVKMHIHSTNEVT

- 18 - 2a2m dá problemas porque tem heteroatoms no meio da cadeia (vários MSE)
- 39 - 2pvk tem um CSO que também parte a cadeia
- 73 - 4s3e tem um KCX que parte as duas cadeias
- 103 - 2fkf tem um SEP na cadeia
- 173 - 1ydt tem um TPO na cadeia

In [3]:
import pickle

In [2]:
with open("../../elements.pkl", "rb") as f:
    elements = pickle.load(f)

In [4]:
with open("../../protein_seqs.pkl", "rb") as f:
    protein_sequences = pickle.load(f)

In [13]:
protein_sequences["1a7t"]

['SVKISDDISITQLSDKVYTYVSLAEIEGWGMVPSNGMIVINNHQAALLDTPINDAQTEMLVNWVTDSLHAKVTTFIPNHWHGDCIGGLGYLQRKGVQSYANQMTIDLAKEKGLPVPEHGFTDSLTVSLDGMPLQCYYLGGGHATDNIVVWLPTENILFGGCMLKDNQTTSIGNISDADVTAWPKTLDKVKAKFPSARYVVPGHGNYGGTELIEHTKQIVNQYIESTS',
 'SVKISDDISITQLSDKVYTYVSLAEIEGWGMVPSNGMIVINNHQAALLDTPINDAQTEMLVNWVTDSLHAKVTTFIPNHWHGDCIGGLGYLQRKGVQSYANQMTIDLAKEKGLPVPEHGFTDSLTVSLDGMPLQCYYLGGGHATDNIVVWLPTENILFGGCMLKDNQTTSIGNISDADVTAWPKTLDKVKAKFPSARYVVPGHGNYGGTELIEHTKQIVNQYIESTS']

In [14]:
import torch
import esm
from esm import pretrained, MSATransformer, FastaBatchedDataset

model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

ESM2(
  (embed_tokens): Embedding(33, 320, padding_idx=1)
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=320, out_features=320, bias=True)
        (v_proj): Linear(in_features=320, out_features=320, bias=True)
        (q_proj): Linear(in_features=320, out_features=320, bias=True)
        (out_proj): Linear(in_features=320, out_features=320, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=320, out_features=1280, bias=True)
      (fc2): Linear(in_features=1280, out_features=320, bias=True)
      (final_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=320, out_features=320, bias=True)
        (v_proj): Linear(in_features=320, out_features=320, bias=True)
        (

In [26]:
%%timeit
repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in [8]]

sequences = protein_sequences["1a7t"]
sequences = [(i, chain) for i,chain in enumerate(sequences)]
batch_labels, batch_strs, _ = batch_converter(sequences)
dataset = FastaBatchedDataset(batch_labels, batch_strs)
batches = dataset.get_batch_indices(1, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader(dataset,collate_fn=batch_converter,batch_sampler=batches)

results = dict()
with torch.no_grad():
    for i, (labels, strs, toks) in enumerate(data_loader):
        out = model(toks, repr_layers=repr_layers, return_contacts=False)
        results[labels[0]] = out['representations'][1]

32.9 ms ± 2.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
results[0].size()

torch.Size([1, 229, 320])

In [10]:
sequences

[(0,
  'SVKISDDISITQLSDKVYTYVSLAEIEGWGMVPSNGMIVINNHQAALLDTPINDAQTEMLVNWVTDSLHAKVTTFIPNHWHGDCIGGLGYLQRKGVQSYANQMTIDLAKEKGLPVPEHGFTDSLTVSLDGMPLQCYYLGGGHATDNIVVWLPTENILFGGCMLKDNQTTSIGNISDADVTAWPKTLDKVKAKFPSARYVVPGHGNYGGTELIEHTKQIVNQYIESTS'),
 (1,
  'SVKISDDISITQLSDKVYTYVSLAEIEGWGMVPSNGMIVINNHQAALLDTPINDAQTEMLVNWVTDSLHAKVTTFIPNHWHGDCIGGLGYLQRKGVQSYANQMTIDLAKEKGLPVPEHGFTDSLTVSLDGMPLQCYYLGGGHATDNIVVWLPTENILFGGCMLKDNQTTSIGNISDADVTAWPKTLDKVKAKFPSARYVVPGHGNYGGTELIEHTKQIVNQYIESTS')]

In [5]:
for key in protein_sequences:
    sequences = protein_sequences[key]
    for seq in sequences:
        

2xei 2
3qts 1
3hqh 1
3fwv 1
1jut 2
3ff6 2
3fue 1
4cwf 2
3qkv 2
3g0b 2
4m7j 2
3ppp 1
4r92 1
3abu 1
4uwl 1
4lph 2
2wnl 5
3sie 1
2am2 1
4ran 4
2d1n 1
2v2q 1
4zwx 1
3t9t 1
2iok 2
4d2r 1
1f4e 2
3wde 1
1i80 3
2j3q 1
1uvu 2
2idk 4
1g6r 4
4non 1
1hqg 3
4ju6 1
4wr7 1
4b8y 1
2pvk 1
4b4q 1
4mlt 1
2cgu 1
4rwj 1
3qti 1
4aa7 3
4xu0 1
4cxw 1
4bh3 6
3ip8 1
4ayt 2
2f0z 1
1g2a 1
2fx9 2
5c7c 1
4yz9 1
3h1x 1
3kb3 2
4bgg 1
2e9u 1
2xx2 2
4av0 1
2uy5 1
3np7 2
4m2u 1
4bzs 1
4q1s 2
3djk 2
4m7c 2
2xk6 1
4xmr 2
3mt7 2
4s3e 2
1jik 2
3fuk 1
4trz 1
3run 1
1h1r 2
4gzw 4
1o0n 1
4b0j 2
3e2m 2
2z7h 2
4ebw 1
1p19 2
1gmy 3
4bdd 2
4bkj 1
11gs 2
1qhc 1
2zaz 1
1gyy 2
3tpu 3
5d9p 1
2a4z 1
3sut 1
5d1t 1
1u6q 1
3eta 1
1zea 2
3ejt 1
1epo 1
2fkf 1
4oyk 1
3k4q 1
4w5a 1
2wtc 1
4buq 1
4q81 1
3g3m 2
3ask 1
4kln 6
4z90 5
2v0c 1
2ewa 1
2obo 4
1o3l 1
4ay6 1
1yyr 2
1zhl 2
1xb7 2
4hpy 2
4mmf 1
1nfx 2
2q9m 1
4x14 3
3kqd 2
4k3l 2
2z6w 1
3d91 1
2eum 1
4afe 1
2v54 2
4q9y 1
4mrw 1
4oc2 2
4o43 1
4qga 2
4uxj 2
1j4k 1
5auy 1
1d3v 3
1mrn 2
1eve 1

In [8]:
elements["1a7t"]

[(0, 23),
 (0, 23),
 (0, 23),
 (0, 23),
 (0, 23),
 (0, 25),
 (0, 25),
 (0, 25),
 (0, 25),
 (0, 25),
 (0, 25),
 (0, 25),
 (0, 25),
 (0, 26),
 (0, 26),
 (0, 26),
 (0, 26),
 (0, 26),
 (0, 26),
 (0, 26),
 (0, 26),
 (0, 26),
 (0, 27),
 (0, 27),
 (0, 27),
 (0, 27),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 28),
 (0, 29),
 (0, 29),
 (0, 29),
 (0, 29),
 (0, 31),
 (0, 31),
 (0, 31),
 (0, 31),
 (0, 31),
 (0, 31),
 (0, 31),
 (0, 32),
 (0, 32),
 (0, 32),
 (0, 32),
 (0, 32),
 (0, 32),
 (0, 32),
 (0, 33),
 (0, 33),
 (0, 33),
 (0, 33),
 (0, 33),
 (0, 33),
 (0, 51),
 (0, 51),
 (0, 51),
 (0, 51),
 (0, 51),
 (0, 51),
 (0, 51),
 (0, 51),
 (0, 78),
 (0, 78),
 (0, 78),
 (0, 78),
 (0, 78),
 (0, 78),
 (0, 78),
 (0, 78),
 (0, 78),
 (0, 78),
 (0, 80),
 (0, 80),
 (0, 80),
 (0, 80),
 (0, 80),
 (0, 80),
 (0, 80),
 (0, 80),
 (0, 80),
 (0, 80),
 (0, 81),
 (0, 81),
 (0, 81),
 (0, 81),
 (0, 82),
 (0, 82),
 (0, 82),
 (0, 82),
