# Train my [UNetLike](http://eg.bucknell.edu/~jvs008/research/cardiac/SPIE20/SPIE20.html) model on the [Stanford data](https://echonet.github.io/dynamic/)
Stough, 6/20

Pretty self-explanatory

The 2D echonet model (really, deeplabv3) applied to the CAMUS data yielded Dice scores in the low 80s (see [play_DynamicSegmentsCAMUS.ipynb](play_DynamicSegmentsCAMUS.ipynb)). My UNetLike model, trained on CAMUS and applied to the echonet dataset, yielded Dice scores in the high 80s (see [play_camusSegmentsDynamic.ipynb](play_camusSegmentsDynamic.ipynb)). 

So the next question: What if we **trained** my model on the echonet dataset? 

Since echonet already provides a useful Dataset type, I wrap echonet's Dataset in a standard torch DataLoader. Only afterwards do I apply my additional "camusizing" transforms (single channel, 256x256, intensities in [0, 1]) and, potentially, data augmentations (rotation, windowing, noise) through callable objects such as random_GaussNoiser in camus.utils.camus_transforms. Because these post-loader transforms run in the main training loop rather than in the loader's worker processes, they will definitely create a bottleneck in training, but I don't want to try to multiprocess them right now...

Since my UNetLike training code for CAMUS was specific to that project, I have to rewrite run_training and run_validation. I'll just write them in this bloated notebook for prototyping...


In [1]:
%matplotlib widget
import matplotlib
import matplotlib.pyplot as plt

import numpy as np
import copy
import SimpleITK as itk
import pandas as pd
import pickle
import copy
import tempfile

import os
import h5py

import cv2

# Make the project's helper packages importable from this notebook:
# 'camus' (as an absolute path) and 'camus/utils/' (relative to the CWD).
import sys
module_path = os.path.abspath(os.path.join('camus'))
for _extra_path in (module_path, 'camus/utils/'):
    # Only prepend when absent, so re-running this cell does not keep
    # stacking duplicate entries onto sys.path.
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
del _extra_path


from camus.utils.camus_config import CAMUS_CONFIG
from camus.utils.echo_utils import makeVideo
from camus.utils.camus_load_info import (make_camus_echo_dataset,
                                       split_camus_echo,
                                       make_camus_EHR,
                                       CAMUS_TRAINING_DIR,
                                       CAMUS_TESTING_DIR,
                                       CAMUS_RESULTS_DIR)
from camus.utils.camus_transforms import (LoadSITKFromFilename,
                                        SitkToNumpy,
                                        ResizeTransform,
                                        ResizeImagesAndLabels,
                                        WindowImagesAndLabels,
                                        RotateImagesAndLabels,
                                        random_Rotater,
                                        random_GaussNoiser,
                                        random_Windower,
                                        identity_Transform,
                                        GaussianNoiseEcho)
from camus.utils.camus_validate import (camus_overlay,
                                      labColorMap,
                                      labColor_cmap,
                                      labNameMap,
                                      nameLabMap,
                                      dict_extend_values,
                                      camus_dice_by_name,
                                      cleanupBinary,
                                      cleanupSegmentation)
from camus.utils.torch_utils import (TransformDataset,
                                   torch_collate,
                                   BetterLoss)
from camus.torch_models import UNetLike
from camus.segment_common import segment_echo

from camus.utils.echo_utils import (transformResizeImage,
                                    readTransformResizeImage)

import random
from random import shuffle
from scipy.special import softmax

from skimage.transform import (rescale, 
                               resize, 
                               rotate)


# torch
import torch
import torchvision
import torch.optim as optim
import torch.backends.cudnn as cudnn

from torch import nn
from torch.autograd import Variable
from torch.nn import Module

from torch.utils.data import (DataLoader, Dataset)
from torchvision.transforms import Compose
from torch.nn.functional import interpolate

# For timing.
import time
tic, toc = (time.time, time.time)

import echonet
from argparse import Namespace
from pprint import pprint

In [2]:
# Attribute-style (argparse Namespace) views of the training, augmentation and
# UNet sections of CAMUS_CONFIG, built in that order.
training_args, augment_args, unet_args = (
    Namespace(**CAMUS_CONFIG[section])
    for section in ('training', 'augment', 'unet')
)

