### 1. Importing important libraries

In [97]:

from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import timeit

import os
import pickle
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import  confusion_matrix, ConfusionMatrixDisplay
from matplotlib.pyplot import figure
import wandb

wandb.login()

True

In [98]:
!pwd

/home/danish/PPI_prediction/new_experiment


### 2. Basic EDA

In [99]:
train =  pd.read_csv("Train_0_0.tsv",delimiter='\t',header=None)
test =  pd.read_csv("Test_0_0.tsv",delimiter='\t',header=None)

In [100]:
train

Unnamed: 0,0,1,2
0,8812,165140,1
1,5581,93953,0
2,56165,284293,0
3,1788,10919,1
4,999,54894,1
...,...,...,...
99995,3172,10987,1
99996,10107,58478,0
99997,27185,55180,0
99998,1674,6047,1


In [101]:
test[2].value_counts()

1    3415
0    3415
Name: 2, dtype: int64

In [102]:
train_pos = train[train[2] == 1][0:50000]
train_neg = train[train[2] == 0][0:50000]
train = pd.concat([train_pos,train_neg])

test_pos = test[test[2] == 1]
test_neg = test[test[2] ==0][0:len(test_pos)]
test = pd.concat([test_pos,test_neg])
test[2].value_counts()

1    3415
0    3415
Name: 2, dtype: int64

In [103]:
val_pos =  test[test[2] == 1][0:1000]
val_neg =  test[test[2] == 0][0:1000]
val = pd.concat([val_pos,val_neg])

test_pos =  test[test[2] == 1][1000:]
test_neg =  test[test[2] == 0][1000:]
test = pd.concat([test_pos,test_neg])

In [104]:
print("Size of Train dataset: ", len(train))
print("Size of Test dataset: ", len(test))
print("Size of val dataset: ", len(val))

Size of Train dataset:  100000
Size of Test dataset:  4830
Size of val dataset:  2000


In [105]:
print(f"Number of negative points in training set: {train[2].value_counts()[0]}")
print(f"Number of positive points in training set: {train[2].value_counts()[1]}")
print("----"*57)
print(f"Number of negative points in test set: {test[2].value_counts()[0]}")
print(f"Number of positive points in test set: {test[2].value_counts()[1]}")
print("----"*57)
print(f"Number of negative points in test set: {val[2].value_counts()[0]}")
print(f"Number of positive points in test set: {val[2].value_counts()[1]}")

Number of negative points in training set: 50000
Number of positive points in training set: 50000
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Number of negative points in test set: 2415
Number of positive points in test set: 2415
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Number of negative points in test set: 1000
Number of positive points in test set: 1000


### 3. Importing embedding vectors from pickle file

In [106]:
with open("sum_of_amino_acid_benchmark_vector_bert_normalize_bfd.pickle",'rb') as handle:
    dc = pickle.load(handle)

In [107]:
train

Unnamed: 0,0,1,2
0,8812,165140,1
3,1788,10919,1
4,999,54894,1
6,10771,51438,1
7,274,5338,1
...,...,...,...
99993,138065,219854,0
99994,25763,729597,0
99996,10107,58478,0
99997,27185,55180,0


In [108]:
def return_embed(prot_name):
    try:
        return dc[str(prot_name)]*500-1
    except:
        return np.nan
train['embed_vec_protein_A'] = train[0].apply(return_embed)
train['embed_vec_protein_B'] = train[1].apply(return_embed)

test['embed_vec_protein_A'] = test[0].apply(return_embed)
test['embed_vec_protein_B'] = test[1].apply(return_embed)

val['embed_vec_protein_A'] = val[0].apply(return_embed)
val['embed_vec_protein_B'] = val[1].apply(return_embed)


train = train.dropna()
test = test.dropna()
val = val.dropna()

In [109]:
train.head()

Unnamed: 0,0,1,2,embed_vec_protein_A,embed_vec_protein_B
0,8812,165140,1,"[-4.137003, 1.1387072, -5.577691, 23.804869, 1...","[10.549672, 0.6898947, -7.3381624, -17.99396, ..."
4,999,54894,1,"[3.429071, 1.9356925, -6.0805483, -0.8916067, ...","[6.4256415, 13.054534, 0.22838259, 16.80025, 1..."
6,10771,51438,1,"[6.155548, 9.562074, -0.1157856, 3.327611, 6.1...","[10.115945, -1.2304958, -9.726263, 16.496975, ..."
7,274,5338,1,"[-10.636223, -7.661087, -16.337265, 18.760033,...","[10.423297, 10.890132, -5.2806287, 15.781599, ..."
8,537,8856,1,"[-6.659892, -3.1955101, -0.2977332, 5.195242, ...","[14.750025, 17.76645, -6.3672185, 10.122561, 1..."


In [110]:
train_features_Protein_A = []
train_features_Protein_B = []
train_label = []
test_features_Protein_A = []
test_features_Protein_B = []
test_label =[]
val_features_Protein_A = []
val_features_Protein_B = []
val_label = []
for i in tqdm(range(len(train))):
    train_features_Protein_A.append(np.array(train.iloc[i].embed_vec_protein_A))
    train_features_Protein_B.append(np.array(train.iloc[i].embed_vec_protein_B))
    train_label.append(np.array(train.iloc[i][2]))
    
