# STIFMap Network Training Notebook

This notebook lets users retrain the neural networks using the same inputs and hyperparameters reported in the manuscript.

Note that results may differ slightly between runs because network parameters are initialized randomly and samples are randomly assigned to the training and validation groups.
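
The sketch below illustrates the second source of variation with a hypothetical `split_train_val` helper, which is not part of STIFMaps. Without a fixed seed, each call assigns samples to the training and validation groups differently. A fixed seed reproduces the same split. This notebook does not document whether `train_model` exposes a seed. If exact repeatability matters, calling `torch.manual_seed` before training may also reduce variation, but that is an untested assumption.

```python
import random


def split_train_val(sample_ids, val_fraction=0.2, seed=None):
    """Hypothetical helper: randomly assign samples to training and validation groups."""
    rng = random.Random(seed)  # private generator; seed=None draws fresh entropy
    ids = list(sample_ids)
    rng.shuffle(ids)
    n_val = round(len(ids) * val_fraction)
    # Everything after the first n_val shuffled ids is used for training
    return ids[n_val:], ids[:n_val]


# Two splits made with the same seed are identical; unseeded splits almost always differ
train_a, val_a = split_train_val(range(100), seed=0)
train_b, val_b = split_train_val(range(100), seed=0)
print(val_a == val_b)  # True
```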

### Import necessary packages

In [None]:
# Import the pipeline for neural network training
from STIFMaps import training

# Other imports
import pandas as pd
import torch.nn as nn

### Specify Input Files

Training the networks requires the collagen images and their corresponding stiffness values, which are available from [Mendeley](https://data.mendeley.com/datasets/vw2bb5jy99/2).

In [None]:
# Path to the stiffnesses.csv annotation file
annot_file = '/path/to/stiffnesses.csv'

# Path to the images in raw_squares
img_dir = '/path/to/squares/'

In [None]:
# Import the stiffness csv file
df = pd.read_csv(annot_file)

df.head()

### Train the Models

Model training is done with the function `train_model`, which has the following inputs:

##### Basic inputs
 - **df**: The dataframe loaded from 'stiffnesses.csv'
 - **img_dir**: The directory where the squares are stored 
 - **name**: The name that will be used during saving (if applicable)
 
##### Image transformation parameters
 - **brightness_range**: The upper and lower bounds of random brightness adjustments. These adjustments, together with the contrast and sharpness adjustments below, augment the data. They artificially 'increase' the size of the training set, emphasize relevant features, and help prevent overfitting to the training data
 - **contrast_range**: The upper and lower bounds of contrast adjustments
 - **sharpness_range**: The upper and lower bounds of sharpness adjustments
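
The sketch below shows how a pair of bounds can drive a random adjustment. A factor is drawn uniformly between the two bounds and applied to the pixel intensities. This is a hypothetical, dependency-free illustration operating on a flat list of 8-bit intensities, not necessarily how STIFMaps implements its transformations. Sharpness is omitted because it needs a neighborhood filter. Pillow's `ImageEnhance.Sharpness` is one library implementation, where a factor of 1.0 returns the original image. Python's `random.uniform` accepts its bounds in either order, so a pair like `(1.5, .67)` is still valid. A range of `(1, 1)` always yields a factor of 1 and leaves the image unchanged.

```python
import random


def sample_factor(bounds, rng):
    # random.uniform works whether the bounds are given low-high or high-low
    a, b = bounds
    return rng.uniform(a, b)


def adjust_brightness(pixels, factor):
    # Scale every intensity, clipping to the 8-bit range [0, 255]
    return [min(255, max(0, round(p * factor))) for p in pixels]


def adjust_contrast(pixels, factor):
    # Stretch (factor > 1) or compress (factor < 1) intensities around the image mean
    mean = sum(pixels) / len(pixels)
    return [min(255, max(0, round(mean + (p - mean) * factor))) for p in pixels]


def augment(pixels, brightness_range, contrast_range, seed=None):
    """Hypothetical augmentation: random brightness, then random contrast."""
    rng = random.Random(seed)
    b = sample_factor(brightness_range, rng)
    c = sample_factor(contrast_range, rng)
    return adjust_contrast(adjust_brightness(pixels, b), c), (b, c)


# A (1, 1) range leaves the pixels untouched
print(augment([0, 100, 200], (1, 1), (1, 1)))  # ([0, 100, 200], (1.0, 1.0))
```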
 
##### Model hyperparameters
 - **batch_size**: The number of samples passed through the model together to compute each update of the model parameters
 - **n_epochs**: The number of epochs that the model will be trained over
 - **learning_rate**: The step size used when updating the model parameters against the error gradient (that is, in the direction that reduces the loss)
 - **weight_decay**: Regularization parameter that helps prevent overfitting by shrinking the network weights toward zero during training
 - **criterion**: The loss/cost function used to compute model errors
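
The sketch below illustrates how these hyperparameters interact. It is a minimal pure-Python version of mini-batch gradient descent on a one-variable linear model, not the STIFMaps network or its training loop. The model, the data, and the helper names (`mse_loss`, `train_linear`) are hypothetical. Weight decay is implemented here as an L2 term added to the weight's gradient, which is one common formulation.

```python
import random


def mse_loss(preds, targets):
    # Mean squared error, the quantity nn.MSELoss() computes with its default 'mean' reduction
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def train_linear(xs, ys, batch_size=16, n_epochs=100, learning_rate=0.1,
                 weight_decay=0.0, seed=0):
    """Hypothetical example: fit y = w*x + b with mini-batch gradient descent."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    order = list(range(len(xs)))
    for _ in range(n_epochs):                        # one epoch = one pass over all samples
        rng.shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]  # batch_size samples per update
            n = len(batch)
            # Gradients of the batch MSE with respect to w and b
            grad_w = sum(2 * (w * xs[i] + b - ys[i]) * xs[i] for i in batch) / n
            grad_b = sum(2 * (w * xs[i] + b - ys[i]) for i in batch) / n
            # L2-style weight decay pulls the weight toward zero
            grad_w += weight_decay * w
            # Step against the gradient, scaled by the learning rate
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b
    return w, b


xs = [i / 63 for i in range(64)]
ys = [3 * x + 1 for x in xs]
w, b = train_linear(xs, ys, n_epochs=300)
print(round(w, 2), round(b, 2))  # 3.0 1.0
```

Increasing `weight_decay` in this sketch trades some fit for smaller weights. That is the same trade-off the real hyperparameter controls.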
 
##### Saving parameters
 - **save_directory**: The directory where training statistics and summary plots will be stored. A value of 'False' means that these will not be saved
 - **save_visualizations**: If 'True', will save saliency plots for the best and worst fits in the training and validation data sets. Note that 'save_directory' must be specified for plots to be saved
 
##### Output
In addition to saving output plots (if specified), `train_model` returns the trained network.

In [None]:
# Train a single model
net = training.train_model(df, img_dir, name = 'test',
    # Transformation parameters
    brightness_range = (.9,1.1), 
    contrast_range = (.5,1.5), 
    sharpness_range = (1.5,.67),
    # Parameters for the model
    batch_size = 16,
    n_epochs = 100,
    learning_rate = 4e-6,
    weight_decay = 4e-7,
    # Loss function
    criterion = nn.MSELoss(),
    save_directory = False,
    save_visualizations = False
    )

In the manuscript, 25 models in total were trained using five different sets of image transformation parameters.

In [None]:
# 5 different sets of transformation options that were used for training 
brightness_options = [(.9,1.1), (.8,1.2), (.5,1.5), (.8,1.2), (.9,1.1)]
contrast_options = [(.5,1.5), (.8,1.2), (.5,1.5), (.5,1.5), (.75,1.5)]
sharpness_options = [(1.5,.67), (1.5,.67), (1.5,.67), (1,1), (1,1)]

In [None]:
# Train one model with each of the 5 parameter sets above (5 models total)
nets = []
for i in range(5):

    print(f'Now training model number {i}')

    brightness_range = brightness_options[i]
    contrast_range = contrast_options[i]
    sharpness_range = sharpness_options[i]

    net = training.train_model(df, img_dir, name = 'network_' + str(i),
        # Transformation parameters
        brightness_range = brightness_range,
        contrast_range = contrast_range,
        sharpness_range = sharpness_range,
        # Parameters for the model
        batch_size = 16,
        n_epochs = 100,
        learning_rate = 4e-6,
        weight_decay = 4e-7,
        # Loss function
        criterion = nn.MSELoss(),
        save_directory = False,
        save_visualizations = False
        )

    # Keep each trained network; otherwise 'net' is overwritten on the next iteration
    nets.append(net)
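
The manuscript's 25 models come from five parameter sets. Assume the models were split evenly, with five independently initialized models per set. The text above does not state this explicitly. Under that assumption, the runs could be enumerated with a small helper like the hypothetical `build_run_schedule` below. Each entry can be passed to `training.train_model` as keyword arguments. The `network_set{i}_rep{j}` naming scheme is also an assumption.

```python
from itertools import product

# Transformation options from the cell above: one tuple per parameter set
brightness_options = [(.9, 1.1), (.8, 1.2), (.5, 1.5), (.8, 1.2), (.9, 1.1)]
contrast_options = [(.5, 1.5), (.8, 1.2), (.5, 1.5), (.5, 1.5), (.75, 1.5)]
sharpness_options = [(1.5, .67), (1.5, .67), (1.5, .67), (1, 1), (1, 1)]


def build_run_schedule(n_replicates=5):
    """Hypothetical helper: list one run per (parameter set, replicate) pair."""
    runs = []
    n_sets = len(brightness_options)
    for set_idx, rep in product(range(n_sets), range(n_replicates)):
        runs.append({
            'name': f'network_set{set_idx}_rep{rep}',
            'brightness_range': brightness_options[set_idx],
            'contrast_range': contrast_options[set_idx],
            'sharpness_range': sharpness_options[set_idx],
        })
    return runs


schedule = build_run_schedule()
print(len(schedule))  # 25

# Example use inside the notebook (not executed here):
# nets = [training.train_model(df, img_dir, **run, batch_size = 16, n_epochs = 100,
#                              learning_rate = 4e-6, weight_decay = 4e-7,
#                              criterion = nn.MSELoss(), save_directory = False,
#                              save_visualizations = False)
#         for run in schedule]
```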