In [2]:
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [3]:
from sklearn.linear_model import Ridge

In [8]:
from scr.preprocess.data_process import split_protrain_loader

In [9]:
# Build one loader per subset of the GB1 "two_vs_rest" split.
# The dataset path is relative, so this cell relies on the `%cd` into the
# repository root performed in the first cell.
# NOTE(review): the per-argument notes below are inferred from the argument
# names and from this cell's log output; confirm against the definition of
# `split_protrain_loader`.
train_loader, val_loader, test_loader = split_protrain_loader(
    dataset_path="data/proeng/gb1/two_vs_rest.pkl",
    # The log prints "Loading esm1_t6_43M_UR50S using 6 layer embedding" once
    # per requested subset.
    encoder_name="esm1_t6_43M_UR50S",
    embed_layer=6,
    embed_batch_size=128,
    flatten_emb="mean",
    # presumably None means embeddings are computed here rather than read from
    # disk -- TODO confirm
    embed_path=None,
    seq_start_idx=0,
    seq_end_idx=56,
    # Three subsets requested, three loaders unpacked on the left-hand side.
    subset_list=["train", "val", "test"],
    loader_batch_size=64,
    worker_seed=42,
)

Loading esm1_t6_43M_UR50S using 6 layer embedding


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main
100%|██████████| 3/3 [00:04<00:00,  1.59s/it]
Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main


Loading esm1_t6_43M_UR50S using 6 layer embedding


100%|██████████| 1/1 [00:00<00:00,  2.14it/s]
Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main


Loading esm1_t6_43M_UR50S using 6 layer embedding


100%|██████████| 65/65 [02:14<00:00,  2.08s/it]


In [11]:
import numpy as np

# Ridge baseline on the precomputed embeddings with a fixed alpha of 0.5.
# NOTE(review): `normalize` was deprecated in scikit-learn 1.0 and removed in
# 1.2. It is kept so results match the environment this notebook was run in;
# on newer releases drop it and scale the features explicitly instead.
model = Ridge(alpha=0.5, normalize=True, tol=0.001, solver="auto", random_state=42)

# `Ridge.fit` is not incremental: each call discards the previous solution.
# Fitting inside the batch loop therefore leaves a model trained on the last
# mini-batch only, so gather the whole training split and fit exactly once.
train_features = []
train_targets = []
for (x, y, _, _, _) in train_loader:
    train_features.append(np.asarray(x))
    train_targets.append(np.asarray(y))

model.fit(np.concatenate(train_features), np.concatenate(train_targets))


def predict_batches(loader):
    """Return a list with one prediction array per batch yielded by `loader`."""
    return [model.predict(x) for (x, _, _, _, _) in loader]


train_preds = predict_batches(train_loader)
val_preds = predict_batches(val_loader)

# TODO: loop over alpha (and other hyperparameters) and select on the
# validation split before looking at the test split.

test_preds = predict_batches(test_loader)

In [13]:
import numpy as np

In [15]:
len(train_preds)

6

In [16]:
np.concatenate(train_preds).shape

(381, 1)

In [None]:
"""Classes for models"""

# Import third party modules
from copy import deepcopy
from abc import ABC, abstractmethod

from sklearn.preprocessing import StandardScaler
import sklearn.metrics as sm
import numpy as np
import scipy.stats as ss
import torch

# Import custom objects
from code.params.globals import BATCH_SIZE
from code.params.model_info import *
# from code.models.utils import pick_best_params


