In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from utils.dataloader import *
from utils.GFP import *
from utils.AAV import * 
from utils.pesudo_MSA import AAV_END, AAV_START, WT_AAV2
from utils.torch_utils import *
from Groundtruth_model.CNN import *
from Groundtruth_model.ensemble import *
from Oracle_model.oracle_from_CbAS import *
import warnings
warnings.filterwarnings("ignore")

## Groundtruth model using a CNN

In [3]:
Epoch = 400; lr=0.0005; batch_size=128; seed=2; model_id = 0
device = torch.device("cuda")
#################################################
# seed_everything(seed=seed)
# df = pd.read_csv('GFP_data/gfp_data.csv')
# X, y = get_gfp_X_y_aa(df, large_only=True, ignore_stops=True)
df = pd.read_csv('AAV_data/AAV_library.csv')
X, y, seqs = get_AAV_X_y_aa(df, large_only=False, return_str=True)
#################################################
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=seed, shuffle=True)
train_set = torch.utils.data.TensorDataset(X_train, y_train)
# train_set = torch.utils.data.TensorDataset(X, y)
# X_test, y_test = X, y
loader = torch.utils.data.DataLoader(train_set, batch_size = batch_size, shuffle=True)

CNN = CNN_ground_AAV(seq_len = X.shape[1], hidden_fc=512, hidden_conv=12, n_chars = 21)
CNN = CNN.to(device=device)
criterion = NLL_loss 
MSE = nn.MSELoss()
optimizer = torch.optim.Adam(CNN.parameters(), lr=lr)
 

In [9]:
def save_checkpoints(CNN, model_id, epoch):
    torch.save({
        'epoch': epoch,
        'model_id': model_id,
        'CNN_state_dict': CNN.state_dict(),
    }, os.path.join('AAV_data/ground_model/'+'CNN' + str(model_id) + \
        '_epoch_{}'.format(epoch)))
    return

def train(epoch, loss_tr, loss_te, train_full=False):
    train_loss_running = 0
    for batch, (X_b, y_b) in enumerate(loader):
        X_b, y_b = X_b.to(device), y_b.to(device)
        optimizer.zero_grad()
        y_pred = CNN(X_b).squeeze(-1)
        loss = criterion(y_b, y_pred)
        loss.backward()
        optimizer.step()
        train_loss_running += loss.item()
    # get avg training loss 
    train_loss = train_loss_running/batch
    # eval
    if train_full == False:
        with torch.no_grad():
            y_pred = CNN(X_test.to(device)).squeeze(-1)
            loss_test = MSE(y_pred[:, 0], y_test.to(device)) 
            print("Epoch:", epoch)
            print("training loss is:", train_loss, 
                  "; testing MSE is:", loss_test.item())
    else:
        print("Epoch:", epoch, "loss for training is:", train_loss)
    loss_tr.append(train_loss)
    loss_te.append(loss_test.item())
    return loss_tr, loss_te

In [10]:
loss_tr, loss_te = [], []
for e in range(300):
    loss_tr, loss_te = train(e, loss_tr, loss_te)

Epoch: 0
training loss is: 2.1255210611762565 ; testing MSE is: 2.670632839202881
Epoch: 1
training loss is: 1.6882629860007068 ; testing MSE is: 2.4085123538970947
Epoch: 2
training loss is: 1.635924819777395 ; testing MSE is: 2.3074216842651367
Epoch: 3
training loss is: 1.6012870977138558 ; testing MSE is: 2.2508997917175293
Epoch: 4
training loss is: 1.5775697422909363 ; testing MSE is: 2.1630165576934814
Epoch: 5
training loss is: 1.5632242526472209 ; testing MSE is: 2.1274120807647705
Epoch: 6
training loss is: 1.5512531548102584 ; testing MSE is: 2.0148847103118896
Epoch: 7
training loss is: 1.5382396330365415 ; testing MSE is: 1.9992750883102417
Epoch: 8
training loss is: 1.5296867254957196 ; testing MSE is: 1.9720592498779297
Epoch: 9
training loss is: 1.5208473279330335 ; testing MSE is: 1.9234211444854736
Epoch: 10
training loss is: 1.5158523795106162 ; testing MSE is: 1.927497148513794
Epoch: 11
training loss is: 1.5084619484788833 ; testing MSE is: 1.9655823707580566
Epoch

In [None]:
CNNs = load_GFP_ground_CNNs(n_models=2, train_epoch=300)
y_pred = ensemble_infer(CNNs, X_test[:10])
y_pred, y_test[:10]

(array([3.61016524, 3.10935128, 3.34050274, 3.65482473, 3.64157796,
        3.53691149, 3.49341369, 3.43271708, 3.4210366 , 3.63521731]),
 tensor([3.5584, 3.1370, 3.5275, 3.5782, 3.7308, 3.3761, 3.7312, 3.4574, 3.4884,
         3.7136]))

## Oracle model

In [2]:
##################################################
# generate oracles for n_experiment 
n_experiment = 10
train_size = 512 # use 256 sequence with fitness values for each experiment
##################################################
seed_everything()
df = pd.read_csv('AAV_data/AAV_library.csv')
X, y, seqs = get_AAV_X_y_aa(df, large_only=False, return_str=True)
WT_AAV_encoding = torch.Tensor(
    one_hot_aav_mutation_seq(WT_AAV2[AAV_START:AAV_END]).argmax(-1))
