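# CNNGeometric demo notebook
This notebook shows how to run trained CNNGeometric models (affine, TPS, and affine followed by TPS) on a given image pair. Here the pair is two L-MLO mammograms of the same patient, taken in different exams and with different original image sizes.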

## Imports

In [None]:
from __future__ import print_function, division
import os
import argparse
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from model.cnn_geometric_model import CNNGeometric
from data.pf_dataset import PFDataset
from data.download_datasets import download_PF_willow
from image.normalization import NormalizeImageDict, normalize_image
from util.torch_util import BatchTensorToVars, str_to_bool
from geotnf.transformation import GeometricTnf
from geotnf.point_tnf import *
import matplotlib.pyplot as plt
from skimage import io
import warnings
from torchvision.transforms import Normalize
from collections import OrderedDict

warnings.filterwarnings('ignore')

## Parameters

In [None]:
# Backbone of the pretrained CNNGeometric checkpoints: 'vgg' or 'resnet101'.
feature_extraction_cnn = 'resnet101'

# (affine checkpoint, TPS checkpoint) for each supported backbone.
# Setting either path to '' disables the corresponding model in the next cell.
_checkpoint_paths_by_cnn = {
    'vgg': (
        'trained_models/best_pascal_checkpoint_adam_affine_grid_loss.pth.tar',
        'trained_models/best_pascal_checkpoint_adam_tps_grid_loss.pth.tar',
    ),
    'resnet101': (
        'trained_models/best_pascal_checkpoint_adam_affine_grid_loss_resnet_random.pth.tar',
        'trained_models/best_pascal_checkpoint_adam_tps_grid_loss_resnet_random.pth.tar',
    ),
}
# Fail fast here instead of hitting a NameError on model_aff_path in the next cell.
if feature_extraction_cnn not in _checkpoint_paths_by_cnn:
    raise ValueError(
        "feature_extraction_cnn must be one of %s, got %r"
        % (sorted(_checkpoint_paths_by_cnn), feature_extraction_cnn))
model_aff_path, model_tps_path = _checkpoint_paths_by_cnn[feature_extraction_cnn]

# PF-dataset example pair. NOTE(review): only referenced by commented-out code below;
# the images actually used are loaded from HDF5 mammograms.
source_image_path='datasets/PF-dataset/duck(S)/060_0036.png'
target_image_path='datasets/PF-dataset/duck(S)/060_0013.png'

## Load models

In [None]:
use_cuda = torch.cuda.is_available()
# An empty checkpoint path disables the corresponding model.
do_aff = model_aff_path != ''
do_tps = model_tps_path != ''


def load_checkpoint_weights(model, checkpoint_path):
    """Load the ``state_dict`` stored in ``checkpoint_path`` into ``model``.

    Tensors are mapped to CPU storage while loading, so a checkpoint saved on a
    GPU can be read on a CPU-only machine. Every ``'vgg'`` substring in the
    parameter names is replaced by ``'model'`` before loading (presumably older
    checkpoints used a ``vgg`` attribute name; confirm against CNNGeometric).
    """
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    renamed_state_dict = OrderedDict(
        (key.replace('vgg', 'model'), value)
        for key, value in checkpoint['state_dict'].items())
    model.load_state_dict(renamed_state_dict)


# Create model
print('Creating CNN model...')
if do_aff:
    model_aff = CNNGeometric(use_cuda=use_cuda, geometric_model='affine', feature_extraction_cnn=feature_extraction_cnn)
if do_tps:
    model_tps = CNNGeometric(use_cuda=use_cuda, geometric_model='tps', feature_extraction_cnn=feature_extraction_cnn)

# Load trained weights
print('Loading trained model weights...')
if do_aff:
    load_checkpoint_weights(model_aff, model_aff_path)
if do_tps:
    load_checkpoint_weights(model_tps, model_tps_path)

## Create image transformers

In [None]:
# Warpers that apply the predicted affine / TPS parameters to the CNN-resolution source image.
affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)

## Load and preprocess images

In [None]:
# Resizes inputs to the 240x240 CNN resolution (on CPU; tensors are moved to the GPU afterwards).
resizeCNN = GeometricTnf(out_h=240, out_w=240, use_cuda = False)
# NOTE(review): not used anywhere below; normalize_image is used for normalisation instead.
normalizeTnf = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

def preprocess_image(image, resize=True):
    """Turn an H x W x 3 image into a normalised 1 x 3 x h x w tensor for the CNN.

    Args:
        image: numpy array laid out H x W x C; values are divided by 255, so they
            are assumed to lie in [0, 255].
        resize: when True (default), resample to 240 x 240 with ``resizeCNN``;
            when False, keep the native resolution.

    Returns:
        The output of ``normalize_image`` applied to the (optionally resized) tensor,
        wrapped in a Variable with ``requires_grad=False``.
    """
    # Import explicitly instead of relying on names leaked by `from geotnf.point_tnf import *`.
    import numpy as np
    from torch.autograd import Variable

    # HWC -> 1 x C x H x W, scaled to [0, 1]
    image = np.expand_dims(image.transpose((2, 0, 1)), 0)
    image = torch.Tensor(image.astype(np.float32) / 255.0)
    image_var = Variable(image, requires_grad=False)

    # Resize image using bilinear sampling with identity affine tnf
    if resize:
        image_var = resizeCNN(image_var)

    # Normalize image
    return normalize_image(image_var)

In [None]:
# The original demo downloaded PF-Willow here (data.download_datasets.download_PF_willow);
# this notebook loads pickled exam metadata instead.
import pickle


def unpickle_from_file(file_name):
    """Return the object stored in the pickle file ``file_name``.

    Security: ``pickle.load`` can execute arbitrary code, so only call this on
    trusted files.
    """
    with open(file_name, mode='rb') as pickled_file:
        unpickled = pickle.load(pickled_file)
    return unpickled

# Pickled exam sequences. The cluster path is the default; set $DATA_SEQUENCES_PATH to override it.
DATA_SEQUENCES_PATH = os.environ.get(
    'DATA_SEQUENCES_PATH',
    '/gpfs/data/geraslab/jp4989/data/2010_2017_data/data_sequences_0709_1453.pkl')
data_sequences_2017 = unpickle_from_file(DATA_SEQUENCES_PATH)

In [125]:
# Test different original input sizes: pick the first patient whose first two entries
# (exams) have different L-MLO original image sizes.
for patient in data_sequences_2017[0]:
    if len(patient) > 1 and (patient[0]['original_image_size']['L-MLO'][0]
                             != patient[1]['original_image_size']['L-MLO'][0]):
        exam_list = patient
        break
else:
    # Without this, exam_list would be undefined (or silently stale from an earlier run).
    raise LookupError('No patient with two exams of different L-MLO image sizes was found')

In [126]:
# Original L-MLO image sizes of the two selected exams (they differ by construction).
tuple(exam['original_image_size']['L-MLO'][0] for exam in (exam_list[0], exam_list[1]))

((4096, 3328), (3328, 2560))

In [None]:
import h5py
import numpy
def read_image_mat(file_name):
    """Read the ``'image'`` dataset of an HDF5 file and return it transposed as a numpy array.

    The transpose presumably undoes a column-major (MATLAB-style) storage layout;
    confirm against the data-export code. The file is closed even if reading fails.
    """
    with h5py.File(file_name, 'r') as data:
        image = numpy.array(data['image']).T
    return image
def flip_image(image, view, horizontal_flip, mode='training'):
    """Return ``image`` mirrored left-right when the view/flag combination requires it.

    mode='training': with horizontal_flip 'NO' the right views (R-CC, R-MLO) are
    flipped; with 'YES' the left views (L-CC, L-MLO) are flipped.
    mode='medical': the image is flipped whenever horizontal_flip is 'YES'.
    Any other mode or flag value returns the input object unchanged (no copy).
    """
    if mode == 'medical':
        should_flip = horizontal_flip == 'YES'
    elif mode != 'training':
        should_flip = False
    elif horizontal_flip == 'NO':
        should_flip = view in ('R-CC', 'R-MLO')
    elif horizontal_flip == 'YES':
        should_flip = view in ('L-CC', 'L-MLO')
    else:
        should_flip = False
    return numpy.fliplr(image) if should_flip else image
def normalise_single_image(image):
    """Standardise a floating-point ``image`` in place to zero mean and unit std.

    The std is floored at 1e-5 so a constant image does not cause a division by
    zero. Returns None; the input array is modified.
    """
    centre = numpy.mean(image)
    image -= centre
    spread = numpy.maximum(numpy.std(image), 10 ** (-5))
    image /= spread
def minmax(image):
    """Linearly rescale a floating-point array in place so that it spans [0, 255].

    Returns None; the input array is modified. A constant array has no range to
    stretch, so it is set to all zeros instead of being divided by zero (which
    previously filled it with NaN).
    """
    image_min = image.min()
    value_range = image.max() - image_min
    image -= image_min
    if value_range == 0:
        # After subtracting the minimum the array is already all zeros.
        return
    image /= value_range
    image *= 255

In [None]:
data_prefix = '/gpfs/data/geraslab/jp4989/data/2010_2017_cropped_images_hdf5/'


def load_lmlo_image(exam):
    """Load the first L-MLO image of ``exam`` as an int16 array rescaled to [0, 255].

    The image is read from ``data_prefix + <name> + '.hdf5'``, oriented with this
    exam's own 'horizontal_flip' flag ('medical' mode), and min-max rescaled in
    float64 before the integer cast.
    """
    image = read_image_mat(data_prefix + exam["L-MLO"][0] + '.hdf5')
    image = flip_image(image, "L-MLO", exam['horizontal_flip'], 'medical')
    # astype() copies, so the in-place minmax never writes into a view of the raw data.
    image = image.astype(np.float64)
    minmax(image)
    return image.astype(np.int16)


source_image = load_lmlo_image(exam_list[0])
# Fix: the target was previously flipped using exam_list[0]'s flag; each exam has its own.
target_image = load_lmlo_image(exam_list[1])

In [None]:
# Replicate the grayscale images into 3 identical channels (H x W -> H x W x 3), because
# preprocess_image transposes an H x W x C array. The ndim guard keeps this cell idempotent:
# re-running it no longer stacks an already 3-channel image into H x W x 3 x 3.
if source_image.ndim == 2:
    source_image = np.stack((source_image,) * 3, axis=-1)
if target_image.ndim == 2:
    target_image = np.stack((target_image,) * 3, axis=-1)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.imshow(source_image)

In [None]:
# Images come from the HDF5 mammograms above, not io.imread(source/target_image_path).
# Resize to the CNN resolution, normalise, and move to the GPU when one is available.
source_image_var, target_image_var = preprocess_image(source_image), preprocess_image(target_image)
if use_cuda:
    source_image_var, target_image_var = source_image_var.cuda(), target_image_var.cuda()

batch = dict(source_image=source_image_var, target_image=target_image_var)

# Resampler back to the target's native (H, W), used when displaying warped results.
resizeTgt = GeometricTnf(out_h=target_image.shape[0],
                         out_w=target_image.shape[1],
                         use_cuda=use_cuda)

## Evaluate models

In [None]:
# Switch to inference mode (affects layers such as dropout / batch-norm).
if do_aff:
    model_aff.eval()
if do_tps:
    model_tps.eval()

# Evaluate models. no_grad() avoids building an autograd graph that is never used,
# which saves memory; the predicted thetas then carry no gradient history either.
with torch.no_grad():
    if do_aff:
        theta_aff = model_aff(batch)
        warped_image_aff = affTnf(batch['source_image'], theta_aff.view(-1, 2, 3))

    if do_tps:
        theta_tps = model_tps(batch)
        warped_image_tps = tpsTnf(batch['source_image'], theta_tps)

    if do_aff and do_tps:
        # Refine the affine result: run TPS on the affinely-warped source vs. the target.
        theta_aff_tps = model_tps({'source_image': warped_image_aff, 'target_image': batch['target_image']})
        warped_image_aff_tps = tpsTnf(warped_image_aff, theta_aff_tps)

In [121]:
# Shape of the CNN input: batch x channels x height x width.
batch['source_image'].size()

torch.Size([1, 3, 240, 240])

In [None]:
# Inspect the element type of the preprocessed source tensor fed to the models.
batch['source_image'].dtype

In [None]:
# Inspect the raw affine parameters predicted by model_aff (defined only when do_aff is True);
# they are reshaped to (-1, 2, 3) wherever they are applied.
theta_aff

In [None]:
def preprocess_image_without_resizing(image):
    """Like ``preprocess_image`` but keeps the image at its native resolution.

    Converts an H x W x C array (values divided by 255, so assumed in [0, 255])
    into a 1 x C x H x W float tensor wrapped in a non-grad Variable, and returns
    ``normalize_image`` applied to it.
    """
    chw_batch = np.expand_dims(image.transpose((2, 0, 1)), 0).astype(np.float32) / 255.0
    wrapped = Variable(torch.Tensor(chw_batch), requires_grad=False)
    return normalize_image(wrapped)

In [None]:
# Full-resolution (not resized) versions of both images, on the GPU when available.
source_image_var_without_resizing, target_image_var_without_resizing = (
    preprocess_image_without_resizing(img) for img in (source_image, target_image))

if use_cuda:
    source_image_var_without_resizing, target_image_var_without_resizing = (
        var.cuda() for var in (source_image_var_without_resizing, target_image_var_without_resizing))

In [None]:
# NOTE(review): affTnf was built without out_h/out_w, so it presumably samples onto
# GeometricTnf's default CNN-sized grid, which throws away the full-resolution input.
# Warp directly onto a grid of the target's native size instead. This assumes the
# affine theta acts on normalised coordinates (as in F.affine_grid), so the same
# theta applies at any resolution.
affTnfTgtRes = GeometricTnf(geometric_model='affine',
                            out_h=target_image.shape[0], out_w=target_image.shape[1],
                            use_cuda=use_cuda)
warped_image_aff_without_resizing = affTnfTgtRes(source_image_var_without_resizing, theta_aff.view(-1, 2, 3))

In [None]:
# Resample to the target size, undo normalize_image, and lay out as H x W x C for imshow.
warped_image_aff_without_resizing_np = normalize_image(resizeTgt(warped_image_aff_without_resizing), forward=False).data.squeeze(0).permute(1, 2, 0).cpu().numpy()

In [None]:
# Affine warp computed from the full-resolution source image.
plt.imshow(warped_image_aff_without_resizing_np)
plt.title('aff (full-resolution source)')
plt.axis('off');

## Process results

In [None]:
def tensor_to_display_image(image_var):
    """Convert a network-space image tensor into an H x W x C numpy array for display.

    The tensor is resampled to the target size with ``resizeTgt``, ``normalize_image``
    is undone (``forward=False``), the batch axis is squeezed (assumes batch size 1),
    and the channel axis is moved last.
    """
    restored = normalize_image(resizeTgt(image_var), forward=False)
    return restored.data.squeeze(0).transpose(0, 1).transpose(1, 2).cpu().numpy()


# Un-normalize images and convert to numpy
if do_aff:
    warped_image_aff_np = tensor_to_display_image(warped_image_aff)

if do_tps:
    warped_image_tps_np = tensor_to_display_image(warped_image_tps)

if do_aff and do_tps:
    warped_image_aff_tps_np = tensor_to_display_image(warped_image_aff_tps)


## Display

In [None]:
# Collect (image, title) panels in display order; warped panels appear only for models that ran.
panels = [(source_image, 'src'), (target_image, 'tgt')]
if do_aff:
    # The affine panel shows the warp computed from the full-resolution source.
    panels.append((warped_image_aff_without_resizing_np, 'aff'))
if do_tps:
    panels.append((warped_image_tps_np, 'tps'))
if do_aff and do_tps:
    panels.append((warped_image_aff_tps_np, 'aff+tps'))

N_subplots = len(panels)
fig, axs = plt.subplots(1, N_subplots)
for ax, (panel_image, panel_title) in zip(axs, panels):
    ax.imshow(panel_image)
    ax.set_title(panel_title)
    ax.axis('off')

fig.set_dpi(150)
plt.show()