In [3]:
# pprint doesn't handle a Namespace directly, but vars() of a Namespace hands
# back its underlying dictionary, which pprint lays out nicely.
for _namespace in (training_args, augment_args):
    pprint(vars(_namespace))
del _namespace

{'batch_size': 16,
 'effective_batchsize': 1,
 'howOftenToReport': 10,
 'image_size': [256, 256],
 'learning_rate': 0.001,
 'loss_weights': [1, 1, 1, 1],
 'num_epochs': 300,
 'patienceLimit': 41,
 'patienceToLRcut': 10,
 'weight_decay': 1e-06}
{'noise_scale': [0.0, 0.15],
 'rotation_scale': 5.0,
 'training_augment': True,
 'windowing_scale': [0.5, 1.0]}


&nbsp;

## Set up the Echonet Datasets and DataLoaders.

We'll also set up the post-processing, applied after the Echonet dataloaders, 
that actually fits the data to our network. We'll do this through 
a Camusize object (ripped from [play_camusSegmentsDynamic.ipynb](play_camusSegmentsDynamic.ipynb)).

In [4]:
# Targets requested from each echonet sample: the large/small frames and traces.
tasks = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]

# Normalization statistics are computed from the training split only.
mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(split="train"))
kwargs = dict(target_type=tasks, mean=mean, std=std)

# Every echonet sample carries both an ED and an ES frame, so each loader is
# given half of training_args.batch_size (Camusize later concatenates the pair).
param_trainLoader = dict(
    batch_size=training_args.batch_size // 2,
    num_workers=4,
    shuffle=True,
    pin_memory=False,
    drop_last=True,
)
# Validation/test: same settings, but fixed order and no dropped samples.
param_ValTestLoader = dict(param_trainLoader, shuffle=False, drop_last=False)

# val and test share one parameter dict.
paramLoader = dict(train=param_trainLoader,
                   val=param_ValTestLoader,
                   test=param_ValTestLoader)

datasets, dataloaders = {}, {}
for datatype in ('train', 'val', 'test'):
    datasets[datatype] = echonet.datasets.Echo(split=datatype, **kwargs)
    dataloaders[datatype] = DataLoader(datasets[datatype], **paramLoader[datatype])

100%|██████████| 16/16 [00:01<00:00, 14.32it/s]


In [5]:
class Camusize(object):
    """Reformat an echonet dataloader sample so UNetLike can train on it.

    The echonet dataloader is iterated like::

        for (_, (large_frame, small_frame, large_trace, small_trace)) in test_dataloader:

    where the ``*_frame`` tensors are N x 3 x 112 x 112 echo frames (already
    mean/var normalized by the echonet dataset) and the ``*_trace`` tensors
    are N x 112 x 112 masks. UNetLike instead works on N x 1 x H x W
    single-channel images with intensities in [0, 1].

    Calling the object concatenates the large and small frames (and traces)
    along the batch axis and converts them, returning the two-tuple
    ``(frames, traces)``:

    - frames: 2N x 1 x H x W float images, each min-max scaled to [0, 1]
      before the grayscale conversion;
    - traces: 2N x H x W ``torch.long`` label maps,

    with ``(H, W)`` given by ``im_size``.

    Parameters
    ----------
    im_size : int or (int, int)
        Output spatial size, passed as ``size`` to
        ``torch.nn.functional.interpolate``.
    """

    def __init__(self, im_size):
        self.im_size = im_size

    def _norm(self, frames):
        """Min-max scale every image of the batch independently to [0, 1].

        All channels of one image share the same min/max. Adapted from
        https://discuss.pytorch.org/t/how-to-efficiently-normalize-a-batch-of-tensor-to-0-1/65122/4
        A constant image (zero range) maps to all zeros instead of NaN.
        """
        flat = frames.reshape(frames.size(0), -1)
        # Subtraction allocates a new tensor, so the caller's frames are untouched.
        flat = flat - flat.min(1, keepdim=True)[0]
        value_range = flat.max(1, keepdim=True)[0]
        # Avoid 0/0 -> NaN for constant images (their shifted values are all 0).
        value_range = torch.where(value_range > 0, value_range,
                                  torch.ones_like(value_range))
        return (flat / value_range).view(frames.shape)

    def _rgb2gray(self, frames):
        """Collapse N x 3 x H x W frames to N x 1 x H x W.

        Computes the weighted sum 0.2989 * R + 0.5870 * G + 0.1140 * B.
        """
        # Build the weights on the frames' device so GPU batches work too.
        weights = torch.tensor([.2989, .5870, .1140], device=frames.device)
        return (frames * weights[None, :, None, None]).sum(1, keepdim=True)

    def __call__(self, echonet_tuple):
        """Convert one echonet sample ``(_, (large_frame, small_frame,
        large_trace, small_trace))`` into ``(frames, traces)``.

        See the class docstring for the input and output shapes.
        """
        _, (large_frame, small_frame, large_trace, small_trace) = echonet_tuple

        frames = torch.cat((large_frame, small_frame), 0)
        traces = torch.cat((large_trace, small_trace), 0)

        out_frames = self._rgb2gray(self._norm(frames))
        out_frames = interpolate(out_frames, size=self.im_size,
                                 mode='bilinear', align_corners=False)

        # Nearest-neighbour resizing copies label values without blending.
        # interpolate needs a floating-point input (masks may arrive as
        # uint8/bool), and squeeze(1) removes only the channel axis added here;
        # a bare squeeze() would also drop a batch or spatial axis of size 1.
        out_traces = interpolate(traces.unsqueeze(1).float(), size=self.im_size,
                                 mode='nearest').squeeze(1).long()

        return out_frames, out_traces

