# MLP Controller Network Training

This notebook creates and trains a neural network model that predicts worm movement.

In [None]:
# Make the parent directory of the notebook's working directory importable,
# so the `wtracker` package can be imported without being installed.
import os
import sys

module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    # Append only once, so re-running the cell does not duplicate the entry.
    sys.path.append(module_path)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, ChainDataset, ConcatDataset

import numpy as np
import pandas as pd

from wtracker.utils.path_utils import Files, join_paths
from wtracker.utils.config_base import print_initialization
from wtracker.neural.config import DatasetConfig, TrainConfig, IOConfig, LOSSES, OPTIMIZERS
from wtracker.neural.dataset import NumpyDataset
from wtracker.neural.mlp import MlpBlock, RMLP, WormPredictor
from wtracker.neural.training import MLPTrainer
from wtracker.neural.train_results import FitResult
from wtracker.utils.gui_utils import UserPrompt


pd.options.display.max_columns = 30

torch.set_printoptions(sci_mode=False)

### Configure the Model, Dataset and Training Parameters

In [None]:
################################ User Input ################################

# Paths to the log files used to train the network; must be a list of paths.
# If None, a file dialog opens repeatedly so log files can be selected one at a time.
log_paths = ['D:/Guy_Gilad/Exp1_GuyGilad/logs_yolo/init_bboxes.csv'] 

# io config for 6 frame cycle (100ms) with 3 frames of pred and movement
# io_config = IOConfig(
#     input_frames=[0, -2, -9, -11, -18, -20, -27],  # list
#     pred_frames=[9],  # list
# )

# io config for 15 frame cycle (200ms) with 3 frames of pred and movement
io_config = IOConfig(
    input_frames=[0, -3, -15, -18, -30, -33, -45],  # list
    pred_frames=[3, 6, 9, 12],  # list
)

############################################################################

if log_paths is None:
    log_paths = []
    while True:
        path = UserPrompt.open_file(f"Please select log file {len(log_paths)} to use for training, cancel if done")
        # Treat any falsy result (empty string, None, empty sequence) as "cancel".
        # NOTE(review): the exact value open_file returns on cancel is not visible here.
        if not path:
            break
        log_paths.append(path)

# Fail early with a clear message rather than deep inside dataset construction.
if not log_paths:
    raise ValueError("No log files were selected; at least one log file is required for training.")

# Create a dataset config object from the io_config and log_paths.
dataset_config = DatasetConfig.from_io_config(io_config, log_paths)
print(f"dataset_config= {dataset_config.__dict__}")

In [None]:
# Build the neural network whose input/output sizes follow the dataset configuration.
# NOTE: the training cells below pass this `model` object directly to the optimizer and trainer.
# Input width is 4 values per input frame and output width is 2 values per predicted frame
# (presumably bbox x, y, w, h in and x, y out — TODO confirm against NumpyDataset).
in_dim = len(dataset_config.input_frames) * 4
out_dim = len(dataset_config.pred_frames) * 2

# Residual-block layout: each block maps block_in_dim through block_dims.
block_in_dim = 80
block_dims = [40, 10, 40, 80]
block_activations = ["relu" for _ in block_dims]  # one activation per block layer

model = RMLP(
    in_dim=in_dim,
    out_dim=out_dim,
    block_in_dim=block_in_dim,
    block_dims=block_dims,
    block_nonlins=block_activations,
    n_blocks=4,
)

In [None]:
# we wrap the model in a WormPredictor object, which will hold the io_config for future use and distinguish it from general Neural Network models.
# Wrap the model in a WormPredictor, which holds the io_config for future use and
# distinguishes it from general neural network models.
model = WormPredictor(model, io_config)

In [None]:
print(model) # print the wrapped model's layer structure for inspection

In [None]:
################################ User Input ################################

