# Spectral-invariant Matching Network

## The source code  contains

 - Data preparation
 - Code description

## Dataset preparation

1) Downloading three cross-spectral and multi-spectral dataset

RGB-NIR Scene Dataset

https://ivrlwww.epfl.ch/supplementary_material/cvpr11/nirscene1.zip

PittsStereo-RGBNIR: A Large RGB-NIR Stereo Dataset Collected in Pittsburgh with Challenging Materials

http://www.cs.cmu.edu/~ILIM/projects/AA/RGBNIRStereo/

KAIST Multispectral Pedestrian Dataset

https://sites.google.com/site/pedestrianbenchmark/

2) Use the 'ICCV2021_Data Preparation for RGB-NIR Patch Dataset.ipynb', 'ICCV2021_Data Preparation for RGB-NIR Stereo Dataset.ipynb', and 'ICCV2021_Data Preparation for KAIST RGB-thermal Dataset.ipynb' files to construct image patch datasets.


## Train

1) Training setting

We trained our model from scratch for 35 epochs in total. All the convolution and convolution transpose layers used the initialisation method in [15] to set initial values for their weights. All models were trained in an end-to-end manner with the ADAM optimiser (β1 = 0.9, β2 = 0.999) [23]. We used a batch size of 32 and set the learning rate to 0.04 with a decay factor 0.1 after 20 epochs. The training was performed with a customised version of Tensorﬂow 2.0 on an NVIDIA Titan Xp GPU, which usually takes two days. A forward pass of SPIMNet takes about 0.1 seconds for patches with a 64×64 resolution. 

To prevent an overﬁtting problem, all samples were normalised to [-1, 1], and data augmentation was carried out through random ﬂipping, random rotation (90, 180, 270 degrees), and random cropping. In addition, two regularisation techniques were employed: the label smoothing [39] and L2 kernel regularisation for the convolution layers of the feature extraction networks with l2 = 0.001.

2) Code description
2.1) Important packages
     - sklearn
     - scipy
     - tensorflow
     - tensorflow_addons
     
 2.2) Important functions
    - load(image_file): Read an image based on its path using Tensorflow
    - normalize(): Normalize pixel values to the range [-1,1]
    - Data augmentation including
      + resize()
      + random_crop()
      + random_jitter(): flip images, rotate images, change brightness and contrast
    - downsample(): create a block for an encoding network
    - upsample(): create a block for a decoding network
    - extract_features(): create a block of an extraction network
    - RGB2NIR_convertor(): create VIS2NIR network
    - NIR2RGB_convertor(): create NIR2VIS network
    - NIR_domain_matching(): Feature extraction in NIR domain
    - RGB_domain_matching(): Feature extraction in VIS domain
    - similaritor_loss() loss function for SPIMNet
    - train(): train SPIMNet
    - compute_test_acc(): compute FPR95 for the testing sets
    
2.3) We provide jupyter files each of which divided into cells of codes. Each cell is designed for a small task or a function, and we have added comments for it. 

In [1]:
# cell 1: Package declaration 

#%env CUDA_DEVICE_ORDER=PCI_BUS_ID
#%env CUDA_VISIBLE_DEVICES=1

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
import numpy as np

from IPython import display
from sklearn import metrics
from scipy import interpolate

import tensorflow as tf
import tensorlayer as tl
import tensorflow_addons as tfa

tf.__version__

TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.


'2.7.0-dev20210627'

In [2]:
# Cell 2: load images using Tensorflow

PATH = '/home/shared_dir/research/SPIMNet/NIR_RGB/patch-merged-datasets/'

BUFFER_SIZE = 1024*4
BATCH_SIZE  = 8  # for each positive and negative pairs, altogether = 32
IMG_WIDTH   = 64
IMG_HEIGHT  = 64
n_train_samples = 138752

def load(image_file):
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)

    w = tf.shape(image)[1]
    w = w // 4
    
    rgb_pos = image[:, :w, :]
    nir_pos = image[:, w*1:w*2, :]
    rgb_neg = image[:, w*2:w*3, :]
    nir_neg = image[:, w*3:w*4, :]

    rgb_pos = tf.cast(rgb_pos, tf.float32)
    nir_pos = tf.cast(nir_pos, tf.float32)
    rgb_neg = tf.cast(rgb_neg, tf.float32)
    nir_neg = tf.cast(nir_neg, tf.float32)

    return rgb_pos, nir_pos, rgb_neg, nir_neg