In [6]:
# Ripped from camus_overlay and camusSegmentsDynamic
def overlay(im, lab, lab_gt=None, alpha=1.0, whichLabels=(0, 1)):
    """Color-overlay a label map onto a grayscale echo image.

    Parameters
    ----------
    im : array-like, H x W
        Grayscale image; presumably in [0, 1], since the result is clipped to
        that range. CPU torch tensors are accepted (converted with np.asarray);
        non-float images are converted to float64.
    lab : array-like, H x W
        Label map. Pixels equal to each key in ``whichLabels`` get
        ``alpha * labColorMap[key]`` added to their RGB value.
    lab_gt : array-like, H x W (or 1 x H x W), optional
        Ground-truth label map. When given, a second image is built over the
        plain echo image that marks label-1 (LV) disagreements between ``lab``
        and ``lab_gt``: false positives in ``labColorMap[1]``, false negatives
        in that color's complement (max component minus each component).
    alpha : float
        Weight of the label colors in the first image.
    whichLabels : iterable
        Label keys of ``lab`` to color in the first image.

    Returns
    -------
    I : H x W x 3 array clipped to [0, 1], or
    (I, I_gt) : when ``lab_gt`` is given.
    """
    im = np.asarray(im)
    if not np.issubdtype(im.dtype, np.floating):
        # Integer images cannot take the float color increments in place.
        im = im.astype(np.float64)
    lab = np.asarray(lab)

    # Three-channel (H x W x 3) version of the echo image.
    I = np.stack([im, im, im], axis=-1)

    if lab_gt is not None:
        lab_gt = np.asarray(lab_gt)
        if lab_gt.ndim == 3:  # tolerate a 1 x H x W ground truth
            lab_gt = lab_gt.squeeze()
        # Copied before any label colors are added, so I_gt shows only the
        # disagreement colors over the plain echo image.
        I_gt = I.copy()

    for key in whichLabels:
        I[lab == key, :] += alpha * np.array(labColorMap[key])

    if lab_gt is not None:
        for key in [1]:  # hardcoded to the LV label so it shows best
            color = np.asarray(labColorMap[key])
            comp_color = color.max() - color  # complement marks false negatives
            whereFP = np.logical_and(lab == key, lab_gt != key)
            whereFN = np.logical_and(lab != key, lab_gt == key)
            I_gt[whereFP, :] += color
            I_gt[whereFN, :] += comp_color
        I_gt = I_gt.clip(0, 1)

    I = I.clip(0, 1)

    if lab_gt is not None:
        return I, I_gt

    return I

