In [None]:
import sys
sys.path.append('..')

In [None]:
import os
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from peft import LoraConfig, get_peft_model

In [None]:
import esm

In [None]:
from DomainPrediction.utils import helper
from DomainPrediction.eval import metrics
from DomainPrediction.esm.esm2 import ESM2

In [None]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from scipy import stats
from sklearn.model_selection import train_test_split

In [None]:
import warnings

# Suppress two recurring warnings about `num_workers` and `log_every_n_steps`
# (presumably emitted by Lightning during training — confirm). The second
# argument is a regular expression matched against the start of the message.
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*Set a lower value for log_every_n_steps*")

#### Define Functions

In [None]:
# Module-level ESM2 wrapper; the get_embeddings_* helpers below read this global.
# NOTE(review): the checkpoint path is absolute and user-specific — adjust it
# for the machine the notebook runs on.
esm2 = ESM2(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')

In [None]:
def get_embeddings_mean(sequences):
    """Return one mean-pooled layer-33 representation per sequence.

    Every sequence is sent through the module-level ``esm2`` wrapper on its
    own. The first and last positions along axis 1 are dropped before the
    remaining positions are averaged.

    Args:
        sequences: iterable of sequence strings.

    Returns:
        np.ndarray: the pooled vectors concatenated along axis 0.
    """
    pooled = []
    for sequence in tqdm(sequences):
        layer_out = esm2.get_res(sequence=sequence)['representations'][33]
        trimmed = layer_out[:, 1:-1, :]
        pooled.append(trimmed.mean(1).cpu().numpy())

    return np.concatenate(pooled, axis=0)

def get_embeddings_flatten(sequences):
    """Return one flattened layer-33 representation per sequence.

    The first and last positions along axis 1 are dropped, the first batch
    entry is taken and the remaining matrix is flattened into one vector.
    The vectors are stacked, so all sequences must yield the same length.

    Args:
        sequences: iterable of sequence strings.

    Returns:
        np.ndarray: stacked flat vectors, one row per sequence.
    """
    flat_vectors = []
    for sequence in tqdm(sequences):
        layer_out = esm2.get_res(sequence=sequence)['representations'][33]
        per_position = layer_out[:, 1:-1, :].cpu().numpy()[0]
        flat_vectors.append(per_position.flatten())

    return np.stack(flat_vectors, axis=0)

def get_embeddings_mean_batch(sequences, batch_size=3):
    """Batched variant of the mean-pooled layer-33 embedding.

    Sequences are processed ``batch_size`` at a time through
    ``esm2.get_res_batch``. Every batch must contain sequences of identical
    length (asserted), because the same ``1:-1`` trim is applied to all rows
    before averaging.

    Args:
        sequences: sliceable collection of sequence strings.
        batch_size: number of sequences per forward pass.

    Returns:
        np.ndarray: the pooled vectors concatenated along axis 0.
    """
    pooled = []
    n_total = len(sequences)
    for start in tqdm(range(0, n_total, batch_size)):
        chunk = sequences[start:start + batch_size]
        rep, batch_lens = esm2.get_res_batch(sequences=chunk)
        # Mixed lengths would put padding inside the trimmed window.
        same_length = (batch_lens == batch_lens[0]).all()
        assert same_length == True
        layer_out = rep['representations'][33]
        pooled.append(layer_out[:, 1:-1, :].mean(1).cpu().numpy())

    return np.concatenate(pooled, axis=0)

def get_embeddings_full(sequences):
    """Return the untouched per-position layer-33 representation of each sequence.

    The first and last positions along axis 1 are dropped and the first batch
    entry is taken. Nothing is pooled or stacked, so the result is a plain
    list with one array per sequence.

    Args:
        sequences: iterable of sequence strings.

    Returns:
        list of np.ndarray, one per input sequence.
    """
    return [
        esm2.get_res(sequence=sequence)['representations'][33][:, 1:-1, :].cpu().numpy()[0]
        for sequence in tqdm(sequences)
    ]

def one_hot_encode(sequences: list[str]) -> np.ndarray:
    """One-hot encode protein sequences into flat feature vectors.

    Each sequence becomes a ``(len(seq), 20)`` indicator matrix over the
    alphabet ``ACDEFGHIKLMNPQRSTVWY`` and is then flattened. Characters
    outside that alphabet leave their row all-zero.

    Args:
        sequences: sequences to encode. They must all have the same length,
            because the flattened vectors are stacked into one matrix.

    Returns:
        np.ndarray of shape ``(n_sequences, len(seq) * 20)``; an empty
        ``(0, 0)`` array when ``sequences`` is empty.

    Raises:
        ValueError: (from ``np.stack``) if the sequences differ in length.
    """
    amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
    # Built once: the alphabet does not change from one sequence to the next.
    aa_to_index = {aa: i for i, aa in enumerate(amino_acids)}

    embeddings = []
    for seq in tqdm(sequences):
        one_hot = np.zeros((len(seq), len(amino_acids)))
        for i, aa in enumerate(seq):
            column = aa_to_index.get(aa)
            if column is not None:
                one_hot[i, column] = 1

        embeddings.append(one_hot.flatten())

    if not embeddings:
        # np.stack rejects an empty list; return an empty matrix instead.
        return np.zeros((0, 0))

    return np.stack(embeddings, axis=0)

In [None]:
class RFSurrogate():
    """Random-forest regressor behind the notebook's surrogate interface
    (``trainmodel`` / ``print_eval`` / ``predict``)."""

    def __init__(self) -> None:
        # Every hyper-parameter is spelled out; random_state is pinned to 1.
        forest_kwargs = dict(
            n_estimators=100, criterion='friedman_mse', max_depth=None, min_samples_split=2,
            min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=1.0,
            max_leaf_nodes=None, min_impurity_decrease=0.0, bootstrap=True, oob_score=False,
            n_jobs=None, random_state=1, verbose=0, warm_start=False, ccp_alpha=0.0,
            max_samples=None,
        )
        self.model = RandomForestRegressor(**forest_kwargs)

    def trainmodel(self, X, y, val=None, debug=True):
        '''
            Fit the forest and optionally report metrics.

            X - feature matrix, shape (n, features)
            y - targets, shape (n, )
            val - optional (X_val, y_val) pair, only used for reporting
            debug - when True, print train (and val) metrics after fitting
        '''
        self.model.fit(X, y)
        if not debug:
            return

        self.print_eval(X, y, label='train')
        if val is None:
            return

        val_features, val_targets = val
        self.print_eval(val_features, val_targets, label='val')

    def print_eval(self, X, y, label='set'):
        '''Print the MSE and Spearman correlation of the predictions on (X, y).'''
        predictions = self.model.predict(X)
        mse = mean_squared_error(predictions, y)
        corr = stats.spearmanr(predictions, y)

        print(f'{label}: mse = {mse}, spearman correlation = {corr.statistic}')

    def predict(self, X):
        '''Return the fitted model's predictions for X.'''
        return self.model.predict(X)

In [None]:
class RidgeSurrogate():
    """Ridge regressor behind the notebook's surrogate interface
    (``trainmodel`` / ``print_eval`` / ``predict``)."""

    def __init__(self) -> None:
        ridge_kwargs = dict(alpha=1.0, fit_intercept=True, random_state=1)
        self.model = Ridge(**ridge_kwargs)

    def trainmodel(self, X, y, val=None, debug=True):
        '''
            Fit the ridge model and optionally report metrics.

            X - feature matrix, shape (n, features)
            y - targets, shape (n, )
            val - optional (X_val, y_val) pair, only used for reporting
            debug - when True, print train (and val) metrics after fitting
        '''
        self.model.fit(X, y)
        if not debug:
            return

        self.print_eval(X, y, label='train')
        if val is None:
            return

        val_features, val_targets = val
        self.print_eval(val_features, val_targets, label='val')

    def print_eval(self, X, y, label='set'):
        '''Print the MSE and Spearman correlation of the predictions on (X, y).'''
        predictions = self.model.predict(X)
        mse = mean_squared_error(predictions, y)
        corr = stats.spearmanr(predictions, y)

        print(f'{label}: mse = {mse}, spearman correlation = {corr.statistic}')

    def predict(self, X):
        '''Return the fitted model's predictions for X.'''
        return self.model.predict(X)

In [None]:
class ProteinFunDataset(Dataset):
    """Map-style dataset that pairs entry ``idx`` of ``X`` with entry ``idx`` of ``y``."""

    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        # The sample count is the leading dimension of X.
        n_rows = self.X.shape[0]
        return n_rows

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        return features, target

class MLPSurrogate(pl.LightningModule):
    """MLP regressor on fixed feature vectors, trained with Lightning.

    ``config`` keys read by this class:
        layers              - list of layer widths; ReLU follows every layer
                              except the last
        epoch               - maximum number of training epochs
        batch_size          - DataLoader batch size
        patience            - EarlyStopping patience (epochs)
        lr                  - Adam learning rate
        early_stopping      - whether to attach EarlyStopping on "val/loss"
        print_every_n_epoch - optional, defaults to 1; progress print interval
    """

    def __init__(self, config=None) -> None:
        super().__init__()
        if config is None:
            # Fresh dict per instance: a dict literal in the signature would be
            # shared between every instance that relies on the default.
            config = {'layers': [1280, 2048, 1280, 1],
                      'epoch': 10,
                      'batch_size': 16,
                      'patience': 10,
                      'lr': 1e-3,
                      'early_stopping': True}
        self.config = config

        layers = []
        for i in range(1, len(config['layers'])-1):
            layers.append(nn.Linear(config['layers'][i-1], config['layers'][i]))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(config['layers'][-2], config['layers'][-1]))
        self.mlp = nn.Sequential(*layers)

        # Per-batch losses of the current epoch, used only for the progress prints.
        self.accumulate_batch_loss_train = []
        self.accumulate_batch_loss_val = []
        self.debug=True

    @staticmethod
    def _mean_loss(losses):
        """Mean of a loss list; NaN (without a NumPy warning) when it is empty."""
        return float(np.mean(losses)) if losses else float('nan')

    def forward(self, x):
        x = self.mlp(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat.flatten(), y)
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        self.accumulate_batch_loss_train.append(loss.item())
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat.flatten(), y)
        self.log("val/loss", loss, on_step=True, on_epoch=True)
        self.accumulate_batch_loss_val.append(loss.item())


    def trainmodel(self, X, y, val=None, debug=True):
        '''
            Fit the MLP with a Lightning Trainer.

            X - feature matrix, shape (n, features)
            y - targets, shape (n, )
            val - optional (X_val, y_val) pair; EarlyStopping monitors
                  "val/loss", so it needs validation data
            debug - when True, print losses every `print_every_n_epoch` epochs
        '''
        self.debug = debug

        train_dataset = ProteinFunDataset(X, y)

        val_loader = None
        if val is not None:
            X_val, y_val = val
            val_dataset = ProteinFunDataset(X_val, y_val)
            val_loader = DataLoader(val_dataset, batch_size=self.config['batch_size'], shuffle=False)

        train_loader = DataLoader(train_dataset, batch_size=self.config['batch_size'], shuffle=True)

        callbacks = None
        if self.config['early_stopping']:
            callbacks = []
            earlystopping_callback = EarlyStopping(monitor="val/loss", patience=self.config['patience'], verbose=False, mode="min")
            callbacks.append(earlystopping_callback)

        trainer = pl.Trainer(max_epochs=self.config['epoch'], callbacks=callbacks,
                                accelerator="auto",
                                enable_progress_bar=False,
                                enable_model_summary=True
                                )

        trainer.fit(model=self, train_dataloaders=train_loader, val_dataloaders=val_loader)

        ## Needs to change - we need to load the least val loss model
        if val is not None:
            y_pred = self.predict(X_val)
            val_mse = mean_squared_error(y_pred, y_val)
            print(f'Train end val mse: {val_mse}')

    def on_train_epoch_start(self):
        self.accumulate_batch_loss_train.clear()
        self.accumulate_batch_loss_val.clear()

    def on_train_epoch_end(self):
        # The key is optional so that the default config does not raise KeyError.
        print_every = self.config.get('print_every_n_epoch', 1)
        if self.debug and self.current_epoch % print_every == 0:
            print(f'Epoch: {self.current_epoch}: train mse: {self._mean_loss(self.accumulate_batch_loss_train)} val mse: {self._mean_loss(self.accumulate_batch_loss_val)}')

    def on_train_end(self):
        print(f'Epoch: {self.current_epoch}: train mse: {self._mean_loss(self.accumulate_batch_loss_train)} val mse: {self._mean_loss(self.accumulate_batch_loss_val)}')

    def predict(self, X):
        '''
            Predict targets for a numpy feature matrix X.

            X is cast to the dtype and device of the model parameters, and the
            forward pass runs in eval mode without gradients; the previous
            train/eval mode is restored afterwards.
            Returns a flat numpy array with one value per row of X.
        '''
        reference = next(self.parameters())
        inputs = torch.as_tensor(X, dtype=reference.dtype, device=reference.device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                y = self(inputs)
        finally:
            self.train(was_training)
        return y.cpu().numpy().flatten()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.config['lr'])

