# Generate splits for an experimental dataset
This notebook shows how to generate splits for an experimental dataset, using the avGFP dataset as an example.

You can generate multiple types of splits.
- A "super test" or withholding split. It's a simple random sample of variants meant to be completely held out until the final model training and evaluation.
- Classic train, validation, and test splits based on percentages of the total dataset.
- Reduced dataset sizes for evaluating performance as a function of training set size.
- Extrapolation splits (mutation, position, score, and regime extrapolation) for testing the generalization performance of the models.


This example generates a single replicate of each type of split, but it is recommended to use multiple replicates in practice.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# define the name of the project root directory
project_root_dir_name = "metl"

# find the project root by checking each parent directory
current_dir = os.getcwd()
while os.path.basename(current_dir) != project_root_dir_name and current_dir != os.path.dirname(current_dir):
    current_dir = os.path.dirname(current_dir)

# change the current working directory to the project root directory
if os.path.basename(current_dir) == project_root_dir_name:
    os.chdir(current_dir)
else:
    print("project root directory not found")
    
# add the project code folder to the system path so imports work
module_path = os.path.abspath("code")
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
import random
import split_dataset as sd
import utils

In [4]:
import logging
logging.basicConfig()
logger = logging.getLogger("METL")
logger.setLevel(logging.INFO)

# Load the dataset

In [5]:
ds_name = "avgfp"
ds = utils.load_dataset(ds_name)

# some additional info needed for extrapolation splits
datasets = utils.load_dataset_metadata()
seq_len = len(datasets[ds_name]["wt_aa"])
wt_ofs = datasets[ds_name]["wt_ofs"]

# Withhold a "super test" set

I recommend having a completely held-out "super test" set. Don't use this set for development of the algorithm and don't look at evaluation results on this set until the very end, when you are ready to publish. Here we will create a super test set for avgfp and save it to the avgfp splits directory [data/avgfp/splits](../data/avgfp/splits).



In [6]:
out_dir = "data/dms_data/avgfp/splits/"

# use a fixed random seed for demonstration purposes
# rseed = random.randint(1000, 9999)
rseed = 5958

supertest_idxs, supertest_fn = sd.supertest(ds, size=.1, rseed=rseed, out_dir=out_dir, overwrite=False)
supertest_fn

INFO:METL.split_dataset:saving supertest split to file data/dms_data/avgfp/splits/supertest_w1abc2f4e9a64_s0.1_r5958.txt


'data/dms_data/avgfp/splits/supertest_w1abc2f4e9a64_s0.1_r5958.txt'

# Standard train, validation, and test splits

This will randomly sample train, validation, and test splits from the full dataset. You must specify the size of each set as a fraction of the total number of examples.

In [9]:
out_dir = "data/dms_data/avgfp/splits/standard"

# specify the super test set from above
# this set will be withheld from this train test split
withhold_fn = "data/dms_data/avgfp/splits/supertest_w1abc2f4e9a64_s0.1_r5958.txt"

# specify 80% train, 10% validation, and 10% test sizes
train_size = 0.8
val_size = 0.1
test_size = 0.1

# multiple replicate splits
replicates = 1

# random seeds
# rseeds = [random.randint(1000, 9999) for _ in range(replicates)]
# for purposes of this demo, make the rseeds constant
rseeds = [3597]

for rseed in rseeds:    
    split, out_dir_split = sd.train_val_test(ds, 
                                             train_size=train_size, 
                                             val_size=val_size, 
                                             test_size=test_size, 
                                             withhold=withhold_fn, 
                                             out_dir=out_dir, 
                                             rseed=rseed, 
                                             overwrite=False)

INFO:METL.split_dataset:saving train-val-test split to directory data/dms_data/avgfp/splits/standard/standard_tr0.8_tu0.1_te0.1_w1abc2f4e9a64_r3597


# Resampled dataset sizes

This splits enable you to evaluate performance as a function of train size.

In [8]:
out_dir = "data/dms_data/avgfp/splits/resampled"
withhold_fn = "data/dms_data/avgfp/splits/supertest_w1abc2f4e9a64_s0.1_r5958.txt"

# specify the dataset sizes and number of replicates per dataset size
dataset_sizes = [10, 20, 40, 80, 160, 320, 640, 1280, 2560, 5120, 10240, 20480]
# just one replicate for each dataset size for this example
replicates = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
# use multiple replicates in practice
# replicates = [101, 23, 11, 11, 11, 11, 7, 7, 5, 5, 3, 3]

# rseed = random.randint(1000, 9999)
rseed = 8099

# the test set is sampled from the full dataset
test_fraction = 0.1

# the validation set is sampled from the reduced dataset size
# the train set will be 1 minus the validation fraction
# so in this case, the train set will be 80%, and the validation set 20%
val_fraction = 0.2

# create the suite of resampled dataset size splits
for ds_size, reps in zip(dataset_sizes, replicates):
    splits, reduced_split_dir = sd.resampled_dataset_size(full_dataset_size=ds.shape[0], 
                                                          test_fraction=test_fraction, 
                                                          dataset_size=ds_size,
                                                          val_fraction=val_fraction,
                                                          num_replicates=reps, 
                                                          withhold=withhold_fn, 
                                                          rseed=rseed, 
                                                          out_dir=out_dir,
                                                          overwrite=False)

