In [1]:
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [2]:
%load_ext blackcellmagic

In [3]:
from sklearn.linear_model import Ridge

In [4]:
from scr.preprocess.data_process import split_protrain_loader

In [10]:
"""Pre-processing the dataset"""

from __future__ import annotations

from collections import Sequence, defaultdict

import os
from glob import glob
import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader

from scr.utils import pickle_save, pickle_load, replace_ext
from scr.params.sys import RAND_SEED
from scr.params.emb import TRANSFORMER_INFO
from scr.preprocess.seq_loader import SeqLoader
from scr.encoding.encoding_classes import AbstractEncoder, ESMEncoder, CARPEncoder


def get_mut_name(mut_seq: str, parent_seq: str) -> str:
    """
    A function for returning the mutant name

    Args:
    - mut_seq: str, the full mutant sequence
    - parent_seq: str, the full parent sequence

    Returns:
    - str, parent, indel, or mutant name in the format of
        ParentAAMutLocMutAA:ParentAAMutLocMutAA:..., ie. W39W:D40G:G41C:V54Q
    """

    mut_list = []
    if parent_seq == mut_seq:
        return "parent"
    elif len(parent_seq) == len(mut_seq):
        for i, (p, m) in enumerate(zip(list(parent_seq), list(mut_seq))):
            if p != m:
                mut_list.append(f"{p}{i+1}{m}")
        return ":".join(mut_list)
    else:
        return "indel"


class AddMutInfo:
    """A class for appending mutation info for mainly protein engineering tasks"""

    def __init__(self, parent_seq_path: str, csv_path: str):

        # Load the parent sequence from the fasta file
        self._parent_seq = SeqLoader(parent_seq_path=parent_seq_path)

        # load the dataframe
        self._init_df = pd.read_csv(csv_path)

        self._df = self._init_df.copy()
        # add a column with the mutant names
        self._df["mut_name"] = self._init_df["sequence"].apply(
            get_mut_name, parent_seq=self._parent_seq
        )
        # add a column with the number of mutations
        self._df["mut_numb"] = (
            self._df["mut_name"].str.split(":").map(len, na_action="ignore")
        )

        # get the pickle file path
        self._pkl_path = replace_ext(input_path=csv_path, ext=".pkl")

        pickle_save(what2save=self._df, where2save=self._pkl_path)

    @property
    def parent_seq(self) -> str:
        """Return the parent sequence"""
        return self._parent_seq

    @property
    def pkl_path(self) -> str:
        """Return the pkl file path for the processed dataframe"""
        return self._pkl_path

    @property
    def df(self) -> pd.DataFrame:
        """Return the processed dataframe"""
        return self._df


class TaskProcess:
    """A class for handling different downstream tasks"""

    def __init__(self, data_folder: str = "data/"):
        """
        Args:
        - data_folder: str, a folder path with all the tasks as subfolders where
            all the subfolders have datasets as the subsubfolders, ie

            {data_folder}/
                proeng/
                    aav/
                        one_vs_many.csv
                        two_vs_many.csv
                        P03135.fasta
                    thermo/
                        mixed.csv
        """

        if data_folder[-1] == "/":
            self._data_folder = data_folder
        else:
            self._data_folder = data_folder + "/"

        # sumamarize all files i nthe data folder
        self._sum_file_df = self.sum_files()

    def sum_files(self) -> pd.DataFrame:
        """
        Summarize all files in the data folder

        Returns:
        - A dataframe with "task", "dataset", "split", "csv_path", "fasta_path", "pkl_path" as columns, ie.
            (proeng, gb1, low_vs_high, data/proeng/gb1/low_vs_high.csv, data/proeng/gb1/5LDE_1.fasta)
            note that csv_path is the list of lmdb files for the structure task
        """
        dataset_folders = glob(f"{self._data_folder}*/*")
        # need a list of tuples in the order of:
        # (task, dataset, split, csv_path, fasta_path)
        list_for_df = []
        for dataset_folder in dataset_folders:
            _, task, dataset = dataset_folder.split("/")
            if task == "structure":
                structure_file_list = [
                    file_path
                    for file_path in glob(f"{dataset_folder}/*.*")
                    if os.path.basename(os.path.splitext(file_path)[0]).split("_")[-1]
                    in ["train", "valid", "cb513"]
                ]
                list_for_df.append(
                    tuple([task, dataset, "cb513", structure_file_list, "", ""])
                )
            else:
                csv_paths = glob(f"{dataset_folder}/*.csv")
                fasta_paths = glob(f"{dataset_folder}/*.fasta")
                pkl_paths = glob(f"{dataset_folder}/*.pkl")

                assert len(csv_paths) >= 1, "Less than one csv"
                assert len(fasta_paths) <= 1, "More than one fasta"

                for csv_path in csv_paths:
                    # if parent seq fasta exists
                    if len(fasta_paths) == 1:
                        fasta_path = fasta_paths[0]

                        # if no existing pkl file, generate and save
                        if len(pkl_paths) == 0:
                            print(f"Adding mutation info to {csv_path}...")
                            pkl_path = AddMutInfo(
                                parent_seq_path=fasta_path, csv_path=csv_path
                            ).pkl_path
                        # pkl file exits
                        else:
                            pkl_path = replace_ext(input_path=csv_path, ext=".pkl")
                    # no parent fasta no pkl file
                    else:
                        fasta_path = ""
                        pkl_path = ""

                    list_for_df.append(
                        tuple(
                            [
                                task,
                                dataset,
                                os.path.basename(os.path.splitext(csv_path)[0]),
                                csv_path,
                                fasta_path,
                                pkl_path,
                            ]
                        )
                    )

        return pd.DataFrame(
            list_for_df,
            columns=["task", "dataset", "split", "csv_path", "fasta_path", "pkl_path"],
        )

    @property
    def sum_file_df(self) -> pd.DataFrame:
        """A summary table for all files in the data folder"""
        return self._sum_file_df


