In [None]:
import sys
sys.path.append('..')

In [None]:
import os
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# import pytorch_lightning as pl
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

In [None]:
from DomainPrediction.utils import helper
from DomainPrediction.eval import metrics
from DomainPrediction.esm.esm2 import ESM2

In [None]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from scipy import stats
from sklearn.model_selection import train_test_split

In [None]:
import warnings

warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*Set a lower value for log_every_n_steps*")

#### Define Functions

In [None]:
# ESM-2 t33 / 650M checkpoint (per the file name); layer 33 representations are used below.
# Set the ESM2_CHECKPOINT environment variable to point at a local copy instead of editing
# this absolute, user-specific default path.
ESM2_CHECKPOINT = os.environ.get(
    'ESM2_CHECKPOINT',
    '/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt',
)
esm2 = ESM2(model_path=ESM2_CHECKPOINT, device='gpu')

In [None]:
def get_embeddings_mean(sequences):
    '''
        Embed sequences one at a time with ESM-2 and mean-pool layer-33 representations
        over dim 1.

        sequences - iterable of protein sequence strings
        returns   - numpy array: the per-sequence pooled arrays concatenated along axis 0
                    (assumes each has a leading batch dim of 1, giving one row per
                    sequence — TODO confirm against ESM2.get_res)
    '''
    pooled = [
        esm2.get_res(sequence=seq)['representations'][33].mean(1).cpu().numpy()
        for seq in tqdm(sequences)
    ]
    return np.concatenate(pooled, axis=0)

def get_embeddings_mean_batch(sequences, batch_size=3):
    '''
        Batched variant of get_embeddings_mean: run ESM-2 on `batch_size` sequences per
        call and mean-pool layer-33 representations over dim 1.

        Every batch must contain sequences of identical length, because the plain mean over
        dim 1 does not mask anything (presumably get_res_batch pads shorter sequences —
        TODO confirm).

        sequences  - sliceable sequence of protein strings (list or pandas Series)
        batch_size - number of sequences per ESM-2 call
        returns    - numpy array, per-batch pooled arrays concatenated along axis 0
        raises     - AssertionError if a batch mixes sequence lengths
    '''
    embeddings = []
    for start in tqdm(range(0, len(sequences), batch_size)):
        batch = sequences[start:start + batch_size]
        rep, batch_lens = esm2.get_res_batch(sequences=batch)
        assert bool((batch_lens == batch_lens[0]).all()), \
            f'batch starting at index {start} has unequal sequence lengths: {batch_lens}'
        embeddings.append(rep['representations'][33].mean(1).cpu().numpy())

    return np.concatenate(embeddings, axis=0)

def get_embeddings_full(sequences):
    '''
        Per-position (un-pooled) ESM-2 layer-33 representations, one sequence at a time.

        sequences - iterable of protein sequence strings
        returns   - list with one numpy array per sequence: element [0] (first batch entry)
                    of that sequence's layer-33 output
    '''
    return [
        esm2.get_res(sequence=seq)['representations'][33].cpu().numpy()[0]
        for seq in tqdm(sequences)
    ]