for i in tqdm(range(len(test))):
    test_features_Protein_A.append(np.array(test.iloc[i].embed_vec_protein_A))
    test_features_Protein_B.append(np.array(test.iloc[i].embed_vec_protein_B))
    test_label.append(np.array(test.iloc[i][2]))  
for i in tqdm(range(len(val))):
    
    
    val_features_Protein_A.append(np.array(val.iloc[i].embed_vec_protein_A))
    val_features_Protein_B.append(np.array(val.iloc[i].embed_vec_protein_B))
    val_label.append(np.array(val.iloc[i][2]))

100%|██████████| 79639/79639 [00:23<00:00, 3362.79it/s]
100%|██████████| 3956/3956 [00:01<00:00, 3365.53it/s]
100%|██████████| 1629/1629 [00:00<00:00, 3328.37it/s]


In [111]:
train_features_Protein_A = np.array(train_features_Protein_A)
train_features_Protein_B = np.array(train_features_Protein_B)
train_label = np.array(train_label)

test_features_Protein_A = np.array(test_features_Protein_A)
test_features_Protein_B = np.array(test_features_Protein_B)
test_label = np.array(test_label)

val_features_Protein_A = np.array(val_features_Protein_A)
val_features_Protein_B = np.array(val_features_Protein_B)
val_label = np.array(val_label)

In [112]:
np.unique(train_label, return_counts=True)

(array([0, 1]), array([38723, 40916]))

### 4. Dataloader

In [113]:
class Data(Dataset):
    
    def __init__(self, X_data_A,X_data_B, y_data):
        self.X_data_A = X_data_A
        self.X_data_B = X_data_B
        self.y_data = y_data
        
    def __getitem__(self, index):
        return self.X_data_A[index],self.X_data_B[index], self.y_data[index]
        
    def __len__ (self):
        return len(self.X_data_A)

In [114]:
train_data = Data(torch.FloatTensor(train_features_Protein_A), torch.FloatTensor(train_features_Protein_B),
                       torch.FloatTensor(train_label))

test_data = Data(torch.FloatTensor(test_features_Protein_A), torch.FloatTensor(test_features_Protein_B),
                       torch.FloatTensor(test_label))

val_data = Data(torch.FloatTensor(val_features_Protein_A), torch.FloatTensor(val_features_Protein_B),
                       torch.FloatTensor(val_label))

In [115]:

train_loader = DataLoader(dataset=train_data, batch_size=512, shuffle=True,drop_last=True)
test_loader = DataLoader(dataset=test_data, batch_size=512,shuffle=True,drop_last=True )
val_loader = DataLoader(dataset=val_data, batch_size=512,shuffle=True,drop_last=True )

In [116]:
for i,j,k in train_loader:
    print(i.shape)
    print(j.shape)
    print(k.shape)
    break
    

torch.Size([512, 1024])
torch.Size([512, 1024])
torch.Size([512])


### 5. Building Models

In [117]:
class BertClassifier(nn.Module):
    
    def __init__(self,config, embed_dim =1024):
        super(BertClassifier,self).__init__()
        self.relu = nn.ReLU()
        self.config = config
        self.conv1  = nn.Conv1d(in_channels = 1,out_channels = 33, kernel_size = 10, stride=1)
        self.fc1 = nn.Linear(33495,config['dim_1'])
        self.fully_connected_layers_1 = nn.ModuleList([nn.Linear(config['dim_1']*2,config['dim_1']*2)
                                                    for _ in range(config['layer_fc_1'])])
        
        
    
        self.fc_2 = nn.Linear(config['dim_1']*2,config['dim_2'])
        self.fully_connected_layers_2 = nn.ModuleList([nn.Linear(config['dim_2'],config['dim_2'])
                                                    for _ in range(config['layer_fc_2'])]) 
        self.bn2 = nn.BatchNorm1d(num_features=config['dim_2'])
        self.fc3 = nn.Linear(config['dim_2'],config['dim_1'])
        
        self.fc4 = nn.Linear(config['dim_1'],256)
        self.drop = nn.Dropout(p = 0.2)
        self.fc5 = nn.Linear(256,128)
        self.fc6 = nn.Linear(128,64)
        self.fc7 = nn.Linear(64,32)
        self.fc8 = nn.Linear(32,16)
        self.fc9 = nn.Linear(16,8)
        self.fc10 = nn.Linear(8,1)
    
    def forward(self, inputs_A,inputs_B):
        
        
        inputs_A = inputs_A.reshape(512,1,1024)
        output_conv_A = self.relu(self.conv1(inputs_A))
        output_conv_A = output_conv_A.reshape(512,33495) 
        output_A = self.relu(self.fc1(output_conv_A))
        
        inputs_B = inputs_B.reshape(512,1,1024)
        output_conv_B = self.relu(self.conv1(inputs_B))
        output_conv_B = output_conv_B.reshape(512,33495) 
        output_B = self.relu(self.fc1(output_conv_B))
        
        
        output = torch.cat((output_A, output_B),1)
        for i in range(self.config['layer_fc_1']):
            output = self.relu(self.fully_connected_layers_1[i](output))
        output = self.relu(self.fc_2(output))
        for i in range(self.config['layer_fc_2']):
            output = self.relu(self.fully_connected_layers_2[i](output))
            
        output  = self.bn2(output)
        output = self.relu(self.fc3(output))
        if self.config['dropout']:
            output = self.drop(output)
            
        output = self.relu(self.fc4(output))
        output = self.relu(self.fc5(output))
        output = self.relu(self.fc6(output))
        output = self.fc7(output)
        output = self.fc8(output)
        output = self.fc9(output)
        output = self.fc10(output)

        return output

