In [1]:
import pickle
import glob

from proteinbert import load_pretrained_model
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs

In [2]:
# Eight proteins are longer than 2048 residues. The input encoder does not
# truncate sequences, so a 3072-token window is used instead.
seq_len = 3072
pretrained_model_generator, input_encoder = load_pretrained_model()
model = get_model_with_hidden_layers_as_outputs(pretrained_model_generator.create_model(seq_len))

# Collect the FASTA files stored exactly two directory levels below `sequences`.
# The pattern contains no '**', so glob's `recursive` flag would have no effect
# and is omitted. glob returns matches in arbitrary order, so the list is sorted
# to make the processing order (and the order of the resulting dict) reproducible.
files = sorted(glob.glob('../alignment/data/sequences/*/*/*.fa'))
print(len(files))

# There are 1564 protein sequences; however, only 1510 of them are unique.

1564


In [3]:
# Token ids from the encoder: 22 = UNK, 23 = BOS, 24 = EOS, 25 = PAD.
# Everything at or above this id is a special token that does not correspond
# to a residue; UNK (22) is kept because it still stands for a residue.
FIRST_SPECIAL_TOKEN_ID = 23


def _read_fasta_sequence(fasta_path):
    """Return the residues stored in `fasta_path` as a single string.

    Blank lines and header lines (starting with '>') are skipped wherever they
    occur, rather than assuming the header is exactly the first line. All
    remaining lines are stripped and concatenated in file order, so the file is
    expected to hold a single record.
    """
    with open(fasta_path, 'r') as fasta_file:
        stripped_lines = (raw_line.strip() for raw_line in fasta_file)
        return "".join(
            text for text in stripped_lines
            if text and not text.startswith('>')
        )


def _protein_key(fasta_path):
    """Return the lookup key for `fasta_path`, e.g. 'RV11/BB11001/1aab_'.

    The key is the part of the path after the last 'sequences' component,
    without the trailing '.fa' extension and with forward slashes, so the same
    key is produced for Windows ('\\') and POSIX ('/') paths.
    """
    normalized = fasta_path.replace('\\', '/')
    # Strip the extension only at the end of the path; splitting on '.fa'
    # would truncate any name that merely contains '.fa' somewhere inside.
    if normalized.endswith('.fa'):
        normalized = normalized[:-len('.fa')]
    return normalized.split('sequences')[-1].lstrip('/')


# Maps protein key -> {'sequence', 'local_representations', 'global_representations'}.
protein_embeddings_result = {}
for file in files:
    sequence = _read_fasta_sequence(file)
    protein_name = _protein_key(file)

    # One sequence per call; the model returns local and global outputs.
    encoded_x = input_encoder.encode_X([sequence], seq_len)
    local_representations, global_representations = model.predict(encoded_x, batch_size = 1, verbose = 0)

    # encoded_x[0][0] is presumably the token-id row of the single encoded
    # sequence; keep only the positions that are not BOS / EOS / PAD.
    residue_mask = encoded_x[0][0] < FIRST_SPECIAL_TOKEN_ID
    local_rep = local_representations[0][residue_mask]

    # Downstream code indexes local embeddings by residue position, so fail
    # loudly if the rows do not line up with the sequence.
    assert local_rep.shape[0] == len(sequence), (
        f"{protein_name}: {local_rep.shape[0]} local embedding rows "
        f"for {len(sequence)} residues"
    )

    protein_embeddings_result[protein_name] = {
        'sequence': sequence,
        'local_representations': local_rep,
        'global_representations': global_representations[0],
    }

In [6]:
# Sanity check: print the first six entries. For each one show the key, the
# sequence, the top-left corner of the local embeddings and the last twelve
# values of the global embedding.
preview_entries = list(protein_embeddings_result.items())[:6]
for protein_key, embedding in preview_entries:
    print(protein_key)
    print(embedding['sequence'])
    print(embedding['local_representations'][:2, :8])
    print(embedding['global_representations'][-12:])
    print('------------')

RV11/BB11001/1aab_
GKGDPKKPRGKMSSYAFFVQTSREEHKKKHPDASVNFSEFSKKCSERWKTMSAKEKGKFEDMAKADKARYEREMKTYIPPKGE
[[-0.03063907  0.0439351  -0.0230548  -0.00454346 -0.07563948  0.23824188
   0.28760016  0.21519585]
 [-0.10842812 -0.01052955 -0.12841734 -0.6247565  -0.15854235 -0.06815735
   0.18639801 -0.2140724 ]]
[3.7444304e-06 1.6075720e-06 2.0214695e-06 2.6980438e-06 4.1014137e-06
 3.1268844e-06 1.9361430e-06 5.6238391e-06 4.0274549e-06 8.0087545e-07
 2.5007503e-06 6.9546523e-08]
------------
RV11/BB11001/1j46_A
MQDRVKRPMNAFIVWSRDQRRKMALENPRMRNSEISKQLGYQWKMLTEAEKWPFFQEAQKLQAMHREKYPNYKYRPRRKAKMLPK
[[-0.0844609  -0.01888618  0.13432218 -0.15523     0.04527348 -0.10490427
   0.16837597  0.3024746 ]
 [-0.13980944  0.08109107 -0.12077155 -0.81356317 -0.10226007 -0.06641937
   0.4179446   0.12429711]]
[1.7145287e-05 3.1849218e-06 4.1299913e-06 4.0996460e-06 4.1652265e-06
 3.1683444e-06 1.9399226e-06 8.1369217e-06 5.2518585e-06 2.4262645e-06
 2.5675452e-06 1.8604208e-06]
------------
RV11/BB11001/1k

In [7]:
# Persist the embeddings with pickle. JSON is not an option because NumPy
# arrays are not JSON-serializable. The full dataset takes about 4.7 GB on disk.
output_path = 'protein_embeddings.pkl'
with open(output_path, 'wb') as pickle_file:
    pickle.dump(protein_embeddings_result, pickle_file)