In [None]:
class RFSurrogate():
    '''
        Random-forest surrogate that regresses a scalar target on embedding vectors.
    '''
    def __init__(self) -> None:

        # All hyperparameters spelled out explicitly; random_state=1 seeds the forest's own
        # randomness (bootstrap=True resampling), so only the train/val split below varies.
        self.model = RandomForestRegressor(n_estimators=100, criterion='friedman_mse', max_depth=None, min_samples_split=2,
                                            min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=1.0,
                                            max_leaf_nodes=None, min_impurity_decrease=0.0, bootstrap=True, oob_score=False,
                                            n_jobs=None, random_state=1, verbose=0, warm_start=False, ccp_alpha=0.0,
                                            max_samples=None)

    def train(self, X, y, val=True, debug=True, random_state=None):
        '''
            Fit the forest.

            X            - embeddings from esm2, shape (n, features)
            y            - targets, shape (n, )
            val          - if True, hold out 20% of the rows and fit on the other 80% only;
                           if False, fit on all rows
            debug        - print MSE / Spearman correlation on the fitted (and held-out) rows
            random_state - seed for the train/val split. None (default) keeps the split
                           unseeded; pass an int to make the reported val metrics reproducible
        '''
        if val:
            idx = np.arange(X.shape[0])
            train_idx, val_idx = train_test_split(idx, test_size=0.2, random_state=random_state)
            self.model.fit(X[train_idx], y[train_idx])

            if debug:
                self.print_eval(X[train_idx], y[train_idx], label='train')
                self.print_eval(X[val_idx], y[val_idx], label='val')
        else:
            self.model.fit(X, y)
            if debug:
                self.print_eval(X, y, label='train')

    def print_eval(self, X, y, label='set'):
        '''
            Print the MSE and Spearman correlation between predictions on X and the targets y.
        '''
        ypred = self.model.predict(X)
        # Ground truth first, prediction second (the sklearn / scipy argument convention).
        mse = mean_squared_error(y, ypred)
        corr = stats.spearmanr(y, ypred)

        print(f'{label}: mse = {mse}, spearman correlation = {corr.statistic}')

    def predict(self, X):
        '''
            Return the forest's predictions for X of shape (n, features).
        '''
        return self.model.predict(X)

In [None]:
class RidgeSurrogate():
    '''
        Ridge-regression surrogate: a linear model from embedding vectors to a scalar target.
    '''
    def __init__(self) -> None:

        self.model = Ridge(alpha=1.0, fit_intercept=True, random_state=1)

    def train(self, X, y, val=True, debug=True, random_state=None):
        '''
            Fit the ridge model.

            X            - embeddings from esm2, shape (n, features)
            y            - targets, shape (n, )
            val          - if True, hold out 20% of the rows and fit on the other 80% only;
                           if False, fit on all rows
            debug        - print MSE / Spearman correlation on the fitted (and held-out) rows
            random_state - seed for the train/val split. None (default) keeps the split
                           unseeded; pass an int to make the reported val metrics reproducible
        '''
        if val:
            idx = np.arange(X.shape[0])
            train_idx, val_idx = train_test_split(idx, test_size=0.2, random_state=random_state)
            self.model.fit(X[train_idx], y[train_idx])

            if debug:
                self.print_eval(X[train_idx], y[train_idx], label='train')
                self.print_eval(X[val_idx], y[val_idx], label='val')
        else:
            self.model.fit(X, y)
            if debug:
                self.print_eval(X, y, label='train')

    def print_eval(self, X, y, label='set'):
        '''
            Print the MSE and Spearman correlation between predictions on X and the targets y.
        '''
        ypred = self.model.predict(X)
        # Ground truth first, prediction second (the sklearn / scipy argument convention).
        mse = mean_squared_error(y, ypred)
        corr = stats.spearmanr(y, ypred)

        print(f'{label}: mse = {mse}, spearman correlation = {corr.statistic}')

    def predict(self, X):
        '''
            Return the ridge model's predictions for X of shape (n, features).
        '''
        return self.model.predict(X)

In [None]:
class ProteinFunDataset(Dataset):
    '''
        Map-style dataset pairing row X[i] (an embedding) with its target y[i].
    '''
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        # Number of samples = size of the first axis of X.
        n_rows = self.X.shape[0]
        return n_rows

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        return features, target

