In [22]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from utils.dataloader import *
from utils.GFP import *
from utils.torch_utils import *
from Groundtruth_model.CNN import *
from Groundtruth_model.ensemble import *
from Oracle_model.oracle_from_CbAS import *
from utils.pesudo_MSA import WT_GFP
import warnings
warnings.filterwarnings("ignore")

## Groundtruth model using a CNN

In [7]:
Epoch = 400; lr=0.001; batch_size=64; seed=2; model_id = 0
device = torch.device("cuda")
#################################################
seed_everything(seed=seed)
df = pd.read_csv('GFP_data/gfp_data.csv')
X, y = get_gfp_X_y_aa(df, large_only=True, ignore_stops=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=seed, shuffle=True)
train_set = torch.utils.data.TensorDataset(X_train, y_train)
# train_set = torch.utils.data.TensorDataset(X, y)
# X_test, y_test = X, y
loader = torch.utils.data.DataLoader(train_set, batch_size = batch_size, shuffle=True)
#################################################
CNN = CNN_ground(seq_len = X.shape[1], hidden_fc=128, hidden_conv=12, n_chars = 20)
CNN = CNN.to(device=device)
criterion = NLL_loss 
MSE = nn.MSELoss()
optimizer = torch.optim.Adam(CNN.parameters(), lr=lr)
 

In [8]:
def save_checkpoints(CNN, model_id, epoch):
    torch.save({
        'epoch': epoch,
        'model_id': model_id,
        'CNN_state_dict': CNN.state_dict(),
    }, os.path.join('GFP_data/ground_model/'+'CNN' + str(model_id) + \
        '_epoch_{}'.format(epoch)))
    return

def train(epoch, loss_tr, loss_te, train_full=False):
    train_loss_running = 0
    for batch, (X_b, y_b) in enumerate(loader):
        X_b, y_b = X_b.to(device), y_b.to(device)
        optimizer.zero_grad()
        y_pred = CNN(X_b).squeeze(-1)
        loss = criterion(y_b, y_pred)
        loss.backward()
        optimizer.step()
        train_loss_running += loss.item()
    # get avg training loss 
    train_loss = train_loss_running/batch
    # eval
    if train_full == False:
        with torch.no_grad():
            y_pred = CNN(X_test.to(device)).squeeze(-1)
            loss_test = MSE(y_pred[:, 0], y_test.to(device)) 
            print("Epoch:", epoch)
            print("training loss is:", train_loss, 
                  "; testing MSE is:", loss_test.item())
    else:
        print("Epoch:", epoch, "loss for training is:", train_loss)
    loss_tr.append(train_loss)
    loss_te.append(loss_test.item())
    return loss_tr, loss_te

In [9]:
loss_tr, loss_te = [], []
for e in range(10):
    loss_tr, loss_te = train(e, loss_tr, loss_te)

Epoch: 0
training loss is: 0.3097219543061524 ; testing MSE is: 0.04097069799900055
Epoch: 1
training loss is: -0.19544812995139685 ; testing MSE is: 0.037589818239212036
Epoch: 2
training loss is: -0.22043320188455492 ; testing MSE is: 0.03730425611138344
Epoch: 3
training loss is: -0.2449879591988626 ; testing MSE is: 0.034457772970199585
Epoch: 4
training loss is: -0.24600978961614806 ; testing MSE is: 0.03486085683107376
Epoch: 5
training loss is: -0.2668217604405412 ; testing MSE is: 0.03345346078276634
Epoch: 6
training loss is: -0.2659815056858776 ; testing MSE is: 0.03235723078250885
Epoch: 7
training loss is: -0.29629118378474334 ; testing MSE is: 0.030377931892871857
Epoch: 8
training loss is: -0.3014507502595955 ; testing MSE is: 0.029972100630402565
Epoch: 9
training loss is: -0.3105189248780224 ; testing MSE is: 0.03227581828832626


In [8]:
CNNs = load_GFP_ground_CNNs(n_models=2, train_epoch=300)
y_pred = ensemble_infer(CNNs, X_test[:10])
y_pred, y_test[:10]

(array([3.61016524, 3.10935128, 3.34050274, 3.65482473, 3.64157796,
        3.53691149, 3.49341369, 3.43271708, 3.4210366 , 3.63521731]),
 tensor([3.5584, 3.1370, 3.5275, 3.5782, 3.7308, 3.3761, 3.7312, 3.4574, 3.4884,
         3.7136]))

## Oracle model

In [2]:
##################################################
# generate oracles for n_experiment 
n_experiment = 10
train_size = 96 # use 256 sequence with fitness values for each experiment
##################################################
seed_everything()
df = pd.read_csv('GFP_data/gfp_data.csv')
X, _ = get_gfp_X_y_aa(df, large_only=True, ignore_stops=True)
# since there are different fitness measurement for the same sequence
X = get_unique_X(X)
CNNs = load_GFP_ground_CNNs(n_models=8, train_epoch=300)
y_gt = ensemble_infer(CNNs, torch.tensor(X))
WT_GFP_encoding = onehot_encoder(WT_GFP).argmax(-1)