INFO:METL.split_dataset:saving resampled split to directory data/dms_data/avgfp/splits/resampled/resampled_ds10_val0.2_te0.1_w1abc2f4e9a64_s101_r8099
INFO:METL.split_dataset:saving resampled split to directory data/dms_data/avgfp/splits/resampled/resampled_ds20_val0.2_te0.1_w1abc2f4e9a64_s23_r8099
INFO:METL.split_dataset:saving resampled split to directory data/dms_data/avgfp/splits/resampled/resampled_ds40_val0.2_te0.1_w1abc2f4e9a64_s11_r8099
INFO:METL.split_dataset:saving resampled split to directory data/dms_data/avgfp/splits/resampled/resampled_ds80_val0.2_te0.1_w1abc2f4e9a64_s11_r8099
INFO:METL.split_dataset:saving resampled split to directory data/dms_data/avgfp/splits/resampled/resampled_ds160_val0.2_te0.1_w1abc2f4e9a64_s11_r8099
INFO:METL.split_dataset:saving resampled split to directory data/dms_data/avgfp/splits/resampled/resampled_ds320_val0.2_te0.1_w1abc2f4e9a64_s11_r8099
INFO:METL.split_dataset:saving resampled split to directory data/dms_data/avgfp/splits/resampled/resamp

# Position extrapolation

In [9]:
out_dir = "data/dms_data/avgfp/splits/position"

# 80% of positions are designated train pool, 20% are desinated test pool
train_pos_size = 0.8

# the training pool is split into 90% training set and 10% validation set 
val_size = 0.1

# if the dataset is very large or you want to standardize the dataset size at
# which you perform position extrapolation, you can optionally specify that
# dataset size here
resample_dataset_size = None

replicates = 9
# rseeds = [random.randint(1000, 9999) for _ in range(replicates)]
rseeds = [6822, 6138, 9152, 1796, 7395, 2048, 4155, 2443, 3743]

for rseed in rseeds:    
    split, out_dir_split, additional_info = sd.position_split(ds, 
                                                              seq_len, 
                                                              wt_ofs, 
                                                              train_pos_size, 
                                                              val_size,
                                                              resample_dataset_size=resample_dataset_size,
                                                              out_dir=out_dir, 
                                                              rseed=rseed, 
                                                              overwrite=False)

INFO:METL.split_dataset:num_train_positions: 190, num_test_positions: 47
INFO:METL.split_dataset:train pool size: 25528, test pool size: 655, overlap pool size: 25531
INFO:METL.split_dataset:num_train: 22975, num_val: 2553, num_test: 655
INFO:METL.split_dataset:saving train-val-test split to directory data/dms_data/avgfp/splits/position/position_tr-pos0.8_tu0.1_r6822
INFO:METL.split_dataset:num_train_positions: 190, num_test_positions: 47
INFO:METL.split_dataset:train pool size: 20496, test pool size: 1115, overlap pool size: 30103
INFO:METL.split_dataset:num_train: 18446, num_val: 2050, num_test: 1115
INFO:METL.split_dataset:saving train-val-test split to directory data/dms_data/avgfp/splits/position/position_tr-pos0.8_tu0.1_r6138
INFO:METL.split_dataset:num_train_positions: 190, num_test_positions: 47
INFO:METL.split_dataset:train pool size: 24690, test pool size: 704, overlap pool size: 26320
INFO:METL.split_dataset:num_train: 22221, num_val: 2469, num_test: 704
INFO:METL.split_data

# Mutation extrapolation

In [10]:
out_dir = "data/dms_data/avgfp/splits/mutation"

resample_dataset_size = None

# 80% of mutations are designated train pool, 20% are desinated test pool
train_muts_size = 0.8

# the training pool is split into 90% training set and 10% validation set 
val_size = 0.1

replicates = 9
# rseeds = [random.randint(1000, 9999) for _ in range(replicates)]
rseeds = [4419, 8561, 9891, 4386, 6389, 3367, 3294, 6504, 2035]

for rseed in rseeds:    
    split, out_dir_split, additional_info = sd.mutation_split(ds, 
                                                              train_muts_size,
                                                              val_size, 
                                                              out_dir=out_dir, 
                                                              rseed=rseed, 
                                                              resample_dataset_size=resample_dataset_size,
                                                              overwrite=False)