In [118]:
def build_optimizer(network, optimizer,learning_rate, momentum, weight_decay, amsgrad,momentum_decay):
    
    
    if optimizer == "sgd":
        optimizer_ = optim.SGD(network.parameters(),
                              lr = learning_rate, momentum = momentum, weight_decay = weight_decay,
                              )
        
        
    elif optimizer == "adam":
        optimizer_ = optim.Adam(network.parameters(),
                               lr = learning_rate, betas = (0.9,0.999), weight_decay = weight_decay,
                               amsgrad = amsgrad)
        
    elif optimizer == "nadam":
        optimizer_ = torch.optim.NAdam(network.parameters(), lr = learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=  weight_decay, momentum_decay=momentum_decay, foreach=None)
        
        
    elif optimizer == "radam":
        optimizer_ = torch.optim.RAdam(network.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, foreach=None)
        
    return optimizer_

In [119]:
# sweep_config = {
#     'method': 'random'
    
#     }
# metric = {
#     'name': 'val_accuracy',
#     'goal': 'maximize'   
#     }
# early_terminate = {"type": "hyperband",
#       "min_iter": 3 }

# sweep_config['metric'] = metric 
# sweep_config['early_terminate'] = early_terminate 

# parameters_dict = {
    
#     'layer_fc_1': {
#         'values': [2]
#         },
   
#     'dim_1': {
#           'values': [2048]
#         },
    
#     'layer_fc_2': {
#         'values': [2]
#         },
#     'dim_2': {
#           'values': [512]
#         },
    
    
#     'dropout': {
#           'values': [False]
#         },
   
    
    
#     'optimizer': {
#           'values': ['adam',"rms_prop"]   #
#         }
#     ,
  
    
#     'learning_rate': {
#             'values':[0.0001,0.001,0.01,0.1]
#         },
    
    
#     'momentum': {
#           'values': [0.95,0.9,0.99]
#         },
    
#     'weight_decay': {
#             'values': [0.009827436437331628,0.095,0.95]
#         },
   
        
#     'amsgrad': {
#           'values': [False]
#         },
    
    
#     }


# sweep_config['parameters'] = parameters_dict
# parameters_dict.update({
#     'epochs': {
#         'value': 30}
#     })


# import pprint

config = dict(layer_fc_1 = 2,
        dim_1 = 512,
        layer_fc_2 = 2,
        dim_2 = 512,
        dropout = False,
        optimizer = 'adam',
        learning_rate = 0.00007,
        momentum = 0.95,
        weight_decay = 0.0001,
        amsgrad = False,
        epochs = 100,
        momentum_decay = 0.004)


import pprint

pprint.pprint(config)

{'amsgrad': False,
 'dim_1': 512,
 'dim_2': 512,
 'dropout': False,
 'epochs': 100,
 'layer_fc_1': 2,
 'layer_fc_2': 2,
 'learning_rate': 7e-05,
 'momentum': 0.95,
 'momentum_decay': 0.004,
 'optimizer': 'adam',
 'weight_decay': 0.0001}


### 6. Training

In [138]:
import random
import time
from tqdm import tqdm
loss_fn = nn.BCEWithLogitsLoss()
def train(config, train_dataloader,val_dataloader = None):
    
    best_accuracy = 0
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    print("Start training...\n")
    epochs = config['epochs']
    
    model = BertClassifier(config).to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer,learning_rate, momentum, weight_decay, amsgrad,momentum_decay = config['optimizer'],config['learning_rate'], config['momentum'], config['weight_decay'], config['amsgrad'],config['momentum_decay']
    optimizer = build_optimizer(model,optimizer,learning_rate, momentum, weight_decay, amsgrad,momentum_decay)
    scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.09)
    scheduler2 = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,15], gamma=0.1)
    
    for epoch_i in range(1,epochs+1):
        
        total_loss = 0
        model.train()
        
        for step,batch in tqdm(enumerate(train_dataloader)):
            
            inputs_A,inputs_B, b_labels = tuple(t.to(device) for t in batch)
            b_labels = b_labels.reshape((1,512,1)).squeeze(0)
            model.zero_grad()
            logits = model(inputs_A,inputs_B)
            loss = loss_fn(logits,b_labels.float()) 
            total_loss += loss.item()
            loss.mean().backward()
            optimizer.step()
        scheduler1.step()
        scheduler2.step()
               
        avg_train_loss = total_loss / len(train_dataloader)
        
    
        if val_dataloader is not None:
                
                val_loss, val_accuracy = evaluate(model, val_dataloader)
                print(f"best accuracy {best_accuracy}")
                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    torch.save({
                        'epoch': epoch_i ,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss_fn,
                        }, 'best_model_trained_fc_v3.pth')
        
                
        print(f"Epoch: {epoch_i} | Training Loss: {avg_train_loss}  | Validation Loss: {val_loss}  | Accuracy: {val_accuracy:.2f}")
        with open('result.txt', 'a') as f:
            print(f"Epoch: {epoch_i} | Training Loss: {avg_train_loss}  | Validation Loss: {val_loss}  | Accuracy: {val_accuracy:.2f}", file=f) 
    print("\n")
   
    print(f"Training complete! Best accuracy: {best_accuracy:.2f}%.")
    

