In [1]:
import pandas as pd
from tensorflow import keras

from proteinbert import OutputType, OutputSpec, FinetuningModelGenerator, load_pretrained_model, \
finetune, evaluate_by_len
from proteinbert.finetuning import encode_train_and_valid_sets
from proteinbert.conv_and_global_attention_model import get_model_with_hidden_layers_as_outputs
from os import path
import pickle

In [2]:
import wandb
from wandb.keras import WandbCallback

In [3]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

In [4]:
# Root of the project's data directory, given relative to the notebook's
# working directory; dataset CSV paths below are joined onto it.
DATA_DIR = "../../data/"

In [5]:
# Describe the fine-tuning target: a 'binary' output over the labels 0 and 1.
# NOTE(review): the first positional argument of OutputType (False) is
# presumably the "is per-residue/sequence output" flag -- confirm against
# the proteinbert API.
UNIQUE_LABELS = [0, 1]
OUTPUT_TYPE = OutputType(False, 'binary')
OUTPUT_SPEC = OutputSpec(OUTPUT_TYPE, UNIQUE_LABELS)

In [6]:
# Load the pretrained model generator and its matching input encoder.
# The dump directory is derived from DATA_DIR so the data root is configured
# in exactly one place; the resulting path is the same as before
# ("../../data/protein_bert/").
pretrained_model_generator, input_encoder = load_pretrained_model(
    path.join(DATA_DIR, "protein_bert/"),
    "epoch_92400_sample_23500000.pkl",
)

In [22]:
# Build the fine-tuning model generator on top of the pretrained generator.
# The manipulation function is applied to the pretraining model; dropout is 0.5.
model_generator = FinetuningModelGenerator(
    pretrained_model_generator,
    OUTPUT_SPEC,
    pretraining_model_manipulation_function=get_model_with_hidden_layers_as_outputs,
    dropout_rate=0.5,
)

In [8]:
# Start a Weights & Biases run for this experiment. The project name is a
# plain string literal (the previous f-string had no placeholders).
wandb.init(project="TrainFrozen2", entity="kvetab")