INFO:METL.split_dataset:number of unique mutations in ds: 1810
INFO:METL.split_dataset:num_train_mutations: 1448, num_test_mutations: 362
INFO:METL.split_dataset:train pool size: 23327, test pool size: 877, overlap pool size: 27510
INFO:METL.split_dataset:num_train: 20994, num_val: 2333, num_test: 877
INFO:METL.split_dataset:saving train-val-test split to directory data/dms_data/avgfp/splits/mutation/mutation_tr-muts0.8_tu0.1_r4419
INFO:METL.split_dataset:number of unique mutations in ds: 1810
INFO:METL.split_dataset:num_train_mutations: 1448, num_test_mutations: 362
INFO:METL.split_dataset:train pool size: 22078, test pool size: 963, overlap pool size: 28673
INFO:METL.split_dataset:num_train: 19870, num_val: 2208, num_test: 963
INFO:METL.split_dataset:saving train-val-test split to directory data/dms_data/avgfp/splits/mutation/mutation_tr-muts0.8_tu0.1_r8561
INFO:METL.split_dataset:number of unique mutations in ds: 1810
INFO:METL.split_dataset:num_train_mutations: 1448, num_test_mutat

# Score extrapolation

In [11]:
out_dir = "data/dms_data/avgfp/splits/score"

resample_dataset_size = None

# set the wild-type score for this dataset
wt_score = 0
score_name = "score"

# training pool is split into 90% train and 10% validation sets 
val_size = 0.1

replicates = 9
# rseeds = [random.randint(1000, 9999) for _ in range(replicates)]
rseeds = [5265, 1219, 7249, 3595, 7351, 4097, 6631, 8421, 2425]

for rseed in rseeds:    
    split, out_dir_split = sd.score_extrapolation_split(ds, 
                                                     score_name=score_name, 
                                                     wt_score=wt_score, 
                                                     val_size=val_size,
                                                     resample_dataset_size=resample_dataset_size,
                                                     out_dir=out_dir,
                                                     rseed=rseed, 
                                                     overwrite=False)

INFO:METL.split_dataset:train pool size: 46683, test pool size: 5031
INFO:METL.split_dataset:num_train: 42014, num_val: 4669, num_test: 5031
INFO:METL.split_dataset:saving train-val-test split to directory data/dms_data/avgfp/splits/score/score_thresh0_tu0.1_r5265
INFO:METL.split_dataset:train pool size: 46683, test pool size: 5031
INFO:METL.split_dataset:num_train: 42014, num_val: 4669, num_test: 5031
INFO:METL.split_dataset:saving train-val-test split to directory data/dms_data/avgfp/splits/score/score_thresh0_tu0.1_r1219
INFO:METL.split_dataset:train pool size: 46683, test pool size: 5031
INFO:METL.split_dataset:num_train: 42014, num_val: 4669, num_test: 5031
INFO:METL.split_dataset:saving train-val-test split to directory data/dms_data/avgfp/splits/score/score_thresh0_tu0.1_r7249
INFO:METL.split_dataset:train pool size: 46683, test pool size: 5031
INFO:METL.split_dataset:num_train: 42014, num_val: 4669, num_test: 5031
INFO:METL.split_dataset:saving train-val-test split to directory

# Regime extrapolation

In [12]:
out_dir = "data/dms_data/avgfp/splits/regime"

train_regimes = 1
test_regimes = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

# for the train pool (all single mutants)
# use 80% as the training set and 20% as the validation set
train_size = 0.8
val_size = 0.2

# for the test pool (all 2+ mutants), don't use all for the test set
# to help lessen risk of overfitting to full test set during development
# the test set will be just 10% of all available 2+ mutants
test_size = 0.1


replicates = 9
# rseeds = [random.randint(1000, 9999) for _ in range(replicates)]
rseeds = [8903, 1980, 8938, 9968, 3493, 3390, 2479, 2302, 3586]

for _ in range(replicates):    
    rseed = random.randint(1000,9999)
    split, out_dir_split, additional_info = sd.regime_split(ds, 
                                                            train_regimes=train_regimes, 
                                                            test_regimes=test_regimes, 
                                                            train_size=train_size, 
                                                            val_size=val_size, 
                                                            test_size=test_size,
                                                            rseed=rseed, 
                                                            out_dir=out_dir, 
                                                            overwrite=False)

INFO:METL.split_dataset:train pool size: 1084, test pool size: 50630, discard pool size: 0
INFO:METL.split_dataset:num_train: 867, num_val: 217, num_test: 5063
INFO:METL.split_dataset:saving train-val-test split to directory data/dms_data/avgfp/splits/regime/regime_tr-reg1_te-reg2-3-4-5-6-7-8-9-10-11-12-13-14-15_tr0.8_tu0.2_te0.1_r6698
INFO:METL.split_dataset:train pool size: 1084, test pool size: 50630, discard pool size: 0
INFO:METL.split_dataset:num_train: 867, num_val: 217, num_test: 5063
INFO:METL.split_dataset:saving train-val-test split to directory data/dms_data/avgfp/splits/regime/regime_tr-reg1_te-reg2-3-4-5-6-7-8-9-10-11-12-13-14-15_tr0.8_tu0.2_te0.1_r5138
INFO:METL.split_dataset:train pool size: 1084, test pool size: 50630, discard pool size: 0
INFO:METL.split_dataset:num_train: 867, num_val: 217, num_test: 5063
INFO:METL.split_dataset:saving train-val-test split to directory data/dms_data/avgfp/splits/regime/regime_tr-reg1_te-reg2-3-4-5-6-7-8-9-10-11-12-13-14-15_tr0.8_tu0.