# Define an abstract class for holding all models
class AbstractModelWrapper(ABC):
    """
    This is an abstract class for wrapping all models used for making
    mutant fitness predictions. It will be inherited by a model class
    for making both supervised and unsupervised predictions.
    """

    # Initialization loads the parent sequence
    def __init__(self, *constructor_args, **constructor_kwargs):
        """
        Loads the reference sequence that will be used to make predictions and
        stores parameters for the model that will be trained.
        
        Args:
        - constructor_args: Positional arguments that will be passed to the model
            constructor for this wrapper.
        - constructor_kwargs: Keyword arguments that will be passed to the model
            constructor for this wrapper.
        """
        # Initialize the sequence loader
        super().__init__(fasta_loc)

        # print("AbstractModelWrapper constructor_kwargs inputs before assigning")
        # print(constructor_kwargs)

        # Store the constructor args and kwargs as instance variables
        self._constructor_args = constructor_args
        self._constructor_kwargs = deepcopy(constructor_kwargs)

        # print("AbstractModelWrapper constructor_kwargs inputs")
        # print(constructor_kwargs)
        # print("AbstractModelWrapper self._constructor_kwargs")
        # print(self._constructor_kwargs)

        # The default model is "None". We need to initialize a model later.
        self.model = None

    # All classes will be able to make a prediction
    def predict(self, mutants, muts_preencoded=False, **pred_kwargs):
        """
        Makes predictions given an input of mutants or encoded mutants. If
        unencoded mutants are passed in, then they will first be encoded.
        Otherwise, encodings passed in will be assumed correct and fed into the
        appropriate model.
        Parameters
        ----------
        mutants (tuple of tuple of tuples OR ndarray/torch tensor):
            If a tuple of tuple of tuples: All mutations to make a prediction
            for. Will use the reference sequence loaded during initialization in
            combination with the mutants passed in here. Each entry in the
            outer tuple is a tuple containing mutation information for one
            variant. Each mutation is specified by its own tuple with the format
            (Reference AA, Mutation Position, Mutant AA). Mutation positions are
            assumed to be 1-indexed relative to the sequence. See an example
            below.
            If ndarray or torch tensor: This is mutants that have already been
            encoded.
        Returns
        -------
        predictions (ndarray): Numpy array containing the predictions for the
            input set of mutants.
        Examples
        --------
        An example input for `mutants` when tuple of tuple of tuples:
            mutants = (
                ((V, 39, D), (D, 40, G)),
                ((V, 39, D), (D, 40, G), (G, 41, F))
                )
        """

        # print("""This is the predict inside the AbstractModelWrapper \
        # calling _complete_prediction""")

        # Confirm that we have a model and scaler for making predictions
        assert self.model is not None, "Must train a model before making predictions"

        # Check mutations for accuracy based on whether they are pre-encoded
        # or not
        self._check_mutants(mutants, muts_preencoded)

        # Every child class will need to finish this function
        unscaled_preds = self._complete_prediction(
            mutants, muts_preencoded=muts_preencoded, **pred_kwargs
        )

        # print("mutants and unscaled_preds dims:")
        # print(mutants.shape, unscaled_preds.shape)

        # Confirm that predictions are a numpy array and that they are 2d
        assert isinstance(
            unscaled_preds, np.ndarray
        ), "Predictions should be numpy array"
        assert len(unscaled_preds.shape) == 2, "Expect 2D prediction array"

        # Return the unscaled preds. Supervised models will scale them based on
        # the standard scaler used at input
        return unscaled_preds

    def evaluate(self, mutants, true_fitness, muts_preencoded=False, **pred_kwargs):
        """
        Wraps making predictions and evaluation of predictions using a number
        of evaluation metrics.
        mutants (tuple of tuple of tuples): All mutations to make a prediction
            for. Will use the reference sequence loaded during initialization in
            combination with the mutants passed in here. Each entry in the
            outer tuple is a tuple containing mutation information for one
            variant. Each mutation is specified by its own tuple with the format
            (Reference AA, Mutation Position, Mutant AA). Mutation positions are
            assumed to be 1-indexed relative to the sequence. See an example
            below.
        true_fitness (ndarray): The true fitness values of the mutants.
        muts_preencoded (bool): Whether or not the input mutants are pre-encoded.
            If not pre-encoded, this function expects a tuple of tuples of tuples.
            If it is pre-encoded, then the mutants should be formatted already
            to go into whatever model will be making predictions.
        Examples
        --------
        An example input for `mutants`:
            mutants = (
                ((V, 39, D), (D, 40, G)),
                ((V, 39, D), (D, 40, G), (G, 41, F))
                )
        """
        
        # print("This is the evulate inside the AbstractModelWrapper")

        # Make predictions
        predictions = self.predict(
            mutants, muts_preencoded=muts_preencoded, **pred_kwargs
        )

        # Confirm that the predictions and true values are the same shape and
        # that they are two dimensional
        
        # print("The kwargs for the evaluate function in AbstractModelWrapper")
        # print(pred_kwargs)
        # print("After the predict in the evaluate function in AbstractModelWrapper")
        # print(mutants.shape, predictions.shape, true_fitness.shape)
        
        """assert isinstance(true_fitness, np.ndarray), "True values should be numpy array"
        assert (
            predictions.shape == true_fitness.shape
        ), "Mismatch in truth and prediction shapes"
        n_fitness_vals = predictions.shape[1]
        # Evaluate the predictions using spearman rho
        rho_vals = np.array(
            [
                ss.spearmanr(true_fitness[:, i], predictions[:, i])[0]
                for i in range(n_fitness_vals)
            ]
        )
        # Evaluate predictions using mse. Keep errors separate by label.
        mses = sm.mean_squared_error(
            true_fitness, predictions, multioutput="raw_values"
        )
        return [predictions, [rho_vals, mses]]"""
        
        return self.evaluate_metrics(predictions, true_fitness, muts_preencoded, **pred_kwargs)
    
    def evaluate_metrics(self, predictions, true_fitness, muts_preencoded=False, **pred_kwargs):
        """
        Scores a set of predictions against the true fitness values.

        predictions (ndarray): 2D array of predicted values; columns are
            scored independently of each other.
        true_fitness (ndarray): The true fitness values of the mutants. Must
            have the same shape as `predictions`.
        muts_preencoded (bool): Accepted for signature parity with `evaluate`;
            not used by the computation below.
        pred_kwargs: Accepted for signature parity with `evaluate`; not used
            by the computation below.

        Returns a two-element list `[predictions, [rho_vals, mses]]`, where
        `rho_vals` holds one Spearman rho per column and `mses` holds one mean
        squared error per column.
        """
        assert isinstance(true_fitness, np.ndarray), "True values should be numpy array"
        assert (
            predictions.shape == true_fitness.shape
        ), "Mismatch in truth and prediction shapes"

        # Rank correlation, one value per fitness column
        per_label_rho = []
        for col in range(predictions.shape[1]):
            rho = ss.spearmanr(true_fitness[:, col], predictions[:, col])[0]
            per_label_rho.append(rho)
        rho_vals = np.array(per_label_rho)

        # "raw_values" keeps the squared errors separate per column
        mses = sm.mean_squared_error(
            true_fitness, predictions, multioutput="raw_values"
        )

        return [predictions, [rho_vals, mses]]

    @abstractmethod
    def _check_mutants(self, mutants, muts_preencoded):
        """
        Validation hook that every concrete subclass must implement.

        `predict` calls it with the mutants and the `muts_preencoded` flag
        before any prediction is attempted and ignores its return value.
        """
        pass

    @abstractmethod
    def _complete_prediction(self, mutants, muts_preencoded, **pred_kwargs):
        """
        Every child class will need to finish up the prediction function.

        `predict` calls this after `_check_mutants`, passing `muts_preencoded`
        by keyword along with any extra `pred_kwargs`. `predict` asserts that
        the value returned here is a 2D numpy array.
        """
        pass

    @property
    def constructor_args(self):
        """Read-only view of the positional model-constructor arguments held in `_constructor_args`."""
        return self._constructor_args

    @property
    def constructor_kwargs(self):
        """Read-only view of the keyword model-constructor arguments held in `_constructor_kwargs`."""
        return self._constructor_kwargs