# Training hyper-parameters. The loss_fn and optimizer strings are looked up in
# the LOSSES / OPTIMIZERS registries by the loss and optimizer cells below.
train_config = TrainConfig(
    seed=42,  # int
    dataset=dataset_config,  # DatasetConfig
    model=model,  # Union[nn.Module, str]
    loss_fn="mse",  # str: key into LOSSES
    optimizer="adam",  # str: key into OPTIMIZERS
    device="cpu",  # str: passed to torch.device, e.g. "cpu" or "cuda"
    log=True,  # bool
    num_epochs=100,  # int
    checkpoints="ResMLP(1)_config1.pt",  # str
    early_stopping=15,  # int
    print_every=5,  # int
    learning_rate=0.001,  # float
    weight_decay=1e-05,  # float
    batch_size=128,  # int
    shuffle=True,  # bool
    num_workers=0,  # int
    train_test_split=0.8,  # float: fraction of samples used for training
)

############################################################################

### Run the Training Process

In [None]:
# Build one dataset per log file, each from a DatasetConfig restricted to that single path.
datasets = [
    NumpyDataset.create_from_config(DatasetConfig.from_io_config(io_config, log_path))
    for log_path in dataset_config.log_path
]

In [None]:
# Merge the per-log-file datasets into a single dataset.
dataset = ConcatDataset(datasets=datasets)

In [None]:
# Split the dataset into train/test subsets using the configured training fraction.
# A generator seeded from train_config.seed makes the split reproducible across runs;
# without it, random_split draws from torch's global RNG.
split_generator = torch.Generator().manual_seed(train_config.seed)
train_frac = train_config.train_test_split
ds_train, ds_test = random_split(dataset, [train_frac, 1 - train_frac], generator=split_generator)

In [None]:
# Create the dataloaders, honouring the batch size, shuffle flag and worker count from train_config.
# The training shuffle order is seeded from train_config.seed for reproducibility.
dl_train = DataLoader(
    ds_train,
    batch_size=train_config.batch_size,
    shuffle=train_config.shuffle,
    num_workers=train_config.num_workers,
    generator=torch.Generator().manual_seed(train_config.seed),
)
# The test set is not shuffled. NOTE(review): assumes the trainer only aggregates metrics
# over the test loader, so sample order does not matter — confirm.
dl_test = DataLoader(
    ds_test,
    batch_size=train_config.batch_size,
    shuffle=False,
    num_workers=train_config.num_workers,
)

In [None]:
# Initialize the loss object by looking up the configured name in LOSSES and calling it.
# An unknown name raises a ValueError listing the valid choices instead of a bare KeyError.
try:
    loss_cls = LOSSES[train_config.loss_fn]
except KeyError:
    raise ValueError(
        f"Unknown loss_fn {train_config.loss_fn!r}; expected one of {list(LOSSES)}"
    ) from None
loss_fn = loss_cls()

In [None]:
# Initialize the optimizer over the model's parameters, looked up by name in OPTIMIZERS.
# An unknown name raises a ValueError listing the valid choices instead of a bare KeyError.
lr = train_config.learning_rate
weight_decay = train_config.weight_decay
try:
    optimizer_cls = OPTIMIZERS[train_config.optimizer]
except KeyError:
    raise ValueError(
        f"Unknown optimizer {train_config.optimizer!r}; expected one of {list(OPTIMIZERS)}"
    ) from None
optimizer = optimizer_cls(model.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
# Create the trainer object from the model, loss and optimizer built above.
# The device string from train_config is converted to a torch.device; log enables
# the trainer's logging (its logger's log_dir is used when saving the config below).
device = torch.device(train_config.device)
trainer = MLPTrainer(model, loss_fn, optimizer, device=device, log=train_config.log)

In [None]:
# Train the model using the epoch, checkpoint, early-stopping and reporting settings from train_config.
epochs = train_config.num_epochs
checkpoints = train_config.checkpoints
early_stopping = train_config.early_stopping
print_every = train_config.print_every

# Keep the value returned by fit() instead of discarding it, so it can be inspected after training.
# NOTE(review): presumably a FitResult (imported at the top of the notebook) — confirm.
fit_result = trainer.fit(
    dl_train=dl_train,
    dl_test=dl_test,
    num_epochs=epochs,
    checkpoints=checkpoints,
    print_every=print_every,
    early_stopping=early_stopping,
)
fit_result  # last expression of the cell, so the notebook still displays the result

In [None]:
# When logging is enabled, pickle the full training configuration into the
# trainer's log directory as train_config.pkl so the run can be reproduced.
if train_config.log:
    config_path = join_paths(trainer.logger.log_dir, "train_config.pkl")
    train_config.save_pickle(config_path)