# Using the ***pre-trained*** Stough CAMUS model to segment the Stanford Echonet data
Stough, 5/20

Here I'm taking my CAMUS-trained UNet-like model and pushing Stanford ap4 data through it. The [Stanford 
data and model](https://echonet.github.io/dynamic) are quite different: all videos have been
cropped and are a bit warped in aspect ratio relative to the images my model was trained on. 

In [1]:
%matplotlib widget
import matplotlib
import matplotlib.pyplot as plt

import numpy as np
import copy
import SimpleITK as itk
import pandas as pd
import pickle
import copy

import os
import h5py

import cv2

# Being able to use the utility functions here.
import sys
module_path = os.path.abspath(os.path.join('camus'))
if module_path not in sys.path:
    sys.path.insert(0, module_path)
    
sys.path.insert(0, 'camus/utils/')


from camus.utils.camus_config import CAMUS_CONFIG
from camus.utils.echo_utils import makeVideo
from camus.utils.camus_load_info import (make_camus_echo_dataset,
                                       split_camus_echo,
                                       make_camus_EHR,
                                       CAMUS_TRAINING_DIR,
                                       CAMUS_TESTING_DIR,
                                       CAMUS_RESULTS_DIR)
from camus.utils.camus_transforms import (LoadSITKFromFilename,
                                        SitkToNumpy,
                                        ResizeTransform,
                                        ResizeImagesAndLabels,
                                        WindowImagesAndLabels,
                                        RotateImagesAndLabels,
                                        random_Rotater,
                                        random_GaussNoiser,
                                        random_Windower,
                                        identity_Transform,
                                        GaussianNoiseEcho)
from camus.utils.camus_validate import (camus_overlay,
                                      labColorMap,
                                      labColor_cmap,
                                      labNameMap,
                                      nameLabMap,
                                      dict_extend_values,
                                      camus_dice_by_name,
                                      cleanupBinary,
                                      cleanupSegmentation)
from camus.utils.torch_utils import (TransformDataset,
                                   torch_collate,
                                   run_training,
                                   run_validation,
                                   run_validation_returnAll,
                                   BetterLoss)
from camus.torch_models import UNetLike
from camus.segment_common import segment_echo

from camus.utils.echo_utils import (transformResizeImage,
                                    readTransformResizeImage)

import random
from random import shuffle
from scipy.special import softmax

from skimage.transform import (rescale, 
                               resize, 
                               rotate)


# torch
import torch
import torchvision
import torch.optim as optim
import torch.backends.cudnn as cudnn

from torch import nn
from torch.autograd import Variable
from torch.nn import Module

from torch.utils.data import (DataLoader, Dataset)
from torchvision.transforms import Compose
from torch.nn.functional import interpolate

# For timing.
import time
tic, toc = (time.time, time.time)

import echonet
from argparse import Namespace

In [2]:
args = Namespace(image_size=[256, 256],
                 foldToTest=1,
                 batch_size=20)

&nbsp;

## Set up the model

In [3]:
# The CAMUS views handled in this notebook, and the integer index used for
# each phase tag.
VIEWS=['2CH', '4CH']
PHASE_MAP = {'ED':0, 'ES':1}

# Label value that nameLabMap associates with 'Left Ventricle'.
lvIndex = nameLabMap['Left Ventricle']

# Constant over all CAMUS images (presumably the pixel spacing -- confirm
# the units against the CAMUS documentation).
# Remember this is width x height though, so (x, y) order instead of the
# (i, j) order I use in my math. So the proper scaling to send to
# computVolume is (PIX_Y, PIX_X).
PIX_X = .308
PIX_Y = .154

In [4]:
pathconfig = CAMUS_CONFIG['paths']
augconfig = CAMUS_CONFIG['augment']

# One dataset per view, built from CAMUS_TESTING_DIR.
camusByPat = {key: make_camus_echo_dataset(CAMUS_TESTING_DIR, key) for key in VIEWS}

# Fold info: only the number of folds is actually needed here.
foldfilename = os.path.join(CAMUS_RESULTS_DIR, pathconfig['folds_file'])
with open(foldfilename, 'rb') as fid:
    kf = pickle.load(fid)
NUMFOLDS = len(kf)

# Augmentation settings that are baked into the saved-model file names.
window_lo, window_hi = augconfig['windowing_scale'][0], augconfig['windowing_scale'][1]
rotation = augconfig['rotation_scale']
noise_lo, noise_hi = augconfig['noise_scale'][0], augconfig['noise_scale'][1]

# Per-view path template of the saved model that generated the results.
# The '{}' that survives after 'fold_' is left for a later .format(fold).
modelpathname = {}
for view in VIEWS:
    model_fname = 'aug_{}_fold_{{}}_win_{}_{}_rot_{}_noise_{}_{}.pth'.format(
        view, window_lo, window_hi, rotation, noise_lo, noise_hi)
    modelpathname[view] = os.path.join(CAMUS_RESULTS_DIR, 'saved_models', model_fname)

summary = ('Defined modelpathname {},\n'
           '                      {},\n'
           'camusByPat, and NUMFOLDS {}.')
print(summary.format(modelpathname['2CH'], modelpathname['4CH'], NUMFOLDS))

Defined modelpathname /home/stough/data/CAMUS/results/saved_models/aug_2CH_fold_{}_win_0.5_1.0_rot_5.0_noise_0.0_0.15.pth,
                      /home/stough/data/CAMUS/results/saved_models/aug_4CH_fold_{}_win_0.5_1.0_rot_5.0_noise_0.0_0.15.pth,
camusByPat, and NUMFOLDS 10.


In [5]:
def load_UNet(pathname):
    """Build a UNetLike segmenter and load its saved weights.

    The network is constructed from CAMUS_CONFIG['unet'], wrapped in
    torch.nn.DataParallel, placed on the GPU when one is available (CPU
    otherwise), loaded with the state dict stored at `pathname`, and put
    in eval mode.

    Args:
        pathname: path of the saved state dict (.pth) to load.

    Returns:
        The DataParallel-wrapped network, in eval mode.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # First, clear previous gpu stuff.
    if use_cuda:
        torch.cuda.empty_cache()

    # Wrap before loading: the state dict is loaded into the wrapped module.
    net_seg = torch.nn.DataParallel(UNetLike(**CAMUS_CONFIG['unet']))
    net_seg.to(device)

    # map_location lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine as well.
    net_seg.load_state_dict(torch.load(pathname, map_location=device))
    net_seg.eval()

    n_weights = sum(p.numel() for p in net_seg.parameters() if p.requires_grad)
    print('\tLoaded {}, with {} weights'.format(pathname, n_weights))

    return net_seg

In [6]:
net_seg = load_UNet(modelpathname['4CH'].format(args.foldToTest))

	Loaded /home/stough/data/CAMUS/results/saved_models/aug_4CH_fold_1_win_0.5_1.0_rot_5.0_noise_0.0_0.15.pth, with 13068100 weights


In [7]:
net_seg

DataParallel(
  (module): UNetLike(
    (block_one): ResidualConvBlock(
      (conv): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): GroupNorm2D()
        (2): ReLU(inplace=True)
      )
    )
    (block_one_dw): DownsamplingConvBlock(
      (conv): Sequential(
        (0): Conv2d(32, 64, kernel_size=(2, 2), stride=(2, 2))
        (1): GroupNorm2D()
        (2): ReLU(inplace=True)
      )
    )
    (block_two): ResidualConvBlock(
      (conv): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): GroupNorm2D()
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): GroupNorm2D()
        (5): ReLU(inplace=True)
      )
    )
    (block_two_dw): DownsamplingConvBlock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(2, 2), stride=(2, 2))
        (1): GroupNorm2D()
        (2): ReLU(inplace=True)
  

&nbsp;

## Look at the Stanford echo data

In [8]:
mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(split="train"))
tasks = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]
kwargs = {"target_type": tasks,
          "mean": mean,
          "std": std
          }

100%|██████████| 16/16 [00:01<00:00, 14.26it/s]


In [9]:
# Set up datasets and dataloaders
test_dataset = echonet.datasets.Echo(split="test", **kwargs)
test_dataloader = torch.utils.data.DataLoader(test_dataset, 
                                               batch_size=args.batch_size, 
                                               num_workers=4, 
                                               shuffle=False, 
                                               pin_memory=False, 
                                               drop_last=True)

In [10]:
_, (large_frame, small_frame, large_trace, small_trace) = test_dataset[100]

In [11]:
large_trace.shape

(112, 112)

In [12]:
large_frame.shape

(3, 112, 112)

In [13]:
large_frame.min(), large_frame.max(), large_frame.mean(), large_frame.std()

(-0.6599135, 4.2719173, -0.053806968, 0.89031696)

In [14]:
lf = large_frame.copy()

In [15]:
nf = (lf - lf.min()) / (lf.max() - lf.min())

In [16]:
nf = np.multiply(nf, np.array([.2989, .5870, .1140])[:, None, None])

In [17]:
nf.sum(axis=0).shape

(112, 112)

&nbsp;

## Using a callable object and [map](https://github.com/pytorch/pytorch/issues/437#issuecomment-272192380) on the dataloader 
to modify the data batches to fit into the CAMUS model (with [interpolate](https://discuss.pytorch.org/t/resize-tensor-without-converting-to-pil-image/52401)).

The CAMUS model expects:
- 1x256x256 single-channel
- intensities in [0,1]


In [18]:
class Camusize(object):
    """Callable that reshapes Echonet batches into what the CAMUS model takes.

    Frames are rescaled to [0, 1] per sample, collapsed from three channels
    to one with a weighted sum, and bilinearly resized to `im_size`. Traces
    are resized to `im_size` with nearest-neighbour sampling.
    """

    # 0.2989 * R + 0.5870 * G + 0.1140 * B
    _GRAY_WEIGHTS = (.2989, .5870, .1140)

    def __init__(self, im_size, norm):
        """
        Args:
            im_size: output (height, width) handed to `interpolate`.
            norm: stored on the instance.
        """
        self.im_size = im_size
        # NOTE(review): `norm` is kept but never consulted -- __call__
        # always normalizes. Confirm whether it was meant to be a switch.
        self.norm = norm

    def _norm(self, frames):
        """Rescale every sample of the batch to [0, 1] (min/max per sample).

        A sample whose values are all equal maps to all zeros instead of
        NaN. The input tensor is not modified.
        """
        # make 0-1, but in tensors:
        # https://discuss.pytorch.org/t/how-to-efficiently-normalize-a-batch-of-tensor-to-0-1/65122/4
        flat = frames.flatten(1)
        if not flat.is_floating_point():
            flat = flat.float()

        shifted = flat - flat.min(1, keepdim=True)[0]
        span = shifted.max(1, keepdim=True)[0]
        # Only an exactly-zero span is replaced, so every non-constant
        # sample is scaled by its true range.
        span = torch.where(span > 0, span, torch.ones_like(span))

        return (shifted / span).view(frames.shape)

    def _rgb2gray(self, frames):
        """Weighted sum over the channel axis; returns an n x 1 x h x w tensor."""
        # Build the weights on the input's device so GPU batches work too.
        weights = torch.tensor(self._GRAY_WEIGHTS, device=frames.device)
        return torch.mul(frames, weights[None, :, None, None]).sum(1, keepdim=True)

    def __call__(self, frames, traces):
        """Convert frame and trace batches to CAMUS-acceptable images.

        Args:
            frames: n x 3 x 112 x 112 tensor.
            traces: n x 112 x 112 tensor.

        Returns:
            (out_frames, out_traces): frames as an n x 1 x H x W tensor in
            [0, 1] and traces as an n x H x W tensor, with (H, W) equal to
            `im_size`.
        """
        out_frames = self._rgb2gray(self._norm(frames))
        out_frames = interpolate(out_frames, size=self.im_size,
                                 mode='bilinear', align_corners=False)

        # squeeze(1) removes only the channel axis added for interpolate;
        # a bare squeeze() would also drop the batch axis when n == 1.
        out_traces = interpolate(traces.unsqueeze(1), size=self.im_size,
                                 mode='nearest').squeeze(1)

        return out_frames, out_traces

In [19]:
camusizer = Camusize(im_size=args.image_size, norm=True)

# Run the converter on the first batch only, to inspect what it produces.
for _, (large_frame, small_frame, large_trace, small_trace) in test_dataloader:
    # Large frames/traces first, small ones stacked behind them.
    frames = torch.cat([large_frame, small_frame], dim=0)
    traces = torch.cat([large_trace, small_trace], dim=0)

    camusized_frames, camusized_traces = camusizer(frames, traces)
    break

In [20]:
traces.shape

torch.Size([40, 112, 112])

In [21]:
camusized_frames.shape

torch.Size([40, 1, 256, 256])

In [22]:
%%capture
# vid = echonet.utils.makeVideo(np.transpose(camusized_frames, (0,2,3,1)))
vid = echonet.utils.makeVideo(camusized_frames.squeeze(), cmap='gray')

In [23]:
vid

In [24]:
%%capture
vid = echonet.utils.makeVideo(camusized_traces)

In [25]:
vid

&nbsp;

## Apply the CAMUS-trained model on this Stanford data

In [26]:

net_seg.eval()

# Collect the per-batch pieces in lists and concatenate once at the end.
# Growing a tensor with torch.cat inside the loop recopies everything
# gathered so far on every iteration.
frame_batches = []
output_batches = []
target_batches = []

with torch.no_grad():

    for (_, (large_frame, small_frame, large_trace, small_trace)) in test_dataloader:
        # Each batch contributes its large frames first, then its small ones.
        frames = torch.cat((large_frame, small_frame), 0)
        traces = torch.cat((large_trace, small_trace), 0)

        camusized_frames, camusized_traces = camusizer(frames, traces)

        # Keep the converted frames and the target answers as produced by
        # camusizer, before any device move.
        frame_batches.append(camusized_frames)
        target_batches.append(camusized_traces)

        if torch.cuda.is_available():
            camusized_frames = camusized_frames.cuda()

        # Push through the network
        inputs = camusized_frames
        output_batches.append(net_seg(inputs).detach().cpu())

# Same result names as before; an empty loader still yields empty tensors.
test_frames = torch.cat(frame_batches, 0) if frame_batches else torch.tensor([])
net_outputs = torch.cat(output_batches, 0) if output_batches else torch.tensor([])
target_outputs = torch.cat(target_batches, 0) if target_batches else torch.tensor([])
        

In [27]:
# The outputs were accumulated batch by batch as `bs` large frames/traces
# followed by `bs` small ones. To inspect one flat output index we have to map
# it back to a dataset item and to which of its two frames it was.
bs = args.batch_size
whichcase = 2080

batches_before, batch_offset = divmod(whichcase, 2*bs)
batch_num = bs*batches_before  # dataset index of the first item in that batch

# The first half of a batch's outputs are large frames, the second half small.
is_large = batch_offset < bs
dataset_index = batch_num + (batch_offset if is_large else batch_offset - bs)

_, (large_frame, small_frame, large_trace, small_trace) = test_dataset[dataset_index]
if is_large:
    frame, trace = large_frame, large_trace
else:
    frame, trace = small_frame, small_trace


f, ax = plt.subplots(1,3,figsize=(8,3))
ax[0].imshow(frame[0].squeeze(), cmap='gray')
ax[1].imshow(trace.squeeze())
ax[2].imshow(torch.argmax(net_outputs[whichcase], 0));

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

&nbsp;

## Get the Dices
Some code adapted from the [DynamicSegmentsCAMUS](./play_DynamicSegmentsCAMUS.ipynb) version.

In [28]:
# Custom dice business just here.
def getDices(autoseg, labels, whichLabels=('Background', 'Left Ventricle')):
    """Per-case Dice between a binarized segmentation and reference labels.

    `autoseg` is thresholded at > 0, so it is treated as a 0/1 mask. For
    every name in `whichLabels` the value nameLabMap[name] is looked up and
    the Dice of (mask == value) against (labels == value) is computed.
    NOTE(review): this only makes sense for names that nameLabMap maps to
    0 or 1 -- confirm against nameLabMap.

    Args:
        autoseg: array of N masks, (N, H, W) or (N, 1, H, W).
        labels: reference label images, same layout as `autoseg`.
        whichLabels: names (keys of nameLabMap) to score.

    Returns:
        (N, len(whichLabels)) float array of Dice values. An entry is NaN
        when both the mask and the reference are empty for that label.
    """
    autoseg = np.asarray(autoseg)
    labels = np.asarray(labels)
    n_cases = autoseg.shape[0]

    # Bring both to (N, H, W) explicitly. A bare squeeze() would also drop
    # the case axis when N == 1, and Dice would then be computed per row.
    pred = (autoseg > 0).astype(np.uint8)
    pred = pred.reshape((n_cases,) + pred.shape[-2:])
    truth = labels.reshape((n_cases,) + labels.shape[-2:])

    retDice = np.zeros((n_cases, len(whichLabels)))

    for i, key in enumerate(whichLabels):
        label_value = nameLabMap[key]
        seg = pred == label_value
        lab = truth == label_value

        # Dice is intersection over average
        overlap = np.logical_and(seg, lab).sum(axis=(1, 2))
        total = seg.sum(axis=(1, 2)) + lab.sum(axis=(1, 2))

        # 0/0 stays NaN, but is written explicitly so numpy does not emit
        # an invalid-value RuntimeWarning.
        retDice[:, i] = np.divide(2.0*overlap, total,
                                  out=np.full(n_cases, np.nan),
                                  where=total > 0)

    return retDice

In [29]:
dices = getDices((torch.argmax(net_outputs, 1)==nameLabMap['Left Ventricle']).numpy(), target_outputs.numpy())

In [30]:
dices.shape

(2520, 2)

In [31]:
net_outputs.shape, target_outputs.shape

(torch.Size([2520, 4, 256, 256]), torch.Size([2520, 256, 256]))

In [32]:
dices[:,1].mean(), dices[:,1].std()

(0.880529374398505, 0.06793691250779385)

In [33]:
plt.figure(figsize=(4,3))
plt.hist(dices[:,1], bins=20);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

&nbsp;

# So, CAMUS model applied to the Stanford data.
It's not great, though better than the dynamic model applied to the CAMUS data. Let's visualize a bit more.

In [34]:
# Let's get the arrays net_outputs, target_outputs, and test_frames in order
all_segs = (torch.argmax(net_outputs, 1)==nameLabMap['Left Ventricle']).numpy()
all_labels = target_outputs.numpy()
all_images = test_frames.squeeze().numpy()

In [35]:
# Ripped from camus_overlay and DynamicSegmentsCAMUS
def overlay(im, lab, lab_gt=None, alpha = 1.0, whichLabels = (0, 1)):
    """Tint an echo image with label colours, optionally with a diff image.

    Args:
        im: h x w image; the result is clipped to [0, 1], so intensities
            are presumably already in that range.
        lab: h x w image of label keys (not network output). A 3-D array
            with a singleton axis is squeezed.
        lab_gt: optional h x w reference labels (squeezed the same way).
            When given, a second image marking disagreements for label 1
            is also returned.
        alpha: weight applied to the label colours added to `im`.
        whichLabels: label keys (keys of labColorMap) to draw.

    Returns:
        I, an h x w x 3 image, when `lab_gt` is None; otherwise the tuple
        (I, I_gt), where I_gt shows false positives in the label colour
        and false negatives in its complement.
    """
    LV_KEY = 1  # hardcode, sorry. Want LV to show best.

    lab = np.asarray(lab)
    if lab.ndim == 3:  # maybe, for some reason, 1 x h x w was provided...
        lab = lab.squeeze()

    # Just a three channel version of the echo image.
    I = np.stack([im, im, im], axis=-1)

    I_gt = None
    if lab_gt is not None:
        lab_gt = np.asarray(lab_gt)
        if lab_gt.ndim == 3:
            lab_gt = lab_gt.squeeze()
        # Copy before any tint is added: the diff image starts from the
        # plain echo.
        I_gt = I.copy()

    # Now add the lab image to the echo.
    for key in whichLabels:
        I[lab == key, :] += alpha*np.array(labColorMap[key])
    I = I.clip(0, 1)

    if lab_gt is None:
        return I

    # Complementary colour for the label/gt difference overlay.
    lv_color = labColorMap[LV_KEY]
    brightest = max(lv_color)
    lv_complement = [brightest - x for x in lv_color]

    whereFP = np.logical_and(lab == LV_KEY, lab_gt != LV_KEY)
    whereFN = np.logical_and(lab != LV_KEY, lab_gt == LV_KEY)

    I_gt[whereFP, :] += lv_color
    I_gt[whereFN, :] += lv_complement  # false negative is color complement

    return I, I_gt.clip(0, 1)

In [36]:
from ipywidgets import VBox, HBox, IntSlider

plt.ioff()
plt.clf()

whichCase = 0
case_slider = IntSlider(
    orientation='horizontal',
    value=whichCase,
    min=0,
    # The slider's max is inclusive, so the last selectable value must be
    # the last valid index; len(all_images) itself would raise IndexError
    # in the callback.
    max=max(len(all_images) - 1, 0),
    step=1,
    description='Index'
)

# We're going to keep the image, overlay, gt_overlay, and text representing 
# the case.

I = all_images[whichCase]
I_seg, I_diff = overlay(all_images[whichCase], 
                        lab = all_segs[whichCase], 
                        lab_gt = all_labels[whichCase])
I_gt = overlay(all_images[whichCase],
               lab = all_labels[whichCase])


# Display artists
fig_args = {'num':' ', 'frameon':True, 'sharex':True, 'sharey':True}
fig, ax = plt.subplots(1,4, figsize=(10,3), **fig_args) 

disp_0 = ax[0].imshow(I, interpolation=None, cmap='gray')
disp_1 = ax[1].imshow(I_seg, interpolation=None)
disp_2 = ax[2].imshow(I_gt, interpolation=None)
disp_3 = ax[3].imshow(I_diff, interpolation=None)


ltext = ax[0].set_title(f'Case {whichCase}')

ax[1].set_title('seg result')
ax[2].set_title('ground truth')
ax[3].set_title('diff with GT')

# Other ways of doing this result in more empty space, even with tight_layout
for a in ax:
    a.axes.get_xaxis().set_visible(False)
    a.axes.get_yaxis().set_visible(False)

plt.tight_layout()


def update_images(change):
    """Slider callback: redraw the four panels for the selected case.

    Reads the index from case_slider.value (the `change` payload is not
    used), rebuilds the overlays, and pushes them into the existing image
    artists.
    """
    # Only the names rebound below need the global declaration.
    global I, I_seg, I_gt, I_diff, whichCase

    whichCase = case_slider.value
    image = all_images[whichCase]
    truth = all_labels[whichCase]

    I = image
    I_seg, I_diff = overlay(image, lab=all_segs[whichCase], lab_gt=truth)
    I_gt = overlay(image, lab=truth)

    ltext.set_text(f'Case {whichCase}')
    panels = ((disp_0, I), (disp_1, I_seg), (disp_2, I_gt), (disp_3, I_diff))
    for artist, picture in panels:
        artist.set_array(picture)

    fig.canvas.draw()
    fig.canvas.flush_events()
    
case_slider.observe(update_images, names='value')

VBox([case_slider, fig.canvas])

VBox(children=(IntSlider(value=0, description='Index', max=2520), Canvas(toolbar=Toolbar(toolitems=[('Home', '…