# Setup

In [1]:
# Test GPU: run a tiny matrix product inside a '/gpu:0' device scope and print the result.

import tensorflow as tf

with tf.device('/gpu:0'):
    # a is 2x3 and b is 3x2, both filled with the values 1..6.
    a = tf.constant(
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        shape = [2, 3],
        name = 'a',
    )
    b = tf.constant(
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        shape = [3, 2],
        name = 'b',
    )
    c = tf.matmul(a, b)

print(c)

tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32)


In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py

import tensorflow as tf
from tensorflow import keras

#!pip install bert-for-tf2
import bert
from bert import BertModelLayer

# Create a dataset

In [3]:
# Total number of token ids, including the reserved ones below.
VOCAB_SIZE = 16000
# We want to reserve two symbols: 1) for PADDING, 2) for MASKING.
N_RESERVED_SYMBOLS = 2

# The reserved symbols take the two highest ids of the vocabulary.
PAD_TOKEN, MASK_TOKEN = VOCAB_SIZE - 1, VOCAB_SIZE - 2

# Fixed sequence length (in tokens) that every sample is padded to.
MAX_LEN = 250

In [4]:
# Which rows of the HDF5 datasets to read; an all-None slice selects every row.
SAMPLES_SLICE = slice(None, None, None)

# HDF5 file that stores the tokenized protein sequences.
DATASET_H5_FILE_PATH = '/cs/phd/nadavb/cafa_project/data/protein_tokens.h5'

def pad_tokens(tokens):
    """Right-pad a 1-D token array with PAD_TOKEN up to MAX_LEN entries.

    The padding is created with the same dtype as `tokens` and appended
    after the existing tokens; the input array itself is not modified.
    """
    n_missing = MAX_LEN - len(tokens)
    padding = PAD_TOKEN * np.ones(n_missing, dtype = tokens.dtype)
    return np.concatenate([tokens, padding])

# Read the token sequences, keep those of at most MAX_LEN tokens and pad them into one 2-D array.
with h5py.File(DATASET_H5_FILE_PATH, 'r') as h5f:

    h5f_group = h5f['protein_tokens']

    # Token sequences first, then the per-sequence length filter (same read order as a chained lookup).
    all_seqs_tokens = h5f_group['tokens'][SAMPLES_SLICE]
    fits_max_len = h5f_group['n_tokens'][SAMPLES_SLICE] <= MAX_LEN
    relevant_seqs_tokens = all_seqs_tokens[fits_max_len]
    # Drop the temporaries right away so the unfiltered data is not kept alive.
    del all_seqs_tokens, fits_max_len
    print('Selected %d proteins of relevant length.' % len(relevant_seqs_tokens))

    dataset_tokens = np.array([pad_tokens(seq_tokens) for seq_tokens in relevant_seqs_tokens])
    del relevant_seqs_tokens
    print(dataset_tokens.shape)

Selected 947596 proteins of relevant length.
(947596, 250)


In [5]:
# Fraction of all token positions that get replaced by MASK_TOKEN.
MASK_OUT_FREQ = 0.2

# Flat boolean vector with one entry per token position: True = keep, False = mask out.
# Exactly int(MASK_OUT_FREQ * size) entries are set to False, then scattered by shuffling.
mask = np.ones(dataset_tokens.size, dtype = bool)
mask[:int(MASK_OUT_FREQ * mask.size)] = False
np.random.shuffle(mask)
mask = mask.reshape(dataset_tokens.shape)
print(mask)

# Same shape as dataset_tokens; every masked-out position holds MASK_TOKEN instead of the real id.
masked_dataset_tokens = np.where(mask, dataset_tokens, MASK_TOKEN)
print(masked_dataset_tokens)

[[ True  True  True ...  True False False]
 [ True  True  True ... False  True  True]
 [ True False  True ...  True  True  True]
 ...
 [ True  True  True ... False  True  True]
 [ True  True False ...  True  True  True]
 [False  True False ...  True  True  True]]
[[  316  2194 12416 ... 15999 15998 15998]
 [  313   570  1641 ... 15998 15999 15999]
 [ 4901 15998  2842 ... 15999 15999 15999]
 ...
 [ 8871  7011  2903 ... 15998 15999 15999]
 [  329  5480 15998 ... 15999 15999 15999]
 [15998  1129 15998 ... 15999 15999 15999]]


# Let's BERT

In [6]:
# BERT encoder layer: 12 transformer layers, 12 attention heads, hidden size 768,
# with an embedding table covering the VOCAB_SIZE token ids used in this notebook.
l_bert = BertModelLayer(**BertModelLayer.Params(

    # embedding params
    vocab_size               = VOCAB_SIZE,
    use_token_type           = True,
    use_position_embeddings  = True,
    token_type_vocab_size    = 2,

    # transformer encoder params
    num_heads                = 12,
    num_layers               = 12,
    hidden_size              = 768,
    hidden_dropout           = 0.1,
    intermediate_size        = 4 * 768,
    intermediate_activation  = "gelu",

    # see arXiv:1902.00751 (adapter-BERT)
    adapter_size             = None,

    # True for ALBERT (arXiv:1909.11942)
    shared_layer             = False,
    # None for BERT, wordpiece embedding size for ALBERT
    embedding_size           = None,

    # any other Keras layer params
    name                     = "bert",
))

# One row of MAX_LEN token ids per sample.
l_input_ids = keras.layers.Input(shape = (MAX_LEN,), dtype = np.int32)
# shape: (batch_size, max_len, hidden_size)
bert_output = l_bert(l_input_ids)
# shape: (batch_size, max_len, vocab_size)
# A softmax over the vocabulary at every position: the model's guess for each token.
token_guess_output = keras.layers.Dense(VOCAB_SIZE, activation = 'softmax', name = 'token_guess')(bert_output)

model = keras.Model(inputs = l_input_ids, outputs = token_guess_output)
model.build(input_shape = (None, MAX_LEN))

# summary() prints the table itself and returns None, so wrapping it in print()
# would only add a stray "None" line to the output.
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 250)]             0         
_________________________________________________________________
bert (BertModelLayer)        (None, 250, 768)          97737216  
_________________________________________________________________
token_guess (Dense)          (None, 250, 16000)        12304000  
Total params: 110,041,216
Trainable params: 110,041,216
Non-trainable params: 0
_________________________________________________________________
None


In [None]:
# Load some pre-trained weights into the BERT layer from a stock checkpoint.
bert.load_bert_weights(
    l_bert,
    '/cs/phd/nadavb/cafa_project/data/bret_pretrained_model/multi_cased_L-12_H-768_A-12/bert_model.ckpt',
)

In [7]:
# Checkpoint file for the full model's weights: loaded before training and overwritten by the training loop.
WEIGHTS_FILE = '/cs/phd/nadavb/cafa_project/data/protobret_weights.h5'

In [8]:
# Load previously trained weights, so training continues from the checkpoint in WEIGHTS_FILE
# instead of starting from the model's current weights.
model.load_weights(WEIGHTS_FILE)

In [44]:
# Train the BERT layer itself, not only the token-guess head on top of it.
l_bert.trainable = True

# `learning_rate` is the supported argument name; `lr` is only a deprecated alias of it.
optimizer = keras.optimizers.Adam(learning_rate = 1e-06, amsgrad = True)
# Targets are integer token ids (not one-hot vectors), hence the sparse variant of the loss.
model.compile(optimizer = optimizer, loss = 'sparse_categorical_crossentropy')

In [49]:
N_EPOCHS = 1000
EPOCH_SIZE = 2000
# Periodic checkpoint interval, in epochs.
SAVE_EVERY_N_EPOCHS = 20

try:

    for i in range(N_EPOCHS):

        print('Epoch %d:' % (i + 1))

        # Despite its name, epoch_mask holds EPOCH_SIZE random row indices (drawn with replacement).
        epoch_mask = np.random.randint(0, len(dataset_tokens), EPOCH_SIZE)
        # Inputs are the masked sequences, targets the corresponding unmasked ones.
        epoch_X = masked_dataset_tokens[epoch_mask, :]
        epoch_Y = dataset_tokens[epoch_mask, :]
        model.fit(epoch_X, epoch_Y, batch_size = 3)

        if i % SAVE_EVERY_N_EPOCHS == 0:
            model.save_weights(WEIGHTS_FILE)

except KeyboardInterrupt:
    # Stopping the cell by hand must not discard the epochs trained since the last periodic save.
    model.save_weights(WEIGHTS_FILE)
    raise
else:
    # The last loop index is generally not a multiple of the save interval,
    # so persist the final weights explicitly.
    model.save_weights(WEIGHTS_FILE)

Epoch 1:
Train on 2000 samples
Epoch 2:
Train on 2000 samples
Epoch 3:
Train on 2000 samples
Epoch 4:
Train on 2000 samples
Epoch 5:
Train on 2000 samples
Epoch 6:
Train on 2000 samples
Epoch 7:
Train on 2000 samples
Epoch 8:
Train on 2000 samples
Epoch 9:
Train on 2000 samples
Epoch 10:
Train on 2000 samples
Epoch 11:
Train on 2000 samples
Epoch 12:
Train on 2000 samples
Epoch 13:
Train on 2000 samples
Epoch 14:
Train on 2000 samples
Epoch 15:
Train on 2000 samples
Epoch 16:
Train on 2000 samples
Epoch 17:
Train on 2000 samples
Epoch 18:
Train on 2000 samples
Epoch 19:
Train on 2000 samples
Epoch 20:
Train on 2000 samples
Epoch 21:
Train on 2000 samples
Epoch 22:
Train on 2000 samples
Epoch 23:
Train on 2000 samples
Epoch 24:
Train on 2000 samples
Epoch 25:
Train on 2000 samples
Epoch 26:
Train on 2000 samples
Epoch 27:
Train on 2000 samples
Epoch 28:
Train on 2000 samples
Epoch 29:
Train on 2000 samples
Epoch 30:
Train on 2000 samples
Epoch 31:
Train on 2000 samples
Epoch 32:
Train o

KeyboardInterrupt: 

In [50]:
!nvidia-smi

/usr/bin/sh: nvidia-smi: command not found


# View the model's predictions

In [18]:
import sentencepiece as spm

%cd /cs/phd/nadavb/cafa_project/data

# SentencePiece processor, used further down to turn token ids back into text pieces.
sp = spm.SentencePieceProcessor()
# The model file is given as a relative path, so it is looked up in the current working directory.
sp.load('protopiece.model')

/cs/labs/michall/nadavb/cafa_project/data


In [48]:
def format_token_id(token_id):
    """Return the display text for a single token id.

    PAD_TOKEN is shown as '/', MASK_TOKEN as '?'; any other id is converted
    to its piece by the SentencePiece processor `sp`.
    """
    if token_id == PAD_TOKEN:
        return '/'
    if token_id == MASK_TOKEN:
        return '?'
    # Cast to a plain int before handing the id to SentencePiece.
    return sp.id_to_piece(int(token_id))

def pad_to_max_len(*strings):
    """Right-pad every given string with spaces to the length of the longest one.

    Returns a list with one padded string per argument, in the same order.
    Calling it without any strings returns an empty list instead of raising.
    """
    # default = 0 keeps max() from raising ValueError when no strings are given.
    max_len = max(map(len, strings), default = 0)
    return [string.ljust(max_len) for string in strings]

def display_model_result(i = 0):
    """Print the original tokens of sample `i` above the tokens the model predicts for it.

    The model is fed the masked version of the sample. On the first printed line,
    tokens that were masked out in the input are wrapped in '?' characters. The two
    lines are aligned by padding each original/predicted pair to a common width.
    """

    true_ids = dataset_tokens[i, :]
    keep_flags = mask[i, :]
    model_input = masked_dataset_tokens[i, :].reshape(1, -1)
    # Drop the batch axis, then take the highest-scoring token id at every position.
    token_scores = model.predict(model_input)[0, :, :]
    guessed_ids = token_scores.argmax(axis = -1)

    true_column, guessed_column = [], []

    for true_id, keep_flag, guessed_id in zip(true_ids, keep_flags, guessed_ids):
        # A False mask bit means the position was masked out in the model's input.
        marker = '' if keep_flag else '?'
        true_text = marker + format_token_id(true_id) + marker
        guessed_text = format_token_id(guessed_id)
        true_text, guessed_text = pad_to_max_len(true_text, guessed_text)
        true_column.append(true_text)
        guessed_column.append(guessed_text)

    print(' '.join(true_column) + '\n' + ' '.join(guessed_column))

# Show original vs. predicted tokens for a single sample (row 99999 of the dataset).
display_model_result(i = 99999)

▁MSK GLY ?DIP? ?SWA? ?TTE? TRT LAKL AGVE RLFE PQY ?MAL? ?QAG? VEK GEN LVV AAP ?TGSGKT? FI ALVA IVN SLAR AGGR AFY LVP LKS ?VAY? ?EKY? ?TSF? SILS RMG ?LKLK? ISVG DFR ?EGP? PEAP VVIA TYE ?KFD? SLLR VSP SLA RNV ?SVL? IVDE IHS ?VS? DPK RGP ?ILES? ?IVS? ?RML? ASAG EAQ ?LVGL? SA TVP ?NAG? EIAE WIG GKIV ES SWR ?PVP? LRE YVF KEY KLY SPTG GLR EVP RVY GLY DLDL AAEA IED GGQ ALV ?FTY? SRRR AVT LAKR AAK RLGR RLS SRE ARV YSA EASR AEGA PRS VAEE LASL IAAG ?IA? ?YHH? AGL ?PPS? LRK TVE EAFR AGAV ?KVV? YST PTL AAGV ?NLP? ARR VV ?IDS? ?YYR? ?YEA? GFR EPI ?RVAE? YKQ MAG RAGR PGL DEF GEAI IVA ERLD ?RPED? ?LIS? GYI RAP PERV ESR ?LAGL? ?RGLR? HF ILGI VA PEGE VS IGSI EKV SGLT LYSL QRG LPR ETI ARA VEDL SAW GLV ?EVK? GWR IAA TSL GREV ?AAV? YLD PESV ?PVF? ?REEV? KHL SFD NEF DIL ?YL? IST MPD MVR LPA ?TRR? ?EEE? RLLE ?AI? LDA SPR ?MLS? SVD WLG PEE ?MAA? ?VKT? AVVL KLW IDEA SED TIY GEW GVH TGDL LNM VST AEW IASG LSR IAP YLG ?LNS? KVS HIL SVI ARR ?IKH? ?GVK? PELL QLVE IPGV GRV RAR IL FEA ?GYR? SIED LATA ?RAE? DLM RLP L