In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from utils.dataloader import *
from utils.HIS7 import *
from utils.torch_utils import *
from Groundtruth_model.CNN import *
from Groundtruth_model.ensemble import *
from Oracle_model.oracle_from_CbAS import *
import warnings
warnings.filterwarnings("ignore")

## Groundtruth model using a CNN

In [None]:
#################################################
df = pd.read_csv('HIS7_data/his7.csv')
X, y = get_HIS7_X_y_aa(df, large_only=False, ignore_stops=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=seed, shuffle=True)
train_set = torch.utils.data.TensorDataset(X_train, y_train)
# train_set = torch.utils.data.TensorDataset(X, y)
# X_test, y_test = X, y
loader = torch.utils.data.DataLoader(train_set, batch_size = batch_size, shuffle=True)
#################################################

In [11]:
Epoch = 400; lr=0.001; batch_size=512; seed=1; model_id = 1
device = torch.device("cuda")
seed_everything(seed=seed)
CNN = CNN_ground(seq_len = X.shape[1], hidden_fc=128, hidden_conv=12, n_chars = 20)
CNN = CNN.to(device=device)
criterion = NLL_loss 
MSE = nn.MSELoss()
optimizer = torch.optim.Adam(CNN.parameters(), lr=lr)

In [10]:
def save_checkpoints(CNN, model_id, epoch):
    torch.save({
        'epoch': epoch,
        'model_id': model_id,
        'CNN_state_dict': CNN.state_dict(),
    }, os.path.join('HIS7_data/ground_model/'+'CNN' + str(model_id) + \
        '_epoch_{}'.format(epoch)))
    return

def train(epoch, loss_tr, loss_te, train_full=False):
    train_loss_running = 0
    for batch, (X_b, y_b) in enumerate(loader):
        X_b, y_b = X_b.to(device), y_b.to(device)
        optimizer.zero_grad()
        y_pred = CNN(X_b).squeeze(-1)
        loss = criterion(y_b, y_pred)
        loss.backward()
        optimizer.step()
        train_loss_running += loss.item()
    # get avg training loss 
    train_loss = train_loss_running/batch
    # eval
    if train_full == False:
        with torch.no_grad():
            y_pred = CNN(X_test.to(device)).squeeze(-1)
            loss_test = MSE(y_pred[:, 0], y_test.to(device)) 
            print("Epoch:", epoch)
            print("training loss is:", train_loss, 
                  "; testing MSE is:", loss_test.item())
    else:
        print("Epoch:", epoch, "loss for training is:", train_loss)
    loss_tr.append(train_loss)
    loss_te.append(loss_test.item())
    return loss_tr, loss_te

In [7]:
loss_tr, loss_te = [], []
for e in range(10):
    loss_tr, loss_te = train(e, loss_tr, loss_te)

Epoch: 0
training loss is: -0.5659738105727781 ; testing MSE is: 0.034320078790187836
Epoch: 1
training loss is: -0.5842535290410442 ; testing MSE is: 0.04224729537963867
Epoch: 2
training loss is: -0.5953121509475092 ; testing MSE is: 0.033062104135751724
Epoch: 3
training loss is: -0.5997616277202483 ; testing MSE is: 0.03397340700030327
Epoch: 4
training loss is: -0.5813622550425991 ; testing MSE is: 0.03795141354203224
Epoch: 5
training loss is: -0.5847968480663914 ; testing MSE is: 0.03109155222773552
Epoch: 6
training loss is: -0.6115484714508057 ; testing MSE is: 0.03557571768760681
Epoch: 7
training loss is: -0.6363733656944767 ; testing MSE is: 0.03517623618245125
Epoch: 8
training loss is: -0.6151200048769674 ; testing MSE is: 0.03355681151151657
Epoch: 9
training loss is: -0.6310107684135438 ; testing MSE is: 0.031025070697069168


In [8]:
save_checkpoints(CNN, 0, 200)

## Oracle model

In [4]:
##################################################
# generate oracles for n_experiment 
n_experiment = 10
train_size = 96 # use 256 sequence with fitness values for each experiment
##################################################
seed_everything()
df = pd.read_csv('HIS7_data/his7.csv')
X, _ = get_HIS7_X_y_aa(df, large_only=True, ignore_stops=True)
# since there are different fitness measurement for the same sequence
X = get_unique_X(X)
CNNs = load_HIS7_ground_AAVs(n_models=8, train_epoch=200)
y_gt = ensemble_infer(CNNs, torch.tensor(X))
WT_HIS7_encoding = onehot_encoder(WT_HIS7).argmax(-1)

#### Obtain oracle trained with only single-mutation sequences

