In [41]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import dataset as dd # custom dataset class
import models as md

import warnings

# so that when you change an imported file, it changes in the notebook
%load_ext autoreload 
%autoreload 2
%matplotlib notebook

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [42]:
# Toggle for channel augmentation in the data generator.
do_channel_augmentation = False

# A single-scan generator is built only to read the network's input shape off the
# transformed data; the first axis of x_transformed[0] is dropped
# (presumably the slice/batch axis — confirm against dd.MRImageSequence).
generator_train = dd.MRImageSequence(
    scan_numbers=[1],
    batch_size=10,
    augment_channels=do_channel_augmentation,
)
input_shape = generator_train.x_transformed[0].shape[1:]

loading scan  1
X shape:  (320, 320, 256, 8)
y shape:  (320, 320, 256, 1)
augment_images:  False


In [43]:
def compare_fov(in1, in2):
    """Check a computed field of view against its expected value.

    Prints 'test passed' when the two values are equal; otherwise issues a
    warning that names both values.

    Args:
        in1: field of view reported by the network builder (the computed value
            at the visible call site in unit_test_fov).
        in2: expected field of view.

    Returns:
        True if the values match, False otherwise. (Previously returned None;
        existing callers ignore the return value.)
    """
    if in1 != in2:
        # Put both values in the message. Python's default warning filter shows a
        # given message from a given line only once, so a constant message would let
        # one failure hide a later one. stacklevel=2 attributes the warning to the
        # caller rather than to this helper.
        warnings.warn(
            'receptive fov do not match (got {}, expected {})'.format(in1, in2),
            stacklevel=2,
        )
        return False
    print('test passed')
    return True

In [44]:
# Build the KAIST U-Net on a symbolic input of the data's shape and show the
# field of view reported by md.get_kaist_unet.
inputs = tf.keras.Input(shape=input_shape)  # same function as tf.keras.layers.Input
out, fov = md.get_kaist_unet(inputs, get_fov=True)
print('fov: ', fov)

get_kaist_unet
use_pool:  True
use_bn:  False
gen_fn:  gen_conv_relu
fov:  187


In [45]:
def unit_test_fov(input_shape, unet_sizes, fov_correct_answer, use_pool):
    """Build a U-Net with the given stage sizes and check the field of view it reports.

    Args:
        input_shape: shape given to the Keras Input layer (excludes the batch dimension).
        unet_sizes: list of tuples ``(num_blocks_stage_i, num_filters_stage_i)``,
            one per stage.
        fov_correct_answer: field of view the network is expected to report.
        use_pool: forwarded unchanged to ``md.get_unet``.
    """
    model_input = tf.keras.layers.Input(shape=input_shape)
    _unused_output, fov = md.get_unet(
        model_input, unet_sizes, use_pool=use_pool, get_fov=True
    )

    print('fov: ', fov)
    compare_fov(fov, fov_correct_answer)

In [46]:
# Two 3x3 convolutions, no pooling.
# NOTE(review): the expected values in these tests look like receptive-field size
# minus one (two stacked 3x3 convs see a 5x5 patch, expected value here is 4) —
# confirm this is the convention md.get_unet uses for "fov".
unet_size = [(2, 16)]
unit_test_fov(input_shape, unet_size, fov_correct_answer=4, use_pool=False)

get_unet
use_pool:  False
gen_fn:  gen_conv_relu
unet_shape:  [(2, 16)]
fov:  4
test passed


In [51]:
# Two 3x3 convolutions, three 3x3 convolutions, then two 3x3 convolutions
# (seven in total), no pooling.
unet_size = [(2, 16), (3, 16)]
unit_test_fov(input_shape, unet_size, fov_correct_answer=14, use_pool=False)

get_unet
use_pool:  False
gen_fn:  gen_conv_relu
unet_shape:  [(2, 16), (3, 16)]
fov:  14
test passed


In [54]:
# One 3x3 conv, pool, one 3x3 conv, upsample, one 3x3 conv — expected value of 9
# was verified by hand.
unet_size = [(1, 16), (1, 16)]
unit_test_fov(input_shape, unet_size, fov_correct_answer=9, use_pool=True)

get_unet
use_pool:  True
gen_fn:  gen_conv_relu
unet_shape:  [(1, 16), (1, 16)]
fov:  9
test passed