In [7]:
# Sanity check: show the first 5 camusized test batches' first frame next to its
# trace overlay. The 6th batch is fetched before the break, so afterwards
# i/frames/traces refer to that batch (inspected by the next cells).
for i, (frames, traces) in enumerate(map(Camusize(im_size=training_args.image_size), dataloaders['test'])):
    if i >= 5:
        break

    f, ax = plt.subplots(1, 2, figsize=(4, 2))
    Io = overlay(frames[0].squeeze(), lab=traces[0].numpy())
    # NOTE: interpolation=None means matplotlib's default
    # (rcParams['image.interpolation']), not "no interpolation" ('none').
    ax[0].imshow(frames[0].squeeze(), cmap='gray', interpolation=None)
    ax[1].imshow(Io, interpolation=None)

    # Hide tick marks/labels on both panels.
    for a in ax:
        a.xaxis.set_visible(False)
        a.yaxis.set_visible(False)
    f.tight_layout()
    
    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [8]:
# Shapes and intensity range of the batch left in frames/traces by the loop above.
frames.shape, traces.shape, frames.min(), frames.max()

(torch.Size([16, 1, 256, 256]),
 torch.Size([16, 256, 256]),
 tensor(0.0002),
 tensor(0.9995))

In [9]:
# Confirm Camusize produced integer (torch.long) label maps.
traces.dtype

torch.int64

&nbsp;

## Let's instantiate the model.

In [10]:
# Binary segmentation (background vs. the traced left ventricle): set the number
# of output classes explicitly rather than taking it from CAMUS_CONFIG.
unet_args.n_classes = 2
pprint(vars(unet_args))

{'n_channels': 1, 'n_classes': 2, 'n_filters': 32, 'normalization': 'groupnorm'}


In [11]:
# The UNetLike 2D segmentation network from camus.torch_models, wrapped in
# DataParallel. Moved to the GPU only when one is available, matching the
# availability checks used elsewhere in this notebook.
net_seg = torch.nn.DataParallel(UNetLike(**vars(unet_args)))
if torch.cuda.is_available():
    net_seg.cuda()

In [12]:
# Count only parameters with requires_grad=True.
# https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
print('Trainable params: {}'.format(
    sum(weight.numel() for weight in net_seg.parameters() if weight.requires_grad)))

# Alternative count via the helper below, adapted from this gist:
# https://gist.github.com/zackenton/12a86b6e0ff274b39608e40f4a412f2b
from functools import reduce
def pytorch_count_params(model, trainable_only=False):
    '''Count the parameters of a PyTorch model.

    Parameters
    ----------
    model : torch.nn.Module
    trainable_only : bool, default False
        If True, count only parameters with ``requires_grad`` set. The default
        counts every parameter, frozen or not.

    Returns
    -------
    int
        Total number of scalar elements across the counted parameters
        (0 for a model without parameters).
    '''
    # numel() is the product of a tensor's dimensions and, unlike reducing over
    # x.size() without an initial value, also works for 0-dim (scalar) params.
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad or not trainable_only)

# pytorch_count_params sums over every parameter (frozen or not), so this
# matches the trainable count above only when nothing is frozen.
print('Or maybe: {}'.format(pytorch_count_params(net_seg)))

Trainable params: 13068034
Or maybe: 13068034


&nbsp;

## Training and Validation Functions

I'm ripping most of the content from camus/utils/torch_utils, but I wrote those functions
with my CAMUS dictionary-based dataloader in mind. I could wrap the Camusize 
class above in an iterable that yields something to trick that other 
code, but it seems more straightforward to just write my own loops for this 
little experiment.