In [None]:
class ProteinFunDataset(Dataset):
    """Map-style dataset that pairs entry ``idx`` of ``X`` with entry ``idx`` of ``y``.

    NOTE(review): this repeats the definition from the earlier cell and
    shadows it when both cells are run.
    """

    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        # The sample count is the leading dimension of X.
        n_rows = self.X.shape[0]
        return n_rows

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        return features, target
    
class ESM2Regression(nn.Module):
    """ESM2 backbone followed by an MLP regression head.

    ``config`` keys read here:
        model_path - checkpoint passed to ``esm.pretrained.load_model_and_alphabet``
        layers     - list of MLP widths; ReLU follows every layer except the last.
                     The first width must match the backbone embedding size.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.esm2, self.alphabet = esm.pretrained.load_model_and_alphabet(config['model_path'])
        self.batch_converter = self.alphabet.get_batch_converter()

        self.tok_to_idx = self.alphabet.tok_to_idx
        self.idx_to_tok = {v:k for k,v in self.tok_to_idx.items()}

        # Read the representation from the last layer of whichever checkpoint
        # was loaded; fall back to 30 (the previously hard-coded value) if the
        # model does not expose `num_layers`.
        self.repr_layer = getattr(self.esm2, 'num_layers', 30)

        layers = []
        for i in range(1, len(config['layers'])-1):
            layers.append(nn.Linear(config['layers'][i-1], config['layers'][i]))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(config['layers'][-2], config['layers'][-1]))
        self.mlp = nn.Sequential(*layers)

    def forward(self, batch_tokens):
        """Return one regression output row per row of ``batch_tokens``.

        The chosen layer's representation is averaged over axis 1 and fed to
        the MLP head.
        NOTE(review): the average runs over every token position, so any
        special or padding tokens in ``batch_tokens`` are included — confirm
        this is intended when batching sequences of different lengths.
        """
        # Contact maps are never used here, so they are not requested.
        rep = self.esm2(batch_tokens, repr_layers=[self.repr_layer], return_contacts=False)
        embedding = rep['representations'][self.repr_layer].mean(1)

        pred = self.mlp(embedding)

        return pred

class ESM2loraMLPSurrogate(pl.LightningModule):
    """LoRA-adapted ESM2 + MLP head regressor that trains directly on sequences.

    ``config`` keys read by this class (and by ``ESM2Regression``):
        model_path          - ESM2 checkpoint (required)
        layers              - MLP head widths
        epoch               - maximum number of training epochs
        batch_size          - number of single-sequence steps accumulated per
                              optimizer step (the DataLoaders use batch_size=1)
        patience            - EarlyStopping patience (epochs)
        lr                  - Adam learning rate
        early_stopping      - whether to attach EarlyStopping on "val/loss"
        device              - optional, 'gpu' moves the model to CUDA at
                              construction; defaults to 'cpu'
        print_every_n_epoch - optional, defaults to 1; progress print interval
    """

    def __init__(self, config=None) -> None:
        super().__init__()
        if config is None:
            # Fresh dict per instance instead of a shared mutable default.
            config = {'layers': [1280, 2048, 1280, 1],
                      'epoch': 10,
                      'batch_size': 16,
                      'patience': 10,
                      'lr': 1e-3,
                      'early_stopping': True}
        if 'model_path' not in config:
            raise KeyError("config must provide 'model_path' (path to an ESM2 checkpoint)")
        self.config = config

        regressor = ESM2Regression(config)

        lora_config = LoraConfig(
            r=4,
            lora_alpha=1,
            target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"],
            lora_dropout=0.1,
            bias="all",
        )

        self.model = get_peft_model(regressor, lora_config)
        # Re-enable gradients for the regression head by parameter name.
        for name, param in self.model.named_parameters():
            if 'mlp' in name:
                param.requires_grad = True

        if config.get('device', 'cpu') == 'gpu':
            self.model.cuda()

        # Per-batch losses of the current epoch, used only for the progress prints.
        self.accumulate_batch_loss_train = []
        self.accumulate_batch_loss_val = []
        self.debug=True

    @staticmethod
    def _mean_loss(losses):
        """Mean of a loss list; NaN (without a NumPy warning) when it is empty."""
        return float(np.mean(losses)) if losses else float('nan')

    def _shared_step(self, batch):
        """Tokenize a (sequences, targets) batch and return its MSE loss."""
        x, y = batch
        data = [
            (fun, seq) for (seq, fun) in zip(x, y)
            ]
        _, _, batch_tokens = self.model.batch_converter(data)
        # Follow the device Lightning placed this module on rather than
        # assuming CUDA from the config.
        batch_tokens = batch_tokens.to(self.device)

        y_hat = self(batch_tokens)
        return nn.functional.mse_loss(y_hat.flatten(), y)

    def forward(self, x):
        x = self.model(x)
        return x

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        self.accumulate_batch_loss_train.append(loss.item())
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val/loss", loss, on_step=True, on_epoch=True)
        self.accumulate_batch_loss_val.append(loss.item())


    def trainmodel(self, X, y, val=None, debug=True):
        '''
            Fine-tune the model with a Lightning Trainer.

            X - array of sequence strings, shape (n, )
            y - targets, shape (n, )
            val - optional (X_val, y_val) pair; EarlyStopping monitors
                  "val/loss", so it needs validation data
            debug - when True, print losses every `print_every_n_epoch` epochs

            Sequences are fed one at a time; config['batch_size'] is used as
            the number of gradient-accumulation steps.
        '''
        self.debug = debug

        train_dataset = ProteinFunDataset(X, y)

        val_loader = None
        if val is not None:
            X_val, y_val = val
            val_dataset = ProteinFunDataset(X_val, y_val)
            val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

        train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

        callbacks = None
        if self.config['early_stopping']:
            callbacks = []
            earlystopping_callback = EarlyStopping(monitor="val/loss", patience=self.config['patience'], verbose=False, mode="min")
            callbacks.append(earlystopping_callback)

        trainer = pl.Trainer(max_epochs=self.config['epoch'], callbacks=callbacks,
                                accelerator="auto",
                                enable_progress_bar=False,
                                enable_model_summary=True,
                                precision="16-mixed",
                                accumulate_grad_batches=self.config['batch_size']
                                )

        trainer.fit(model=self, train_dataloaders=train_loader, val_dataloaders=val_loader)

        ## Needs to change - we need to load the least val loss model
        if val is not None:
            y_pred = self.predict(X_val)
            val_mse = mean_squared_error(y_pred, y_val)
            print(f'Train end val mse: {val_mse}')

    def on_train_epoch_start(self):
        self.accumulate_batch_loss_train.clear()
        self.accumulate_batch_loss_val.clear()

    def on_train_epoch_end(self):
        # The key is optional so that a config without it does not raise KeyError.
        print_every = self.config.get('print_every_n_epoch', 1)
        if self.debug and self.current_epoch % print_every == 0:
            print(f'Epoch: {self.current_epoch}: train mse: {self._mean_loss(self.accumulate_batch_loss_train)} val mse: {self._mean_loss(self.accumulate_batch_loss_val)}')

    def on_train_end(self):
        print(f'Epoch: {self.current_epoch}: train mse: {self._mean_loss(self.accumulate_batch_loss_train)} val mse: {self._mean_loss(self.accumulate_batch_loss_val)}')

    def predict(self, X, setting='cpu'):
        '''
            Predict targets for a list/array of sequence strings.

            Only setting='cpu' is supported and the model must be on the CPU.
            The forward passes run in eval mode (so LoRA dropout is off) and
            without gradients, ten sequences at a time; the previous
            train/eval mode is restored afterwards.
            Returns a flat numpy array with one value per sequence.
        '''
        if setting != 'cpu':
            raise Exception('Not working')

        if len(X) == 0:
            # Neither the tokenizer nor np.concatenate accept empty input.
            return np.empty((0,), dtype=np.float32)

        data = [
                    (f'P{i}', seq) for i, seq in enumerate(X)
                ]
        _, _, batch_tokens = self.model.batch_converter(data)

        assert next(self.parameters()).is_cuda == False

        was_training = self.training
        self.eval()
        pred = []
        try:
            with torch.no_grad():
                for i in range(0, batch_tokens.shape[0], 10):
                    y_pred = self(batch_tokens[i:i+10])
                    pred.append(y_pred.cpu().numpy().flatten())
        finally:
            self.train(was_training)

        return np.concatenate(pred)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.config['lr'])

#### Load Data

In [None]:
# Data directory for the Evolvepro experiments, relative to two levels above
# the notebook's working directory.
root = '../..'
data_path = os.path.join(root, 'Data/al_test_experiments/Evolvepro')

In [None]:
# Load the dataset; later cells read its 'seq', 'pos' and 'function' /
# 'function_scaled' columns.
file_name = os.path.join(data_path, 'brenan.csv')
df = pd.read_csv(file_name)

In [None]:
# (rows, columns) of the loaded table.
df.shape

In [None]:
# Preview the first rows.
df.head()

In [None]:
# Length of the sequence in the row labelled 0.
len(df['seq'][0])

#### Tests

In [None]:
## ESM2 mean-pooled embeddings

# Cached on disk: the embeddings are loaded when the file exists and are only
# computed (and written to the cache) when it is missing, so the cell works on
# a fresh checkout without un-commenting code.
file_name = os.path.join(data_path, 'brenan_embeddings.npy')
if os.path.exists(file_name):
    embeddings = np.load(file_name)
else:
    embeddings = get_embeddings_mean(df['seq'])
    np.save(file_name, embeddings)

In [None]:
## ESM2 flattened/concatenated embeddings

# embeddings = get_embeddings_flatten(df['seq'])
# file_name = os.path.join(data_path, 'brenan_embeddings_concat.npy')
# np.save(file_name, embeddings)

# file_name = os.path.join(data_path, 'brenan_embeddings_concat.npy')
# embeddings = np.load(file_name)

In [None]:
# embeddings = one_hot_encode(df['seq'])

In [None]:
# Shape of the feature matrix currently bound to `embeddings`.
embeddings.shape

In [None]:
# Seed NumPy's global RNG; the split cell below samples the validation
# positions with np.random.choice.
np.random.seed(0)

In [None]:
# Position-wise train/val/test split: all rows that share a 'pos' value land in
# the same split.
#   test - `num_blocks` evenly spaced blocks of `block_size` consecutive entries
#          of `positions` (order of first appearance in df)
#   val  - a random sample of the remaining positions
#   train - everything else
# NOTE(review): the resulting index arrays hold df index labels and are later
# used to index numpy arrays positionally — this assumes df has its default
# 0..n-1 index.

# Re-seed here so that re-running this cell on its own reproduces the same
# validation sample (the result is unchanged under Restart & Run All).
np.random.seed(0)

test_split = 0.2
num_seq_in_pos = 19  # presumably the number of variants per position — TODO confirm
block_size = 10
num_blocks = int(df.shape[0]*test_split // (num_seq_in_pos * block_size) + 1)
positions = df['pos'].unique()
step_size = len(positions) // num_blocks
block_indices = [i for i in range(0, len(positions) - block_size + 1, step_size)][:num_blocks]
blocks = [positions[i:i + block_size] for i in block_indices]

for _block in blocks:
    # Compare against block_size rather than a literal so the check follows the setting.
    assert len(_block) == block_size
    for i in _block:
        assert i in positions

test_positions = np.concatenate(blocks)
test_indices = np.array(df[df['pos'].isin(test_positions)].index)

val_split = 0.1
n_pos_val = int((~df['pos'].isin(test_positions)).sum()*val_split // num_seq_in_pos + 1)
val_positions = np.random.choice(df.loc[~df['pos'].isin(test_positions), 'pos'].unique(), n_pos_val, replace=False)
for i in val_positions:
    assert i not in test_positions
val_indices = np.array(df[df['pos'].isin(val_positions)].index)
train_indices = np.array(df[~df['pos'].isin(np.concatenate([val_positions, test_positions]))].index)

# The three splits must partition the table exactly.
assert len(train_indices) + len(val_indices) + len(test_indices) == df.shape[0]

In [None]:
# index = np.array(df.index)
# np.random.shuffle(index)
# train_indices = index[:-int(index.shape[0]*0.2)-int(index.shape[0]*0.1)]
# val_indices = index[-int(index.shape[0]*0.2)-int(index.shape[0]*0.1):-int(index.shape[0]*0.2)]
# test_indices = index[-int(index.shape[0]*0.2):]

In [None]:
# Report the absolute and relative size of every split.
total_rows = df.shape[0]
print(f'Total size: {total_rows}')
for split_label, split_indices in (
    ('train size', train_indices),
    ('val size  ', val_indices),
    ('test size ', test_indices),
):
    n_rows = split_indices.shape[0]
    print(f'{split_label}: {n_rows} ({round(n_rows*100/total_rows, 2)}%)')

In [None]:
# Choose which column of `df` is used as the regression target.
scaled = False
property_label = 'function_scaled' if scaled else 'function'

In [None]:
# Features are whatever is currently bound to `embeddings`; targets are the
# selected property column cast to float32.
X = embeddings
y = df[property_label].to_numpy().astype(np.float32)

In [None]:
# Slice features and targets into the three splits.
X_train, y_train = X[train_indices], y[train_indices]
X_val, y_val = X[val_indices], y[val_indices]
X_test, y_test = X[test_indices], y[test_indices]

In [None]:
# Fit the ridge baseline. The surrogate classes expose `trainmodel`, not `train`.
surrogate = RidgeSurrogate()
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict on every split and draw one true-vs-predicted panel per split, titled
# with the split's MSE and Spearman correlation.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)
fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')

panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (split_name, y_true, y_hat) in zip(ax, panels):
    axis.plot(y_true, y_hat, '.', alpha=0.5)

    mse = mean_squared_error(y_true, y_hat)
    corr = stats.spearmanr(y_true, y_hat)
    s_corr = round(corr.statistic, 2)
    axis.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')

    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Fit the random-forest baseline. The surrogate classes expose `trainmodel`, not `train`.
surrogate = RFSurrogate()
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict on every split and draw one true-vs-predicted panel per split, titled
# with the split's MSE and Spearman correlation.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)
fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')

panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (split_name, y_true, y_hat) in zip(ax, panels):
    axis.plot(y_true, y_hat, '.', alpha=0.5)

    mse = mean_squared_error(y_true, y_hat)
    corr = stats.spearmanr(y_true, y_hat)
    s_corr = round(corr.statistic, 2)
    axis.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')

    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Single linear layer on top of the feature vectors. The input width is read
# from the data so the cell works for whichever embedding variant is loaded
# (a hard-coded width only fits one of them).
config={'layers': [X_train.shape[1], 1], 
        'epoch': 100, 
        'batch_size': 16,
        'patience': 10,
        'early_stopping': False,
        'lr': 1e-3,
        'print_every_n_epoch': 10}
surrogate = MLPSurrogate(config=config)
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
def predict(model, X, setting='cpu', batch_size=10):
    '''
        Run `model` over the rows of `X` in mini-batches on the CPU.

        model - torch module whose parameters live on the CPU
        X - numpy array, shape (n, features); each slice is converted with
            torch.tensor, so its dtype must match what the model expects
        setting - only 'cpu' is supported
        batch_size - rows per forward pass (default 10)

        The forward passes run in eval mode without gradients; the model's
        previous train/eval mode is restored afterwards.
        Returns a flat numpy array of predictions (empty when X has no rows).
    '''
    if setting != 'cpu':
        raise Exception('Not working')

    assert next(model.parameters()).is_cuda == False

    if X.shape[0] == 0:
        # np.concatenate rejects an empty list of batches.
        return np.empty((0,), dtype=np.float32)

    was_training = model.training
    model.eval()
    pred = []
    try:
        with torch.no_grad():
            for i in range(0, X.shape[0], batch_size):
                y_pred = model(torch.tensor(X[i:i+batch_size]))
                pred.append(y_pred.cpu().numpy().flatten())
    finally:
        model.train(was_training)

    return np.concatenate(pred)

In [None]:
# Display the repr of the current surrogate.
surrogate

In [None]:
# Predict on every split (features cast to float32) and draw one
# true-vs-predicted panel per split, titled with the split's MSE and Spearman
# correlation.
y_train_pred = surrogate.predict(X_train.astype(np.float32))
y_val_pred = surrogate.predict(X_val.astype(np.float32))
y_test_pred = surrogate.predict(X_test.astype(np.float32))
# y_train_pred = predict(surrogate, X_train)
# y_val_pred = predict(surrogate, X_val)
# y_test_pred = predict(surrogate, X_test)
fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')

# The train panel uses a lower alpha than the other two.
panels = [
    ('Train', y_train, y_train_pred, 0.3),
    ('Val', y_val, y_val_pred, 0.5),
    ('Test', y_test, y_test_pred, 0.5),
]
for axis, (split_name, y_true, y_hat, point_alpha) in zip(ax, panels):
    axis.plot(y_true, y_hat, '.', alpha=point_alpha)

    mse = mean_squared_error(y_true, y_hat)
    corr = stats.spearmanr(y_true, y_hat)
    s_corr = round(corr.statistic, 2)
    axis.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')

    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Configuration for the LoRA fine-tuning surrogate. 'layers' starts at 640 to
# match the width the MLP head receives from this checkpoint
# (presumably the embedding size of the 150M model — confirm).
# NOTE(review): 'model_path' is absolute and user-specific — adjust it for the
# machine the notebook runs on.
config={'model_path': '/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t30_150M_UR50D.pt',
        'layers': [640, 1280, 640, 1], 
        'epoch': 50, 
        'batch_size': 16,
        'patience': 10,
        'early_stopping': False,
        'lr': 1e-3,
        'print_every_n_epoch': 1,
        'device': 'gpu'}
surrogate = ESM2loraMLPSurrogate(config=config)
# Prints how many parameters are trainable after wrapping with PEFT.
surrogate.model.print_trainable_parameters()

In [None]:
# NOTE: from here on X_train / X_val / X_test hold raw sequence strings and
# overwrite the embedding matrices bound to the same names by earlier cells;
# re-run the earlier split cell before going back to the embedding-based models.
X_train = df['seq'].to_numpy()[train_indices]
X_val = df['seq'].to_numpy()[val_indices]
X_test = df['seq'].to_numpy()[test_indices]
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))