class MLPSurrogate(pl.LightningModule):
    '''
        MLP surrogate (Linear -> ReLU -> Linear) regressing a scalar target from an
        embedding vector, trained with Lightning through MLPSurrogate.trainmodel.

        inp_size    - embedding dimension (default 1280)
        hidden_size - width of the hidden layer
        config      - dict of training options, merged over _DEFAULT_CONFIG:
                        epoch               - max_epochs given to the Trainer
                        batch_size          - DataLoader batch size
                        print_every_n_epoch - epoch interval of the debug loss printout
                        lr                  - SGD learning rate
    '''
    # Fills in any key the caller's config omits. 'print_every_n_epoch' needs a default
    # because on_train_epoch_end always reads it.
    _DEFAULT_CONFIG = {'epoch': 10, 'batch_size': 16, 'print_every_n_epoch': 50, 'lr': 0.0001}

    def __init__(self, inp_size=1280, hidden_size=512, config=None) -> None:
        super().__init__()
        # Build a fresh dict: no shared mutable default, and the caller's dict is never mutated.
        self.config = {**self._DEFAULT_CONFIG, **(config or {})}
        self.mlp = nn.Sequential(
            nn.Linear(inp_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )
        # Per-batch losses of the current epoch; only used for the printed summaries.
        self.accumulate_batch_loss_train = []
        self.accumulate_batch_loss_val = []
        self.debug = True

    def forward(self, x):
        '''Map inputs of shape (..., inp_size) to predictions of shape (..., 1).'''
        return self.mlp(x)

    def training_step(self, batch, batch_idx):
        '''MSE loss on one (x, y) batch, logged as "train/loss" per step and per epoch.'''
        x, y = batch
        y_hat = self(x)
        # Flatten the (batch, 1) predictions to match the 1-D targets.
        loss = nn.functional.mse_loss(y_hat.flatten(), y)
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        self.accumulate_batch_loss_train.append(loss.item())
        return loss

    def validation_step(self, batch, batch_idx):
        '''MSE loss on one validation batch, logged as "val/loss" (monitored by early stopping).'''
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat.flatten(), y)
        self.log("val/loss", loss, on_step=True, on_epoch=True)
        self.accumulate_batch_loss_val.append(loss.item())

    @staticmethod
    def trainmodel(model, X, y, val=True, debug=True):
        '''
            Train `model` on a random 80/20 train/validation split with early stopping on val/loss.

            X     - embeddings from esm2, numpy array of shape (n, features)
            y     - numpy array of shape (n, ); should be float32 to match the model weights
            val   - only val=True is supported
            debug - print the periodic / final losses and the final validation MSE

            Raises NotImplementedError when val=False.
        '''
        model.debug = debug
        if not val:
            # TODO: train on all of (X, y) without a hold-out split or early stopping.
            raise NotImplementedError("MLPSurrogate.trainmodel: val=False is not implemented yet")

        idx = np.arange(X.shape[0])
        train_idx, val_idx = train_test_split(idx, test_size=0.2)
        train_dataset = ProteinFunDataset(X[train_idx], y[train_idx])
        val_dataset = ProteinFunDataset(X[val_idx], y[val_idx])
        train_loader = DataLoader(train_dataset, batch_size=model.config['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=model.config['batch_size'], shuffle=False)

        earlystopping_callback = EarlyStopping(monitor="val/loss", patience=5, verbose=False, mode="min")

        trainer = pl.Trainer(max_epochs=model.config['epoch'], callbacks=[earlystopping_callback],
                             accelerator="auto",
                             enable_progress_bar=False,
                             enable_model_summary=False
                             )
        trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        # TODO: restore the weights with the lowest val loss; the model currently keeps the
        # weights of the last epoch that ran, which is not necessarily the best one.
        if debug:
            y_pred = model.predict(X[val_idx])
            val_mse = mean_squared_error(y[val_idx], y_pred)
            print(f'Train end val mse: {val_mse}')

    def on_train_epoch_start(self):
        # Start each epoch's loss bookkeeping from scratch.
        self.accumulate_batch_loss_train.clear()
        self.accumulate_batch_loss_val.clear()

    def on_train_epoch_end(self):
        # Check debug first so the config lookup and printout are skipped entirely when quiet.
        if self.debug and self.current_epoch % self.config['print_every_n_epoch'] == 0:
            print(f'Epoch: {self.current_epoch}: train mse: {np.mean(self.accumulate_batch_loss_train)} val mse: {np.mean(self.accumulate_batch_loss_val)}')

    def on_train_end(self):
        if self.debug:
            print(f'Epoch: {self.current_epoch}: train mse: {np.mean(self.accumulate_batch_loss_train)} val mse: {np.mean(self.accumulate_batch_loss_val)}')

    def predict(self, X):
        '''
            Predict for a numpy array X of shape (n, features); returns a 1-D numpy array.
        '''
        # Match the module's dtype and device so float64 input or a non-CPU model also work.
        x = torch.as_tensor(X, dtype=self.dtype, device=self.device)
        with torch.no_grad():
            y = self(x)
        return y.cpu().numpy().flatten()

    def configure_optimizers(self):
        '''Plain SGD with learning rate config['lr'] (default 0.0001).'''
        return torch.optim.SGD(self.parameters(), lr=self.config['lr'])

#### Load Data

In [None]:
# Relative paths: resolved against the kernel's current working directory.
root = '../..'
data_path = os.path.join(root, 'Data/al_test_experiments/Evolvepro')

In [None]:
# Load the variant table. Downstream cells need a 'seq' column (sequences fed to ESM-2) and a
# 'function' column (the regression target).
file_name = os.path.join(data_path, 'brenan.csv')
df = pd.read_csv(file_name)
# Fail early, with a clear message, if a column used downstream is missing.
missing_cols = {'seq', 'function'} - set(df.columns)
assert not missing_cols, f'{file_name} is missing required columns: {sorted(missing_cols)}'

In [None]:
# Preview the first rows of the loaded table.
df.head()

#### Quick test: train the MLP surrogate on the first 100 variants

In [None]:
# Smoke test: embed only the first 100 sequences.
embeddings = get_embeddings_mean(df['seq'].head(100))

In [None]:
# Sanity check: shape of the embedding array produced by the cell above.
embeddings.shape

In [None]:
# Quick-test settings: at most 500 epochs (early stopping may end sooner), batches of 16,
# loss printout every 50 epochs.
config = dict(epoch=500, batch_size=16, print_every_n_epoch=50)
surrogate = MLPSurrogate(config=config)

In [None]:
# Train on the first 100 variants. Slicing `embeddings` keeps X aligned with the 100-row y
# even if the full-dataset embedding cell below has already overwritten `embeddings`.
MLPSurrogate.trainmodel(model=surrogate,
                        X=embeddings[:100],
                        y=df['function'][:100].to_numpy().astype(np.float32))

In [None]:
# Predictions (flat numpy array) of the just-trained surrogate on `embeddings`; when cells run
# in order these are the same 100 rows it was trained/validated on.
surrogate.predict(embeddings)

#### Active learning (AL) cycle

In [None]:
# Embed every sequence in df (one ESM-2 call per sequence via get_embeddings_mean).
embeddings = get_embeddings_mean(df['seq'])

In [None]:
# The AL loop indexes `embeddings` with df's index labels, so the rows must line up 1:1 with
# a default 0..n-1 index.
assert embeddings.shape[0] == len(df), 'embeddings rows must correspond one-to-one with df rows'
assert df.index.equals(pd.RangeIndex(len(df))), 'df must have a default 0..n-1 index'
embeddings.shape

In [None]:
# Surrogate choice; swap in one of the commented alternatives to compare models.
# surrogate = RFSurrogate()
# surrogate = RidgeSurrogate()
config = dict(epoch=500, batch_size=16, print_every_n_epoch=50)
surrogate = MLPSurrogate(config=config)

In [None]:
import logging

# Raise these Lightning loggers' threshold to WARNING so their INFO messages are hidden.
for _logger_name in ("lightning.pytorch.utilities.rank_zero",
                     "lightning.pytorch.accelerators.cuda"):
    logging.getLogger(_logger_name).setLevel(logging.WARNING)

In [None]:
## AL setup: the initial batch is drawn only from variants with function < thresh.

rounds = 10          # number of AL rounds
thresh = 0.3         # initial batch is sampled among variants with function below this
n_sample = 10        # variants selected per round
high_thresh = 0.6    # variants with function above this count as hits

df_al = df.copy()
df_al['round'] = -1  # -1 = not selected yet; otherwise the round in which it was selected
df_al['high'] = (df_al['function'] > high_thresh).astype('int64')

index = df_al.index[df_al['function'] < thresh].tolist()
sample_index = random.sample(index, n_sample)

In [None]:
# Each round: mark the current batch as selected in this round, train a fresh MLP surrogate on
# everything selected so far, score all variants, and pick the n_sample unselected variants
# with the highest predictions as the next batch.
for round_idx in range(rounds):
    print(' ')
    print(f'Round {round_idx}')

    df_al.loc[sample_index, 'round'] = round_idx
    selected = df_al['round'] != -1
    train_index = df_al.index[selected].tolist()  ## whatever was selected
    X = embeddings[train_index]
    y = df_al.loc[train_index, 'function'].to_numpy().astype(np.float32)

    assert X.shape[0] == y.shape[0]
    print(f'number of samples for training: {X.shape[0]}')

    # surrogate.train(X, y)
    # ypred = surrogate.predict(embeddings)

    config = {'epoch': 500, 'batch_size': 16, 'print_every_n_epoch': 50}
    surrogate = MLPSurrogate(config=config)
    MLPSurrogate.trainmodel(model=surrogate, X=X, y=y)
    ypred = surrogate.predict(embeddings)

    pred_col = f'round_{round_idx}_pred'
    df_al[pred_col] = ypred

    unselected_preds = df_al.loc[df_al['round'] == -1, pred_col]
    sample_index = unselected_preds.nlargest(n_sample).index

In [None]:
# True `function` values of the variants selected in each round; dashed line = hit threshold.
# One boxplot call with explicit positions: the default tick labels are the positions
# (0..rounds-1), so the deprecated `labels=` keyword is not needed.
fig, ax = plt.subplots()
ax.set_xlabel('Round')
ax.set_ylabel('function')
ax.boxplot([df_al.loc[df_al['round'] == r, 'function'].to_numpy() for r in range(rounds)],
           positions=list(range(rounds)), showmeans=True)
ax.axhline(high_thresh, ls='--', color='red');

In [None]:
# Number of hits (variants with function > high_thresh) selected in each round.
fig, ax = plt.subplots()
ax.set_xlabel('Round')
ax.set_ylabel(f'Hits selected (function > {high_thresh})')

counts = [df_al.loc[df_al['round'] == r, 'high'].sum() for r in range(rounds)]

ax.plot(counts, marker='o');

In [None]:
# Repeat the whole AL campaign n_exp times from different random initial batches, recording
# for each run how many hits (high == 1) were selected in every round.
found_hits = []
n_exp = 100

rounds = 10
thresh = 0.3
n_sample = 10
high_thresh = 0.6

# Selection state and hit flags are identical at the start of every run: build them once and
# give each run its own copy.
df_start = df.copy()
df_start['round'] = -1
df_start['high'] = 0
df_start.loc[df_start.index[df_start['function'] > high_thresh], 'high'] = 1
index = df_start.index[df_start['function'] < thresh].tolist()

for _ in tqdm(range(n_exp)):
    df_al = df_start.copy()
    sample_index = random.sample(index, n_sample)

    for i in range(rounds):
        df_al.loc[sample_index, 'round'] = i
        train_index = df_al.index[df_al['round'] != -1].tolist()  ## whatever was selected
        X = embeddings[train_index]
        y = df_al.loc[train_index, 'function'].to_numpy().astype(np.float32)

        assert X.shape[0] == y.shape[0]

        # surrogate.train(X, y, debug=False)
        config = {'epoch': 100, 'batch_size': 16, 'print_every_n_epoch': 50}
        surrogate = MLPSurrogate(config=config)
        MLPSurrogate.trainmodel(model=surrogate, X=X, y=y, debug=False)
        ypred = surrogate.predict(embeddings)

        df_al[f'round_{i}_pred'] = ypred

        sample_index = df_al.loc[df_al['round'] == -1, f'round_{i}_pred'].nlargest(n_sample).index

    found_hits.append([df_al.loc[df_al['round'] == r, 'high'].sum() for r in range(rounds)])

In [None]:
## TODO: reset train weights
## TODO: remove this print

In [None]:
# Cumulative fraction of all hits in the dataset found after each round (one row per run).
n_high_total = df_al['high'].sum()
assert n_high_total > 0, 'no variant exceeds high_thresh; the fraction of hits found is undefined'
found_hits_cum = np.cumsum(np.array(found_hits), axis=1) / n_high_total

In [None]:
# np.cumsum(np.array(found_hits), axis=1)

In [None]:
# Distribution over runs of the cumulative hit fraction, one panel per round.
n_rounds = found_hits_cum.shape[1]
n_cols = 5
n_rows = -(-n_rounds // n_cols)  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows),
                        layout='constrained', squeeze=False)
for r, ax in enumerate(axs.flat):
    if r < n_rounds:
        ax.hist(found_hits_cum[:, r])
        ax.set_title(f'Round {r}')
        ax.set_xlabel('Fraction of hits found')
    else:
        ax.set_visible(False)  # unused panel when n_rounds is not a multiple of n_cols

In [None]:
# found_hits_mean_ = []
# ci_lower_ = []
# ci_upper_ = []
# for i in range(found_hits_cum.shape[1]):
#     if i == 0:
#         found_hits_mean_.append(0)
#         ci_lower_.append(0)
#         ci_upper_.append(0)
#     else:
#         alpha, loc, beta = stats.gamma.fit(found_hits_cum[:,i])
#         found_hits_mean_.append(stats.gamma.mean(alpha, loc, beta))
#         confidence = 1 - 0.95
#         lower_bound = stats.gamma.ppf(confidence / 2, alpha, loc=loc, scale=beta)
#         upper_bound = stats.gamma.ppf(1 - confidence / 2, alpha, loc=loc, scale=beta)
#         ci_lower_.append(lower_bound)
#         ci_upper_.append(upper_bound)

In [None]:
# Mean cumulative hit fraction per round with a normal-approximation CI: mean ± z * SEM.
# Built directly instead of with stats.norm.interval, which returns NaN when the SEM is 0.
# With the current thresholds that happens in round 0: the initial batch has
# function < thresh < high_thresh, so it never contains hits.
confidence = 0.95
found_hits_mean = found_hits_cum.mean(0)
found_hits_sem = stats.sem(found_hits_cum, axis=0)
z = stats.norm.ppf(0.5 + confidence / 2)
ci = (found_hits_mean - z * found_hits_sem, found_hits_mean + z * found_hits_sem)

In [None]:
# Mean cumulative fraction of hits found per round, with its confidence band.
x_rounds = np.arange(len(found_hits_mean))
fig, ax = plt.subplots()
ax.plot(x_rounds, found_hits_mean, label='Mean Line', color='blue')
ax.fill_between(x_rounds, ci[0], ci[1], color='blue', alpha=0.2, label='Confidence Interval')
# ax.plot(x_rounds, found_hits_mean_, label='Mean Line', color='green')
# ax.fill_between(x_rounds, ci_lower_, ci_upper_, color='green', alpha=0.2, label='Confidence Interval')
ax.set_xlabel('Round')
ax.set_ylabel('Cumulative fraction of hits found')
ax.legend();