In [2]:
from sklearn.model_selection import train_test_split

from utils.prepare_data import read_data, save_data, target_columns

In [3]:
# Path to the simulation-parameters table, as declared in the Snakemake rule's inputs.
sim_params_path = snakemake.input["sim_params"]
raw = read_data(sim_params_path)

1. Remove failed simulations.
2. Remove neutral simulations.
3. Downsample so each of the remaining selection regimes has exactly 5000 cases. (For now, we raise an error if any regime has fewer than that.)

In [4]:
# Keep only simulations whose status is 'ok' and whose regime is not 'neutral'.
df = raw.loc[(raw.simulation_status == 'ok') & (raw.regime != 'neutral')]

# Number of simulations to keep for every remaining selection regime.
NUM_SIMS = 5000

# A regime with exactly NUM_SIMS rows is enough to sample NUM_SIMS from, so
# only fail when some regime has strictly fewer.
regime_counts = df.regime.value_counts()
if (regime_counts < NUM_SIMS).any():
    raise Exception(f"One or more selection regimes did not have {NUM_SIMS} simulations:\n{regime_counts}")

# Use the workflow's random seed (when one is given) for the downsampling too;
# otherwise the selected rows differ from run to run even with a fixed seed.
try:
    sample_seed = snakemake.params["random_seed"]
except AttributeError:  # No random seed given
    sample_seed = None

df = df.groupby('regime').sample(n=NUM_SIMS, random_state=sample_seed)

In [5]:
# Display how many rows each selection regime has after downsampling.
df["regime"].value_counts()

In [6]:
# Look up the optional random seed from the rule's params. An AttributeError
# from the lookup is treated as "no random seed given", leaving the seed unset.
try:
    configured_seed = snakemake.params["random_seed"]
except AttributeError:
    seed = None
else:
    seed = configured_seed

In [7]:
# Size of the validation split, taken from the workflow config.
# NOTE(review): scikit-learn interprets a float test_size as a fraction in
# (0, 1) and an int as an absolute row count -- confirm the config value is a
# fraction despite the "percentage" name.
validation_size = snakemake.config["validation_percentage"]

# Stratify on the regime label so both splits keep the same regime proportions.
regime_labels = df.regime

train, valid = train_test_split(
    df,
    test_size=validation_size,
    random_state=seed,
    shuffle=True,
    stratify=regime_labels,
)

In [8]:
# Write each split to the output path declared for it in the Snakemake rule
# (training first, then validation).
for split_frame, output_key in ((train, "training"), (valid, "validation")):
    save_data(split_frame, snakemake.output[output_key])