# MLP Controller Network Training

This Notebook is used to create and train a neural network model for the purpose of predicting the worm movement.

In [1]:
# fix imports
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, ChainDataset, ConcatDataset

import numpy as np
import pandas as pd

from wtracker.utils.path_utils import Files
from wtracker.utils.config_base import print_initialization
from wtracker.neural.config import DatasetConfig, TrainConfig, IOConfig, LOSSES, OPTIMIZERS
from wtracker.neural.dataset import NumpyDataset
from wtracker.neural.mlp import MlpBlock, RMLP, WormPredictor
from wtracker.neural.training import MLPTrainer
from wtracker.neural.train_results import FitResult
from wtracker.utils.gui_utils import UserPrompt


pd.options.display.max_columns = 30

torch.set_printoptions(sci_mode=False)

### Configure the Model, Dataset and Training Parameters

In [3]:
################################ User Input ################################

# path to the log files, used to train the network, should be a list of paths 
# if None, a file dialog will open to select log files
# log_paths = ['D:/Guy_Gilad/Exp0_GuyGilad/logs_yolo/init_bboxes.csv', 'D:/Guy_Gilad/Exp2_GuyGilad/logs_yolo/init_bboxes.csv', 'D:/Guy_Gilad/Exp3_GuyGilad/logs_yolo/init_bboxes.csv'] 
log_paths = ['D:/Guy_Gilad/Exp1_GuyGilad/logs_yolo/init_bboxes.csv'] 

# io config for 6 frame cycle (100ms) with 3 frames of pred and movement
# io_config = IOConfig(
#     input_frames=[0, -2, -9, -11, -18, -20, -27],  # list
#     pred_frames=[9],  # list
# )

# io config for 15 frame cycle (200ms) with 3 frames of pred and movement
io_config = IOConfig(
    input_frames=[0, -3, -15, -18, -30, -33, -45],  # list
    pred_frames=[3, 6, 9, 12],  # list
)

############################################################################

if log_paths is None:
    path = "start"
    log_paths = []
    while len(path) > 0:
        path = UserPrompt.open_file(f"Please select log file {len(log_paths)} to use for training, cancel if done")
        log_paths.append(path)
    log_paths.pop(-1)

dataset_config = DatasetConfig.from_io_config(io_config, log_paths) # create a dataset config object from the io_config and log_paths
print(f"dataset_config= {dataset_config.__dict__}")

dataset_config= {'input_frames': [0, -3, -15, -18, -30, -33, -45], 'pred_frames': [3, 6, 9, 12], 'log_path': ['D:/Guy_Gilad/Exp1_GuyGilad/logs_yolo/init_bboxes.csv']}


In [4]:
# create a neural network model based on the dataset_config. Will be ignored if a model path is given by TrainConfig
block_in_dim = 80
block_dims = [40, 10, 40, 80]
block_activations = ["relu"] * (len(block_dims))
in_dim = 4 * len(dataset_config.input_frames)
out_dim = 2 * len(dataset_config.pred_frames)

model = RMLP(
    in_dim=in_dim,
    block_in_dim=block_in_dim,
    block_dims=block_dims,
    block_nonlins=block_activations,
    n_blocks=4,
    out_dim=out_dim,
)

In [5]:
# we wrap the model in a WormPredictor object, which will hold the io_config for future use and distinguish it from general Neural Network models.
model = WormPredictor(model, io_config)

In [6]:
print(model) # print the model layers