[34m[1mwandb[0m: Currently logged in as: [33mkvetab[0m (use `wandb login --relogin` to force relogin)


In [9]:
# Load the deduplicated Chen train/validation/test splits. The first CSV
# column is used as the DataFrame index in every split.
train_csv = path.join(DATA_DIR, "chen/deduplicated/chen_train_data.csv")
valid_csv = path.join(DATA_DIR, "chen/deduplicated/chen_valid_data.csv")
test_csv = path.join(DATA_DIR, "chen/deduplicated/chen_test_data.csv")

train_data, valid_data, test_data = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (train_csv, valid_csv, test_csv)
)

# Preview the first training rows.
train_data.head()

Unnamed: 0,Antibody_ID,heavy,light,Y
2073,6aod,EVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLE...,DIVMTKSPSSLSASVGDRVTITCRASQGIRNDLGWYQQKPGKAPKR...,0
1517,4yny,EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLE...,EFVLTQPNSVSTNLGSTVKLSCKRSTGNIGSNYVNWYQQHEGRSPT...,1
2025,5xcv,EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLE...,QFVLTQPNSVSTNLGSTVKLSCKRSTGNIGSNYVNWYQQHEGRSPT...,1
2070,6and,EVQLVESGGGLVQPGGSLRLSCAASGYEFSRSWMNWVRQAPGKGLE...,DIQMTQSPSSLSASVGDRVTITCRSSQSIVHSVGNTFLEWYQQKPG...,1
666,2xqy,QVQLQQPGAELVKPGASVKMSCKASGYSFTSYWMNWVKQRPGRGLE...,DIVLTQSPASLALSLGQRATISCRASKSVSTSGYSYMYWYQQKPGQ...,0


In [10]:
# Build the model input for every split by concatenating the heavy-chain and
# light-chain strings (heavy first, no separator) into a new "seq" column.
for split_frame in (train_data, valid_data, test_data):
    split_frame["seq"] = split_frame["heavy"] + split_frame["light"]

In [13]:
# Learning-rate schedule: multiply the LR by 0.25 after one epoch without
# improvement, never going below 1e-07.
reduce_lr_on_plateau = keras.callbacks.ReduceLROnPlateau(
    patience=1,
    factor=0.25,
    min_lr=1e-07,
    verbose=1,
)

# Stop after three epochs without improvement and restore the best weights.
early_stopping = keras.callbacks.EarlyStopping(
    patience=3,
    restore_best_weights=True,
)

# Callbacks passed to model.fit; WandbCallback logs the run to W&B.
training_callbacks = [reduce_lr_on_plateau, early_stopping, WandbCallback()]

# Sequence length handed to the encoder. The encoder's log output reports
# filtering of records longer than 510 for this value.
seq_len = 512

In [14]:
# Encode the training and validation splits (sequences + labels) with the
# pretrained input encoder, using the output spec and seq_len defined above.
encoded_train_set, encoded_valid_set = encode_train_and_valid_sets(
    train_data['seq'], train_data['Y'],
    valid_data['seq'], valid_data['Y'],
    input_encoder, OUTPUT_SPEC, seq_len,
)

[2022_01_29-10:32:09] Training set: Filtered out 0 of 1338 (0.0%) records of lengths exceeding 510.
[2022_01_29-10:32:09] Validation set: Filtered out 0 of 120 (0.0%) records of lengths exceeding 510.


In [15]:
# Unpack the encoded training split into its three components; train_X and
# train_Y are later passed to model.fit as x and y.
# NOTE(review): "weigths" is a misspelling of "weights"; the name is left
# unchanged here because other cells may refer to it.
train_X, train_Y, train_sample_weigths = encoded_train_set

In [23]:
# Create the fine-tuning model with the pretrained layers frozen.
# Use the same seq_len variable that was used to encode the datasets, so the
# model's input length cannot drift away from the encoded data.
model = model_generator.create_model(seq_len=seq_len, freeze_pretrained_layers=True)

In [24]:
# Training hyperparameters for the frozen-backbone run.
learning_rate = 1e-1  # initial optimizer LR; ReduceLROnPlateau lowers it during training
batch_size = 128      # samples per gradient step
epoch_num = 50        # upper bound on epochs; early stopping may end training sooner

In [25]:
# Record the hyperparameters on the active W&B run.
# `wandb.config.update(...)` writes into the run's config object; assigning a
# plain dict to `wandb.config` would only rebind the module attribute and the
# values would never be attached to the run.
# NOTE(review): "epochs" is logged as epoch_num * 2 while model.fit below is
# called with epochs=epoch_num -- confirm the doubling is intentional.
wandb.config.update({
    "learning_rate": learning_rate,
    "epochs": epoch_num * 2,
    "batch_size": batch_size,
})

In [26]:
# Override the learning rate of the already-created model's optimizer with the
# value configured above. The training log's first reduction (to 0.025 = 0.1 * 0.25)
# is consistent with this assignment having taken effect.
model.optimizer.lr = learning_rate

In [27]:
# Train the model on the encoded training split, validating on the encoded
# validation split after every epoch.
# The per-sample weights returned by the encoder are now passed to training
# as well; previously they were unpacked but unused, while the validation
# tuple (encoded_valid_set) already carries its own weights.
model.fit(
    x=train_X,
    y=train_Y,
    sample_weight=train_sample_weigths,
    batch_size=batch_size,
    epochs=epoch_num,
    callbacks=training_callbacks,
    validation_data=encoded_valid_set
)


Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50

Epoch 00006: ReduceLROnPlateau reducing learning rate to 0.02500000037252903.
Epoch 7/50

Epoch 00007: ReduceLROnPlateau reducing learning rate to 0.0062500000931322575.
Epoch 8/50
Epoch 9/50

Epoch 00009: ReduceLROnPlateau reducing learning rate to 0.0015625000232830644.
Epoch 10/50
Epoch 11/50
Epoch 12/50

Epoch 00012: ReduceLROnPlateau reducing learning rate to 0.0003906250058207661.
Epoch 13/50

Epoch 00013: ReduceLROnPlateau reducing learning rate to 9.765625145519152e-05.
Epoch 14/50

Epoch 00014: ReduceLROnPlateau reducing learning rate to 2.441406286379788e-05.


<keras.callbacks.History at 0x7f21f0146110>

In [45]:
import tensorflow as tf

# Sanity check: list the GPUs TensorFlow can see and print how many there are.
physical_devices = tf.config.list_physical_devices('GPU')
num_gpus = len(physical_devices)
print("Num GPUs:", num_gpus)

Num GPUs: 1


In [21]:
# Hand the model back to the generator via update_state -- presumably so the
# generator keeps the fine-tuned weights for models created later (confirm
# against the proteinbert API).
# NOTE(review): this cell's saved execution count (21) is lower than those of
# the model-creation (23) and training (27) cells; re-run it after training so
# the generator receives the trained model.
model_generator.update_state(model)