In [1]:
import tensorflow as tf
# Enable Keras mixed precision: layers compute in float16 while variables stay float32.
# Must run before any model is built so the policy applies to the ProteinBERT layers.
# NOTE(review): model outputs (including saved embeddings) may then be float16 — confirm
# that reduced precision is acceptable downstream.
tf.keras.mixed_precision.set_global_policy(policy='mixed_float16')

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA GeForce RTX 3070, compute capability 8.6


In [2]:
log_dir = "logs"
# TensorBoard logging callback (weight histograms every epoch).
# NOTE(review): no visible cell passes this callback to fit()/predict(); drop it if unused.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    histogram_freq=1,
    log_dir=log_dir,
)

In [3]:
import proteinbert
from proteinbert import load_pretrained_model
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs
from Bio import SeqIO
import numpy


In [4]:
# Load the pretrained ProteinBERT dump from the local directory. The first value is used
# later to build models (create_model); the second encodes raw sequences (encode_X).
pretrained_model_generator, input_encoder = load_pretrained_model(
    local_model_dump_dir='proteinbert_modeldir',
)

In [5]:
# FASTA file with the CAFA3 training proteins; passed to piecewise_predict to embed them.
train_path = "CAFA3_training_data/uniprot_sprot_exp.fasta"

In [6]:
# Encoded input length for ProteinBERT models built in this notebook.
# NOTE(review): keep in sync with the sequence length used by piecewise_predict.
MAX_SEQ_LEN = 1024

In [21]:
import numpy as np
import tqdm
def chunks(l, size):
    """Split a sliceable sequence into consecutive slices of at most ``size`` items.

    The final slice holds the remainder and may be shorter. Works for anything that
    supports ``len`` and slicing (lists, strings, ...).

    Args:
        l: the sequence to split.
        size: maximum number of items per slice; must be positive.

    Returns:
        A list of slices of ``l`` in their original order ([] for an empty input).

    Raises:
        ValueError: if ``size`` is not positive. Without this check a negative size
            would silently produce an empty list.
    """
    if size <= 0:
        raise ValueError(f"chunk size must be positive, got {size}")
    return [l[start:start + size] for start in range(0, len(l), size)]

def piecewise_predict(sequences_file, output_name, superbatch_size=500, batch_size=16,
                      max_seq_len=1024):
    """Embed every sequence of a FASTA file with ProteinBERT and save the results.

    Sequences are encoded one "superbatch" of ``superbatch_size`` at a time, and each
    superbatch is run through ``model.predict`` in mini-batches of ``batch_size``.
    Uses the notebook-level ``pretrained_model_generator`` and ``input_encoder``
    returned by ``load_pretrained_model``.

    Args:
        sequences_file: path to a FASTA file.
        output_name: prefix of the two files written:
            ``{output_name}_embeddings.npy`` and ``{output_name}_ids.npy``.
        superbatch_size: number of sequences encoded and predicted per step.
        batch_size: mini-batch size forwarded to ``model.predict``.
        max_seq_len: encoded input length of the model. Sequences are truncated to
            ``max_seq_len - 2`` residues (two positions are reserved — presumably for
            special tokens added by ``input_encoder``; TODO confirm).

    Returns:
        ``(embeddings, ids)``. Row ``i`` of ``embeddings`` corresponds to ``ids[i]``,
        in FASTA order.

    Raises:
        ValueError: if ``max_seq_len`` < 3 or the FASTA file contains no records.
        RuntimeError: if the number of embeddings does not match the number of IDs.
    """
    # tqdm.auto renders a single widget in Jupyter instead of one text line per update,
    # and falls back to the plain text bar elsewhere.
    from tqdm.auto import tqdm as progress_bar

    if max_seq_len < 3:
        raise ValueError(f"max_seq_len must be at least 3, got {max_seq_len}")

    seqs = []
    ids = []
    for seq_record in SeqIO.parse(sequences_file, "fasta"):
        seqs.append(str(seq_record.seq))
        ids.append(str(seq_record.id))
    if not seqs:
        # Fail before building the model; otherwise np.concatenate([]) raises an
        # unhelpful error only after all the expensive setup.
        raise ValueError(f"no FASTA records found in {sequences_file!r}")

    model = get_model_with_hidden_layers_as_outputs(
        pretrained_model_generator.create_model(max_seq_len))
    # Reserve two positions of the encoded length (see docstring).
    trimmed_seqs = [seq[:max_seq_len - 2] for seq in seqs]

    all_ems = []
    for superbatch in progress_bar(chunks(trimmed_seqs, superbatch_size), desc=output_name):
        X = input_encoder.encode_X(superbatch, max_seq_len)
        # The model has two outputs; only the second is kept (presumably the
        # per-protein/global representation — TODO confirm against proteinbert).
        _, superbatch_em = model.predict(X, batch_size=batch_size)
        all_ems.append(superbatch_em)

    concat_ems = np.concatenate(all_ems)
    ids_arr = np.asarray(ids)
    # Saved files are only usable if embedding rows line up with the IDs.
    if len(concat_ems) != len(ids_arr):
        raise RuntimeError(
            f"got {len(concat_ems)} embeddings for {len(ids_arr)} sequence IDs")

    np.save(f"{output_name}_embeddings.npy", concat_ems)
    np.save(f"{output_name}_ids.npy", ids_arr)
    return concat_ems, ids_arr