In [14]:
for ensemble_id in range(n_experiment):
    X_oracle, y_oracle = get_experimental_X_y_by_EditDist(
        X, y_gt, WT_HIS7_encoding, train_size=train_size, 
        max_edit_distance=2, random_state=ensemble_id, 
        return_y_noise=True)
    train_and_save_oracles(
        X_oracle, y_oracle, suffix='Double', protein='HIS7',
        train_size=train_size, n_models=5,
        ensemble_id=ensemble_id, n_char=20)

Epoch 1/100
29/29 - 1s - loss: 0.5966 - val_loss: 0.0417
Epoch 2/100
29/29 - 0s - loss: -2.6101e-01 - val_loss: -3.4125e-01
Epoch 3/100
29/29 - 0s - loss: -2.3986e-01 - val_loss: -4.7684e-01
Epoch 4/100
29/29 - 0s - loss: -8.8088e-01 - val_loss: -1.2947e+00
Epoch 5/100
29/29 - 0s - loss: -1.5508e+00 - val_loss: -3.8324e-01
Epoch 6/100
29/29 - 0s - loss: -4.5353e-01 - val_loss: -1.1237e+00
Epoch 7/100
29/29 - 0s - loss: -1.2911e+00 - val_loss: -1.5547e+00
Epoch 8/100
29/29 - 0s - loss: -1.7519e+00 - val_loss: -2.0023e+00
Epoch 9/100
29/29 - 0s - loss: -1.6482e+00 - val_loss: -1.7593e+00
Epoch 10/100
29/29 - 0s - loss: -1.1895e+00 - val_loss: -1.2093e+00
Epoch 11/100
29/29 - 0s - loss: -1.5702e+00 - val_loss: -1.7720e+00
Epoch 12/100
29/29 - 0s - loss: -1.8138e+00 - val_loss: -1.8826e+00
Epoch 13/100
29/29 - 0s - loss: -1.7440e+00 - val_loss: -1.5971e+00
Epoch 00013: early stopping
Epoch 1/100
29/29 - 0s - loss: 0.8481 - val_loss: -1.6138e-01
Epoch 2/100
29/29 - 0s - loss: -6.9012e-01 - 

#### Obtain oracle trained with bottom 20% fit sequences

In [1]:
def partition_data(X, y, percentile=40, train_size=1000, random_state=1, return_test=False):
    """Partition a (X, y) data set by a percentile of the y values"""
    np.random.seed(random_state)
    assert (percentile*0.01 * len(y) >= train_size)
    y_percentile = np.percentile(y, percentile)
    idx = np.where(y < y_percentile)[0]
    rand_idx = np.random.choice(idx, size=train_size, replace=False)
    X_train = X[rand_idx]
    y_train = y[rand_idx]
    if return_test:
        test_idx = [i for i in idx if i not in rand_idx]
        X_test = X[test_idx]
        y_test = y[test_idx]
        return X_train, y_train, X_test, y_test
    else:
        return X_train, y_train


In [15]:
for ensemble_id in range(n_experiment):
    X_oracle, y_oracle = get_experimental_X_y(
        X, y_gt, percentile=20, train_size=train_size, 
        random_state=ensemble_id, return_y_noise=True)
    train_and_save_oracles(
        X_oracle, y_oracle, suffix='HIS7', protein='HIS7',
        train_size=train_size, n_models=5, ensemble_id=ensemble_id)

Epoch 1/100
29/29 - 1s - loss: 0.3739 - val_loss: -8.2955e-01
Epoch 2/100
29/29 - 0s - loss: -6.7849e-01 - val_loss: -1.1124e+00
Epoch 3/100
29/29 - 0s - loss: -4.9406e-01 - val_loss: -4.8797e-01
Epoch 4/100
29/29 - 0s - loss: -7.3733e-01 - val_loss: -1.1479e+00
Epoch 5/100
29/29 - 0s - loss: -6.5517e-01 - val_loss: -8.0710e-01
Epoch 6/100
29/29 - 0s - loss: -9.8043e-01 - val_loss: -1.2503e+00
Epoch 7/100
29/29 - 0s - loss: -1.1786e+00 - val_loss: -1.2410e+00
Epoch 8/100
29/29 - 0s - loss: -1.2410e+00 - val_loss: -6.7701e-01
Epoch 9/100
29/29 - 0s - loss: -1.1755e+00 - val_loss: -1.1982e+00
Epoch 10/100
29/29 - 0s - loss: -1.2922e+00 - val_loss: -1.1951e+00
Epoch 11/100
29/29 - 0s - loss: -1.2124e+00 - val_loss: -1.3225e+00
Epoch 12/100
29/29 - 0s - loss: -1.3277e+00 - val_loss: -9.0857e-01
Epoch 13/100
29/29 - 0s - loss: -1.3794e+00 - val_loss: -1.3748e+00
Epoch 14/100
29/29 - 0s - loss: -1.2042e+00 - val_loss: -3.1831e-01
Epoch 15/100
29/29 - 0s - loss: -1.0641e+00 - val_loss: -1.375