# 7. Model Training

In [1]:
# Import the path helper from the local __init__ module.
from __init__ import set_path

# Call it before the tfcaidm imports in the next cell.
# NOTE(review): presumably this puts the package root on sys.path — confirm in __init__.py.
set_path()

In [2]:
import pandas as pd

from tfcaidm import Jobs
from tfcaidm import Model
from tfcaidm import Dataset

When running on the caidm servers, we can also select a GPU to allocate using the `gpus` utility from jarvis.

## Setup

1. Get hyperparameters
2. Load a dataset
3. Create a model

**Autoselect GPU (use only on caidm cluster)**

In [3]:
from jarvis.utils.general import gpus
# Let jarvis pick the GPU; the log line printed below shows it sets CUDA_VISIBLE_DEVICES.
# Intended for the caidm cluster only (see the heading above this cell).
gpus.autoselect()

[ 2021-11-20 11:37:38 ] CUDA_VISIBLE_DEVICES automatically set to: 1           


In [4]:
import os

# Location of the pipeline YAML that defines the hyperparameter runs.
# Set the TFCAIDM_YAML_PATH environment variable to point at your own checkout;
# the original absolute path is kept as the fallback so existing setups still work.
# `or` (rather than a .get default) also falls back when the variable is set but empty.
YAML_PATH = (
    os.environ.get("TFCAIDM_YAML_PATH")
    or "/home/brandon/tfcaidm-pkg/configs/ymls/xr_pna/pipeline.yml"
)

### Hyperparameters

In [5]:
# --- Get hyperparameters
# Jobs is pointed at the pipeline YAML configured in YAML_PATH.
runs = Jobs(path=YAML_PATH)

# --- Hyperparameters for N runs
# The result is indexed below, one entry per run.
all_hyperparams = runs.get_params()

# --- Hyperparameters for first run
# Only this first configuration is used in the rest of the notebook.
hyperparams = all_hyperparams[0]

In [6]:
hyperparams

{'env/path/root': 'exp',
 'env/path/name': 'xr_pna',
 'env/path/client': '/home/brandon/tfcaidm-pkg/configs/ymls/xr_pna/client.yml',
 'model/model': 'unet',
 'model/conv_type': 'conv',
 'model/pool_type': 'conv',
 'model/eblock': 'conv',
 'model/elayer': 1,
 'model/dblock': 'conv',
 'model/depth': 4,
 'model/width': 32,
 'model/width_scaling': 1,
 'model/kernel_size': [1, 3, 3],
 'model/strides': [1, 2, 2],
 'model/bneck': 2,
 'model/branches': 4,
 'model/atrous_rate': 6,
 'model/order': 'rnc',
 'model/norm': 'bnorm',
 'model/activ': 'leaky',
 'model/attn_msk': 'softmax',
 'train/xs/dat': None,
 'train/ys/pna/mask/name': 'msk',
 'train/ys/pna/mask/remove_bg': True,
 'train/ys/pna/mask/mask_weight': 1,
 'train/ys/pna/mask/output_weight': 5,
 'train/ys/pna/head': 'decoder_classifier',
 'train/ys/pna/n_classes': 2,
 'train/ys/pna/loss': 'sce',
 'train/ys/pna/metric': 'dice',
 'train/trainer/seed': 0,
 'train/trainer/n_folds': 1,
 'train/trainer/batch_size': 8,
 'train/trainer/iters': 3000

### Dataset

In [7]:
# Build the dataset wrapper from the selected hyperparameters and fetch the client for fold 0.
client = Dataset(hyperparams).get_client(fold=0)

Create the training and validation generators by invoking `create_generators` on the jarvis client.

In [8]:
# test=False is passed explicitly; the call's result is unpacked into two generators,
# which the names suggest are the training and validation streams.
gen_train, gen_valid = client.create_generators(test=False)

### Model

Model definition:

```python
{
    'model/model': 'unet',
    'model/conv_type': 'conv',
    'model/pool_type': 'conv',
    'model/eblock': 'conv',
    'model/elayer': 1,
    'model/dblock': 'conv',
    'model/depth': 4,
    'model/width': 32,
    'model/width_scaling': 1,
    'model/kernel_size': [1, 3, 3],
    'model/strides': [1, 2, 2],
    'model/bneck': 2,
    'model/branches': 4,
    'model/atrous_rate': 6,
    'model/order': 'rnc',
    'model/norm': 'bnorm',
    'model/activ': 'leaky',
    'model/attn_msk': 'softmax',
}
 ```

In [9]:
from tensorflow.keras import Input

In [10]:
# Wrap the client in the tfcaidm Model builder.
nn = Model(client)
# Ask the client for its inputs, passing the Keras Input factory; the result is displayed in the next cell.
inputs = client.get_inputs(Input)

In [11]:
inputs

{'dat': <KerasTensor: shape=(None, 1, 512, 512, 1) dtype=float32 (created by layer 'dat')>,
 'msk': <KerasTensor: shape=(None, 1, 512, 512, 1) dtype=float32 (created by layer 'msk')>,
 'pna': <KerasTensor: shape=(None, 1, 512, 512, 1) dtype=uint8 (created by layer 'pna')>}

The `nn.create()` method internally invokes `client.get_inputs(Input)`, builds a model defined in `client.hyperparams['model']`, and compiles it with the loss and optimizer.

In [12]:
# Build and compile the model; the shapes listed below are printed during this call.
model = nn.create()

(None, 1, 512, 512, 32)
(None, 1, 256, 256, 32)
(None, 1, 128, 128, 32)
(None, 1, 64, 64, 32)
(None, 1, 32, 32, 32)
(None, 1, 64, 64, 32)
(None, 1, 128, 128, 32)
(None, 1, 256, 256, 32)
(None, 1, 512, 512, 32)


In [13]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
dat (InputLayer)                [(None, 1, 512, 512, 0                                            
__________________________________________________________________________________________________
conv3d (Conv3D)                 (None, 1, 512, 512,  64          dat[0][0]                        
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 1, 512, 512,  128         conv3d[0][0]                     
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 1, 512, 512,  0           batch_normalization[0][0]        
______________________________________________________________________________________________

### Trainer

In [14]:
from tfcaidm import Trainer

In [15]:
# The trainer is configured from the same hyperparameter dict that was used for the dataset.
trainer = Trainer(hyperparams)

#### Train

The default fields for the trainer are:

```python
history = trainer.fit(
    model,
    gen_train,
    gen_valid,
    iters=100,
    steps_per_epoch=10,
    validation_freq=5,
    callbacks=[],
)

"""
Setting optional kwargs (iters, steps_per_epoch, validation_freq) will override the values set in hyperparams["train"]["trainer"].

iters=100,
steps_per_epoch=10,
validation_freq=5,

Setting callbacks=[] will use the callbacks from hyperparams["train"]["trainer"]. Otherwise, pass a list of callbacks, or None for no callbacks.
"""
```

In [None]:
# Train the model and keep whatever fit() returns in `history`.
history = trainer.fit(
    model,
    gen_train,
    gen_valid,
    iters=3000,  # same value as 'train/trainer/iters' in the hyperparameters shown earlier
    steps_per_epoch=100,  # 3000 iters / 100 steps per epoch = the 30 epochs in the log below
    validation_freq=5,
    callbacks=None,  # None means no callbacks; [] would use the ones from the hyperparameters
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30

## Inference

Let's try out the trained model.

### Visualize

Using the `imshow` function in jarvis we can visualize the model inputs and outputs.

In [None]:
import copy
import numpy as np
from jarvis.utils.display import imshow

In [None]:
def showall(x, figsize=(7,7)):
    """Display every image-like entry of a batch dictionary with ``imshow``.

    Entries with fewer than four dimensions are skipped. An entry whose last
    axis is longer than one (presumably per-class scores) is collapsed with an
    argmax over that axis and given a trailing axis of size one again before it
    is shown. Each figure is titled with the entry's key.

    The caller's dictionary is left untouched. Converted arrays are bound to a
    local name instead of being written back, so the whole batch no longer has
    to be deep-copied up front and the values need not support ``deepcopy``.

    Args:
        x: Dict mapping a name to an array-like that exposes ``ndim`` and
            ``shape``.
        figsize: Figure size forwarded to ``imshow``.

    Returns:
        None. The function is called for its display side effect.
    """
    for name, value in x.items():
        if value.ndim < 4:
            continue

        if value.shape[-1] > 1:
            # Collapse the last axis to its argmax, then restore a size-1 last axis.
            value = np.expand_dims(np.argmax(value, axis=-1), axis=-1)

        imshow(value, title=name, figsize=figsize)

### Forward pass

Using the trained, hopefully generalized model, run a forward pass.

In [None]:
# Pull one batch from the training generator and run the model on its inputs.
# NOTE(review): this batch comes from gen_train, so it shows the fit on training data
# rather than generalization; gen_valid may be the intended source — confirm.
xs, ys = next(gen_train)
yhat = model(xs)

#### Ground-truth

In [None]:
showall(xs) # msk-pna is not actually passed in as an input, it is used for class weights...

#### Prediction

In [None]:
showall(yhat)

In [None]:
import tensorflow as tf

def get_example(x, batch_index=0):
    """Pick one example out of a batch dict, keeping a leading axis of size 1.

    Args:
        x: Dict mapping a name to an array-like that exposes ``shape`` and
            supports integer indexing.
        batch_index: Position along the first axis of the example to extract.

    Returns:
        A new dict holding, for every entry whose ``shape`` is non-empty, the
        selected example with a new first axis of length one. Entries with an
        empty shape (scalars) are left out.
    """
    example = {}

    for key, value in x.items():
        if len(value.shape) == 0:
            # Scalars cannot be indexed by batch position; skip them.
            continue
        example[key] = np.expand_dims(value[batch_index], axis=0)

    return example

def show_example(x, y, xname, yname, batch_index=0, figsize=(5, 5)):
    """Show one input next to the argmax of one output for a single example.

    Args:
        x: Batch dict holding the entry to display (passed to ``get_example``).
        y: Batch dict holding the entry to reduce (passed to ``get_example``).
        xname: Key in ``x`` of the entry to display.
        yname: Key in ``y`` of the entry to reduce with an argmax over its
            last axis.
        batch_index: Which example of the batch to show.
        figsize: Figure size forwarded to ``imshow``.

    Returns:
        None. The function is called for its display side effect.
    """
    image = get_example(x, batch_index)[xname]

    scores = get_example(y, batch_index)[yname]
    labels = np.argmax(scores, axis=-1)

    imshow(image, labels, figsize=figsize)

In [None]:
# Show the "dat" input together with the argmax of the "pna/logits" output; the trailing 2 is the batch index.
show_example(xs, yhat, "dat", "pna/logits", 2)