In [28]:
# Embed the training set. The results get stage-specific names so the test-set call
# further down cannot overwrite them. `ems`/`ids` are still bound for any later cell
# that reads them.
train_ems, train_ids = piecewise_predict(train_path, "proteinbert_train", superbatch_size=512, batch_size=32)
ems, ids = train_ems, train_ids

  0%|          | 0/131 [00:00<?, ?it/s]



  1%|          | 1/131 [00:03<07:23,  3.41s/it]



  2%|▏         | 2/131 [00:05<05:12,  2.42s/it]



  2%|▏         | 3/131 [00:06<04:31,  2.12s/it]



  3%|▎         | 4/131 [00:08<04:12,  1.99s/it]



  4%|▍         | 5/131 [00:10<04:00,  1.91s/it]



  5%|▍         | 6/131 [00:12<03:54,  1.88s/it]



  5%|▌         | 7/131 [00:14<03:47,  1.84s/it]



  6%|▌         | 8/131 [00:15<03:43,  1.81s/it]



  7%|▋         | 9/131 [00:17<03:38,  1.79s/it]



  8%|▊         | 10/131 [00:19<03:35,  1.78s/it]



  8%|▊         | 11/131 [00:21<03:34,  1.79s/it]



  9%|▉         | 12/131 [00:22<03:31,  1.78s/it]



 10%|▉         | 13/131 [00:24<03:29,  1.77s/it]



 11%|█         | 14/131 [00:26<03:28,  1.78s/it]



 11%|█▏        | 15/131 [00:28<03:30,  1.82s/it]



 12%|█▏        | 16/131 [00:29<03:24,  1.78s/it]



 13%|█▎        | 17/131 [00:31<03:20,  1.76s/it]



 14%|█▎        | 18/131 [00:33<03:16,  1.74s/it]



 15%|█▍        | 19/131 [00:35<03:16,  1.76s/it]



 15%|█▌        | 20/131 [00:36<03:13,  1.74s/it]



 16%|█▌        | 21/131 [00:38<03:10,  1.73s/it]



 17%|█▋        | 22/131 [00:40<03:07,  1.72s/it]



 18%|█▊        | 23/131 [00:42<03:05,  1.72s/it]



 18%|█▊        | 24/131 [00:43<03:05,  1.73s/it]



 19%|█▉        | 25/131 [00:45<03:02,  1.72s/it]



 20%|█▉        | 26/131 [00:47<03:00,  1.72s/it]



 21%|██        | 27/131 [00:48<02:59,  1.72s/it]



 21%|██▏       | 28/131 [00:50<02:57,  1.72s/it]



 22%|██▏       | 29/131 [00:52<02:55,  1.72s/it]



 23%|██▎       | 30/131 [00:54<02:55,  1.74s/it]



 24%|██▎       | 31/131 [00:55<02:52,  1.73s/it]



 24%|██▍       | 32/131 [00:57<02:50,  1.72s/it]



 25%|██▌       | 33/131 [00:59<02:48,  1.72s/it]



 26%|██▌       | 34/131 [01:00<02:46,  1.71s/it]



 27%|██▋       | 35/131 [01:02<02:44,  1.71s/it]



 27%|██▋       | 36/131 [01:04<02:42,  1.71s/it]



 28%|██▊       | 37/131 [01:06<02:41,  1.72s/it]



 29%|██▉       | 38/131 [01:07<02:39,  1.72s/it]



 30%|██▉       | 39/131 [01:09<02:38,  1.72s/it]



 31%|███       | 40/131 [01:11<02:38,  1.75s/it]



 31%|███▏      | 41/131 [01:13<02:37,  1.74s/it]



 32%|███▏      | 42/131 [01:14<02:35,  1.75s/it]



 33%|███▎      | 43/131 [01:16<02:33,  1.74s/it]



 34%|███▎      | 44/131 [01:18<02:32,  1.76s/it]



 34%|███▍      | 45/131 [01:20<02:30,  1.75s/it]



 35%|███▌      | 46/131 [01:21<02:29,  1.75s/it]



 36%|███▌      | 47/131 [01:23<02:26,  1.75s/it]



 37%|███▋      | 48/131 [01:25<02:25,  1.75s/it]



 37%|███▋      | 49/131 [01:27<02:22,  1.74s/it]



 38%|███▊      | 50/131 [01:28<02:21,  1.74s/it]



 39%|███▉      | 51/131 [01:30<02:18,  1.73s/it]



 40%|███▉      | 52/131 [01:32<02:18,  1.75s/it]



 40%|████      | 53/131 [01:34<02:16,  1.75s/it]



 41%|████      | 54/131 [01:35<02:15,  1.76s/it]



 42%|████▏     | 55/131 [01:37<02:14,  1.77s/it]



 43%|████▎     | 56/131 [01:39<02:11,  1.76s/it]



 44%|████▎     | 57/131 [01:41<02:09,  1.75s/it]



 44%|████▍     | 58/131 [01:42<02:08,  1.76s/it]



 45%|████▌     | 59/131 [01:44<02:06,  1.75s/it]



 46%|████▌     | 60/131 [01:46<02:04,  1.75s/it]



 47%|████▋     | 61/131 [01:48<02:02,  1.75s/it]



 47%|████▋     | 62/131 [01:49<02:00,  1.75s/it]



 48%|████▊     | 63/131 [01:51<02:02,  1.80s/it]



 49%|████▉     | 64/131 [01:53<02:01,  1.81s/it]



 50%|████▉     | 65/131 [01:55<02:01,  1.84s/it]



 50%|█████     | 66/131 [01:57<01:58,  1.83s/it]



 51%|█████     | 67/131 [01:59<01:55,  1.80s/it]



 52%|█████▏    | 68/131 [02:00<01:53,  1.80s/it]



 53%|█████▎    | 69/131 [02:02<01:50,  1.79s/it]



 53%|█████▎    | 70/131 [02:04<01:49,  1.79s/it]



 54%|█████▍    | 71/131 [02:06<01:47,  1.79s/it]



 55%|█████▍    | 72/131 [02:08<01:45,  1.80s/it]



 56%|█████▌    | 73/131 [02:09<01:44,  1.80s/it]



 56%|█████▋    | 74/131 [02:11<01:41,  1.79s/it]



 57%|█████▋    | 75/131 [02:13<01:39,  1.79s/it]



 58%|█████▊    | 76/131 [02:15<01:38,  1.79s/it]



 59%|█████▉    | 77/131 [02:16<01:36,  1.79s/it]



 60%|█████▉    | 78/131 [02:18<01:34,  1.78s/it]



 60%|██████    | 79/131 [02:20<01:33,  1.79s/it]



 61%|██████    | 80/131 [02:22<01:31,  1.80s/it]



 62%|██████▏   | 81/131 [02:24<01:29,  1.79s/it]



 63%|██████▎   | 82/131 [02:25<01:27,  1.79s/it]



 63%|██████▎   | 83/131 [02:27<01:25,  1.79s/it]



 64%|██████▍   | 84/131 [02:29<01:24,  1.80s/it]



 65%|██████▍   | 85/131 [02:31<01:22,  1.78s/it]



 66%|██████▌   | 86/131 [02:33<01:19,  1.77s/it]



 66%|██████▋   | 87/131 [02:34<01:18,  1.77s/it]



 67%|██████▋   | 88/131 [02:36<01:16,  1.78s/it]



 68%|██████▊   | 89/131 [02:38<01:14,  1.78s/it]



 69%|██████▊   | 90/131 [02:40<01:12,  1.78s/it]



 69%|██████▉   | 91/131 [02:41<01:10,  1.77s/it]



 70%|███████   | 92/131 [02:43<01:09,  1.77s/it]



 71%|███████   | 93/131 [02:45<01:07,  1.79s/it]



 72%|███████▏  | 94/131 [02:47<01:06,  1.79s/it]



 73%|███████▎  | 95/131 [02:49<01:03,  1.77s/it]



 73%|███████▎  | 96/131 [02:50<01:01,  1.77s/it]



 74%|███████▍  | 97/131 [02:52<00:59,  1.76s/it]



 75%|███████▍  | 98/131 [02:54<00:57,  1.75s/it]



 76%|███████▌  | 99/131 [02:56<00:56,  1.76s/it]



 76%|███████▋  | 100/131 [02:57<00:54,  1.76s/it]



 77%|███████▋  | 101/131 [02:59<00:52,  1.76s/it]



 78%|███████▊  | 102/131 [03:01<00:50,  1.76s/it]



 79%|███████▊  | 103/131 [03:03<00:49,  1.76s/it]



 79%|███████▉  | 104/131 [03:04<00:47,  1.76s/it]



 80%|████████  | 105/131 [03:06<00:45,  1.76s/it]



 81%|████████  | 106/131 [03:08<00:43,  1.76s/it]



 82%|████████▏ | 107/131 [03:10<00:42,  1.75s/it]



 82%|████████▏ | 108/131 [03:11<00:40,  1.76s/it]



 83%|████████▎ | 109/131 [03:13<00:38,  1.76s/it]



 84%|████████▍ | 110/131 [03:15<00:37,  1.76s/it]



 85%|████████▍ | 111/131 [03:17<00:35,  1.76s/it]



 85%|████████▌ | 112/131 [03:18<00:33,  1.77s/it]



 86%|████████▋ | 113/131 [03:20<00:32,  1.79s/it]



 87%|████████▋ | 114/131 [03:22<00:30,  1.82s/it]



 88%|████████▊ | 115/131 [03:24<00:29,  1.82s/it]



 89%|████████▊ | 116/131 [03:26<00:27,  1.81s/it]



 89%|████████▉ | 117/131 [03:28<00:25,  1.80s/it]



 90%|█████████ | 118/131 [03:29<00:23,  1.79s/it]



 91%|█████████ | 119/131 [03:31<00:21,  1.78s/it]



 92%|█████████▏| 120/131 [03:33<00:19,  1.78s/it]



 92%|█████████▏| 121/131 [03:35<00:17,  1.77s/it]



 93%|█████████▎| 122/131 [03:36<00:15,  1.77s/it]



 94%|█████████▍| 123/131 [03:38<00:14,  1.79s/it]



 95%|█████████▍| 124/131 [03:40<00:12,  1.80s/it]



 95%|█████████▌| 125/131 [03:42<00:10,  1.82s/it]



 96%|█████████▌| 126/131 [03:44<00:09,  1.80s/it]



 97%|█████████▋| 127/131 [03:45<00:07,  1.79s/it]



 98%|█████████▊| 128/131 [03:47<00:05,  1.78s/it]



 98%|█████████▊| 129/131 [03:49<00:03,  1.77s/it]



 99%|█████████▉| 130/131 [03:51<00:01,  1.78s/it]



100%|██████████| 131/131 [03:54<00:00,  1.79s/it]


In [22]:
# FASTA file with the CAFA3 target (test) proteins to embed.
test_path = "cafa3_targets.fasta"

In [23]:
# Embed the CAFA3 targets under their own names, so earlier results bound to other
# names are not lost. `ems`/`ids` are still rebound here for any later cell that
# reads them.
test_ems, test_ids = piecewise_predict(test_path, "proteinbert_test")
ems, ids = test_ems, test_ids

  0%|          | 0/7 [00:00<?, ?it/s]



 14%|█▍        | 1/7 [00:03<00:21,  3.59s/it]



 29%|██▊       | 2/7 [00:05<00:13,  2.62s/it]



 43%|████▎     | 3/7 [00:07<00:09,  2.29s/it]



 57%|█████▋    | 4/7 [00:09<00:06,  2.17s/it]



 71%|███████▏  | 5/7 [00:11<00:04,  2.08s/it]



 86%|████████▌ | 6/7 [00:13<00:02,  2.04s/it]



100%|██████████| 7/7 [00:14<00:00,  2.09s/it]