class ProtranDataset(Dataset):

    """A dataset class for processing protein transfer data"""

    def __init__(
        self,
        dataset_path: str,
        subset: str,
        encoder_class: AbstractEncoder,
        encoder_name: str,
        embed_layer: int,
        embed_batch_size: int = 0,
        flatten_emb: bool | str = False,
        embed_path: str = None,
        seq_start_idx: bool | int = False,
        seq_end_idx: bool | int = False,
        **encoder_params,
    ):

        """
        Args:
        - dataset_path: str, full path to the dataset, in pkl or panda readable format
            columns include: sequence, target, set, validation, mut_name (optional), mut_numb (optional)
        - subset: str, train, val, test
        - encoder_class: AbstractEncoder, the encoder class
        - encoder_name: str, the name of the encoder
        - embed_layer: int, the layer number of the embedding
        - embed_batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - embed_path: str = None, path to presaved embedding
        - seq_start_idx: bool | int = False, the index for the start of the sequence
        - seq_end_idx: bool | int = False, the index for the end of the sequence
        - encoder_params: kwarg, additional parameters for encoding
        """

        # with additional info mut_name, mut_numb
        if os.path.splitext(dataset_path)[-1] in [".pkl", ".PKL", ""]:
            self._df = pickle_load(dataset_path)
            self._add_mut_info = True
        # without such info
        else:
            self._df = pd.read_csv(dataset_path)
            self._add_mut_info = False

        assert "set" in self._df.columns, f"set is not a column in {dataset_path}"
        assert (
            "validation" in self._df.columns
        ), f"validation is not a column in {dataset_path}"

        self._df_train = self._df.loc[
            (self._df["set"] == "train") & (self._df["validation"] != True)
        ]
        self._df_val = self._df.loc[
            (self._df["set"] == "train") & (self._df["validation"] == True)
        ]
        self._df_test = self._df.loc[(self._df["set"] == "test")]

        self._df_dict = {
            "train": self._df_train,
            "val": self._df_val,
            "test": self._df_test,
        }

        assert subset in list(
            self._df_dict.keys()
        ), "split can only be 'train', 'val', or 'test'"
        self._subset = subset

        self._subdf_len = len(self._df_dict[self._subset])

        # not specified seq start will be from 0
        if seq_start_idx == False:
            self._seq_start_idx = 0
        else:
            self._seq_start_idx = int(seq_start_idx)
        # not specified seq end will be the full sequence length
        if seq_end_idx == False:
            self._seq_end_idx = -1
        else:
            self._seq_end_idx = int(seq_end_idx)

        # get unencoded string of input sequence
        # will need to convert data type
        self.sequence = self._get_column_value("sequence")

        # get the encoder
        self._encoder = encoder_class(encoder_name=encoder_name)

        self._max_emb_layer = self._encoder.max_emb_layer
        self._include_input_layer = self._encoder.include_input_layer

        # check if pregenerated embedding
        if embed_path is not None:
            print(f"Loading pregenerated embeddings from {embed_path}")
            encoded_dict = pickle_load(embed_path)

        # encode the sequences without the mut_name
        else:
            """encoded_dict = {
                layer: [] for layer in range(self._max_emb_layer + self._include_input_layer)
            }"""
            
            encoded_dict = defaultdict(list)

            for encoded_batch_dict in self._encoder.encode(
                mut_seqs=self.sequence,
                batch_size=embed_batch_size,
                flatten_emb=flatten_emb,
                **encoder_params,
            ):

                for layer, emb in encoded_batch_dict.items():
                    encoded_dict[layer].append(emb)

        # assign each layer as its own variable
        for layer, emb in encoded_dict.items():
            setattr(self, "layer" + str(layer), torch.tensor(np.vstack(emb), dtype=torch.float32))
            
        # get and format the fitness or secondary structure values
        # can be numbers or string
        # will need to convert data type
        # make 1D tensor 2D
        self.y = np.expand_dims(self._get_column_value("target"), 1)

        # add mut_name and mut_numb for relevant proeng datasets
        if self._add_mut_info:
            self.mut_name = self._get_column_value("mut_name")
            self.mut_numb = self._get_column_value("mut_numb")
        else:
            self.mut_name = [""] * self._subdf_len
            self.mut_numb = [np.nan] * self._subdf_len

    def __len__(self):
        """Return the length of the selected subset of the dataframe"""
        return self._subdf_len

    def __getitem__(self, idx: int):

        """
        Return the item in the order of
        encoded sequence (x), target (y), sequence, mut_name (optional), mut_numb (optional)

        Args:
        - idx: int
        """
        
        return (
            # self.x[layer_numb][idx],
            self.y[idx],
            self.sequence[idx],
            self.mut_name[idx],
            self.mut_numb[idx],
            *(
                getattr(self,  "layer" + str(layer))[idx]
                for layer in range(self._max_emb_layer + self._include_input_layer)
            ),
        )

    def _get_column_value(self, column_name: str) -> np.ndarray:
        """
        Check and return the column values of the selected dataframe subset
        """
        if column_name in self._df.columns:
            if column_name == "sequence":
                return (
                    self._df_dict[self._subset]["sequence"]
                    .astype(str)
                    .str[self._seq_start_idx : self._seq_end_idx]
                    .values
                )
            else:
                return self._df_dict[self._subset][column_name].values

    @property
    def df_full(self) -> pd.DataFrame:
        """Return the full loaded dataset"""
        return self._df

    @property
    def df_train(self) -> pd.DataFrame:
        """Return the dataset for training only"""
        return self._df_train

    @property
    def df_val(self) -> pd.DataFrame:
        """Return the dataset for validation only"""
        return self._df_val

    @property
    def df_test(self) -> pd.DataFrame:
        """Return the dataset for training only"""
        return self._df_test