In [13]:
def run_training(network,
                 effective_batchsize,
                 criterion=None,
                 cur_learning_rate=1e-3,
                 cur_weight_decay=1e-5):
    """Train ``network`` for one epoch over ``dataloaders['train']``.

    Each echonet batch is converted with ``Camusize`` (image size taken from
    ``training_args.image_size``) before being fed to the network.

    Parameters
    ----------
    network : torch.nn.Module
        Model to train; switched to train mode and updated in place.
    effective_batchsize : int
        Number of loader batches whose gradients are accumulated before each
        optimizer step (1 = step after every batch). Gradients left over from
        a trailing partial group are applied in one final step.
    criterion : callable, optional
        Loss called as ``criterion(net_outputs, correct_outputs)``; a fresh
        ``BetterLoss()`` when omitted.
    cur_learning_rate, cur_weight_decay : float
        Adam hyper-parameters. NOTE: a new Adam optimizer is built on every
        call, so its moment estimates restart each epoch.

    Returns
    -------
    (avg_loss, one_output, one_input, one_correct_output)
        Mean per-batch loss, plus the network output, input and target of the
        first sample of the last batch, as numpy arrays.

    Raises
    ------
    ValueError
        If ``effective_batchsize`` < 1 or the training loader yields no batches.
    """
    if effective_batchsize < 1:
        raise ValueError('effective_batchsize must be >= 1, got {}'.format(effective_batchsize))
    if criterion is None:
        criterion = BetterLoss()
    # is_available must be *called*: the bare function object is always truthy.
    use_cuda = torch.cuda.is_available()

    network.train()

    optimizer = torch.optim.Adam(network.parameters(),
                                 lr=cur_learning_rate,
                                 weight_decay=cur_weight_decay)
    optimizer.zero_grad()

    camusize = Camusize(im_size=training_args.image_size)
    running_loss = 0.0
    num_batches = 0
    pending_grads = False

    for batch_num, (inputs, correct_outputs) in enumerate(map(camusize, dataloaders['train'])):
        if use_cuda:
            inputs, correct_outputs = inputs.cuda(), correct_outputs.cuda()

        net_outputs = network(inputs)
        net_loss = criterion(net_outputs, correct_outputs)
        net_loss.backward()
        pending_grads = True

        # Step once every `effective_batchsize` batches, counted within this epoch.
        if (batch_num + 1) % effective_batchsize == 0:
            optimizer.step()
            optimizer.zero_grad()
            pending_grads = False

        running_loss += net_loss.detach().cpu().item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloaders['train'] yielded no batches")

    # Apply gradients accumulated by a trailing partial group of batches.
    if pending_grads:
        optimizer.step()
        optimizer.zero_grad()

    # Ready to return info, including loss and an example from the last batch.
    avg_loss = running_loss / num_batches
    one_output = net_outputs[0].detach().cpu().numpy()
    one_input = inputs[0].detach().cpu().numpy()
    one_correct_output = correct_outputs[0].detach().cpu().numpy()

    return avg_loss, one_output, one_input, one_correct_output


def run_validation(network,
                   phase='val',  # a dataloaders key: 'val' or 'test'
                   criterion=None):
    """Evaluate ``network`` on ``dataloaders[phase]`` without updating weights.

    Each echonet batch is converted with ``Camusize`` (image size taken from
    ``training_args.image_size``) before being fed to the network.

    Parameters
    ----------
    network : torch.nn.Module
        Model to evaluate; left in eval mode afterwards.
    phase : str
        Which entry of ``dataloaders`` to evaluate.
    criterion : callable, optional
        Loss called as ``criterion(net_outputs, correct_outputs)``; a fresh
        ``BetterLoss()`` when omitted.

    Returns
    -------
    (avg_loss, one_output, one_input, one_correct_output)
        Mean per-batch loss, plus the network output, input and target of the
        first sample of the last batch, as numpy arrays.

    Raises
    ------
    ValueError
        If the loader for ``phase`` yields no batches.
    """
    if criterion is None:
        criterion = BetterLoss()
    # is_available must be *called*: the bare function object is always truthy.
    use_cuda = torch.cuda.is_available()

    # Eval mode (e.g. normalization layers use inference behavior); no_grad
    # below prevents building autograd graphs.
    network.eval()

    camusize = Camusize(im_size=training_args.image_size)
    running_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for inputs, correct_outputs in map(camusize, dataloaders[phase]):
            if use_cuda:
                inputs, correct_outputs = inputs.cuda(), correct_outputs.cuda()

            net_outputs = network(inputs)
            net_loss = criterion(net_outputs, correct_outputs)
            running_loss += net_loss.detach().cpu().item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloaders[{!r}] yielded no batches".format(phase))

    # Ready to return info, including loss and an example from the last batch.
    avg_loss = running_loss / num_batches
    one_output = net_outputs[0].detach().cpu().numpy()
    one_input = inputs[0].detach().cpu().numpy()
    one_correct_output = correct_outputs[0].detach().cpu().numpy()

    return avg_loss, one_output, one_input, one_correct_output
    

&nbsp;

## The Actual Experimental Loop
Alternate a training epoch with the evaluation of the validation set, repeat.

In [14]:
# For saving and restoring the best val loss model. The file is kept after
# training (later cells reload from best_model_file.name); only the handle is
# closed, since torch.save/torch.load reopen the file by name.
best_model_file = tempfile.NamedTemporaryFile(mode='w+b', delete=False)
best_model_file.close()
best_epoch = -1
min_val_loss = 1e15