# since there are different fitness measurement for the same sequence
X = get_unique_X(X)
groundtruth_models = load_AAV_ground_CNNs(n_models=8, train_epoch=300)
y_gt = ensemble_infer(groundtruth_models, torch.tensor(X))

In [24]:
train_size = 256

#### Obtain oracle trained with only single-mutation sequences

In [25]:
for ensemble_id in range(n_experiment):
    X_oracle, y_oracle = get_experimental_X_y_by_EditDist(
        X, y_gt, WT_AAV_encoding, train_size=train_size, 
        max_edit_distance=1, random_state=ensemble_id, 
        return_y_noise=True)
    train_and_save_oracles(
        X_oracle, y_oracle, suffix='Single', protein='AAV',
        train_size=train_size, n_models=5, 
        ensemble_id=ensemble_id, n_char=21)

Epoch 1/100
8/8 - 0s - loss: 1.5804 - val_loss: 1.7511
Epoch 2/100
8/8 - 0s - loss: 1.3081 - val_loss: 1.3486
Epoch 3/100
8/8 - 0s - loss: 1.2220 - val_loss: 1.3367
Epoch 4/100
8/8 - 0s - loss: 1.0722 - val_loss: 1.5520
Epoch 5/100
8/8 - 0s - loss: 1.0035 - val_loss: 1.5852
Epoch 6/100
8/8 - 0s - loss: 1.0021 - val_loss: 1.6198
Epoch 7/100
8/8 - 0s - loss: 0.9437 - val_loss: 1.4193
Epoch 8/100
8/8 - 0s - loss: 0.9515 - val_loss: 1.4145
Epoch 00008: early stopping
Epoch 1/100
8/8 - 0s - loss: 1.2141 - val_loss: 1.3681
Epoch 2/100
8/8 - 0s - loss: 1.0636 - val_loss: 1.4689
Epoch 3/100
8/8 - 0s - loss: 1.0345 - val_loss: 1.4990
Epoch 4/100
8/8 - 0s - loss: 1.0093 - val_loss: 1.4161
Epoch 5/100
8/8 - 0s - loss: 1.0317 - val_loss: 1.3373
Epoch 6/100
8/8 - 0s - loss: 0.9711 - val_loss: 1.4538
Epoch 7/100
8/8 - 0s - loss: 0.9066 - val_loss: 1.8384
Epoch 8/100
8/8 - 0s - loss: 0.9722 - val_loss: 1.5260
Epoch 9/100
8/8 - 0s - loss: 0.8843 - val_loss: 1.5709
Epoch 10/100
8/8 - 0s - loss: 0.8443 

#### Obtain oracle trained with bottom 20% fit sequences

In [26]:
for ensemble_id in range(n_experiment):
    X_oracle, y_oracle = get_experimental_X_y(
        X, y_gt, percentile=20, train_size=train_size, 
        random_state=ensemble_id, return_y_noise=True)
    train_and_save_oracles(
        X_oracle, y_oracle, 
        protein = "AAV", suffix='AAV', train_size=train_size, 
        n_models=5, ensemble_id=ensemble_id, n_char=21)

Epoch 1/100
8/8 - 0s - loss: 1.6446 - val_loss: 1.1653
Epoch 2/100
8/8 - 0s - loss: 1.2737 - val_loss: 0.9435
Epoch 3/100
8/8 - 0s - loss: 0.9320 - val_loss: 0.9189
Epoch 4/100
8/8 - 0s - loss: 0.7412 - val_loss: 0.5398
Epoch 5/100
8/8 - 0s - loss: 0.3674 - val_loss: 0.3176
Epoch 6/100
8/8 - 0s - loss: 0.0029 - val_loss: 0.4384
Epoch 7/100
8/8 - 0s - loss: -2.3278e-01 - val_loss: 0.7330
Epoch 8/100
8/8 - 0s - loss: -3.3733e-01 - val_loss: 1.1727
Epoch 9/100
8/8 - 0s - loss: -3.8766e-01 - val_loss: 1.9084
Epoch 10/100
8/8 - 0s - loss: -4.8807e-01 - val_loss: 1.8214
Epoch 00010: early stopping
Epoch 1/100
8/8 - 0s - loss: 1.5384 - val_loss: 1.0596
Epoch 2/100
8/8 - 0s - loss: 1.1909 - val_loss: 0.8950
Epoch 3/100
8/8 - 0s - loss: 0.8149 - val_loss: 0.7153
Epoch 4/100
8/8 - 0s - loss: 0.4089 - val_loss: 0.2387
Epoch 5/100
8/8 - 0s - loss: -3.0268e-02 - val_loss: 0.3956
Epoch 6/100
8/8 - 0s - loss: -2.3978e-01 - val_loss: 0.7401
Epoch 7/100
8/8 - 0s - loss: -2.4598e-01 - val_loss: 0.6015
E

In [14]:
oracles = load_oracles(protein='AAV', suffix='AAV', 
                       n_models=5, ensemble_id=2, train_size=train_size)

In [None]:
results = get_balaji_predictions(oracles, X[60:120])
results