def split_protrain_loader(
    dataset_path: str,
    encoder_name: str,
    embed_layer: int,
    embed_batch_size: int = 128,
    flatten_emb: bool | str = False,
    embed_path: str | None = None,
    seq_start_idx: bool | int = False,
    seq_end_idx: bool | int = False,
    subset_list: list[str] = ["train", "val", "test"],
    loader_batch_size: int = 64,
    worker_seed: int = RAND_SEED,
    **encoder_params,
):

    """
    A function encode and load the data from a path

    Args:
    - dataset_path: str, full path to the dataset, in pkl or panda readable format
        columns include: sequence, target, set, validation, mut_name (optional), mut_numb (optional)
    - encoder_name: str, the name of the encoder
    - embed_layer: int, the layer number of the embedding
    - embed_batch_size: int, set to 0 to encode all in a single batch
    - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
    - embed_path: str = None, path to presaved embedding
    - seq_start_idx: bool | int = False, the index for the start of the sequence
    - seq_end_idx: bool | int = False, the index for the end of the sequence
    - subset_list: list of str, train, val, test
    - loader_batch_size: int, the batch size for train, val, and test dataloader
    - worker_seed: int, the seed for dataloader
    - encoder_params: kwarg, additional parameters for encoding
    """

    assert set(subset_list) <= set(
        ["train", "val", "test"]
    ), "subset_list can only contrain terms with in be 'train', 'val', or 'test'"

    # specify no shuffling for validation and test
    if_shuffle_list = [True if subset == "train" else False for subset in subset_list]

    if encoder_name in TRANSFORMER_INFO.keys():
        encoder_class = ESMEncoder

    return (
        DataLoader(
            dataset=ProtranDataset(
                dataset_path=dataset_path,
                subset=subset,
                encoder_class=encoder_class,
                encoder_name=encoder_name,
                embed_layer=embed_layer,
                embed_batch_size=embed_batch_size,
                flatten_emb=flatten_emb,
                embed_path=embed_path,
                seq_start_idx=seq_start_idx,
                seq_end_idx=seq_end_idx,
                **encoder_params,
            ),
            batch_size=loader_batch_size,
            shuffle=if_shuffle,
            worker_init_fn=worker_seed,
        )
        for subset, if_shuffle in zip(subset_list, if_shuffle_list)
    )

In [11]:
val_dataset=ProtranDataset(
                dataset_path="data/proeng/gb1/two_vs_rest.pkl",
                subset="val",
                encoder_class=ESMEncoder,
                encoder_name="esm1_t6_43M_UR50S",
                embed_layer=6,
                embed_batch_size=128,
                flatten_emb="mean",
                embed_path=None,
                seq_start_idx=0,
                seq_end_idx=56,
                # **encoder_params,
            ),

Loading esm1_t6_43M_UR50S upto 6 layer embedding


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main
100%|██████████| 1/1 [00:00<00:00,  2.14it/s]


In [92]:
val_dataset, val_dataset[0]

((<__main__.ProtranDataset at 0x7f0559a2f190>,),
 <__main__.ProtranDataset at 0x7f0559a2f190>)

In [89]:
val_dataset[0].layer0.shape

torch.Size([43, 768])

In [12]:
# train_loader, val_loader, test_loader = split_protrain_loader(
train_loader, val_loader = split_protrain_loader(
    dataset_path="data/proeng/gb1/two_vs_rest.pkl",
    encoder_name="esm1_t6_43M_UR50S",
    embed_layer=6,
    embed_batch_size=128,
    flatten_emb="mean",
    embed_path=None,
    seq_start_idx=0,
    seq_end_idx=56,
    # subset_list=["train", "val", "test"],
    subset_list=["train", "val"],
    loader_batch_size=64,
    worker_seed=42,
)

Loading esm1_t6_43M_UR50S upto 6 layer embedding


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main
100%|██████████| 3/3 [00:04<00:00,  1.62s/it]
Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main


Loading esm1_t6_43M_UR50S upto 6 layer embedding


100%|██████████| 1/1 [00:00<00:00,  2.24it/s]


In [11]:
encoder_name="esm1_t6_43M_UR50S"
embed_batch_size=128
flatten_emb="mean"
embed_path=None

In [19]:
for (y, sequence, mut_name, mut_numb, *layer_emb) in val_loader:
    print(y, type(y), layer_emb[0], type(layer_emb[0]), len(layer_emb[0]))

    model = Ridge(alpha = 0.5, normalize = True, tol = 0.001, \
              solver ='auto', random_state = 42)
    
    model.fit(layer_emb[0], y)

    print(model.predict(layer_emb[0]))