In [3]:
# cell 3: data augmentation

def resize(input_l, input_r, target_l, target_r, height, width):
    input_l  = tf.image.resize(input_l, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    input_r  = tf.image.resize(input_r, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    target_l = tf.image.resize(target_l, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    target_r = tf.image.resize(target_r, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return input_l, input_r, target_l, target_r

def random_crop(input_l, input_r, target_l, target_r):
    stacked_image = tf.stack([input_l, input_r, target_l, target_r], axis=0)
    cropped_image = tf.image.random_crop(stacked_image, size=[4, IMG_HEIGHT, IMG_WIDTH, 3])

    return cropped_image[0], cropped_image[1], cropped_image[2], cropped_image[3]

# normalizing the images to [-1, 1]
def normalize(input_l, input_r, target_l, target_r):
    input_l  = (input_l / 127.5) - 1
    input_r  = (input_r / 127.5) - 1
    target_l = (target_l / 127.5) - 1
    target_r = (target_r / 127.5) - 1

    return input_l, input_r, target_l, target_r

def random_jitter(input_l, input_r, target_l, target_r):    
    # resize to 68x68
    input_l, input_r, target_l, target_r = resize(input_l, input_r, target_l, target_r, 68, 68)
    
    # crop
    input_l, input_r, target_l, target_r = random_crop(input_l, input_r, target_l, target_r)

    # flip_left_right
    if tf.random.uniform(()) > 0.5:        
        input_l  = tf.image.flip_left_right(input_l)
        input_r  = tf.image.flip_left_right(input_r)
        target_l = tf.image.flip_left_right(target_l)
        target_r = tf.image.flip_left_right(target_r)
        
    # flip_up_down
    if tf.random.uniform(()) > 0.5:        
        input_l  = tf.image.flip_up_down(input_l)
        input_r  = tf.image.flip_up_down(input_r)
        target_l = tf.image.flip_up_down(target_l)
        target_r = tf.image.flip_up_down(target_r)
        
    # brighness change
    if tf.random.uniform(()) > 0.5:
        rand_value  = tf.random.uniform((), minval=-5.0, maxval=5.0)        
        input_l = input_l + rand_value
        
        rand_value  = tf.random.uniform((), minval=-5.0, maxval=5.0)
        input_r = input_r + rand_value
        
        rand_value  = tf.random.uniform((), minval=-5.0, maxval=5.0)        
        target_l = target_l + rand_value
        
        rand_value  = tf.random.uniform((), minval=-5.0, maxval=5.0)
        target_r = target_r + rand_value
                 
    # contrast change
    if tf.random.uniform(()) > 0.5:        
        rand_value = tf.random.uniform((), minval=0.8, maxval=1.2)
        mean_value = tf.reduce_mean(input_l)
        input_l   = (input_l - mean_value)*rand_value + mean_value
        
        rand_value = tf.random.uniform((), minval=0.8, maxval=1.2)
        mean_value = tf.reduce_mean(input_r)
        input_r   = (input_r - mean_value)*rand_value + mean_value
        
        rand_value = tf.random.uniform((), minval=0.8, maxval=1.2)
        mean_value = tf.reduce_mean(target_l)
        target_l   = (target_l - mean_value)*rand_value + mean_value
        
        rand_value = tf.random.uniform((), minval=0.8, maxval=1.2)
        mean_value = tf.reduce_mean(target_r)
        target_r   = (target_r - mean_value)*rand_value + mean_value

    
    # clip value
    input_l  = tf.clip_by_value(input_l, clip_value_min=0.0, clip_value_max=255.0)
    input_r  = tf.clip_by_value(input_r, clip_value_min=0.0, clip_value_max=255.0)
    target_l = tf.clip_by_value(target_l, clip_value_min=0.0, clip_value_max=255.0)
    target_r = tf.clip_by_value(target_r, clip_value_min=0.0, clip_value_max=255.0)          
    
    # rotate positive samples for making hard positive cases
    if tf.random.uniform(()) > 0.5: 
        if tf.random.uniform(()) < 0.5:
            input_l = tfa.image.rotate(input_l, 1.5707963268) # 90
            input_r = tfa.image.rotate(input_r, 1.570796326)  # 90
        else:
            input_l = tfa.image.rotate(input_l, 4.7123889804) # 270
            input_r = tfa.image.rotate(input_r, 4.7123889804) # 270
                
    return input_l, input_r, target_l, target_r

def load_image_train(image_file):
    input_l, input_r, target_l, target_r = load(image_file)
    input_l, input_r, target_l, target_r = random_jitter(input_l, input_r, target_l, target_r)
    input_l, input_r, target_l, target_r = normalize(input_l, input_r, target_l, target_r)

    return input_l, input_r, target_l, target_r

def load_image_test(image_file):
    input_l, input_r, target_l, target_r = load(image_file)
    input_l, input_r, target_l, target_r = resize(input_l, input_r, target_l, target_r, IMG_HEIGHT, IMG_WIDTH)
    input_l, input_r, target_l, target_r = normalize(input_l, input_r, target_l, target_r)

    return input_l, input_r, target_l, target_r

In [4]:
# cell 4: load training data

# train_dataset
train_dataset = tf.data.Dataset.list_files(PATH+'country/*.jpg')
train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

In [5]:
# cell 5: Network building blocks

def downsample(filters, size, apply_batchnorm=True):
    initializer = tf.keras.initializers.he_normal(seed=None)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
        result.add(tfa.layers.InstanceNormalization())

    result.add(tf.keras.layers.ReLU())

    return result

def upsample(filters, size):
    initializer = tf.keras.initializers.he_normal(seed=None)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                    kernel_initializer=initializer, use_bias=False))

    result.add(tf.keras.layers.BatchNormalization())
    result.add(tfa.layers.InstanceNormalization())

    result.add(tf.keras.layers.ReLU())

    return result

def extract_first_features(filters, size, strides, apply_batchnorm=True):
    initializer = tf.keras.initializers.he_normal(seed=None)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2D(filters, size, strides=strides, padding='same',
                             kernel_initializer=initializer, use_bias=False, 
                             kernel_regularizer=tf.keras.regularizers.l2(0.001)))

    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
        result.add(tfa.layers.InstanceNormalization())

    result.add(tf.keras.layers.ReLU())

    return result

In [6]:
# cell 6: RGB2NIR network

def RGB2NIR_convertor(input_x):
    # input shape:  (64, 64, 3)
    # output shape: (64, 64, 1)
    
    x_1 = input_x
    
    down_stack = [
        downsample(64, 4,  apply_batchnorm=False), # (bs, 32, 32, 64)
        downsample(128, 4, apply_batchnorm=True),  # (bs, 16, 16, 512)
        downsample(256, 4, apply_batchnorm=True),  # (bs, 8, 8, 512)
        downsample(256, 4, apply_batchnorm=True),  # (bs, 4, 4, 512)
        downsample(256, 4, apply_batchnorm=True),  # (bs, 2, 2, 512)
        downsample(256, 4, apply_batchnorm=True),  # (bs, 1, 1, 512)
    ]

    up_stack = [
        upsample(256, 4), # (bs, 2, 2, 1024)
        upsample(256, 4), # (bs, 4, 4, 1024)
        upsample(256, 4), # (bs, 8, 8, 1024)
        upsample(128, 4), # (bs, 16, 16, 1024)
        upsample(64, 4), # (bs, 32, 32, 512)
    ]

    initializer     = tf.keras.initializers.he_normal(seed=None)    
    OUTPUT_CHANNELS = 1
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh') # (bs, 64, 64, 1)

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x_1 = down(x_1)
        skips.append(x_1)

    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    concat = tf.keras.layers.Concatenate()
    for up, skip in zip(up_stack, skips):
        x_1 = up(x_1)
        x_1 = concat([x_1, skip])

    x_1 = last(x_1) 
    
    return x_1

In [7]:
# cell 7: NIR2RGB network

def NIR2RGB_convertor(input_x):
    # input shape:  (64, 64, 1)
    # output shape: (64, 64, 3)
    
    x_1 = input_x
    
    down_stack = [
        downsample(64, 4,  apply_batchnorm=False), # (bs, 32, 32, 64)
        downsample(128, 4, apply_batchnorm=True),  # (bs, 16, 16, 128)
        downsample(256, 4, apply_batchnorm=True),  # (bs, 8, 8, 256)
        downsample(256, 4, apply_batchnorm=True),  # (bs, 4, 4, 256)
        downsample(256, 4, apply_batchnorm=True),  # (bs, 2, 2, 256)
        downsample(256, 4, apply_batchnorm=True),  # (bs, 1, 1, 256)
    ]
    '''
    up_stack = [
        upsample(256, 4, apply_dropout=True), # (bs, 2, 2, 256)
        upsample(256, 4, apply_dropout=True), # (bs, 4, 4, 256)
        upsample(256, 4, apply_dropout=True), # (bs, 8, 8, 256)
        upsample(128, 4, apply_dropout=True), # (bs, 16, 16, 128)
        upsample(64,  4, apply_dropout=False), # (bs, 32, 32, 64)
    ]
    '''
    up_stack = [
        upsample(256, 4), # (bs, 2, 2, 256)
        upsample(256, 4), # (bs, 4, 4, 256)
        upsample(256, 4), # (bs, 8, 8, 256)
        upsample(128, 4), # (bs, 16, 16, 128)
        upsample(64,  4), # (bs, 32, 32, 64)
    ]
    
    initializer     = tf.keras.initializers.he_normal(seed=None)    
    OUTPUT_CHANNELS = 3
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh') # (bs, 64, 64, 1)

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x_1 = down(x_1)
        skips.append(x_1)

    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    concat = tf.keras.layers.Concatenate()
    for up, skip in zip(up_stack, skips):
        x_1 = up(x_1)
        x_1 = concat([x_1, skip])

    x_1 = last(x_1) 
    
    return x_1

In [8]:
# cell 8: NIR domain matching

def NIR_domain_matching(input_x1, input_x2):
    x_1 = input_x1
    x_2 = input_x2
    
    # for x_1
    layer1 = extract_first_features(32, 3, 1, True)
    layer2 = extract_first_features(64, 3, 1, True)
    layer3 = extract_first_features(128, 3, 1, True)
    layer4 = extract_first_features(128, 5, 2, True)
    layer5 = extract_first_features(256, 3, 1, True)
    layer6 = extract_first_features(256, 5, 2, True)
    layer7 = extract_first_features(256, 3, 1, True)
    layer8 = extract_first_features(256, 5, 2, True)
            
    # for x_1
    x_1 = layer1(x_1)
    x_1 = layer2(x_1)
    x_1 = layer3(x_1)
    x_1 = layer4(x_1)
    x_1 = layer5(x_1)
    x_1 = layer6(x_1)
    x_1 = layer7(x_1)
    x_1 = layer8(x_1)
    x_1 = layers.Flatten()(x_1)
    
    # for x_2
    x_2 = layer1(x_2)
    x_2 = layer2(x_2)
    x_2 = layer3(x_2)
    x_2 = layer4(x_2)
    x_2 = layer5(x_2)
    x_2 = layer6(x_2)
    x_2 = layer7(x_2)
    x_2 = layer8(x_2)
    x_2 = layers.Flatten()(x_2)    
    
    x = tf.abs(x_1-x_2)
    x = tf.concat([x_1, x_2, x], 1)
    
    return x

In [9]:
# cell 9: RGB domain matching

def RGB_domain_matching(input_x1, input_x2):
    x_1 = input_x1
    x_2 = input_x2
    
    # for x_1
    layer1 = extract_first_features(32, 3, 1, True)
    layer2 = extract_first_features(64, 3, 1, True)
    layer3 = extract_first_features(128, 3, 1, True)
    layer4 = extract_first_features(128, 5, 2, True)
    layer5 = extract_first_features(256, 3, 1, True)
    layer6 = extract_first_features(256, 5, 2, True)
    layer7 = extract_first_features(256, 3, 1, True)
    layer8 = extract_first_features(256, 5, 2, True)
            
    # for x_1
    x_1 = layer1(x_1)
    x_1 = layer2(x_1)
    x_1 = layer3(x_1)
    x_1 = layer4(x_1)
    x_1 = layer5(x_1)
    x_1 = layer6(x_1)
    x_1 = layer7(x_1)
    x_1 = layer8(x_1)
    x_1 = layers.Flatten()(x_1)
    
    # for x_2
    x_2 = layer1(x_2)
    x_2 = layer2(x_2)
    x_2 = layer3(x_2)
    x_2 = layer4(x_2)
    x_2 = layer5(x_2)
    x_2 = layer6(x_2)
    x_2 = layer7(x_2)
    x_2 = layer8(x_2)
    x_2 = layers.Flatten()(x_2)    
    
    x = tf.abs(x_1-x_2)
    x = tf.concat([x_1, x_2, x], 1)
    
    return x

In [10]:
# cell 10: construct SPIMNet network

def make_similarity_model():      
    inputs_1 = layers.Input(shape=[64, 64, 3])
    inputs_2 = layers.Input(shape=[64, 64, 1])
    x_rgb = inputs_1
    x_nir = inputs_2
    
    # convert domains  
    x_converted_nir = RGB2NIR_convertor(x_rgb)
    x_converted_rgb = NIR2RGB_convertor(x_nir)
    
    # matching
    f_nir = NIR_domain_matching(x_nir, x_converted_nir)
    f_rgb = RGB_domain_matching(x_rgb, x_converted_rgb)
    
    # concat features
    x = tf.concat([f_nir, f_rgb], 1)
    
    # metric learning
    x = layers.Dense(1024)(x)
    x = layers.Dense(128)(x)
    x = layers.Dense(1)(x)
    
    model = tf.keras.Model(inputs=[inputs_1, inputs_2], outputs=[x, x_rgb, x_converted_rgb, x_nir, x_converted_nir])    
    return model

In [11]:
# cell 11: Build SPIMNet and print it

similaritor = make_similarity_model()
similaritor.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 64, 64, 1)]  0                                            
__________________________________________________________________________________________________
input_1 (InputLayer)            [(None, 64, 64, 3)]  0                                            
__________________________________________________________________________________________________
sequential (Sequential)         (None, 32, 32, 64)   3072        input_1[0][0]                    
__________________________________________________________________________________________________
sequential_11 (Sequential)      (None, 32, 32, 64)   1024        input_2[0][0]                    
______________________________________________________________________________________________

