In [3]:
from pathlib import Path

import pandas as pd
import fastai
import fastai.data.core
import fastai.learner
import fastai.losses
import fastai.metrics
import fastai.callback.schedule
import fastai.callback.progress

from utils.prepare_data import read_data, save_data, target_columns
from utils.project_parameters import data_dimension, summary_statistic_order, epochs_of_training, neural_network_batch_size
import utils.deeplearning

In [4]:
# Resolve the Snakemake `target` wildcard into the target column spec
# used by the datasets below.
target_key = snakemake.wildcards["target"]
target_col = target_columns[target_key]

In [5]:
# Build the training and validation datasets from the same `data` input,
# each paired with its own split table (read via read_data), then wrap
# them in fastai DataLoaders with the project's batch size.
datasets = {}
datasets["training"] = utils.deeplearning.SweepsDataset(
    snakemake.input["data"],
    read_data(snakemake.input["training"]),
    target_col,
)
datasets["validation"] = utils.deeplearning.SweepsDataset(
    snakemake.input["data"],
    read_data(snakemake.input["validation"]),
    target_col,
    is_validation=True,
)
loader = fastai.data.core.DataLoaders.from_dsets(
    datasets["training"],
    datasets["validation"],
    bs=neural_network_batch_size,
)

In [6]:
# Number of input channels for the CNN (used as `in_channels` when the
# network is built).
# Feature-subset selection is not implemented yet, so `feature_subset` is
# always None and every summary statistic becomes a channel.
# TODO: set feature_subset = get_feature_subset(snakemake.wildcards["features"])
# once feature analysis is supported.
feature_subset = None
# Both branches assign num_channels. Previously the subset branch was a bare
# `pass`, which would have left num_channels undefined.
num_channels = (
    len(feature_subset)
    if feature_subset is not None
    else len(summary_statistic_order)
)

In [7]:
# Loss function and evaluation metric for each task type the training
# dataset can report.
loss_functions = {
    'classification': fastai.losses.CrossEntropyLossFlat(),
    'regression': fastai.losses.MSELossFlat(),
}

metric_functions = {
    'classification': fastai.metrics.accuracy,
    'regression': fastai.metrics.rmse,
}

# get_task() yields the task type (a key of the dicts above) and the
# network's output width.
model_type, output_dim = datasets["training"].get_task()
if model_type not in loss_functions or model_type not in metric_functions:
    # Fail with an explicit message instead of a bare KeyError.
    raise ValueError(
        f"Unsupported task type {model_type!r} reported by the training dataset; "
        f"expected one of {sorted(loss_functions)}"
    )
metric = metric_functions[model_type]
loss = loss_functions[model_type]

In [8]:
# Two-layer CNN sized by the project's data dimension, the channel count
# chosen earlier, and the task's output width.
neural_network = utils.deeplearning.SimpleCNN2Layer(
    in_channels=num_channels,
    input_dim=data_dimension,
    output_dim=output_dim,
)

# Bundle data, network, loss and metric into a fastai Learner.
# A progress-bar callback is attached automatically, so no explicit
# model.add_cb(fastai.callback.progress.ProgressCallback()) is needed.
model = fastai.learner.Learner(
    loader,
    neural_network,
    loss_func=loss,
    metrics=metric,
)

In [9]:
# Train for a single epoch when the `only_one_epoch` param is truthy;
# otherwise use the per-target epoch count from the project parameters.
num_epochs = (
    1
    if snakemake.params["only_one_epoch"]
    else epochs_of_training[snakemake.wildcards["target"]]
)

In [10]:
# Train with fastai's one-cycle schedule for the selected number of epochs.
model.fit_one_cycle(n_epoch=num_epochs)

### Save model fitting outcomes

In [11]:
def save_label(filename, loader):
    """Write the training dataset's label name(s) to a text file, one per line.

    Args:
        filename: Path of the text file to (over)write.
        loader: Object (here the fastai DataLoaders) whose ``train_ds`` has a
            ``labels`` attribute holding either a single label string or an
            iterable of labels.

    Raises:
        TypeError: If ``train_ds.labels`` is neither a string nor iterable.
    """
    label_raw = loader.train_ds.labels
    if isinstance(label_raw, str):
        # Wrap a single label so it is written as one line rather than
        # being iterated character by character.
        labels = [label_raw]
    else:
        # Accept any iterable (list, tuple, pandas Index, ...) instead of only
        # `list`; other types previously left `labels` unbound.
        try:
            labels = list(label_raw)
        except TypeError as err:
            raise TypeError(
                "expected train_ds.labels to be a str or an iterable of labels, "
                f"got {type(label_raw).__name__}"
            ) from err
    # str() guards against non-string label items; no trailing newline is
    # written, matching the previous file format.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write("\n".join(str(label) for label in labels))

In [12]:
if snakemake.params["save_model"]:
    fit_model_path = Path(snakemake.output["fit_model"])
    # model.save is given only a bare name and resolves it under model_dir,
    # so point model_dir at the target's folder and pass the file stem.
    # NOTE(review): the extension is dropped here, presumably because save
    # re-appends its own; confirm it matches the fit_model output suffix.
    model.model_dir = str(fit_model_path.parent)
    model.save(str(fit_model_path.stem))
    model.export(fname=str(Path(snakemake.output["model_object"])))
    # Write the training dataset's label name(s) alongside the model.
    save_label(snakemake.output["model_labels"], loader)

In [13]:
# Per-epoch training history from the Learner's recorder, saved as a TSV
# indexed by 1-based epoch number. metric_names[1:-1] trims the first and last
# names so the columns line up with the per-epoch rows in recorder.values
# (presumably the 'epoch' and 'time' entries; confirm for the fastai version).
fit_report = (
    pd.DataFrame.from_records(
        model.recorder.values, columns=model.recorder.metric_names[1:-1]
    )
    # Number epochs from the rows actually recorded rather than the requested
    # epoch count, so a fit that stopped early still yields a consistent
    # report instead of a length-mismatch error.
    .assign(epoch=lambda report: range(1, len(report) + 1))
    .set_index('epoch')
)
fit_report.to_csv(snakemake.output["fit_report"], sep='\t', index=True)

In [14]:
if snakemake.params["save_inferences"]:
    # The fitted model's inferences on each split, written as tab-separated
    # files with the frame index kept. The training split is written before
    # the validation split is computed.
    train_inferences = utils.deeplearning.get_training_inferences(
        model, loader, "training"
    )
    train_inferences.to_csv(
        snakemake.output["training_inferences"], sep='\t', index=True
    )

    valid_inferences = utils.deeplearning.get_training_inferences(
        model, loader, "validation"
    )
    valid_inferences.to_csv(
        snakemake.output["validation_inferences"], sep='\t', index=True
    )