WormPredictor(
  (model): RMLP(
    (input): MLPLayer(
      (mlp_layer): Sequential(
        (0): Linear(in_features=28, out_features=80, bias=True)
        (1): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (blocks): ModuleList(
      (0-3): 4 x MlpBlock(
        (sequence): Sequential(
          (0): MLPLayer(
            (mlp_layer): Sequential(
              (0): Linear(in_features=80, out_features=40, bias=True)
              (1): BatchNorm1d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
          )
          (1): MLPLayer(
            (mlp_layer): Sequential(
              (0): Linear(in_features=40, out_features=10, bias=True)
              (1): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU()
            )
          )
          (2): MLPLayer(
            (mlp_layer): Sequential(
          

In [7]:
################################ User Input ################################

train_config = TrainConfig(
    seed=42,  # int
    dataset=dataset_config,  # Dataset
    model=model,  # Union[nn.Module, str]
    loss_fn="mse",  # nn.Module
    optimizer="adam",  # Union[Optimizer, str]
    device=f"cpu",  # str
    log=True,  # bool
    num_epochs=100,  # int
    checkpoints="ResMLP(1)_config1.pt",  # str
    early_stopping=15,  # int
    print_every=5,  # int
    learning_rate=0.001,  # float
    weight_decay=1e-05,  # float
    batch_size=128,  # int
    shuffle=True,  # bool
    num_workers=0,  # int
    train_test_split=0.8,  # float
)

############################################################################

### Run The Training Process

In [8]:
datasets = []
for path in dataset_config.log_path:
    # create a dataset according to the dataset_config
    config = DatasetConfig.from_io_config(io_config, path)
    datasets.append(NumpyDataset.create_from_config(config))

In [9]:
# combine datasets of all log files
dataset = ConcatDataset(datasets)

In [10]:
# Split the dataset
ds_train, ds_test = random_split(dataset, [train_config.train_test_split, 1 - train_config.train_test_split])

In [11]:
# Create the dataloaders
dl_train = DataLoader(ds_train, batch_size=train_config.batch_size, shuffle=train_config.shuffle)
dl_test = DataLoader(ds_test, batch_size=train_config.batch_size, shuffle=train_config.shuffle)

In [12]:
# initialize the loss object
loss_fn = LOSSES[train_config.loss_fn]()

In [13]:
# initialize the optimizer object
lr = train_config.learning_rate
weight_decay = train_config.weight_decay
optimizer = OPTIMIZERS[train_config.optimizer](model.parameters(), lr=lr, weight_decay=weight_decay)

In [14]:
# create the trainer object
device = torch.device(train_config.device)
trainer = MLPTrainer(model, loss_fn, optimizer, device=device, log=train_config.log)

In [15]:
# train the model
epochs = train_config.num_epochs
checkpoints = train_config.checkpoints
early_stopping = train_config.early_stopping
print_every = train_config.print_every

trainer.fit(
    dl_train=dl_train,
    dl_test=dl_test,
    num_epochs=epochs,
    checkpoints=checkpoints,
    print_every=print_every,
    early_stopping=early_stopping,
)

--- EPOCH 1/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=4.342

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=3.656

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=3.253

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=3.085

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.877
--- EPOCH 6/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.862

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.725

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.649

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.589
--- EPOCH 11/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.514

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.487

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.441

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.394

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.261
--- EPOCH 16/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.235

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.186

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.151
--- EPOCH 21/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.125

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.115

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.109

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=2.087
--- EPOCH 26/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.972

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.916
--- EPOCH 31/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.870
--- EPOCH 36/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.844

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.833
--- EPOCH 41/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.814

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.767
--- EPOCH 46/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.749

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.746

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.663
--- EPOCH 51/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]

--- EPOCH 56/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.656

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.632
--- EPOCH 61/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]

--- EPOCH 66/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.630

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.612
--- EPOCH 71/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.608

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.595

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.566
--- EPOCH 76/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.547
--- EPOCH 81/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.516
--- EPOCH 86/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.501
--- EPOCH 91/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.490

*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.477
--- EPOCH 96/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]


*** Saved checkpoint runs\Jun07_23-59-30_bi-slevylab9/ResMLP(1)_config1.pt :: val_loss=1.461
--- EPOCH 100/100 ---


train_batch:   0%|          | 0/399 [00:00<?, ?it/s]

test_batch:   0%|          | 0/100 [00:00<?, ?it/s]

FitResult(num_epochs=100, train_loss=[66.38268280029297, 71.17364501953125, 65.47247314453125, 61.168052673339844, 63.082603454589844, 58.50596618652344, 54.780029296875, 55.84389877319336, 54.95668411254883, 48.22549057006836, 49.656105041503906, 44.56403732299805, 40.55813217163086, 42.38209533691406, 42.41008377075195, 34.74921417236328, 35.766693115234375, 35.20514678955078, 30.13496208190918, 32.06049346923828, 27.81822967529297, 25.560035705566406, 27.804283142089844, 30.765878677368164, 24.958093643188477, 22.92986297607422, 28.002803802490234, 18.676305770874023, 20.728775024414062, 21.048524856567383, 17.376514434814453, 14.689143180847168, 19.372827529907227, 15.994953155517578, 14.748144149780273, 15.562997817993164, 14.316312789916992, 13.354541778564453, 11.949588775634766, 12.357645034790039, 11.347588539123535, 13.990410804748535, 14.249670028686523, 11.584867477416992, 8.077718734741211, 12.139749526977539, 10.907896041870117, 9.780574798583984, 11.41259765625, 7.965291

In [16]:
# save the training configuration if logging is enabled
if train_config.log:
    train_config.save_pickle(trainer.logger.log_dir + "/train_config.pkl")