def evaluate(model,val_dataloader):
    
    
    model.eval()
    val_accuracy = []
    val_loss = []
    for batch in tqdm(val_dataloader):
       
        device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
        inputs_A,inputs_B ,b_labels = tuple(t.to(device) for t in batch)
        b_labels = b_labels.reshape((1,512,1)).squeeze(0)
        with torch.no_grad():
                logits = model(inputs_A,inputs_B)
        
        loss = loss_fn(logits, b_labels.float())
        val_loss.append(loss.item())
        preds = torch.round(torch.sigmoid(logits))
        
        accuracy = (preds.float() == b_labels.float()).cpu().numpy().mean() * 100
        val_accuracy.append(accuracy)
    
    val_loss = np.mean(val_loss)
    val_accuracy = np.mean(val_accuracy)

    return val_loss, val_accuracy 

In [139]:
def make(config):
    # Make the data
    if torch.cuda.is_available():
        device = torch.device("cuda:1")
        print('The code uses GPU...')
    else:
        device = torch.device('cpu')
        print('The code uses CPU!!!')

    
    model = BertClassifier(config).to(device)

    # Make the loss and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer,learning_rate, momentum, weight_decay, amsgrad,momentum_decay = config['optimizer'],config['learning_rate'], config['momentum'], config['weight_decay'], config['amsgrad'],config['momentum_decay']
    optimizer = build_optimizer(model,optimizer,learning_rate, momentum, weight_decay, amsgrad,momentum_decay)
    
    return model, criterion, optimizer

In [140]:
def model_pipeline(config=None):

    # tell wandb to get started
    
      # access all HPs through wandb.config, so logging matches execution!
      
      
      # make the model, data, and optimization problem
      
      
      
      # and use them to train the model
      train(config, train_loader,val_dataloader = val_loader)
     
      for i,j in test_loader:
          x = i
     


In [141]:
config = dict(layer_fc_1 = 1,
        dim_1 = 256,
        layer_fc_2 = 1,
        dim_2 = 256,
        dropout = False,
        optimizer = 'adam',
        learning_rate = 0.00007,
        momentum = 0.85,
        weight_decay = 0.00001,
        amsgrad = False,
        epochs = 100,
        momentum_decay = 0.04)


import pprint

pprint.pprint(config)

model_pipeline(config)

{'amsgrad': False,
 'dim_1': 256,
 'dim_2': 256,
 'dropout': False,
 'epochs': 100,
 'layer_fc_1': 1,
 'layer_fc_2': 1,
 'learning_rate': 7e-05,
 'momentum': 0.85,
 'momentum_decay': 0.04,
 'optimizer': 'adam',
 'weight_decay': 1e-05}
Start training...



155it [00:03, 50.38it/s]
100%|██████████| 3/3 [00:00<00:00, 70.26it/s]


best accuracy 0
Epoch: 1 | Training Loss: 0.5884150495452266  | Validation Loss: 0.5400416652361552  | Accuracy: 72.92


155it [00:03, 50.36it/s]
100%|██████████| 3/3 [00:00<00:00, 70.64it/s]


best accuracy 72.91666666666667
Epoch: 2 | Training Loss: 0.45559899576248664  | Validation Loss: 0.5162798464298248  | Accuracy: 75.00


155it [00:03, 46.55it/s]
100%|██████████| 3/3 [00:00<00:00, 68.87it/s]


best accuracy 75.0
Epoch: 3 | Training Loss: 0.43983732442702017  | Validation Loss: 0.5167311231295267  | Accuracy: 75.00


155it [00:03, 49.38it/s]
100%|██████████| 3/3 [00:00<00:00, 68.09it/s]


best accuracy 75.0
Epoch: 4 | Training Loss: 0.4381897909025992  | Validation Loss: 0.5099599361419678  | Accuracy: 75.39


155it [00:03, 49.40it/s]
100%|██████████| 3/3 [00:00<00:00, 67.85it/s]


best accuracy 75.390625
Epoch: 5 | Training Loss: 0.43849333428567455  | Validation Loss: 0.5206258296966553  | Accuracy: 74.48


155it [00:03, 45.72it/s]
100%|██████████| 3/3 [00:00<00:00, 69.65it/s]


best accuracy 75.390625
Epoch: 6 | Training Loss: 0.43819870237381225  | Validation Loss: 0.5228941837946574  | Accuracy: 74.48


155it [00:03, 49.39it/s]
100%|██████████| 3/3 [00:00<00:00, 69.96it/s]


best accuracy 75.390625
Epoch: 7 | Training Loss: 0.4382821604128807  | Validation Loss: 0.5231680075327555  | Accuracy: 74.61


155it [00:03, 49.45it/s]
100%|██████████| 3/3 [00:00<00:00, 69.15it/s]


best accuracy 75.390625
Epoch: 8 | Training Loss: 0.438079825908907  | Validation Loss: 0.5193617443243662  | Accuracy: 74.61


155it [00:03, 45.91it/s]
100%|██████████| 3/3 [00:00<00:00, 69.22it/s]


best accuracy 75.390625
Epoch: 9 | Training Loss: 0.43832576505599485  | Validation Loss: 0.5193424622217814  | Accuracy: 74.54


155it [00:03, 49.34it/s]
100%|██████████| 3/3 [00:00<00:00, 69.25it/s]


best accuracy 75.390625
Epoch: 10 | Training Loss: 0.4383297724108542  | Validation Loss: 0.518080602089564  | Accuracy: 74.87


155it [00:03, 49.39it/s]
100%|██████████| 3/3 [00:00<00:00, 77.61it/s]


best accuracy 75.390625
Epoch: 11 | Training Loss: 0.43816626129611846  | Validation Loss: 0.519973615805308  | Accuracy: 75.00


155it [00:03, 46.01it/s]
100%|██████████| 3/3 [00:00<00:00, 68.33it/s]


best accuracy 75.390625
Epoch: 12 | Training Loss: 0.4384944158215677  | Validation Loss: 0.5181633730729421  | Accuracy: 74.74


155it [00:03, 49.71it/s]
100%|██████████| 3/3 [00:00<00:00, 73.86it/s]


best accuracy 75.390625
Epoch: 13 | Training Loss: 0.4385120068826983  | Validation Loss: 0.5131854812304179  | Accuracy: 75.07


155it [00:03, 49.45it/s]
100%|██████████| 3/3 [00:00<00:00, 68.13it/s]


best accuracy 75.390625
Epoch: 14 | Training Loss: 0.43825615067635815  | Validation Loss: 0.5234959522883097  | Accuracy: 74.15


155it [00:03, 46.13it/s]
100%|██████████| 3/3 [00:00<00:00, 67.77it/s]


best accuracy 75.390625
Epoch: 15 | Training Loss: 0.43869044780731203  | Validation Loss: 0.5120989680290222  | Accuracy: 75.20


155it [00:03, 49.01it/s]
100%|██████████| 3/3 [00:00<00:00, 69.63it/s]


best accuracy 75.390625
Epoch: 16 | Training Loss: 0.4382560366584409  | Validation Loss: 0.5134421785672506  | Accuracy: 75.26


155it [00:03, 49.51it/s]
100%|██████████| 3/3 [00:00<00:00, 69.75it/s]


best accuracy 75.390625
Epoch: 17 | Training Loss: 0.43836645618561776  | Validation Loss: 0.5161235829194387  | Accuracy: 74.87


155it [00:03, 49.65it/s]
100%|██████████| 3/3 [00:00<00:00, 68.84it/s]


best accuracy 75.390625
Epoch: 18 | Training Loss: 0.43837068984585426  | Validation Loss: 0.5216646989186605  | Accuracy: 74.54


155it [00:03, 49.22it/s]
100%|██████████| 3/3 [00:00<00:00, 70.29it/s]


best accuracy 75.390625
Epoch: 19 | Training Loss: 0.4378655066413264  | Validation Loss: 0.5121968984603882  | Accuracy: 74.87


155it [00:03, 46.20it/s]
100%|██████████| 3/3 [00:00<00:00, 69.74it/s]


best accuracy 75.390625
Epoch: 20 | Training Loss: 0.4380676111867351  | Validation Loss: 0.5194561084111532  | Accuracy: 74.54


155it [00:03, 49.56it/s]
100%|██████████| 3/3 [00:00<00:00, 70.10it/s]


best accuracy 75.390625
Epoch: 21 | Training Loss: 0.43793478588904106  | Validation Loss: 0.5237301190694174  | Accuracy: 74.22


155it [00:03, 49.51it/s]
100%|██████████| 3/3 [00:00<00:00, 68.82it/s]


best accuracy 75.390625
Epoch: 22 | Training Loss: 0.438488019281818  | Validation Loss: 0.5216884811719259  | Accuracy: 74.54


155it [00:03, 46.03it/s]
100%|██████████| 3/3 [00:00<00:00, 69.52it/s]


best accuracy 75.390625
Epoch: 23 | Training Loss: 0.43828402526917  | Validation Loss: 0.5175898770491282  | Accuracy: 74.87


155it [00:03, 49.66it/s]
100%|██████████| 3/3 [00:00<00:00, 70.17it/s]


best accuracy 75.390625
Epoch: 24 | Training Loss: 0.43817469842972295  | Validation Loss: 0.5196505089600881  | Accuracy: 74.41


155it [00:03, 46.10it/s]
100%|██████████| 3/3 [00:00<00:00, 68.16it/s]


best accuracy 75.390625
Epoch: 25 | Training Loss: 0.43823693010114856  | Validation Loss: 0.5150837500890096  | Accuracy: 75.07


155it [00:03, 49.27it/s]
100%|██████████| 3/3 [00:00<00:00, 69.25it/s]


best accuracy 75.390625
Epoch: 26 | Training Loss: 0.4384529819411616  | Validation Loss: 0.513094445069631  | Accuracy: 75.07


155it [00:03, 49.62it/s]
100%|██████████| 3/3 [00:00<00:00, 70.17it/s]


best accuracy 75.390625
Epoch: 27 | Training Loss: 0.43818537765933624  | Validation Loss: 0.5207540988922119  | Accuracy: 74.74


155it [00:03, 49.43it/s]
100%|██████████| 3/3 [00:00<00:00, 69.84it/s]


best accuracy 75.390625
Epoch: 28 | Training Loss: 0.4384802031901575  | Validation Loss: 0.5208157102266947  | Accuracy: 74.67


155it [00:03, 45.74it/s]
100%|██████████| 3/3 [00:00<00:00, 68.76it/s]


best accuracy 75.390625
Epoch: 29 | Training Loss: 0.4383002898385448  | Validation Loss: 0.5161800980567932  | Accuracy: 75.07


155it [00:03, 49.59it/s]
100%|██████████| 3/3 [00:00<00:00, 69.45it/s]


best accuracy 75.390625
Epoch: 30 | Training Loss: 0.43818801641464233  | Validation Loss: 0.5203109383583069  | Accuracy: 74.54


155it [00:03, 45.95it/s]
100%|██████████| 3/3 [00:00<00:00, 69.61it/s]


best accuracy 75.390625
Epoch: 31 | Training Loss: 0.43871129635841616  | Validation Loss: 0.5194349884986877  | Accuracy: 75.00


155it [00:03, 49.58it/s]
100%|██████████| 3/3 [00:00<00:00, 69.30it/s]


best accuracy 75.390625
Epoch: 32 | Training Loss: 0.43862001665176886  | Validation Loss: 0.5184207955996195  | Accuracy: 74.48


155it [00:03, 49.63it/s]
100%|██████████| 3/3 [00:00<00:00, 76.91it/s]


best accuracy 75.390625
Epoch: 33 | Training Loss: 0.4380664477425237  | Validation Loss: 0.5123016635576884  | Accuracy: 75.39


155it [00:03, 49.54it/s]
100%|██████████| 3/3 [00:00<00:00, 69.85it/s]


best accuracy 75.390625
Epoch: 34 | Training Loss: 0.438270386380534  | Validation Loss: 0.5212702453136444  | Accuracy: 74.61


155it [00:03, 49.62it/s]
100%|██████████| 3/3 [00:00<00:00, 67.97it/s]


best accuracy 75.390625
Epoch: 35 | Training Loss: 0.4381669500181752  | Validation Loss: 0.514831135670344  | Accuracy: 74.80


155it [00:03, 45.68it/s]
100%|██████████| 3/3 [00:00<00:00, 68.92it/s]


best accuracy 75.390625
Epoch: 36 | Training Loss: 0.43851785044516284  | Validation Loss: 0.512914220492045  | Accuracy: 75.00


155it [00:03, 49.17it/s]
100%|██████████| 3/3 [00:00<00:00, 68.78it/s]


best accuracy 75.390625
Epoch: 37 | Training Loss: 0.43830032713951605  | Validation Loss: 0.5144792099793752  | Accuracy: 74.93


155it [00:03, 49.47it/s]
100%|██████████| 3/3 [00:00<00:00, 69.47it/s]


best accuracy 75.390625
Epoch: 38 | Training Loss: 0.4380388811711342  | Validation Loss: 0.5163825651009878  | Accuracy: 74.80


155it [00:03, 46.22it/s]
100%|██████████| 3/3 [00:00<00:00, 69.14it/s]


best accuracy 75.390625
Epoch: 39 | Training Loss: 0.43799181253679337  | Validation Loss: 0.5162126620610555  | Accuracy: 74.93


155it [00:03, 49.24it/s]
100%|██████████| 3/3 [00:00<00:00, 74.28it/s]


best accuracy 75.390625
Epoch: 40 | Training Loss: 0.43866238305645605  | Validation Loss: 0.5024237235387167  | Accuracy: 75.65


155it [00:03, 49.49it/s]
100%|██████████| 3/3 [00:00<00:00, 70.02it/s]


best accuracy 75.65104166666667
Epoch: 41 | Training Loss: 0.43865638086872716  | Validation Loss: 0.5151950716972351  | Accuracy: 74.93


155it [00:03, 46.00it/s]
100%|██████████| 3/3 [00:00<00:00, 67.76it/s]


best accuracy 75.65104166666667
Epoch: 42 | Training Loss: 0.4382599997904993  | Validation Loss: 0.5245776573816935  | Accuracy: 74.41