# Define an abstract class for unsupervised models
class AbstractUnsupervisedModelWrapper(AbstractModelWrapper):
    """
    Abstract base for wrappers around unsupervised models.

    It adds a mutant-formatting step (`format_mutants`) and routes the
    `_check_mutants` hook to one of two validators depending on whether the
    mutants are already formatted.

    NOTE(review): `_check_unformatted_mutants` is called below but is not
    defined in this class; it is presumably supplied by the concrete subclass
    or a mixin -- confirm.
    """

    def format_mutants(self, unformatted_mutlists, *reformat_args):
        """
        Converts mutants from the standard input format into the format the
        wrapped model expects.

        The input is validated first, then `_complete_reformat` performs the
        conversion (receiving `reformat_args` unchanged), and finally the
        converted mutants are validated before being returned.
        """
        # First confirm that the unformatted mutants are correct
        self._check_unformatted_mutants(unformatted_mutlists)

        # Complete formatting
        reformatted_mutlists = self._complete_reformat(
            unformatted_mutlists, *reformat_args
        )

        # Check the mutants at the end for accuracy
        self._check_formatted_mutants(reformatted_mutlists)

        return reformatted_mutlists

    def _check_mutants(self, mutants, muts_preencoded):
        """
        Validates `mutants` with the checker matching their state: the
        formatted checker when `muts_preencoded` is truthy, otherwise the
        unformatted checker.
        """
        if muts_preencoded:
            self._check_formatted_mutants(mutants)
        else:
            self._check_unformatted_mutants(mutants)

    @abstractmethod
    def _check_formatted_mutants(self, mutants):
        """
        Confirms that the mutants are input to the class in a format expected by
        the model.
        """
        pass

    # All classes will need a function for formatting mutants. This can mean
    # tokenizing them for the NLP models or just rearranging the input mutants
    # for EVcouplings or triad.
    @abstractmethod
    def _complete_reformat(self, mutants, *reformat_args):
        """
        When mutations are not pre-encoded, this function handles their processing
        so that they can be converted from the standard input format to a format
        useful for the model. This must be overwritten by inheriting classes.

        `format_mutants` forwards any extra positional arguments it receives,
        so the signature accepts `*reformat_args` to match that call.
        """
        pass