In [12]:
# cell 12: Construct content loss (~perceptual loss)

from tensorflow.python.keras.applications.vgg19 import VGG19

def vgg_54():
    return _vgg(20)

def _vgg(output_layer):
    vgg = VGG19(input_shape=(64, 64, 3), include_top=False, weights='imagenet')
    return tf.keras.Model(vgg.input, vgg.layers[output_layer].output)

vgg = vgg_54()

mean_squared_error = tf.keras.losses.MeanSquaredError()
def content_loss(img1, img2):    
    img1_fea = vgg(img1)
    img2_fea = vgg(img2)
        
    loss = mean_squared_error(img1_fea, img2_fea)
            
    return loss

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5


In [13]:
# test case
img1 = tf.random.normal((1, 64, 64, 3))
img2 = tf.random.normal((1, 64, 64, 3))
cont_loss_t = content_loss(img1, img2)
print(cont_loss_t.numpy())

0.012737564


In [14]:
# cell 13: build loss function

# path to save checkpoints
checkpoint_dir    = ''
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint        = tf.train.Checkpoint(similaritor=similaritor)

# Instantiate an optimizer.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.01)

beta = 30.0
alpha = 0.1
def similaritor_loss(pos_output, x_rgb, x_converted_rgb, x_nir, x_converted_nir, neg_output):
    # total_loss1
    pos_loss = cross_entropy(tf.ones_like(pos_output), pos_output)    
    neg_loss = cross_entropy(tf.zeros_like(neg_output), neg_output)
    total_loss1 = pos_loss + neg_loss
    
    # total_loss3
    l1_loss1  = tf.reduce_mean(tf.abs(x_rgb - x_converted_rgb))
    l1_loss2  = tf.reduce_mean(tf.abs(x_nir - x_converted_nir))
    total_loss3  = alpha*l1_loss1 + alpha*l1_loss2    
    
    # total_loss2
    pos_nir1 = tf.concat([x_nir, x_nir, x_nir], 3)
    pos_nir2 = tf.concat([x_converted_nir, x_converted_nir, x_converted_nir], 3)
    
    pos_nir_loss = content_loss(pos_nir1, pos_nir2)
    pos_rgb_loss = content_loss(x_rgb, x_converted_rgb)
    total_loss2  = pos_nir_loss*beta + pos_rgb_loss*beta
    
    # total_loss
    total_loss = total_loss1 + total_loss2 + total_loss3
    return total_loss, total_loss1, total_loss2, total_loss3