155it [00:03, 49.23it/s]
100%|██████████| 3/3 [00:00<00:00, 68.64it/s]


best accuracy 75.65104166666667
Epoch: 43 | Training Loss: 0.4383506111560329  | Validation Loss: 0.5201704899470011  | Accuracy: 74.41


155it [00:03, 48.75it/s]
100%|██████████| 3/3 [00:00<00:00, 68.40it/s]


best accuracy 75.65104166666667
Epoch: 44 | Training Loss: 0.438204469988423  | Validation Loss: 0.5254964033762614  | Accuracy: 74.22


155it [00:03, 45.61it/s]
100%|██████████| 3/3 [00:00<00:00, 68.99it/s]


best accuracy 75.65104166666667
Epoch: 45 | Training Loss: 0.43833802823097473  | Validation Loss: 0.514864424864451  | Accuracy: 74.87


155it [00:03, 49.58it/s]
100%|██████████| 3/3 [00:00<00:00, 70.06it/s]


best accuracy 75.65104166666667
Epoch: 46 | Training Loss: 0.43826047201310436  | Validation Loss: 0.5178589820861816  | Accuracy: 74.80


155it [00:03, 49.44it/s]
100%|██████████| 3/3 [00:00<00:00, 67.69it/s]


best accuracy 75.65104166666667
Epoch: 47 | Training Loss: 0.4382505338038168  | Validation Loss: 0.520646870136261  | Accuracy: 74.48


155it [00:03, 45.87it/s]
100%|██████████| 3/3 [00:00<00:00, 72.49it/s]


best accuracy 75.65104166666667
Epoch: 48 | Training Loss: 0.4383307651165993  | Validation Loss: 0.5170087615648905  | Accuracy: 74.74


155it [00:03, 48.91it/s]
100%|██████████| 3/3 [00:00<00:00, 67.94it/s]


best accuracy 75.65104166666667
Epoch: 49 | Training Loss: 0.43803377343762306  | Validation Loss: 0.519640843073527  | Accuracy: 74.48


155it [00:03, 49.29it/s]
100%|██████████| 3/3 [00:00<00:00, 66.05it/s]


best accuracy 75.65104166666667
Epoch: 50 | Training Loss: 0.4379986155417658  | Validation Loss: 0.518661896387736  | Accuracy: 74.74


155it [00:03, 45.88it/s]
100%|██████████| 3/3 [00:00<00:00, 70.09it/s]


best accuracy 75.65104166666667
Epoch: 51 | Training Loss: 0.43791879607785134  | Validation Loss: 0.5162906448046366  | Accuracy: 75.07


155it [00:03, 49.43it/s]
100%|██████████| 3/3 [00:00<00:00, 69.69it/s]


best accuracy 75.65104166666667
Epoch: 52 | Training Loss: 0.4381064718769443  | Validation Loss: 0.5184447268644968  | Accuracy: 75.07


155it [00:03, 45.71it/s]
100%|██████████| 3/3 [00:00<00:00, 67.99it/s]


best accuracy 75.65104166666667
Epoch: 53 | Training Loss: 0.43830524586862135  | Validation Loss: 0.5138219197591146  | Accuracy: 75.26


155it [00:03, 49.38it/s]
100%|██████████| 3/3 [00:00<00:00, 69.10it/s]


best accuracy 75.65104166666667
Epoch: 54 | Training Loss: 0.4384203045598922  | Validation Loss: 0.517377108335495  | Accuracy: 75.07


155it [00:03, 49.47it/s]
100%|██████████| 3/3 [00:00<00:00, 68.54it/s]


best accuracy 75.65104166666667
Epoch: 55 | Training Loss: 0.43884330065019667  | Validation Loss: 0.5191065073013306  | Accuracy: 74.61


155it [00:03, 49.40it/s]
100%|██████████| 3/3 [00:00<00:00, 69.35it/s]


best accuracy 75.65104166666667
Epoch: 56 | Training Loss: 0.4380667951799208  | Validation Loss: 0.5163869261741638  | Accuracy: 74.80


155it [00:03, 45.87it/s]
100%|██████████| 3/3 [00:00<00:00, 68.92it/s]


best accuracy 75.65104166666667
Epoch: 57 | Training Loss: 0.4381325060321439  | Validation Loss: 0.5171574254830679  | Accuracy: 74.67


155it [00:03, 49.27it/s]
100%|██████████| 3/3 [00:00<00:00, 69.97it/s]


best accuracy 75.65104166666667
Epoch: 58 | Training Loss: 0.4381982482248737  | Validation Loss: 0.525213360786438  | Accuracy: 74.09


155it [00:03, 49.67it/s]
100%|██████████| 3/3 [00:00<00:00, 68.13it/s]


best accuracy 75.65104166666667
Epoch: 59 | Training Loss: 0.4384203705095476  | Validation Loss: 0.5159481763839722  | Accuracy: 75.13


155it [00:03, 45.56it/s]
100%|██████████| 3/3 [00:00<00:00, 68.08it/s]


best accuracy 75.65104166666667
Epoch: 60 | Training Loss: 0.438071895030237  | Validation Loss: 0.5147464871406555  | Accuracy: 74.93