# Define an abstract class for supervised models
class AbstractSupervisedModelWrapper(AbstractModelWrapper):
    """
    This is an abstract class for wrapping all supervised models used for making
    single-to-multi mutant predictions. This class will be inherited by classes
    specific to the different model architectures/types we will be training.
    """

    # We add a fitness scaler to the initialized attributes
    def __init__(
        self,
        fasta_loc,
        model_class,
        encoder=None,
        *constructor_args,
        **constructor_kwargs
    ):
        """
        Sets up a supervised wrapper and builds its first model instance.

        Args:
        - fasta_loc: Forwarded as the first positional argument to the parent
            initializer.
        - model_class: Stored as `self.model_class`.
        - encoder: Stored as `self.encoder`. `_check_for_encoder` requires it
            to be non-None whenever mutants are not pre-encoded.
        - constructor_args / constructor_kwargs: Forwarded to the parent
            initializer.

        Note that `_initialize_model` runs before `self.encoder` and
        `self.fitness_scaler` are assigned, so an implementation of that hook
        cannot read either attribute during construction.
        """

        # Initialize using parent method
        super().__init__(fasta_loc, *constructor_args, **constructor_kwargs)

        # Record the model class
        self.model_class = model_class

        # Initialize the model
        self._initialize_model()

        # Specify the encoder encoding type
        self.encoder = encoder

        # The default fitness scaler is "None". We initialize one during training.
        self.fitness_scaler = None

    # All classes will have a train method. This method might not do anything
    # for some of them though (like triad)
    def train(
        self,
        train_mutants,
        train_fitnesses,
        test_mutants=None,
        test_fitnesses=None,
        muts_preencoded=False,
        batch_size=BATCH_SIZE,
        _skip_check=False,
        **pred_kwargs
    ):
        """
        Rebuilds the model, fits a fresh fitness scaler on the training labels,
        and hands the scaled data to `_complete_training`.
        Parameters
        ----------
        train_mutants (tuple of tuple of tuples OR ndarray/torch tensor):
            If a tuple of tuple of tuples: each inner tuple describes one
            variant, and each mutation has the format
            (Reference AA, Mutation Position, Mutant AA) with 1-indexed
            positions.
            If ndarray or torch tensor: mutants that are already encoded.
        train_fitnesses (2D array): Labels to train against. A new
            `StandardScaler` is fitted on them and stored as
            `self.fitness_scaler`.
        test_mutants, test_fitnesses (optional): Held-out data. When
            `test_mutants` is given, `test_fitnesses` must be given as well;
            it is transformed with the scaler fitted on the training labels.
        muts_preencoded (bool): Whether or not the input mutants are pre-encoded.
        batch_size: Forwarded to `_complete_training`.
        _skip_check (bool): Private argument. Skips validation of the
            *training* inputs (used by `self.train_cv`, which validates once
            up front). Test inputs, when provided, are validated regardless.
        pred_kwargs: Forwarded to `_complete_training`.
        Returns
        -------
        Whatever `_complete_training` returns.
        """
        # Validate the training inputs unless the caller already did
        if not _skip_check:
            self._check_training_data(
                train_mutants, train_fitnesses, None, muts_preencoded
            )

        # Every call to train starts from a freshly built model
        self._initialize_model()

        # Fit the label scaler on the training fitnesses only
        self.fitness_scaler = StandardScaler()
        train_targets = self.fitness_scaler.fit_transform(train_fitnesses)

        # Held-out labels are optional; scale them with the same scaler
        test_targets = None
        if test_mutants is not None:
            assert test_fitnesses is not None, "Did not provide test_fitnesses"
            self._check_training_data(
                test_mutants, test_fitnesses, None, muts_preencoded
            )
            test_targets = self.fitness_scaler.transform(test_fitnesses, copy=True)

        return self._complete_training(
            train_mutants,
            train_targets,
            test_mutants,
            test_targets,
            muts_preencoded,
            batch_size,
            **pred_kwargs
        )

    # All classes will have a cv-train method.
    def train_cv(
        self,
        mutants,
        fitnesses,
        muts_preencoded=False,
        positions=None,
        n_cv=N_CV,
        shuffle=True,
        random_state=2,
        split_type="random",
        **pred_kwargs
    ):
        """
        Performs kfold cross validation to train a model.

        Parameters
        ----------
        mutants, fitnesses: Full dataset; validated once here and then split
            per fold with `_make_split`.
        muts_preencoded (bool): Forwarded to validation, `train` and
            `evaluate`.
        positions: Optional 1D array forwarded to validation and to
            `_build_splits`.
        n_cv, shuffle, random_state, split_type: Forwarded to `_build_splits`.
        pred_kwargs: Forwarded to both `train` and `evaluate` on every fold.

        Returns
        -------
        (mean_errs, avg_stop_epoch):
            mean_errs is the fold-average of the metrics returned by
            `evaluate`, with shape (N_EVAL_METRICS, n fitness values).
            avg_stop_epoch is `int(sum_of_train_return_values / n_cv) + 1`;
            folds where `train` returns None contribute nothing to the sum, so
            the value is 1 when `train` never returns a number.
        """
        # Check the training data
        self._check_training_data(mutants, fitnesses, positions, muts_preencoded)

        # Generate splits from the cross-validator
        splits = self._build_splits(
            fitnesses, positions, n_cv, shuffle, random_state, split_type
        )

        # (n_cv, N metrics, n fitness values)
        all_test_errs = np.empty([n_cv, N_EVAL_METRICS, fitnesses.shape[1]])

        stop_epoch_sum = 0

        for i, (train_inds, test_inds) in enumerate(splits):
            print("Running fold {0}".format(i+1))

            # Make the split. How the split is performed depends on the dtype
            # of the mutants
            train_muts, test_muts = self._make_split(train_inds, test_inds, mutants)
            train_fitness, test_fitness = self._make_split(
                train_inds, test_inds, fitnesses
            )

            # Train. The full dataset was validated above, so the per-fold
            # training check is skipped.
            train_op = self.train(
                train_muts,
                train_fitness,
                test_mutants=test_muts,
                test_fitnesses=test_fitness,
                muts_preencoded=muts_preencoded,
                _skip_check=True,
                **pred_kwargs
            )

            # Accumulate the value returned by train when there is one
            if train_op is not None:
                stop_epoch_sum += train_op

            # evaluate returns [predictions, [rho_vals, mses]]; keep the metrics
            all_test_errs[i] = self.evaluate(
                test_muts, test_fitness, muts_preencoded=muts_preencoded, **pred_kwargs
            )[1]

        # `np.int` was only an alias of the builtin and was removed in
        # NumPy 1.24; the builtin truncates identically.
        return all_test_errs.mean(axis=0), int(stop_epoch_sum / n_cv) + 1

    # We extend the prediction method to rescale any predictions made
    def predict(
        self, mutants, muts_preencoded=False, batch_size=BATCH_SIZE, **pred_kwargs
    ):
        """
        Extends the parent `predict` by mapping its output back through the
        fitness scaler fitted in `train`.

        Requires a fitted scaler (i.e. `train` has run) and, when the mutants
        are not pre-encoded, an encoder. `batch_size` and `pred_kwargs` are
        forwarded to the parent `predict`.

        Returns the result of `self.fitness_scaler.inverse_transform` applied
        to the parent's predictions.
        """
        # train() is what creates the scaler, so its absence means "untrained"
        assert (
            self.fitness_scaler is not None
        ), "Must train model before making predictions"

        # Unencoded mutants can only be handled with an encoder
        self._check_for_encoder(muts_preencoded)

        # Parent returns predictions in the scaler's transformed space
        scaled_space_preds = super().predict(
            mutants,
            muts_preencoded=muts_preencoded,
            batch_size=batch_size,
            **pred_kwargs
        )

        # Undo the label scaling
        return self.fitness_scaler.inverse_transform(scaled_space_preds)

    def hyperopt_gridsearch(
        self,
        mutants,
        fitnesses,
        constructor_args_list,
        constructor_kwargs_list,
        **cv_kwargs
    ):
        """
        Performs a grid search to identify the optimal hyperparameters for the
        model.

        Each (constructor_args, constructor_kwargs) pair is installed on the
        wrapper and scored with `train_cv`. The candidate with the lowest
        cross-validated MSE for the first fitness column wins; an exact MSE tie
        is broken by the higher Spearman rho. Candidates whose MSE is NaN are
        never selected.

        Parameters
        ----------
        mutants, fitnesses: Forwarded to `train_cv`.
        constructor_args_list, constructor_kwargs_list: Equal-length sequences
            of positional / keyword constructor arguments to try.
        cv_kwargs: Forwarded to `train_cv`.

        Returns
        -------
        (hyperopt_test_errs, opt_params, train_params, best_cv_mse, best_cv_rho):
            hyperopt_test_errs has shape (n candidates, N_EVAL_METRICS,
            n fitness values). opt_params is the kwargs dict of the winning
            candidate, or NaN if no candidate was selected. train_params is
            `{"epochs": n}` for the winning candidate when its `train_cv`
            epoch value differs from 1, otherwise `{}`.

        Side effects: `self._constructor_args` / `self._constructor_kwargs` are
        left holding the last candidate tried, not the best one.
        """

        print("""Inside hyperopt_gridsearch in AbstractSupervisedModelWrapper \
        with constructor_kwargs_list and cv_kwargs""")
        print(constructor_kwargs_list, cv_kwargs)

        # Confirm that there are as many constructor args as there are kwargs
        n_tests = len(constructor_args_list)
        assert n_tests == len(constructor_kwargs_list)

        best_cv_rho = 0
        best_cv_mse = np.inf
        opt_params = np.nan
        train_params = {}

        # Loop over all combinations of constructor args and kwargs and test
        n_fitness_vals = fitnesses.shape[1]
        hyperopt_test_errs = np.empty([n_tests, N_EVAL_METRICS, n_fitness_vals])
        for i, (constructor_args, constructor_kwargs) in enumerate(
            zip(constructor_args_list, constructor_kwargs_list)
        ):

            # Reassign constructor args and kwargs
            self._constructor_args = constructor_args
            self._constructor_kwargs = deepcopy(constructor_kwargs)

            # Run cross-validation and record error
            hyperopt_test_errs[i], stop_epoch_avg = self.train_cv(
                mutants, fitnesses, **cv_kwargs
            )

            # Metric 0 is rho and metric 1 is mse; only the first fitness
            # column drives the selection
            current_rho = hyperopt_test_errs[i][0][0]
            current_mse = hyperopt_test_errs[i][1][0]

            # `x != np.nan` is always True, so NaN has to be detected with
            # np.isnan. A NaN error can never be the best candidate.
            if np.isnan(current_mse):
                continue

            lower_mse = current_mse < best_cv_mse
            tie_with_better_rho = (
                not np.isnan(current_rho)
                and current_mse == best_cv_mse
                and current_rho > best_cv_rho
            )

            if lower_mse or tie_with_better_rho:
                best_cv_mse = current_mse
                best_cv_rho = current_rho
                opt_params = self._constructor_kwargs

                # Keep the epoch count that belongs to the selected candidate
                # rather than whichever candidate happened to run last.
                train_params = (
                    {"epochs": stop_epoch_avg} if stop_epoch_avg != 1 else {}
                )

        print(
            "Chosen best opt_params, train_params: {0} {1}".format(
                opt_params, train_params
            )
        )

        return hyperopt_test_errs, opt_params, train_params, best_cv_mse, best_cv_rho

    # Confirm that we have an encoder if mutations are not preencoded
    def _check_for_encoder(self, muts_preencoded):
        """
        Asserts that an encoder is available whenever the mutants still need
        encoding. Pre-encoded mutants pass without looking at `self.encoder`.
        """
        assert (
            muts_preencoded or self.encoder is not None
        ), "If muts are not pre-encoded, an encoder is needed"

    def _check_training_data(self, mutants, fitnesses, positions, muts_preencoded):
        """
        Asserts the assumptions made about a (mutants, fitnesses, positions)
        dataset. Nothing is encoded or modified here; the method only raises
        AssertionError when an assumption is violated.

        Checks, in order: an encoder exists if one is needed; the mutants pass
        `_check_mutants`; `fitnesses` is 2D; mutants and fitnesses have equal
        length; pre-encoded mutants share the type of `fitnesses`; and, when
        given, `positions` is 1D with the same length as `fitnesses`.
        """
        # Encoder availability, then the mutants themselves
        self._check_for_encoder(muts_preencoded)
        self._check_mutants(mutants, muts_preencoded)

        # Labels must be 2D and aligned with the mutants
        assert len(fitnesses.shape) == 2, "Expect 2D array for fitness"
        assert len(mutants) == len(fitnesses), "Mismatch in training data lengths"

        # Pre-encoded inputs are expected to be the same container type as
        # their labels
        if muts_preencoded:
            assert type(mutants) is type(
                fitnesses
            ), "Expect x and y to have the same type"

        # Positions are optional; nothing further to verify without them
        if positions is None:
            return

        assert len(positions.shape) == 1, "Expect 1D array for positions"
        assert len(positions) == len(
            fitnesses
        ), "Mismatch in number of positions provided"

    def _build_splits(
        self, fitnesses, positions, n_cv, shuffle, random_state, split_type
    ):
        """
        Builds the train/test index splits for cross-validation.

        Parameters
        ----------
        fitnesses (ndarray): Only used to size the splits.
        positions (1D ndarray or None): Required for `split_type="even_pos"`,
            where it is used as the stratification label.
        n_cv (int): Number of folds.
        shuffle (bool): Whether to shuffle before splitting.
        random_state: Seed for shuffling. Ignored when `shuffle` is False.
        split_type (str): "random" for a plain KFold, "even_pos" for a
            StratifiedKFold over `positions`.

        Returns
        -------
        The iterable of (train_inds, test_inds) produced by the splitter's
        `split` method.

        Raises
        ------
        AssertionError: On invalid inputs or an unrecognized `split_type`.
        """
        # Imported here because this module does not import the splitters
        # explicitly at the top of the file.
        from sklearn.model_selection import KFold, StratifiedKFold

        # Confirm that fitnesses are a numpy array
        assert isinstance(fitnesses, np.ndarray), "Fitnesses expected to be numpy array"

        # scikit-learn raises ValueError when random_state is set while
        # shuffle is False, so only pass the seed when it has an effect.
        splargs = {
            "n_splits": n_cv,
            "shuffle": shuffle,
            "random_state": random_state if shuffle else None,
        }

        # "random": split without regard to position
        if split_type == "random":
            return KFold(**splargs).split(fitnesses)

        # "even_pos": each position is roughly equally represented in each
        # split
        if split_type == "even_pos":

            # Positions must be provided as a 1d numpy array aligned with
            # the fitnesses
            assert positions is not None, "Must provide position information"
            assert isinstance(positions, np.ndarray), "Positions should be numpy array"
            assert len(positions.shape) == 1, "Positions should be 1d array"
            assert len(positions) == len(
                fitnesses
            ), "Mismatch between fitnesses and positions"

            return StratifiedKFold(**splargs).split(X=fitnesses, y=positions)

        # Anything else and we have an unrecognized splitter
        raise AssertionError("Unrecognized kfold type requested")

    def _make_split(self, train_inds, test_inds, to_split):
        """
        Selects the train and test portions of `to_split` by index.

        `train_inds` and `test_inds` must be numpy arrays. `to_split` may be a
        numpy array (indexed directly), a torch tensor (indexed with the
        indices converted via `torch.from_numpy`), or a tuple (rebuilt
        element by element). Any other type raises AssertionError.

        Returns a `(train_part, test_part)` pair of the same kind as
        `to_split`.
        """
        # Index arrays come from the cross-validator and must be numpy
        assert isinstance(train_inds, np.ndarray)
        assert isinstance(test_inds, np.ndarray)

        if isinstance(to_split, np.ndarray):
            # Fancy indexing works directly
            train_part = to_split[train_inds]
            test_part = to_split[test_inds]

        elif isinstance(to_split, torch.Tensor):
            # Tensors are indexed with tensor indices
            train_part = to_split[torch.from_numpy(train_inds)]
            test_part = to_split[torch.from_numpy(test_inds)]

        elif isinstance(to_split, tuple):
            # Tuples have no fancy indexing; pick items one at a time
            train_part = tuple(to_split[idx] for idx in train_inds)
            test_part = tuple(to_split[idx] for idx in test_inds)

        else:
            raise AssertionError("Unknown datatype input to cross-validation")

        return train_part, test_part

    def _check_mutants(self, mutants, muts_preencoded):
        """
        Validates mutants only when they are not pre-encoded; encoded mutants
        are accepted as-is because there is no good way to check them.
        """
        # Nothing to verify for already-encoded input
        if muts_preencoded:
            return

        self._check_unformatted_mutants(mutants)

    @abstractmethod
    def _initialize_model(self):
        """
        Builds a new instance of self.model using self.constructor_args and
        self.constructor_kwargs

        Called once from `__init__` and again at the start of every `train`
        call, so each training run starts from a fresh model.
        """
        pass

    @abstractmethod
    def _complete_training(
        self,
        train_mutants,
        scaled_train_fitnesses,
        test_mutants,
        scaled_test_fitnesses,
        muts_preencoded,
        batch_size,
        **pred_kwargs
    ):
        """
        Finishes the `train` method for each child class.

        `train` passes all arguments positionally, in this order. Both fitness
        arrays have already been transformed by `self.fitness_scaler`;
        `test_mutants` and `scaled_test_fitnesses` are None when no test data
        was supplied. The return value is returned by `train` unchanged, and
        `train_cv` sums it across folds whenever it is not None.
        """
        pass

    @abstractmethod
    def _get_model_class_name(self):
        """
        Obtain the model class name of the class as a string

        Abstract: each concrete subclass must implement it. It is not called
        anywhere in the code visible here.
        """
        pass