tensor([[1.8029e+00],
        [7.1123e-01],
        [8.6005e-01],
        [6.1898e-01],
        [2.2602e+00],
        [1.7291e+00],
        [1.7197e+00],
        [2.2677e+00],
        [1.7626e+00],
        [2.3160e-01],
        [6.1434e-01],
        [6.1031e-01],
        [1.0339e+00],
        [5.1069e-01],
        [6.9899e-01],
        [1.8930e+00],
        [0.0000e+00],
        [6.9803e-01],
        [5.9696e-01],
        [1.1877e+00],
        [9.7038e-01],
        [5.5813e-01],
        [4.1580e+00],
        [8.5351e-01],
        [6.0882e-01],
        [1.0669e+00],
        [1.7081e+00],
        [3.0943e+00],
        [2.9738e-03],
        [7.3959e-02],
        [1.9298e+00],
        [1.9539e+00],
        [6.2064e-01],
        [7.0479e-01],
        [1.1327e+00],
        [1.5569e-01],
        [1.9346e+00],
        [1.2409e+00],
        [9.9316e-01],
        [1.4394e+00],
        [2.0135e+00],
        [4.0264e+00],
        [1.2880e+00]], dtype=torch.float64) <class 'torch.Tensor'> tensor([[

In [None]:
def pick_model(params_list: list[dict]):
    for params in params_list:
        # init model
        model = Ridge(**params)

        train_preds = []
        val_preds = []
        test_preds = []

        


In [13]:
# training model with 0.5 alpha value
model = Ridge(alpha = 0.5, normalize = True, tol = 0.001, \
              solver ='auto', random_state = 42)

train_preds = []
val_preds = []
test_preds = []

layer_numb = 0

for (x, y, _, _, _) in train_loader:
    print(x)
    model.fit(x[layer_numb], y)
    
for (x, y, _, _, _) in train_loader:
    train_preds.append(model.predict(x[layer_numb]))

for (x, y, _, _, _) in val_loader:
    val_preds.append(model.predict(x[layer_numb]))

# need to add to loop over alpha etc

for (x, y, _, _, _) in test_loader:
    test_preds.append(model.predict(x[layer_numb]))

KeyError: 42

In [13]:
import numpy as np

In [15]:
len(train_preds)

6

In [16]:
np.concatenate(train_preds).shape

(381, 1)

In [None]:
"""Classes for models"""

# Import third party modules
from copy import deepcopy
from abc import ABC, abstractmethod

from sklearn.preprocessing import StandardScaler
import sklearn.metrics as sm
import numpy as np
import scipy.stats as ss
import torch

# Import custom objects
from code.params.globals import BATCH_SIZE
from code.params.model_info import *
# from code.models.utils import pick_best_params


# Define an abstract class for holding all models
class AbstractModelWrapper(ABC):
    """
    This is an abstract class for wrapping all models used for making
    mutant fitness predictions. It will be inherited by a model class
    for making both supervised and unsupervised predictions.
    """

    # Initialization loads the parent sequence
    def __init__(self, *constructor_args, **constructor_kwargs):
        """
        Loads the reference sequence that will be used to make predictions and
        stores parameters for the model that will be trained.
        
        Args:
        - constructor_args: Positional arguments that will be passed to the model
            constructor for this wrapper.
        - constructor_kwargs: Keyword arguments that will be passed to the model
            constructor for this wrapper.
        """
        # Initialize the sequence loader
        super().__init__(fasta_loc)

        # print("AbstractModelWrapper constructor_kwargs inputs before assigning")
        # print(constructor_kwargs)

        # Store the constructor args and kwargs as instance variables
        self._constructor_args = constructor_args
        self._constructor_kwargs = deepcopy(constructor_kwargs)

        # print("AbstractModelWrapper constructor_kwargs inputs")
        # print(constructor_kwargs)
        # print("AbstractModelWrapper self._constructor_kwargs")
        # print(self._constructor_kwargs)

        # The default model is "None". We need to initialize a model later.
        self.model = None

    # All classes will be able to make a prediction
    def predict(self, mutants, muts_preencoded=False, **pred_kwargs):
        """
        Makes predictions given an input of mutants or encoded mutants. If
        unencoded mutants are passed in, then they will first be encoded.
        Otherwise, encodings passed in will be assumed correct and fed into the
        appropriate model.
        Parameters
        ----------
        mutants (tuple of tuple of tuples OR ndarray/torch tensor):
            If a tuple of tuple of tuples: All mutations to make a prediction
            for. Will use the reference sequence loaded during initialization in
            combination with the mutants passed in here. Each entry in the
            outer tuple is a tuple containing mutation information for one
            variant. Each mutation is specified by its own tuple with the format
            (Reference AA, Mutation Position, Mutant AA). Mutation positions are
            assumed to be 1-indexed relative to the sequence. See an example
            below.
            If ndarray or torch tensor: This is mutants that have already been
            encoded.
        Returns
        -------
        predictions (ndarray): Numpy array containing the predictions for the
            input set of mutants.
        Examples
        --------
        An example input for `mutants` when tuple of tuple of tuples:
            mutants = (
                ((V, 39, D), (D, 40, G)),
                ((V, 39, D), (D, 40, G), (G, 41, F))
                )
        """

        # print("""This is the predict inside the AbstractModelWrapper \
        # calling _complete_prediction""")

        # Confirm that we have a model and scaler for making predictions
        assert self.model is not None, "Must train a model before making predictions"

        # Check mutations for accuracy based on whether they are pre-encoded
        # or not
        self._check_mutants(mutants, muts_preencoded)

        # Every child class will need to finish this function
        unscaled_preds = self._complete_prediction(
            mutants, muts_preencoded=muts_preencoded, **pred_kwargs
        )

        # print("mutants and unscaled_preds dims:")
        # print(mutants.shape, unscaled_preds.shape)

        # Confirm that predictions are a numpy array and that they are 2d
        assert isinstance(
            unscaled_preds, np.ndarray
        ), "Predictions should be numpy array"
        assert len(unscaled_preds.shape) == 2, "Expect 2D prediction array"

        # Return the unscaled preds. Supervised models will scale them based on
        # the standard scaler used at input
        return unscaled_preds

    def evaluate(self, mutants, true_fitness, muts_preencoded=False, **pred_kwargs):
        """
        Wraps making predictions and evaluation of predictions using a number
        of evaluation metrics.
        mutants (tuple of tuple of tuples): All mutations to make a prediction
            for. Will use the reference sequence loaded during initialization in
            combination with the mutants passed in here. Each entry in the
            outer tuple is a tuple containing mutation information for one
            variant. Each mutation is specified by its own tuple with the format
            (Reference AA, Mutation Position, Mutant AA). Mutation positions are
            assumed to be 1-indexed relative to the sequence. See an example
            below.
        true_fitness (ndarray): The true fitness values of the mutants.
        muts_preencoded (bool): Whether or not the input mutants are pre-encoded.
            If not pre-encoded, this function expects a tuple of tuples of tuples.
            If it is pre-encoded, then the mutants should be formatted already
            to go into whatever model will be making predictions.
        Examples
        --------
        An example input for `mutants`:
            mutants = (
                ((V, 39, D), (D, 40, G)),
                ((V, 39, D), (D, 40, G), (G, 41, F))
                )
        """
        
        # print("This is the evulate inside the AbstractModelWrapper")

        # Make predictions
        predictions = self.predict(
            mutants, muts_preencoded=muts_preencoded, **pred_kwargs
        )

        # Confirm that the predictions and true values are the same shape and
        # that they are two dimensional
        
        # print("The kwargs for the evaluate function in AbstractModelWrapper")
        # print(pred_kwargs)
        # print("After the predict in the evaluate function in AbstractModelWrapper")
        # print(mutants.shape, predictions.shape, true_fitness.shape)
        
        """assert isinstance(true_fitness, np.ndarray), "True values should be numpy array"
        assert (
            predictions.shape == true_fitness.shape
        ), "Mismatch in truth and prediction shapes"
        n_fitness_vals = predictions.shape[1]
        # Evaluate the predictions using spearman rho
        rho_vals = np.array(
            [
                ss.spearmanr(true_fitness[:, i], predictions[:, i])[0]
                for i in range(n_fitness_vals)
            ]
        )
        # Evaluate predictions using mse. Keep errors separate by label.
        mses = sm.mean_squared_error(
            true_fitness, predictions, multioutput="raw_values"
        )
        return [predictions, [rho_vals, mses]]"""
        
        return self.evaluate_metrics(predictions, true_fitness, muts_preencoded, **pred_kwargs)
    
    def evaluate_metrics(self, predictions, true_fitness, muts_preencoded=False, **pred_kwargs):
        """
        Wraps making evaluation metrics calculation given predictions.
        predictions (ndarray): 
        true_fitness (ndarray): The true fitness values of the mutants.
        muts_preencoded (bool): Whether or not the input mutants are pre-encoded.
            If not pre-encoded, this function expects a tuple of tuples of tuples.
            If it is pre-encoded, then the mutants should be formatted already
            to go into whatever model will be making predictions.
        """
        # print("predictions.shape, true_fitness.shape")
        # print(predictions.shape, true_fitness.shape)
        assert isinstance(true_fitness, np.ndarray), "True values should be numpy array"
        assert (
            predictions.shape == true_fitness.shape
        ), "Mismatch in truth and prediction shapes"
        n_fitness_vals = predictions.shape[1]

        # Evaluate the predictions using spearman rho
        rho_vals = np.array(
            [
                ss.spearmanr(true_fitness[:, i], predictions[:, i])[0]
                for i in range(n_fitness_vals)
            ]
        )

        # Evaluate predictions using mse. Keep errors separate by label.
        mses = sm.mean_squared_error(
            true_fitness, predictions, multioutput="raw_values"
        )

        return [predictions, [rho_vals, mses]]

    @abstractmethod
    def _check_mutants(self, mutants, muts_preencoded):
        pass

    @abstractmethod
    def _complete_prediction(self, mutants, muts_preencoded, **pred_kwargs):
        """
        Every child class will need to finish up the prediction function.
        """
        pass

    @property
    def constructor_args(self):
        return self._constructor_args

    @property
    def constructor_kwargs(self):
        return self._constructor_kwargs


# Define an abstract class for unsupervised models
class AbstractUnsupervisedModelWrapper(AbstractModelWrapper):
    def format_mutants(self, unformatted_mutlists, *reformat_args):

        # First confirm that the unformatted mutants are correct
        self._check_unformatted_mutants(unformatted_mutlists)

        # Complete formatting
        reformatted_mutlists = self._complete_reformat(
            unformatted_mutlists, *reformat_args
        )

        # Check the mutants at the end for accuracy
        self._check_formatted_mutants(reformatted_mutlists)

        return reformatted_mutlists

    def _check_mutants(self, mutants, muts_preencoded):

        # The function we use for checking depends on whether or not the mutants
        # are pre-encoded
        if muts_preencoded:
            self._check_formatted_mutants(mutants)
        else:
            self._check_unformatted_mutants(mutants)

    @abstractmethod
    def _check_formatted_mutants(self, mutants):
        """
        Confirms that the mutants are input to the class in a format expected by
        the model.
        """
        pass

    # All classes will need a function for formatting mutants. This can mean
    # tokenizing them for the NLP models or just rearranging the input mutants
    # for EVcouplings or triad.
    @abstractmethod
    def _complete_reformat(self, mutants):
        """
        When mutations are not pre-encoded, this function handles their processing
        so that they can be converted from the standard input format to a format
        useful for the model. This must be overwritten by inheriting classes.
        """
        pass


# Define an abstract class for supervised models
class AbstractSupervisedModelWrapper(AbstractModelWrapper):
    """
    This is an abstract class for wrapping all supervised models used for making
    single-to-multi mutant predictions. This class will be inherited by classes
    specific to the different model architectures/types we will be training.
    """

    # We add a fitness scaler to the initialized attributes
    def __init__(
        self,
        fasta_loc,
        model_class,
        encoder=None,
        *constructor_args,
        **constructor_kwargs
    ):

        # Initialize using parent method
        super().__init__(fasta_loc, *constructor_args, **constructor_kwargs)

        # Record the model class
        self.model_class = model_class

        # print(
        #     """initilizing models in AbstractSupervisedModelWrapper \
        #     with constructor_kwargs and self.constructor_kwargs"""
        # )
        # print(constructor_kwargs)
        # print(self.constructor_kwargs)

        # Initialize the model
        self._initialize_model()

        # Specify the encoder encoding type
        self.encoder = encoder

        # The default fitness scaler is "None". We initialize one during training.
        self.fitness_scaler = None

    # All classes will have a train method. This method might not do anything
    # for some of them though (like triad)
    def train(
        self,
        train_mutants,
        train_fitnesses,
        test_mutants=None,
        test_fitnesses=None,
        muts_preencoded=False,
        batch_size=BATCH_SIZE,
        _skip_check=False,
        **pred_kwargs
    ):
        """
        Trains a model given input mutants and fitnesses.
        Parameters
        ----------
        mutants (tuple of tuple of tuples OR ndarray/torch tensor):
            If a tuple of tuple of tuples: All mutations to make a prediction
            for. Will use the reference sequence loaded during initialization in
            combination with the mutants passed in here. Each entry in the
            outer tuple is a tuple containing mutation information for one
            variant. Each mutation is specified by its own tuple with the format
            (Reference AA, Mutation Position, Mutant AA). Mutation positions are
            assumed to be 1-indexed relative to the sequence. See an example
            below.
            If ndarray or torch tensor: This is mutants that have already been
            encoded.
        fitnesses (2D array with same type as `mutants` after they have
        been encoded):
            Provides the labels against which training will be performed
        muts_preencoded (bool): Whether or not the input mutants are pre-encoded.
            If not pre-encoded, this function expects a tuple of tuples of tuples.
            If it is pre-encoded, then the mutants should be formatted already
            to go into whatever model will be making predictions.
        _skip_check (bool): Private method. Will skip checking the input data.
            Used in conjunction with self.train_cv (self.train_cv handles
            checking, so there's no need to do it on every iteration)
        """
        # Confirm training inputs are acceptable
        if not _skip_check:
            self._check_training_data(
                train_mutants, train_fitnesses, None, muts_preencoded
            )

        # print("""initilizing models in the train \ 
        #     inside the AbstractSupervisedModelWrapper \
        #     with self.constructor_args and self.constructor_kwargs""")
        
        # print(self.constructor_args)
        # print(self.constructor_kwargs)
        
        # Initialize a new model.
        self._initialize_model()

        # Build a scaler and scale fitnesses
        self.fitness_scaler = StandardScaler()
        scaled_train_fitnesses = self.fitness_scaler.fit_transform(train_fitnesses)

        # Check testing data if provided. Scale it too.
        if test_mutants is not None:
            assert test_fitnesses is not None, "Did not provide test_fitnesses"
            self._check_training_data(
                test_mutants, test_fitnesses, None, muts_preencoded
            )
            scaled_test_fitnesses = self.fitness_scaler.transform(
                test_fitnesses, copy=True
            )
        else:
            scaled_test_fitnesses = None

        # print("Now complete training with pred_kwargs")
        # print(pred_kwargs)

        return self._complete_training(
            train_mutants,
            scaled_train_fitnesses,
            test_mutants,
            scaled_test_fitnesses,
            muts_preencoded,
            batch_size,
            **pred_kwargs
        )

    # All classes will have a cv-train method.
    def train_cv(
        self,
        mutants,
        fitnesses,
        muts_preencoded=False,
        positions=None,
        n_cv=N_CV,
        shuffle=True,
        random_state=2,
        split_type="random",
        **pred_kwargs
    ):
        """
        Performs kfold cross validation to train a model.
        """
        # Check the training data
        self._check_training_data(mutants, fitnesses, positions, muts_preencoded)

        # Generate splits from the cross-validator
        splits = self._build_splits(
            fitnesses, positions, n_cv, shuffle, random_state, split_type
        )

        # Loop over all splits
        all_test_errs = np.empty(
            [n_cv, N_EVAL_METRICS, fitnesses.shape[1]]
        )  # (n_cv, N metrics, n fitness values)

        stop_epoch_sum = 0

        for i, (train_inds, test_inds) in enumerate(splits):
            print("Running fold {0}".format(i+1))
            # Make the split. How the split is performed depends on the dtype
            # of the mutants
            train_muts, test_muts = self._make_split(train_inds, test_inds, mutants)
            train_fitness, test_fitness = self._make_split(
                train_inds, test_inds, fitnesses
            )

            # Train
            # print("""This is the train inside the AbstractSupervisedModelWrapper \
            #     train_cv with pred_kwargs""")
            # print(pred_kwargs)

            train_op = self.train(
                train_muts,
                train_fitness,
                test_mutants=test_muts,
                test_fitnesses=test_fitness,
                muts_preencoded=muts_preencoded,
                _skip_check=True,
                **pred_kwargs
            )

            # get the number of early stopping epoch
            if train_op is not None:
                stop_epoch_sum += train_op

            # Evaluate
            # print("""This is the evaluate inside the AbstractSupervisedModelWrapper \
            # train_cv with pred_kwargs""")
            # print(pred_kwargs)

            all_test_errs[i] = self.evaluate(
                test_muts, test_fitness, muts_preencoded=muts_preencoded, **pred_kwargs
            )[1]

        # Return the average test error over the different splits
        # print("avg epoch after train_cv {0}".format(np.int(stop_epoch_sum/n_cv)+1))
        
        """print("all_test_errs")
        print(all_test_errs)
        
        print("all_test_errs taking the mean")
        print(all_test_errs.mean(axis=0))"""
        
        return all_test_errs.mean(axis=0), np.int(stop_epoch_sum / n_cv) + 1

    # We extend the prediction method to rescale any predictions made
    def predict(
        self, mutants, muts_preencoded=False, batch_size=BATCH_SIZE, **pred_kwargs
    ):

        # print(
        #     "This is the predict inside the AbstractSupervisedModelWrapper with pred_kwargs"
        # )
        # print(pred_kwargs)

        # Confirm that we have a scaler
        assert (
            self.fitness_scaler is not None
        ), "Must train model before making predictions"

        # See if we need an encoder
        self._check_for_encoder(muts_preencoded)

        # Get the unscaled predictions
        unscaled_predictions = super().predict(
            mutants,
            muts_preencoded=muts_preencoded,
            batch_size=batch_size,
            **pred_kwargs
        )

        # print("mutants and unscaled_predictions dim:")
        # print(mutants.shape, unscaled_predictions.shape)

        # Return the rescaled predictions
        return self.fitness_scaler.inverse_transform(unscaled_predictions)

    def hyperopt_gridsearch(
        self,
        mutants,
        fitnesses,
        constructor_args_list,
        constructor_kwargs_list,
        **cv_kwargs
    ):
        """
        Performs a grid search to identify the optimal hyperparameters for the
        model.
        """
        
        print("""Inside hyperopt_gridsearch in AbstractSupervisedModelWrapper \
        with constructor_kwargs_list and cv_kwargs""")
        print(constructor_kwargs_list, cv_kwargs)

        # Confirm that there are as many constructor args as there are kwargs
        n_tests = len(constructor_args_list)
        assert n_tests == len(constructor_kwargs_list)

        best_cv_rho = 0
        best_cv_mse = np.inf
        opt_params = np.nan
        train_params = {}

        # Loop over all combinations of constructor args and kwargs and test
        n_fitness_vals = fitnesses.shape[1]
        hyperopt_test_errs = np.empty([n_tests, N_EVAL_METRICS, n_fitness_vals])
        for i, (constructor_args, constructor_kwargs) in enumerate(
            zip(constructor_args_list, constructor_kwargs_list)
        ):

            # Reassign constructor args and kwargs
            self._constructor_args = constructor_args
            self._constructor_kwargs = deepcopy(constructor_kwargs)
            
            # print("""Inside hyperopt_gridsearch in AbstractSupervisedModelWrapper \
            # before train_cv constructor_kwargs and self._constructor_kwargs""")
            # print(constructor_kwargs, self._constructor_kwargs)

            # Run cross-validation and record error
            hyperopt_test_errs[i], stop_epoch_avg = self.train_cv(
                mutants, fitnesses, **cv_kwargs
            )
            # print("stop_epoch_avg in hyper tune is {0}".format(stop_epoch_avg))

            current_rho = hyperopt_test_errs[i][0][0]
            current_mse = hyperopt_test_errs[i][1][0]
            # print("current rho: {0}, mse: {1}".format(current_rho, current_mse))

            """best_cv_mse, best_cv_rho, opt_params, _ = pick_best_params(current_mse, best_cv_mse, 
                                                                    current_rho, best_cv_rho, 
                                                                    self._constructor_kwargs)"""
            """# Update the best errors and associated kwargs
            if current_mse != np.nan and current_mse < best_cv_mse:
                best_cv_mse = current_mse
                best_cv_rho = current_rho
                opt_params = self._constructor_kwargs
            elif (
                current_mse != np.nan
                and current_rho != np.nan
                and current_mse == best_cv_mse
                and current_rho > best_cv_rho
            ):
                best_cv_rho = current_rho
                opt_params = self._constructor_kwargs"""

            if (current_mse != np.nan and current_mse < best_cv_mse) or (
                current_mse != np.nan
                and current_rho != np.nan
                and current_mse == best_cv_mse
                and current_rho > best_cv_rho
            ):
                best_cv_mse = current_mse
                best_cv_rho = current_rho
                opt_params = self._constructor_kwargs

            if stop_epoch_avg != 1:
                train_params = {"epochs": stop_epoch_avg}

        print(
            "Chosen best opt_params, train_params: {0} {1}".format(
                opt_params, train_params
            )
        )
        
        return hyperopt_test_errs, opt_params, train_params, best_cv_mse, best_cv_rho

    # Confirm that we have an encoder if mutations are not preencoded
    def _check_for_encoder(self, muts_preencoded):
        if not muts_preencoded:
            assert (
                self.encoder is not None
            ), "If muts are not pre-encoded, an encoder is needed"

    def _check_training_data(self, mutants, fitnesses, positions, muts_preencoded):
        """
        Confirms assumptions that we make about the training data as well as
        encodes mutants if a tuple of tuple of tuples was passed in.
        """
        # See if we need an encoder
        self._check_for_encoder(muts_preencoded)

        # Confirm that the input mutants will work. Determine whether the mutants
        # are encoded or not.
        self._check_mutants(mutants, muts_preencoded)

        # Confirm that (1) `fitnesses` is 2D and (2) that the lengths of the
        # mutant and fitness arrays match
        assert len(fitnesses.shape) == 2, "Expect 2D array for fitness"
        assert len(mutants) == len(fitnesses), "Mismatch in training data lengths"

        # If the mutants are pre-encoded, then we expect the fitnesses and
        # mutants to have the same datatype
        if muts_preencoded:
            assert type(fitnesses) is type(
                mutants
            ), "Expect x and y to have the same type"

        # If positions are provided, then they should be 1d and and have the
        # same length as mutants and fitnesses
        if positions is not None:
            assert len(positions.shape) == 1, "Expect 1D array for positions"
            assert len(positions) == len(
                fitnesses
            ), "Mismatch in number of positions provided"

    def _build_splits(
        self, fitnesses, positions, n_cv, shuffle, random_state, split_type
    ):

        # Confirm that fitnesses are a numpy array
        assert isinstance(fitnesses, np.ndarray), "Fitnesses expected to be numpy array"

        # Group split kwargs
        splargs = {"n_splits": n_cv, "shuffle": shuffle, "random_state": random_state}

        # Initialize the appropriate object. If the split type is "random", then
        # we slit randomly by position
        if split_type == "random":

            # Build the object
            kfold_obj = KFold(**splargs)

            # Generate splits
            splits = kfold_obj.split(fitnesses)

        # If the split type is "even_pos", then we split such that each position
        # is roughly equivalent represented in each split
        elif split_type == "even_pos":

            # Confirm that classes are provided, the positions are a numpy array,
            # the positions have the same length as fitnesses, and that the positions
            # are a 1d numpy array
            assert positions is not None, "Must provide position information"
            assert isinstance(positions, np.ndarray), "Positions should be numpy array"
            assert len(positions.shape) == 1, "Positions should be 1d array"
            assert len(positions) == len(
                fitnesses
            ), "Mismatch between fitnesses and positions"

            # Build the object
            kfold_obj = StratifiedKFold(**splargs)

            # Generate splits
            splits = kfold_obj.split(X=fitnesses, y=positions)

        # Anything else and we have an unrecognized splitter
        else:
            raise AssertionError("Unrecognized kfold type requested")

        return splits

    def _make_split(self, train_inds, test_inds, to_split):

        # Confirm that the inputs inds are numpy arrays
        assert isinstance(train_inds, np.ndarray)
        assert isinstance(test_inds, np.ndarray)

        # If the object to be split is a numpy array, just fancy index
        if isinstance(to_split, np.ndarray):
            return to_split[train_inds], to_split[test_inds]

        # If the object to be split is a torch tensor, convert training and
        # testing inds to tensors first, then split
        elif isinstance(to_split, torch.Tensor):
            return (
                to_split[torch.from_numpy(train_inds)],
                to_split[torch.from_numpy(test_inds)],
            )

        # If the object to be split is a tuple, then we use list-comp
        elif isinstance(to_split, tuple):
            return (
                tuple(to_split[ind] for ind in train_inds),
                tuple(to_split[ind] for ind in test_inds),
            )

        else:
            raise AssertionError("Unknown datatype input to cross-validation")

    def _check_mutants(self, mutants, muts_preencoded):
        """
        We only check if muts and not preencoded. There is no good way to check
        encoded mutations.
        """
        # The function we use for checking depends on whether or not the mutants
        # are pre-encoded
        if not muts_preencoded:
            self._check_unformatted_mutants(mutants)

    @abstractmethod
    def _initialize_model(self):
        """
        Builds a new instance of self.model using self.constructor_args and
        self.constructor_kwargs
        """
        pass

    @abstractmethod
    def _complete_training(
        self,
        train_mutants,
        scaled_train_fitnesses,
        test_mutants,
        scaled_test_fitnesses,
        muts_preencoded,
        batch_size,
        **pred_kwargs
    ):
        """
        Finishes the `train` method for each child class.
        """
        pass

    @abstractmethod
    def _get_model_class_name(self):
        """
        Obtain the model class name of the class as a string
        """
        pass