In [15]:
# cell 14: train SPIMNet

def train(train_data, epochs):
    for epoch in range(1, epochs):
        
        # learning rate
        if epoch < 20:
            lr = 1e-4
        else:
            lr = 1e-5
        optimizer = tf.keras.optimizers.Adam(lr)        
        
        average_loss = 0
        average_posl = 0
        average_negl = 0
        average_l1lo = 0
                
        count = 0        
        count_ones_pos = 0
        count_ones_neg = 0
    
        for pos_bs_img0, pos_bs_img1, neg_bs_img0, neg_bs_img1 in train_data:  
            pos_bs_img1 = pos_bs_img1[:,:,:,0:1]
            neg_bs_img1 = neg_bs_img1[:,:,:,0:1]
            
            with tf.GradientTape() as sim_tape:  
                # training
                pos_output,x_rgb, x_c_rgb, x_nir, x_c_nir = similaritor([pos_bs_img0, pos_bs_img1], training=True)
                neg_output, _, _, _, _ = similaritor([neg_bs_img0, neg_bs_img1], training=True)
                
                sim_loss, pos_loss, neg_loss, total_loss3 = similaritor_loss(pos_output,x_rgb, x_c_rgb, x_nir, x_c_nir, neg_output)

                # --------- compute training acc ---------
                bool_pos_output = pos_output > 0
                ones_pos_output = tf.reduce_sum(tf.cast(bool_pos_output, tf.float32))        
                count_ones_pos  = count_ones_pos + ones_pos_output

                bool_neg_output = neg_output < 0
                ones_neg_output = tf.reduce_sum(tf.cast(bool_neg_output, tf.float32))        
                count_ones_neg  = count_ones_neg + ones_neg_output
            
            gradients = sim_tape.gradient(sim_loss, similaritor.trainable_variables)            
            optimizer.apply_gradients(zip(gradients, similaritor.trainable_variables))
            
            average_loss = average_loss + sim_loss
            average_posl = average_posl + pos_loss
            average_negl = average_negl + neg_loss
            average_l1lo = average_l1lo + total_loss3
                        
            count = count + 1
            
        average_loss = average_loss / count
        average_posl = average_posl / count
        average_negl = average_negl / count
        average_l1lo = average_l1lo / count
                
        print('epoch {}  average_loss {}  lr {}'.format(epoch, average_loss, lr)) 
        print('normal loss {}  perceptual loss {}  l1 loss {}'.format(average_posl, average_negl, average_l1lo))  
        
        pos_acc = (count_ones_pos*100.0) / n_train_samples
        neg_acc = (count_ones_neg*100.0) / n_train_samples
        print('train acc (pos) {} - acc (neg) {}'.format(pos_acc, neg_acc))
        
        if epoch%10 == 0:            
            checkpoint.save(file_prefix = checkpoint_prefix)
            compute_test_acc()
        print('')

In [16]:
# compute FPR95 for testing sets

def compute_test_acc():
    path_test  = '/home/shared_dir/research/SPIMNet/NIR_RGB/'
    cate_names = ['field', 'indoor', 'oldbuilding', 'street', 'urban', 'water', 'forest', 'mountain']
    cate_index = [120448, 30336, 50688, 82304, 73856, 71552, 188416, 75648]

    for i in range(0,8):
        test_dataset = tf.data.Dataset.list_files(path_test + 'patch-merged-datasets/' + cate_names[i] + '/*.jpg')
        test_dataset = test_dataset.map(load_image_test)
        test_dataset = test_dataset.batch(BATCH_SIZE)

        print(cate_names[i] + ' ...')

        scores_pos = []
        scores_neg = []

        for pos_bs_img0, pos_bs_img1, neg_bs_img0, neg_bs_img1 in test_dataset:
            pos_bs_img1 = pos_bs_img1[:,:,:,0:1]
            neg_bs_img1 = neg_bs_img1[:,:,:,0:1]

            data_outputp, _, _, _, _ = similaritor([pos_bs_img0, pos_bs_img1], training=True)
            data_outputn, _, _, _, _ = similaritor([neg_bs_img0, neg_bs_img1], training=True)

            s_pos = tf.math.sigmoid(data_outputp)
            s_neg = tf.math.sigmoid(data_outputn)

            scores_pos.append(s_pos[:,0].numpy())
            scores_neg.append(s_neg[:,0].numpy())

        scores_np_pos = np.concatenate(scores_pos, axis=0)
        scores_np_neg = np.concatenate(scores_neg, axis=0)

        labels_pos = np.ones((cate_index[i],), dtype=int)
        labels_neg = np.zeros((cate_index[i],), dtype=int)
        
        scores_np = np.concatenate((scores_np_pos,scores_np_neg), axis=0)
        labels_np = np.concatenate((labels_pos,labels_neg), axis=0)

        fpr, tpr, thresholds = metrics.roc_curve(labels_np, scores_np, pos_label=1)
        fpr95 = float(interpolate.interp1d(tpr, fpr)(0.95))
        print('FPR95:', fpr95)
        print('')