# To ensure we don't go on too long
patienceLimit, patience = training_args.patienceLimit, 0
patienceToLRcut = training_args.patienceToLRcut


# Report dices every some number epochs.
howOftenToReport = training_args.howOftenToReport

# For updating the learning rate during training.
cur_learning_rate = training_args.learning_rate
num_epochs = training_args.num_epochs


print('STARTING TRAINING')

for i in range(1, num_epochs+1):
    # After every patienceToLRcut consecutive epochs without improvement,
    # halve the learning rate and restart from the best model so far.
    if patience > 0 and patience % patienceToLRcut == 0:
        cur_learning_rate /= 2
        print('\n\ncutting learning rate to {}'.format(cur_learning_rate))
        if best_epoch > 0:  # nothing to reload if no model was ever saved
            print('Reloading best model, from epoch {}'.format(best_epoch))
            net_seg.load_state_dict(torch.load(best_model_file.name))

    # Run the training epoch and validation test, and report.
    start = tic()
    train_loss, train_output_seg, train_input_img, train_label = \
        run_training(net_seg,
                     effective_batchsize=training_args.effective_batchsize,
                     cur_learning_rate=cur_learning_rate,
                     cur_weight_decay=training_args.weight_decay)

    valid_loss, valid_output_seg, valid_input_img, val_label = \
        run_validation(net_seg,
                       phase='val')

    print('\n\nEPOCH {} of {} ({:.3f} sec)'.format(i, num_epochs, toc()-start))
    print('-- train loss {} -- valid loss {} --'.format(train_loss, valid_loss))

    # Visualize on epochs 1, 1 + howOftenToReport, ... (also correct when
    # howOftenToReport == 1). Reporting itself is currently disabled below.
    visThisIter = (i - 1) % howOftenToReport == 0


    ########### Check for best so far model

    # Now save if this is the best model so far
    if valid_loss < min_val_loss:
        min_val_loss = valid_loss
        patience = 0
        best_epoch = i
        print('Epoch {}, saving new best loss model, {}'.format(i, valid_loss))
        torch.save(net_seg.state_dict(), best_model_file.name)
    else:
        patience += 1
        if patience >= patienceLimit:
            print('Breaking on patience, epoch {}.'.format(i))
            break


    ########### Potentially report results this iter.
    # TODO: adapt report_validation_sets to the echonet loaders.
#     if visThisIter:
#         report_validation_sets(net_seg=net_seg, datadict=datadict, args=args)



# After all the training, now reload the best model (if one was ever saved).
if best_epoch > 0:
    print('Reloading best model, from epoch {}'.format(best_epoch))
    net_seg.load_state_dict(torch.load(best_model_file.name))
else:
    print('No best model was saved; keeping the current weights.')

STARTING TRAINING


EPOCH 1 of 300 (159.086 sec)
-- train loss 0.36341487168626213 -- valid loss 0.33903557413853475 --
Epoch 1, saving new best loss model, 0.33903557413853475


EPOCH 2 of 300 (164.698 sec)
-- train loss 0.33308682788978833 -- valid loss 0.33290436327087214 --
Epoch 2, saving new best loss model, 0.33290436327087214


EPOCH 3 of 300 (169.045 sec)
-- train loss 0.3306458076949795 -- valid loss 0.3295757207811249 --
Epoch 3, saving new best loss model, 0.3295757207811249


EPOCH 4 of 300 (168.106 sec)
-- train loss 0.32964072763408203 -- valid loss 0.3298109914205089 --


EPOCH 5 of 300 (168.832 sec)
-- train loss 0.3287846425891946 -- valid loss 0.3283388760889539 --
Epoch 5, saving new best loss model, 0.3283388760889539


EPOCH 6 of 300 (168.853 sec)
-- train loss 0.3283675887937709 -- valid loss 0.3283363771364556 --
Epoch 6, saving new best loss model, 0.3283363771364556


EPOCH 7 of 300 (168.795 sec)
-- train loss 0.32800453108267724 -- valid loss 0.32841680397898

<All keys matched successfully>

In [15]:
# Reload the best-validation-loss weights saved to the temporary file during training.
net_seg.load_state_dict(torch.load(best_model_file.name))

