# Visual-Area Autolabeler: Training

## About

[Noah C. Benson](nben@uw.edu)$^{1,2,3}$, [Shaoling Chen](sc6995@nyu.edu)$^{4}$, and [Jonathan Winawer](jonathan.winawer@nyu.edu)$^{1,2}$

$^1$Department of Psychology  
$^2$Center for Neural Sciences  
$^4$Courant Institute for Mathematics  
New York University  
New York, NY 10012

$^3$**Current Affiliation:**  
eScience Institute  
University of Washington  
Seattle, WA 98122

## Initialization

### Configuration

Here we define any configuration item that needs to be set locally for the system running this notebook. Most likely, you will have to edit these in order for the model to work correctly.

In [1]:
# data_cache_path
# The directory into which data for the model training should be cached. This
# can be None, but if it is, then the training images will need to be
# regenerated every time the notebook is run.
data_cache_path  = '/data/visual-autolabel/data'

# model_cache_path
# The directory into which to store models that are generated during training.
# This may be None, but if it is, then the best models will not be saved out to
# disk during rounds of training.
model_cache_path = '/data/visual-autolabel/models'

### Libraries

In [2]:
import os, sys, time, pimms, pandas
import numpy as np
import scipy as sp
import nibabel as nib
import pyrsistent as pyr
import neuropythy as ny
import torch, torchvision, torchsummary

import matplotlib as mpl
import matplotlib.pyplot as plt
import ipyvolume as ipv

import visual_autolabel as va

In [3]:
%matplotlib inline

In [4]:
# Additional matplotlib preferences; these are just display preferences.
mpl_font_config = {'family':'sans-serif',
                   'sans-serif':['HelveticaNeue', 'Helvetica', 'Arial'],
                   'size': 10,
                   'weight': 'light'}
mpl.rc('font', **mpl_font_config)
# We want relatively high-res images, especially when saving to disk.
mpl.rcParams['figure.dpi'] = 72*4
mpl.rcParams['savefig.dpi'] = 72*8

### Utilities

In [5]:
def constantly(x):
    """Returns a function that always returns the argument.
    
    `constantly(x)` returns a function `f` such that `f(...)` returns `x`, no
    matter what arguments are passed to `f`.
    """
    def _lambda(*args, **kw): return x
    return _lambda
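
A brief usage sketch of `constantly` follows. The definition is repeated so that the snippet runs on its own; the name `default_lr` and its value are hypothetical and chosen only for illustration.

```python
def constantly(x):
    """Returns a function that always returns `x`, ignoring its arguments."""
    def _lambda(*args, **kw): return x
    return _lambda

# Hypothetical example value; any object works here.
default_lr = constantly(0.00375)

# The returned function accepts any positional or keyword arguments.
results = [default_lr(), default_lr(1, 2, 3), default_lr(epoch=7)]
```

A function like this is handy wherever an API expects a callable but a fixed value is all that is needed.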

In [23]:
def train_until(training_plan, until=None, model_key=None,
                model_cache_path=model_cache_path,
                data_cache_path=data_cache_path,
                create_directories=True,
                create_mode=0o755):
    """Continuously runs the given training plan for models until an interrupt.
        
    Runs training on `'anat'`, `'func'`, and `'both'` models, sequentially,
    using random partitions until a keyboard interrupt is caught, at which
    point a `pandas` dataframe of the results is returned. The partition is
    generated only once per group of model trainings (i.e., per training of an
    anatomical, functional, and combined model).
    
    Parameters
    ----------
    training_plan : list of dicts
        The training-plan to pass to the `visual_autolabel.run_modelplan()`
        function.
    model_key : str or None, optional
        A string that should be appended, as a sub-directory name, to the
        `model_cache_path`; this argument allows one to save model training
        to a specific sub-directory of the `model_cache_path` directory.
    model_cache_path : str, optional
        The cache-path to use for the model training. By default, this is the
        global variable `model_cache_path`, defined above.
    data_cache_path : str, optional
        The cache-path from which data for the model training should be loaded.
        By default, this is the global variable `data_cache_path`, defined
        above.
    until : int or None, optional
        If an integer is provided, then only `until` groups of trainings are
        performed, then the result is returned. If `None`, then the training
        continues until a `KeyboardInterrupt` is caught. The default is `None`.
    create_directories : boolean, optional
        Whether to create cache directories that do not exist (default `True`).
    create_mode : int, optional
        What mode to use when creating directories (default: `0o755`).
    """
    if data_cache_path is Ellipsis:
        data_cache_path = globals()['data_cache_path']
    if model_cache_path is Ellipsis:
        model_cache_path = globals()['model_cache_path']
    if model_key is not None:
        if model_cache_path is None:
            model_cache_path = model_key
        else:
            model_cache_path = os.path.join(model_cache_path, model_key)
    if create_directories:
        # Either cache path may be None, in which case there is nothing to make.
        for path in (model_cache_path, data_cache_path):
            if path is not None and not os.path.isdir(path):
                os.makedirs(path, create_mode)
    training_history = []
    datatype_tr = dict(anat='Anatomical Data Only',
                       func='Functional Data Only',
                       both='Anatomical & Functional Data')
    try:
        print('')
        iterno = 0
        while True:
            if until is not None and iterno >= until: break
            iterno += 1
            # Make one partition for all three minimization types.
            part = va.partition(va.sids, how=(0.8, 0.2))
            pid = va.partition_id(part)
            print('%-15s%70s' % ('Iteration %d' % iterno,
                                 'Partition ID: %s' % pid))
            print('=' * 85)
            for (dtype,dnm) in datatype_tr.items():
                print('')
                print(dnm + ' ' + '-'*(85 - len(dnm) - 1))
                print('')
                t0 = time.time()
                (model, loss, dice) = va.train.run_modelplan(
                    training_plan,
                    partition=part,
                    features=dtype,
                    model_cache_path=model_cache_path,
                    data_cache_path=data_cache_path)
                t1 = time.time()
                row = dict(input=dtype, loss=loss, dice=dice,
                           training_time=(t1-t0))
                training_history.append(row)
                print('')
    except KeyboardInterrupt:
        print('')
        print('KeyboardInterrupt caught; ending training.')
    training_history = ny.to_dataframe(training_history)
    return training_history
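
The 80%/20% split that `train_until` requests with `how=(0.8, 0.2)` can be illustrated with a short standard-library sketch. This is not the implementation of `va.partition`, whose internals are not shown in this notebook; the function `split_ids` and the example subject IDs are hypothetical.

```python
import random

def split_ids(ids, fractions=(0.8, 0.2), seed=None):
    """Randomly splits `ids` into (training, validation) tuples.

    Hypothetical helper for illustration only; it is not `va.partition`.
    """
    rng = random.Random(seed)
    shuffled = list(ids)
    rng.shuffle(shuffled)
    # The training set gets the first fraction; validation gets the rest.
    n_trn = int(round(fractions[0] * len(shuffled)))
    return (tuple(shuffled[:n_trn]), tuple(shuffled[n_trn:]))

# Made-up subject IDs, used only to demonstrate the split.
example_ids = [100000 + k for k in range(10)]
(trn, val) = split_ids(example_ids, seed=0)
```

Drawing one split per iteration and reusing it for all three feature sets, as `train_until` does, means that the `'anat'`, `'func'`, and `'both'` models within a group are compared on the same validation subjects.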

In [14]:
def plot_prediction(dataset, k, model,
                    axes=None, figsize=(6,1), dpi=72*4, min_alpha=0.5,
                    channels=(0,1,4,5)):
    """Plots the data, true label, and predicted label (by model) of a dataset.
    
    `plot_prediction(dataset, k, model)` creates a `matplotlib` figure for
    `dataset[k]` (i.e., the `k`th subject/image in `dataset`). The `axes` are
    always returned.
    
    Parameters
    ----------
    dataset : HCPVisualDataset
        The dataset used to plot the predictions. This may alternately be a
        PyTorch dataloader, in which case the dataloader's dataset must be an
        `HCPVisualDataset` object.
    k : int
        The sample number or subject ID to plot. A sample number is just the
        index number for the subject in the dataset; if a number less than 1000
        is given, then it is assumed to be a subject index, while if it is over
        1000, it is assumed to be a subject ID.
    model : PyTorch Module
        A UNet model or other PyTorch model that makes a segmentation
        of the images from the given `dataset`.
    axes : Matplotlib axes or `None`, optional
        A set of axes onto which to plot the predictions. Must have a total
        flattened length of 3.
    figsize : tuple, optional
        A tuple of `(width, height)` in inches to use for the figure size. This
        is ignored if `axes` is provided. The default is `(6, 1)`.
    dpi : int, optional
        The number of dots per inch in the output image. If `axes` is provided,
        this option is ignored. The default is `72 * 4`.
    min_alpha : float, optional
        The minimum alpha value to show in the alpha channel of the image.
        Alpha values are rescaled using the formula
        `adjusted_value = value * (1 - min_alpha) + min_alpha`, which maps the
        range `[0, 1]` onto `[min_alpha, 1]`. The default is `0.5`.
    channels : iterable of ints, optional
        When a dataset whose input images have more than 4 image channels is
        provided (i.e., the `'both'` datasets, which have 4 anatomical and 4
        functional image layers), then this list of 4 channels is used. By
        default this is `(0,1,4,5)`. This option is ignored if the dataset
        contains images with only 4 channels.
    """
    if k > 1000:
        # We have a subject-ID instead of an index; np.where returns a tuple of
        # index arrays, so take the first match as a scalar index.
        k = int(np.where(dataset.sids == k)[0][0])
    (imdat, imlbl) = dataset[k]
    impre = model(imdat[None,:,:,:].float())
    if not model.apply_sigmoid:
        impre = torch.sigmoid(impre)
    impre = dataset.inv_transform(None, impre.detach()[0])
    (imdat, imlbl) = dataset.inv_transform(imdat, imlbl)
    if axes is None:
        (fig,axes) = plt.subplots(1, 3, figsize=figsize, dpi=dpi)
    # with imdat we want to adjust the alpha layer
    imdat = np.array(imdat)
    imdat[:,:,3] = imdat[:,:,3]*(1 - min_alpha) + min_alpha
    for (ax,im) in zip(axes, [imdat, imlbl, impre]):
        if im.shape[2] > 4: im = im[:,:,:4]
        ax.imshow(np.clip(im, 0, 1))
        ax.axis('off')
    return axes
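
The alpha adjustment that `plot_prediction` applies to the input image can be seen in isolation with a small NumPy sketch. The helper name `rescale_alpha` and the 2x2 example image are hypothetical; only the formula itself comes from the function above.

```python
import numpy as np

def rescale_alpha(image, min_alpha=0.5):
    """Returns a copy of an RGBA image whose alpha channel is rescaled.

    Alpha values in [0, 1] are mapped linearly onto [min_alpha, 1].
    """
    image = np.array(image, dtype=float)  # np.array copies its input
    image[:, :, 3] = image[:, :, 3] * (1 - min_alpha) + min_alpha
    return image

# A made-up 2x2 RGBA image whose alpha channel spans 0 to 1.
rgba = np.zeros((2, 2, 4))
rgba[:, :, 3] = [[0.0, 0.25], [0.5, 1.0]]
adjusted = rescale_alpha(rgba, min_alpha=0.5)
```

With `min_alpha=0.5`, fully transparent pixels become half-opaque, so no part of the plotted image disappears entirely.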

## Training

### Training Plan

First, we define our standard training plan for training models.

In [2]:
training_plan = [
    dict(lr=0.00375, gamma=0.9, num_epochs=10,  bce_weight=0.67),
    dict(lr=0.00250, gamma=0.9, num_epochs=10,  bce_weight=0.33),
    dict(lr=0.00125, gamma=0.9, num_epochs=10,  bce_weight=0.00)]
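
To get a feel for what this plan implies, the sketch below tabulates each round. It rests on an assumption that this notebook does not confirm: that `gamma` is a multiplicative learning-rate decay applied once per epoch within a round. The helper `summarize_plan` is hypothetical, and `bce_weight` is carried through unchanged because its exact role in the loss is defined inside `visual_autolabel`.

```python
training_plan = [
    dict(lr=0.00375, gamma=0.9, num_epochs=10, bce_weight=0.67),
    dict(lr=0.00250, gamma=0.9, num_epochs=10, bce_weight=0.33),
    dict(lr=0.00125, gamma=0.9, num_epochs=10, bce_weight=0.00)]

def summarize_plan(plan):
    """Returns one summary dict per round of the plan.

    ASSUMPTION: `gamma` decays the learning rate once per epoch, so the
    rate during the last epoch of a round is lr * gamma**(num_epochs - 1).
    """
    rows = []
    for (round_no, step) in enumerate(plan, start=1):
        final_lr = step['lr'] * step['gamma'] ** (step['num_epochs'] - 1)
        rows.append(dict(round=round_no, epochs=step['num_epochs'],
                         first_lr=step['lr'], final_lr=final_lr,
                         bce_weight=step['bce_weight']))
    return rows

summary = summarize_plan(training_plan)
total_epochs = sum(row['epochs'] for row in summary)
```

Under that assumption, each group of trainings runs 30 epochs per model, with the starting learning rate and `bce_weight` both stepping down from round to round.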

### Train Continuously

Here, we train the model continuously until an interruption is received (typically a keyboard interruption, which can be sent via the `Kernel > Interrupt Kernel` menu item in Jupyter). This slowly produces a lot of output, but the result, which is returned once the interrupt is sent, is a `pandas` dataframe of training statistics.

In [None]:
hist = train_until(training_plan, model_key='2022-04-15_01')
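
The interrupt-driven control flow used by `train_until` can be shown with a stripped-down sketch. The names `run_until_interrupt` and `fake_step` are hypothetical; `fake_step` stands in for a round of training and simulates an interrupt so that the snippet terminates on its own.

```python
def run_until_interrupt(step, until=None):
    """Calls `step(iterno)` repeatedly and collects its results.

    Stops after `until` iterations (if given) or on KeyboardInterrupt,
    returning whatever results were collected so far.
    """
    history = []
    iterno = 0
    try:
        while until is None or iterno < until:
            iterno += 1
            history.append(step(iterno))
    except KeyboardInterrupt:
        print('KeyboardInterrupt caught; ending loop.')
    return history

def fake_step(iterno):
    # Hypothetical stand-in for a round of training; it simulates the user
    # interrupting the kernel during the fourth iteration.
    if iterno == 4: raise KeyboardInterrupt()
    return dict(iteration=iterno)

partial = run_until_interrupt(fake_step)
limited = run_until_interrupt(fake_step, until=2)
```

Because the results are appended as each step finishes, an interrupt discards only the step that was in progress; everything completed before it is still returned.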

### Visualization of Predictions

In [None]:
# What feature-set do we want to plot?
features = 'anat'
# What model are we plotting the results for?
model = ...  # placeholder: assign a trained model here before running this cell
# Which subject-IDs (or subject indices, either one) are we plotting?
plot_subs = np.arange(6)
# Training or Validation data?
phase = 'val'

# Make the figure and axes that we're going to use.
(fig,axs) = plt.subplots(len(plot_subs), 3, figsize=(6, len(plot_subs)), dpi=72*4)
# And plot each row using the `plot_prediction` function.
for (axrow,idx) in zip(axs, plot_subs):
    plot_prediction(datasets[features][phase], idx, model, axes=axrow)