In [17]:
# cell 15: train SPIMNet with 35 epochs

EPOCHS = 36 # from 1
train(train_dataset, EPOCHS)            
checkpoint.save(file_prefix = checkpoint_prefix)
compute_test_acc()

epoch 1  average_loss 2.962709426879883  lr 0.0001
normal loss 0.5010076761245728  perceptual loss 2.3033647537231445  l1 loss 0.1583368480205536
train acc (pos) 91.44445037841797 - acc (neg) 93.17631530761719

epoch 2  average_loss 2.1625747680664062  lr 0.0001
normal loss 0.22761917114257812  perceptual loss 1.799654245376587  l1 loss 0.13529269397258759
train acc (pos) 96.71932983398438 - acc (neg) 97.79029846191406

epoch 3  average_loss 1.9807138442993164  lr 0.0001
normal loss 0.17784222960472107  perceptual loss 1.687182903289795  l1 loss 0.11569993942975998
train acc (pos) 97.72111511230469 - acc (neg) 98.46920776367188

epoch 4  average_loss 1.8795654773712158  lr 0.0001
normal loss 0.15127530694007874  perceptual loss 1.626526117324829  l1 loss 0.10176645964384079
train acc (pos) 98.3113784790039 - acc (neg) 98.87208557128906

epoch 5  average_loss 1.8101754188537598  lr 0.0001
normal loss 0.13693860173225403  perceptual loss 1.5802676677703857  l1 loss 0.09296322613954544
tr

epoch 35  average_loss 1.313899040222168  lr 1e-05
normal loss 0.07579156756401062  perceptual loss 1.1793227195739746  l1 loss 0.05878268554806709
train acc (pos) 99.98918914794922 - acc (neg) 99.95243072509766

field ...
FPR95: 0.023516372210414287

indoor ...
FPR95: 0.017728111814345967

oldbuilding ...
FPR95: 0.006944444444444444

street ...
FPR95: 0.0029038685847589426

urban ...
FPR95: 0.004265056325823223

water ...
FPR95: 0.02396019677996422

forest ...
FPR95: 0.0011092476222826087

mountain ...
FPR95: 0.009013677382966722



In [18]:
# cell 16: restore the lastest checkpoint

!ls {checkpoint_dir}/
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

NGC-DL-CONTAINER-LICENSE  dev	lib    mnt   root  srv	tmp
bin			  etc	lib64  opt   run   sys	usr
boot			  home	media  proc  sbin  tf	var


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f763c64e0f0>