<All keys matched successfully>

In [16]:
# Run the model over the whole test split, collecting (on the CPU) the
# camusized input frames, the raw network outputs and the target traces.
net_seg.eval()

# Per-batch pieces, concatenated once after the loop: growing a tensor with
# torch.cat on every batch re-copies all earlier batches (quadratic cost).
frame_batches, output_batches, target_batches = [], [], []


with torch.no_grad():

    for batch_num, (camusized_frames, camusized_traces) in \
        enumerate(map(Camusize(im_size=training_args.image_size), dataloaders['test'])):

        # Keep CPU copies of the frames and target answers.
        frame_batches.append(camusized_frames)
        target_batches.append(camusized_traces)

        if torch.cuda.is_available():
            camusized_frames = camusized_frames.cuda()

        # Push through the network; bring the outputs back to the CPU.
        inputs = camusized_frames
        output_batches.append(net_seg(inputs).detach().cpu())

# Empty placeholders (same dtypes as before) if the test loader yielded nothing.
test_frames = torch.cat(frame_batches, 0) if frame_batches else torch.tensor([])
net_outputs = torch.cat(output_batches, 0) if output_batches else torch.tensor([])
target_outputs = (torch.cat(target_batches, 0) if target_batches
                  else torch.tensor([]).type(torch.long))
del frame_batches, output_batches, target_batches
        

In [17]:
# Raw network outputs over the test split: N x n_classes x H x W.
net_outputs.shape

torch.Size([2552, 2, 256, 256])

In [18]:
# Custom dice business just here.
def getDices(autoseg, labels, whichLabels=('Background', 'Left Ventricle'),
             label_map=None):
    """Per-case Dice scores of a binarized auto-segmentation against labels.

    Parameters
    ----------
    autoseg : array-like, N x H x W (or N x 1 x H x W)
        Automatic segmentation, binarized with ``autoseg > 0``; it therefore
        only ever takes the values 0 and 1.
    labels : array-like with N x H x W elements per case layout as autoseg
        Reference label maps, compared against the values in ``label_map``.
    whichLabels : sequence of str
        Label names to score, looked up in ``label_map``. Because ``autoseg``
        is binarized, names mapping to values other than 0/1 never match a
        segmented pixel.
    label_map : dict, optional
        Name -> label value; defaults to ``nameLabMap``.

    Returns
    -------
    numpy.ndarray, N x len(whichLabels), float64
        Dice = 2|S & L| / (|S| + |L|) per case and label; NaN when the label is
        absent from both the segmentation and the reference of that case.
    """
    if label_map is None:
        label_map = nameLabMap

    autoseg = np.asarray(autoseg)
    n_cases = autoseg.shape[0]
    retDice = np.zeros((n_cases, len(whichLabels)))
    if n_cases == 0:
        return retDice

    # One flattened row per case. (A squeeze() here would also drop the case
    # axis when N == 1; flattening handles an optional channel axis too.)
    seg_flat = (autoseg > 0).astype(np.uint8).reshape(n_cases, -1)
    lab_flat = np.asarray(labels).reshape(n_cases, -1)

    for i, key in enumerate(whichLabels):
        value = label_map[key]
        seg = seg_flat == value
        lab = lab_flat == value
        intersection = np.logical_and(seg, lab).sum(axis=1)
        total = seg.sum(axis=1) + lab.sum(axis=1)
        # Dice is intersection over average; 0/0 is left as NaN (quietly).
        with np.errstate(divide='ignore', invalid='ignore'):
            retDice[:, i] = 2.0 * intersection / total

    return retDice

In [19]:
# Predicted class = argmax over the class axis; pass only the binary LV mask,
# so column 0 of dices scores Background and column 1 the Left Ventricle.
dices = getDices((torch.argmax(net_outputs, 1)==nameLabMap['Left Ventricle']).numpy(), target_outputs.numpy())

In [20]:
# Mean and standard deviation of the per-case LV Dice (column 1) on the test set.
dices[:,1].mean(), dices[:,1].std()

(0.922743060622721, 0.04441620803445517)

In [21]:
# Distribution of per-case LV Dice scores (trailing semicolon hides the return value).
plt.figure(figsize=(4,3))
plt.hist(dices[:,1], bins=20);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …