<a href="https://colab.research.google.com/github/honzas83/bio-bert/blob/main/bert_sample_usage.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Example usage of BERT in protein classification

## Required Python modules

In [None]:
%%capture pip_install
# Installs the notebook's dependencies. %%capture stores the (long) pip output
# in the variable `pip_install`; run `pip_install.show()` to inspect it if an
# install fails. keras-bert / keras-transformer are pinned to the versions the
# USE_ADAPTER option in the configuration cell requires.

!pip install tensorflow-text \
             transformers==3.3.0 \
             pyyaml \
             sentencepiece \
             keras-bert==0.81.0 \
             keras-transformer==0.33.0 \
             pandas \
             scikit-learn==0.23 \
             biopython

## Model and data download


In [None]:
# Download and unpack the pre-trained BERT archive (presumably the checkpoint and
# SentencePiece model read from ./bert-pfam-10k/ below) and the SCOPe data
# splits (read from ./SCOPe/). `unzip -u` only extracts missing or newer files.
# NOTE(review): newer gdown releases deprecate `--id`; if it errors, pass the file ID positionally.
!gdown --id 17FLvsbpjqR_SHAYY-S0YwAmjyLW2PYwx && unzip -u bert-pfam-10k.zip
!gdown --id 10jnY335GcVon8EGqo5h5Kxh_gjQoxbJH && unzip -u SCOPe.zip


### Imports

In [None]:
import keras_bert
from keras_bert import load_trained_model_from_checkpoint

import sentencepiece 
import os
import pandas as pd
import numpy as np
import itertools
import random
from Bio import SeqIO

import tensorflow as tf
import keras
import keras.backend as K
from keras import layers
from keras.layers import Input, Dense, TimeDistributed, Bidirectional, LSTM, Concatenate, Conv1D, Dropout, Dot, Lambda, GlobalAvgPool1D, Add, GaussianNoise, Flatten
from keras.models import Model, load_model
from keras_bert.layers import Extract, MaskedGlobalMaxPool1D
from keras.regularizers import l1
from keras_bert import gelu
from keras_position_wise_feed_forward import FeedForward
from tensorflow.python.framework import tensor_shape
from keras.callbacks import LearningRateScheduler, EarlyStopping

from sklearn.metrics import f1_score, accuracy_score, confusion_matrix, average_precision_score

### Experiment configuration

In [None]:
# --- Pre-trained BERT checkpoint ---------------------------------------------
CHECKPOINT_PATH = './bert-pfam-10k/'
CHECKPOINT_NAME = 'model.ckpt-1500000'
BERT_CONFIG = 'bert_config.json'

# SentencePiece tokenizer model distributed with the checkpoint.
SPM_MODEL = "./bert-pfam-10k/pfam.model"

# How many BERT encoder layers to load.
LAYER_NUM = 12

# 1-based indices of encoder layers whose self-attention weights are trained,
# e.g. [10, 11, 12]. An empty list trains none of the self-attention layers.
FINE_TUNE = []

# Adapter layers add trainable parameters inside BERT. Set to True to enable
# them (needs keras-bert==0.81.0 and keras-transformer==0.33.0).
USE_ADAPTER = False

# Pooling of the per-token outputs: "extract", "max" or "kmax".
POOL_TYPE = "kmax"
POOL_K = 4  # values kept per feature by k-max pooling

# SCOPe classes to predict, one letter per class, e.g. "abcd" or "abcdefg".
# Class definitions: https://scop.berkeley.edu/ver=2.07
TARGETS = "abcd"

# --- Data --------------------------------------------------------------------
# Original source of the splits:
#   http://bergerlab-downloads.csail.mit.edu/bepler-protein-sequence-embeddings-from-structure-iclr2019/scope.tar.gz
#   https://github.com/tbepler/protein-sequence-embedding-iclr2019#data-sets
TRAIN_DATA = "./SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.train.fa"
DEV_DATA = "./SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.train.dev.fa"
TEST_DATA = "./SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.06.test.fa"
TEST_DATA_NEW = "./SCOPe/astral-scopedom-seqres-gd-sel-gs-bib-95-2.07-new.fa"

# Model input length in SentencePiece tokens; two positions are taken by the
# [CLS] and [SEP] tokens, leaving MAX_LEN1 for the sequence itself.
MAX_LEN = 128
MAX_LEN1 = MAX_LEN - 2

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 32   # bounded by GPU memory
LR_START = 1e-3   # initial learning rate, decayed as 1/sqrt(epoch + 1)
EPOCHS = 3

# Number of prediction passes averaged at test time. Pass 0 uses the
# non-sampled tokenization, later passes use sampled ones; 1 disables averaging.
PREDICT_AVG = 5

# File the trained model is written to.
SAVE_MODEL_FN = "bert-sample-usage.hdf5"

### Custom class for K-max pooling

In [None]:
class KMaxPooling(keras.layers.Layer):
    """
    K-max pooling layer that keeps the k highest activations along one axis.

    For each (sample, feature) pair, the k largest values along ``axis`` are
    selected. They are returned in their original positional order rather than
    sorted by magnitude (order-preserving k-max pooling). TensorFlow backend.

    Args:
        k: number of values kept along ``axis``.
        axis: axis pooled over; 1 or 2 of a 3-D input (axis 0 is the batch).

    Input shape:  3-D tensor.
    Output shape: same as the input, but with size ``k`` along ``axis``.
    """
    def __init__(self, k=1, axis=1, **kwargs):
        super(KMaxPooling, self).__init__(**kwargs)
        self.supports_masking = True
        self.input_spec = layers.InputSpec(ndim=3)
        self.k = k

        assert axis in [1,2],  'expected dimensions (samples, filters, convolved_values),\
                   cannot fold along samples dimension or axis not in list [1,2]'
        self.axis = axis

        # tf.math.top_k works on the last axis, so swap the pooled axis with the
        # last one. This permutation is its own inverse, which lets the same
        # perm restore the original layout after pooling.
        self.transpose_perm = [0, 1, 2]
        self.transpose_perm[self.axis] = 2
        self.transpose_perm[2] = self.axis

    def get_config(self):
        # Merge with the base config so name/trainable/dtype survive save/load.
        config = super(KMaxPooling, self).get_config()
        config.update({"k": self.k, "axis": self.axis})
        return config

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[self.axis] = self.k
        return tuple(output_shape)

    def compute_mask(self, inputs, mask=None):
        # The pooled axis has k entries, so an incoming per-position mask no
        # longer lines up with the output; do not propagate it.
        return None

    def call(self, x):
        # Move the pooled axis last: e.g. (batch, seq, feat) -> (batch, feat, seq).
        transposed_for_topk = tf.transpose(x, perm=self.transpose_perm)

        # Positions of the k largest values along the last axis ...
        _, top_k_indices = tf.math.top_k(transposed_for_topk, k=self.k, sorted=True)
        # ... re-sorted so the kept values keep their original positional order.
        sorted_top_k_ind = tf.sort(top_k_indices)

        # Per-row gather: with batch_dims=2 each (sample, feature) row is indexed
        # by its own k positions along the last axis.
        k_max_out = tf.gather(transposed_for_topk, sorted_top_k_ind, batch_dims=2)

        # Restore the original axis order; `axis` now has size k.
        return tf.transpose(k_max_out, perm=self.transpose_perm)



### Pre-trained BERT loading

Loads the pre-trained BERT model into a Keras model.

See also: https://github.com/CyberZHG/keras-bert

In [None]:
config_path = os.path.join(CHECKPOINT_PATH, BERT_CONFIG)
model_path = os.path.join(CHECKPOINT_PATH, CHECKPOINT_NAME)

# Names of BERT layers that stay trainable: adapter and layer-norm layers of
# every encoder block, plus the self-attention of the (1-based) blocks listed
# in FINE_TUNE. When adapters are disabled, the adapter names presumably match
# no layer; TODO confirm keras-bert ignores unknown names.
trainable = (['Encoder-{}-MultiHeadSelfAttention-Adapter'.format(i + 1) for i in range(LAYER_NUM)] +
    ['Encoder-{}-FeedForward-Adapter'.format(i + 1) for i in range(LAYER_NUM)] +
    ['Encoder-{}-MultiHeadSelfAttention-Norm'.format(i + 1) for i in range(LAYER_NUM)] +
    ['Encoder-{}-FeedForward-Norm'.format(i + 1) for i in range(LAYER_NUM)] +
    ['Encoder-{}-MultiHeadSelfAttention'.format(i) for i in FINE_TUNE])


bert_model = load_trained_model_from_checkpoint(
    config_path,
    model_path,
    seq_len=MAX_LEN,
    training=False,
    trainable=trainable,
    # Request adapters only when USE_ADAPTER is True. Previously this was
    # inverted and enabled adapters when USE_ADAPTER was False. Omitting the
    # keyword otherwise keeps the call free of the adapter-only argument
    # (the config notes adapters need keras-bert==0.81.0).
    **({"use_adapter": True} if USE_ADAPTER else {})
)

### SentencePiece model initialization

Loads the SentencePiece tokenizer from file and defines some helper functions.

See also: https://github.com/google/sentencepiece/blob/master/python/README.md

In [None]:
# SentencePiece processor loaded from SPM_MODEL; used by encode() and PieceToId().
spm = sentencepiece.SentencePieceProcessor(SPM_MODEL)

def encode(seq, enable_sampling=True):
    """Split a protein sequence into SentencePiece strings.

    The sequence is upper-cased before tokenization. With enable_sampling a
    tokenization is sampled (alpha=0.2, nbest_size=-1); otherwise the standard
    segmentation is used. Leading "▁" markers are stripped from every piece,
    and pieces that become empty are dropped.
    """
    sampling_options = (
        dict(enable_sampling=True, alpha=0.2, nbest_size=-1) if enable_sampling else {}
    )
    pieces = spm.encode(seq.upper(), out_type=str, **sampling_options)
    stripped = (piece.lstrip("▁") for piece in pieces)
    return [piece for piece in stripped if piece]

# Vocabulary id of the "[UNK]" piece; PieceToId() substitutes it for id 0.
UNK = spm.PieceToId("[UNK]")

def PieceToId(piece):
    """Return the vocabulary id of a SentencePiece string.

    Pieces for which spm.PieceToId gives 0 (presumably SentencePiece's
    built-in unknown id) are mapped to the id of the "[UNK]" piece instead.
    """
    piece_id = spm.PieceToId(piece)
    return UNK if piece_id == 0 else piece_id

### Batch vectorization

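Vectorizes the input batch, which is a list of (sequence, target_class) pairs; pairs whose class is not in TARGETS are skipped.

Returns the tuple ([tokens, segments], targets) where:
 - tokens is an array of shape (number of kept pairs, MAX_LEN) holding the SentencePiece ids, starting with [CLS], ending with [SEP] and zero-padded
 - segments is an array of zeros with the same shape
 - targets is an array of one-hot target vectors (one column per class in TARGETS), matching the softmax output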

In [None]:
def vectorize_batch(batch, enable_sampling=True):

    tokens = []
    segments = []
    targets = []
    for seqA, y in batch:
        if y not in TARGETS:
            continue
            
        tokens1 = []
        tokens1.append(PieceToId("[CLS]"))
        encoded = encode(seqA, enable_sampling)
        if len(encoded) > MAX_LEN1:
            if random.random() < 0.5:
                encoded = encoded[:MAX_LEN1]
            else:
                encoded = encoded[-MAX_LEN1:]
        tokens1.extend([PieceToId(i) for i in encoded])
        tokens1.append(PieceToId("[SEP]"))
         
        segments1 = [0]*len(tokens1)

        assert len(tokens1) == len(segments1)
        while len(tokens1) < MAX_LEN:
            tokens1.append(0)
            segments1.append(0)

        Y = [0]*len(TARGETS)
        Y[TARGETS.index(y)] = 1
            
        tokens.append(tokens1)
        segments.append(segments1)
        targets.append(Y)
    
    tokens = np.array(tokens, dtype=np.int32)
    segments = np.array(segments, dtype=np.int32)
    targets = np.array(targets)
    return [tokens, segments], targets

### Model definition

In [None]:
input_tokens = Input(shape=(MAX_LEN,), dtype='int32')
input_segments = Input(shape=(MAX_LEN,), dtype='int32')

bert_out = bert_model([input_tokens, input_segments])

# noise_shape (None, MAX_LEN, 1) shares one dropout decision across the last
# axis, i.e. whole token positions are dropped. Tied to MAX_LEN rather than a
# hard-coded 128, so changing MAX_LEN does not break the model.
bert_out = Dropout(0.25, noise_shape=(None, MAX_LEN, 1))(bert_out)

dense_out = TimeDistributed(Dense(256, activation="relu"))(bert_out)

if POOL_TYPE == "extract":
    # Output at position 0, where vectorize_batch places the [CLS] token.
    pool_out = Extract(index=0)(dense_out)
elif POOL_TYPE == "max":
    pool_out = MaskedGlobalMaxPool1D()(dense_out)
elif POOL_TYPE == "kmax":
    pool_out = KMaxPooling(k=POOL_K)(dense_out)
    pool_out = Flatten()(pool_out)
else:
    # Fail here instead of with a NameError on pool_out below.
    raise ValueError("Unknown POOL_TYPE {!r}; expected 'extract', 'max' or 'kmax'".format(POOL_TYPE))

out = Dense(len(TARGETS))(pool_out)
out = keras.layers.Activation("softmax")(out)
out = keras.layers.Activation("softmax")(out)

In [None]:
# Connect both BERT inputs to the softmax head. Train with Adam on categorical
# cross-entropy (the targets are one-hot) and report accuracy.
model = Model(
    inputs=[input_tokens, input_segments],
    outputs=[out],
)
model.compile(
    loss=["categorical_crossentropy"],
    metrics=["accuracy"],
    optimizer="adam",
)

In [None]:
# Print the layer-by-layer overview of the assembled model, including parameter counts.
model.summary()

### Dataset loading functions

In [None]:
def iter_training_sample(fn):
    """Load a SCOPe FASTA file as a shuffled list of (SEQUENCE, class_letter) pairs.

    The class letter is the first character of the second whitespace-separated
    field of each record description (e.g. "a" from "a.1.1.1"), and sequences
    are upper-cased. The list is shuffled after re-seeding the global `random`
    module with `fn`, so the order is reproducible per file. Note that this also
    resets the global random state seen by all code that runs afterwards.
    """
    samples = [
        (str(record.seq).upper(), record.description.split()[1][0])
        for record in SeqIO.parse(fn, "fasta")
    ]
    random.seed(fn)
    random.shuffle(samples)
    return samples

def train_gen(data, batch_size, shuffle=True, enable_sampling=True):
    """Yield vectorized batches from `data` indefinitely.

    Every pass works on a copy of `data`, shuffled with the global `random`
    module when shuffle=True. The copy is cut into consecutive slices of
    batch_size items (the last slice may be shorter), and each slice is passed
    through vectorize_batch.
    """
    while True:
        epoch_items = data[:]
        if shuffle:
            random.shuffle(epoch_items)
        for start in range(0, len(epoch_items), batch_size):
            chunk = epoch_items[start:start + batch_size]
            yield vectorize_batch(chunk, enable_sampling=enable_sampling)

In [None]:
train_data = iter_training_sample(TRAIN_DATA)
train_iter = train_gen(train_data, BATCH_SIZE)
# One training epoch = floor(n / BATCH_SIZE) batches. The generator is infinite,
# so a trailing partial batch is simply carried over into the next epoch.
TRAIN_STEPS = len(train_data) // BATCH_SIZE

dev_data = iter_training_sample(DEV_DATA)
dev_iter = train_gen(dev_data, BATCH_SIZE, enable_sampling=False)
# Ceiling division: the same dev_iter generator is drawn from every epoch, so
# validation_steps must equal the number of batches in one full pass (including
# the final partial batch). Otherwise each epoch is scored on a shifting window
# of the dev set instead of the whole set.
DEV_STEPS = -(-len(dev_data) // BATCH_SIZE)

### Model training

In [None]:
def sqrt_decay(epoch):
    """Learning rate for a 0-based epoch index: LR_START / sqrt(epoch + 1)."""
    return LR_START / (epoch + 1) ** 0.5


callbacks = [LearningRateScheduler(sqrt_decay, verbose=1)]

model.fit(
    train_iter,
    epochs=EPOCHS,
    steps_per_epoch=TRAIN_STEPS,
    validation_data=dev_iter,
    validation_steps=DEV_STEPS,
    callbacks=callbacks,
)

### Predictions on test data

In [None]:
test_data = iter_training_sample(TEST_DATA)

# Without these guards an empty test set raises a range() step-0 error, and
# PREDICT_AVG < 1 leaves test_Y undefined (NameError below).
assert test_data, "no test samples loaded from TEST_DATA"
assert PREDICT_AVG >= 1, "PREDICT_AVG must be at least 1"

# Sum class probabilities over PREDICT_AVG passes. Pass 0 uses the non-sampled
# tokenization; later passes sample SentencePiece tokenizations.
pred_Y = 0
for i in range(PREDICT_AVG):
    # A single batch holding the whole test set, in the order returned by
    # iter_training_sample (shuffle=False).
    test_X, test_Y = next(train_gen(test_data, len(test_data), shuffle=False, enable_sampling=(i != 0)))
    pred_Y = pred_Y + model.predict(test_X, batch_size=128, verbose=True)

test_Y1 = test_Y.argmax(axis=1)
pred_Y1 = pred_Y.argmax(axis=1)
print("Model accuracy", accuracy_score(test_Y1, pred_Y1))

In [None]:
print("Confusion matrix")
# Pass labels explicitly so the matrix is always len(TARGETS) x len(TARGETS).
# By default sklearn keeps only labels that occur in y_true or y_pred, and the
# fixed row/column labels below would then fail to fit. Rows are true classes
# and columns are predicted classes (sklearn convention). Counts are integers,
# so no rounding is needed.
pd.DataFrame(confusion_matrix(test_Y1, pred_Y1, labels=list(range(len(TARGETS)))),
             columns=pd.Index(list(TARGETS), name="predicted"),
             index=pd.Index(list(TARGETS), name="true"))


In [None]:
# Evaluate on the SCOPe 2.07 "new" set. The results use their own variable
# names so the 2.06 test results (test_Y1 / pred_Y1, read by the confusion
# matrix cell) are not overwritten.
test_new_data = iter_training_sample(TEST_DATA_NEW)

assert test_new_data, "no test samples loaded from TEST_DATA_NEW"
assert PREDICT_AVG >= 1, "PREDICT_AVG must be at least 1"

# Pass 0: non-sampled tokenization; later passes: sampled tokenizations.
pred_new_Y = 0
for i in range(PREDICT_AVG):
    test_new_X, test_new_Y = next(train_gen(test_new_data, len(test_new_data), shuffle=False, enable_sampling=(i != 0)))
    pred_new_Y = pred_new_Y + model.predict(test_new_X, batch_size=128, verbose=True)

test_new_Y1 = test_new_Y.argmax(axis=1)
pred_new_Y1 = pred_new_Y.argmax(axis=1)
print("Model accuracy", accuracy_score(test_new_Y1, pred_new_Y1))

In [None]:
# Write the trained model to SAVE_MODEL_FN (.hdf5). NOTE(review): reloading with
# load_model presumably needs custom_objects for KMaxPooling and the keras-bert
# layers; TODO confirm.
model.save(SAVE_MODEL_FN)

## Copy out the resulting model into Google Drive

Click on the displayed link and authorize the app to access your Google Drive.

In [None]:
from google.colab import drive

# Mount Google Drive at /content/gdrive (Colab only); follow the displayed authorization prompt.
drive.mount("/content/gdrive")

Modify the target directory according to your needs.

In [None]:
!cp $SAVE_MODEL_FN gdrive/MyDrive/Models/