In [23]:
train_size = 512

#### Obtain oracle trained with only single-mutation sequences

In [25]:
for ensemble_id in range(n_experiment):
    X_oracle, y_oracle = get_experimental_X_y_by_EditDist(
        X, y_gt, WT_GFP_encoding, train_size=train_size, 
        max_edit_distance=2, random_state=ensemble_id, 
        return_y_noise=True)
    train_and_save_oracles(
        X_oracle, y_oracle, suffix='Single', protein='GFP',
        train_size=train_size, n_models=5, ensemble_id=ensemble_id)

Epoch 1/100
15/15 - 0s - loss: 2.4771 - val_loss: 1.6947
Epoch 2/100
15/15 - 0s - loss: 1.7255 - val_loss: 1.7237
Epoch 3/100
15/15 - 0s - loss: 1.6480 - val_loss: 1.5741
Epoch 4/100
15/15 - 0s - loss: 1.4844 - val_loss: 1.3436
Epoch 5/100
15/15 - 0s - loss: 1.1577 - val_loss: 0.8710
Epoch 6/100
15/15 - 0s - loss: 0.4762 - val_loss: -1.4915e-01
Epoch 7/100
15/15 - 0s - loss: 0.1748 - val_loss: -1.8824e-02
Epoch 8/100
15/15 - 0s - loss: -1.8576e-01 - val_loss: -2.7537e-01
Epoch 9/100
15/15 - 0s - loss: -5.5958e-01 - val_loss: -7.7972e-01
Epoch 10/100
15/15 - 0s - loss: -8.4482e-01 - val_loss: 3.1187
Epoch 11/100
15/15 - 0s - loss: -2.8819e-01 - val_loss: -6.6012e-01
Epoch 12/100
15/15 - 0s - loss: -7.6366e-01 - val_loss: -4.3057e-01
Epoch 13/100
15/15 - 0s - loss: -9.9656e-01 - val_loss: -2.5118e-01
Epoch 14/100
15/15 - 0s - loss: -7.8085e-01 - val_loss: -7.9743e-01
Epoch 15/100
15/15 - 0s - loss: -1.1481e+00 - val_loss: -4.5592e-01
Epoch 16/100
15/15 - 0s - loss: -1.3039e+00 - val_loss

#### Obtain oracle trained with bottom 20% fit sequences

In [24]:
for ensemble_id in range(n_experiment):
    X_oracle, y_oracle = get_experimental_X_y(
        X, y_gt, percentile=20, train_size=train_size, 
        random_state=ensemble_id, return_y_noise=True)
    train_and_save_oracles(
        X_oracle, y_oracle, suffix='GFP', protein='GFP',
        train_size=train_size, n_models=5, ensemble_id=ensemble_id)

Epoch 1/100
15/15 - 0s - loss: 2.4760 - val_loss: 1.9766
Epoch 2/100
15/15 - 0s - loss: 1.9417 - val_loss: 1.9181
Epoch 3/100
15/15 - 0s - loss: 1.8255 - val_loss: 1.7165
Epoch 4/100
15/15 - 0s - loss: 1.5519 - val_loss: 1.2381
Epoch 5/100
15/15 - 0s - loss: 0.6391 - val_loss: -2.8862e-01
Epoch 6/100
15/15 - 0s - loss: -2.3976e-01 - val_loss: -1.6013e-01
Epoch 7/100
15/15 - 0s - loss: -7.6885e-01 - val_loss: -8.4242e-01
Epoch 8/100
15/15 - 0s - loss: -9.2194e-01 - val_loss: 0.5081
Epoch 9/100
15/15 - 0s - loss: -6.1868e-01 - val_loss: -8.0745e-01
Epoch 10/100
15/15 - 0s - loss: -9.4769e-01 - val_loss: -7.6175e-01
Epoch 11/100
15/15 - 0s - loss: -1.2756e+00 - val_loss: -7.8523e-01
Epoch 12/100
15/15 - 0s - loss: -1.0174e+00 - val_loss: 2.5414
Epoch 00012: early stopping
Epoch 1/100
15/15 - 0s - loss: 2.2803 - val_loss: 1.8038
Epoch 2/100
15/15 - 0s - loss: 1.7661 - val_loss: 1.7220
Epoch 3/100
15/15 - 0s - loss: 1.5785 - val_loss: 1.3549
Epoch 4/100
15/15 - 0s - loss: 0.8680 - val_loss:

In [15]:
oracles = load_oracles(protein='GFP', suffix='Single', 
                       n_models=5, ensemble_id=2, train_size=train_size)

In [None]:
results = get_balaji_predictions(oracles, X[60:120])
results