155it [00:03, 49.09it/s]
100%|██████████| 3/3 [00:00<00:00, 68.49it/s]


best accuracy 75.65104166666667
Epoch: 61 | Training Loss: 0.4382727546076621  | Validation Loss: 0.5151610871156057  | Accuracy: 75.07


155it [00:03, 49.40it/s]
100%|██████████| 3/3 [00:00<00:00, 75.32it/s]


best accuracy 75.65104166666667
Epoch: 62 | Training Loss: 0.43811519895830464  | Validation Loss: 0.5134928127129873  | Accuracy: 75.07


155it [00:03, 46.15it/s]
100%|██████████| 3/3 [00:00<00:00, 67.71it/s]


best accuracy 75.65104166666667
Epoch: 63 | Training Loss: 0.4384600498983937  | Validation Loss: 0.5226327379544576  | Accuracy: 74.67


155it [00:03, 48.67it/s]
100%|██████████| 3/3 [00:00<00:00, 69.07it/s]


best accuracy 75.65104166666667
Epoch: 64 | Training Loss: 0.4382097215421738  | Validation Loss: 0.5234674413998922  | Accuracy: 74.54


155it [00:03, 49.23it/s]
100%|██████████| 3/3 [00:00<00:00, 69.02it/s]


best accuracy 75.65104166666667
Epoch: 65 | Training Loss: 0.43835596999814436  | Validation Loss: 0.5159341295560201  | Accuracy: 74.74


155it [00:03, 46.00it/s]
100%|██████████| 3/3 [00:00<00:00, 70.22it/s]


best accuracy 75.65104166666667
Epoch: 66 | Training Loss: 0.4380831295444119  | Validation Loss: 0.5242762366930643  | Accuracy: 74.35


155it [00:03, 49.32it/s]
100%|██████████| 3/3 [00:00<00:00, 69.28it/s]


best accuracy 75.65104166666667
Epoch: 67 | Training Loss: 0.438420186696514  | Validation Loss: 0.5260798136393229  | Accuracy: 74.48


155it [00:03, 49.16it/s]
100%|██████████| 3/3 [00:00<00:00, 68.17it/s]


best accuracy 75.65104166666667
Epoch: 68 | Training Loss: 0.438129521377625  | Validation Loss: 0.5138694643974304  | Accuracy: 74.74


155it [00:03, 45.91it/s]
100%|██████████| 3/3 [00:00<00:00, 67.42it/s]


best accuracy 75.65104166666667
Epoch: 69 | Training Loss: 0.4384561452173418  | Validation Loss: 0.5216773748397827  | Accuracy: 74.48


155it [00:03, 48.77it/s]
100%|██████████| 3/3 [00:00<00:00, 69.64it/s]


best accuracy 75.65104166666667
Epoch: 70 | Training Loss: 0.4380194067955017  | Validation Loss: 0.5156113902727762  | Accuracy: 74.87


155it [00:03, 45.96it/s]
100%|██████████| 3/3 [00:00<00:00, 68.88it/s]


best accuracy 75.65104166666667
Epoch: 71 | Training Loss: 0.4385559343522595  | Validation Loss: 0.5179776151974996  | Accuracy: 74.74


155it [00:03, 49.64it/s]
100%|██████████| 3/3 [00:00<00:00, 69.49it/s]


best accuracy 75.65104166666667
Epoch: 72 | Training Loss: 0.4380067963753977  | Validation Loss: 0.5152046084403992  | Accuracy: 75.00


155it [00:03, 49.39it/s]
100%|██████████| 3/3 [00:00<00:00, 70.07it/s]


best accuracy 75.65104166666667
Epoch: 73 | Training Loss: 0.4381939845700418  | Validation Loss: 0.5141197840372721  | Accuracy: 75.20


155it [00:03, 49.32it/s]
100%|██████████| 3/3 [00:00<00:00, 69.41it/s]


best accuracy 75.65104166666667
Epoch: 74 | Training Loss: 0.43799942885675736  | Validation Loss: 0.5214619636535645  | Accuracy: 74.28


155it [00:03, 46.12it/s]
100%|██████████| 3/3 [00:00<00:00, 68.48it/s]


best accuracy 75.65104166666667
Epoch: 75 | Training Loss: 0.43813660721625053  | Validation Loss: 0.5205340087413788  | Accuracy: 74.41


155it [00:03, 49.55it/s]
100%|██████████| 3/3 [00:00<00:00, 69.43it/s]


best accuracy 75.65104166666667
Epoch: 76 | Training Loss: 0.4380654336944703  | Validation Loss: 0.5220190982023875  | Accuracy: 74.74


155it [00:03, 49.58it/s]
100%|██████████| 3/3 [00:00<00:00, 69.56it/s]


best accuracy 75.65104166666667
Epoch: 77 | Training Loss: 0.4382580391822323  | Validation Loss: 0.5199108918507894  | Accuracy: 74.87


12it [00:00, 45.95it/s]


KeyboardInterrupt: 

### 7. Testing model

In [None]:
model = bert

In [None]:
train_features_Protein_A

In [None]:
config