<a href="https://colab.research.google.com/github/kyle-gao/GRSS_TrackMSD2021/blob/main/DDN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images

In [2]:
import tensorflow as tf
import keras.backend as K

import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import shutil
import os
from PIL import Image

from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, multiply, Permute, Concatenate, \
    Conv2D, Add, Activation, Lambda
from keras import backend as K
from keras.activations import sigmoid

from keras import applications
from keras.models import Model
from keras.layers import Input, Dense, Dropout, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, concatenate, \
    Activation, ZeroPadding2D,Conv2DTranspose,Subtract,multiply,add,UpSampling2D
from keras import  layers
from keras.optimizers import Adam
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping, TensorBoard, CSVLogger

In [3]:
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same',droprate=0.3):
    x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides,activation = 'relu')(x)
    x=BatchNormalization(axis=3)(x)
    x=Dropout(rate = droprate)(x)
    return x

def attach_attention_module(net, attention_module):
    if attention_module == 'se_block':  # SE_block
        net = se_block(net)
    elif attention_module == 'cbam_block':  # CBAM_block
        net = cbam_block(net)
    else:
        raise Exception("'{}' is not supported attention module!".format(attention_module))

    return net


def se_block(input_feature, ratio=8):
    """Contains the implementation of Squeeze-and-Excitation(SE) block.
    As described in https://arxiv.org/abs/1709.01507.
    """

    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    channel = input_feature.shape[channel_axis]

    se_feature = GlobalAveragePooling2D()(input_feature)
    se_feature = Reshape((1, 1, channel))(se_feature)
    assert se_feature.shape[1:] == (1, 1, channel)
    se_feature = Dense(channel // ratio,
                       activation='relu',
                       kernel_initializer='he_normal',
                       use_bias=True,
                       bias_initializer='zeros')(se_feature)
    assert se_feature.shape[1:] == (1, 1, channel // ratio)
    se_feature = Dense(channel,
                       activation='sigmoid',
                       kernel_initializer='he_normal',
                       use_bias=True,
                       bias_initializer='zeros')(se_feature)
    assert se_feature.shape[1:] == (1, 1, channel)
    if K.image_data_format() == 'channels_first':
        se_feature = Permute((3, 1, 2))(se_feature)

    se_feature = multiply([input_feature, se_feature])
    return se_feature


def cbam_block(cbam_feature, ratio=8):
    """Contains the implementation of Convolutional Block Attention Module(CBAM) block.
    As described in https://arxiv.org/abs/1807.06521.
    """

    cbam_feature = channel_attention(cbam_feature, ratio)
    cbam_feature = spatial_attention(cbam_feature)
    return cbam_feature


def channel_attention(input_feature, ratio=8):
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    channel = input_feature.shape[channel_axis]

    shared_layer_one = Dense(channel // ratio,
                             activation='relu',
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')
    shared_layer_two = Dense(channel,
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')

    avg_pool = GlobalAveragePooling2D()(input_feature)
    avg_pool = Reshape((1, 1, channel))(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, channel)
    avg_pool = shared_layer_one(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, channel // ratio)
    avg_pool = shared_layer_two(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, channel)

    max_pool = GlobalMaxPooling2D()(input_feature)
    max_pool = Reshape((1, 1, channel))(max_pool)
    assert max_pool.shape[1:] == (1, 1, channel)
    max_pool = shared_layer_one(max_pool)
    assert max_pool.shape[1:] == (1, 1, channel // ratio)
    max_pool = shared_layer_two(max_pool)
    assert max_pool.shape[1:] == (1, 1, channel)

    cbam_feature = Add()([avg_pool, max_pool])
    cbam_feature = Activation('sigmoid')(cbam_feature)

    if K.image_data_format() == "channels_first":
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    return multiply([input_feature, cbam_feature])


def spatial_attention(input_feature):
    kernel_size = 7

    if K.image_data_format() == "channels_first":
        channel = input_feature.shape[1]
        cbam_feature = Permute((2, 3, 1))(input_feature)
    else:
        channel = input_feature.shape[-1]
        cbam_feature = input_feature

    avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
    assert avg_pool.shape[-1] == 1
    max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
    assert max_pool.shape[-1] == 1
    concat = Concatenate(axis=3)([avg_pool, max_pool])
    assert concat.shape[-1] == 2
    cbam_feature = Conv2D(filters=1,
                          kernel_size=kernel_size,
                          strides=1,
                          padding='same',
                          activation='sigmoid',
                          kernel_initializer='he_normal',
                          use_bias=False)(concat)
    assert cbam_feature.shape[-1] == 1

    if K.image_data_format() == "channels_first":
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    return multiply([input_feature, cbam_feature])


def get_spatial_attention_map(input_feature):
    kernel_size = 7
    cbam_feature = input_feature

    avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
    assert avg_pool.shape[-1] == 1
    max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
    assert max_pool.shape[-1] == 1
    concat = Concatenate(axis=3)([avg_pool, max_pool])
    assert concat.shape[-1] == 2
    cbam_feature = Conv2D(filters=1,
                          kernel_size=kernel_size,
                          strides=1,
                          padding='same',
                          activation='sigmoid',
                          kernel_initializer='he_normal',
                          use_bias=False)(concat)
    assert cbam_feature.shape[-1] == 1

    if K.image_data_format() == "channels_first":
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    return cbam_feature

#Data

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
shutil.unpack_archive("/content/drive/MyDrive/Kaggle_CoverChange.zip","/content/")

In [6]:
data_dir = "/content/Kaggle_CoverChange"
list_ds = tf.data.Dataset.list_files(str(data_dir+"/im1/*"),shuffle=False)
#dataset is made up of strings

In [7]:

#classes = array([ 29.,  38.,  75.,  76., 128., 150., 255.])
class_dict = {(29):1, (38):2, (75):3,(76):3,128:4,150:5,(255):0} #mapping 75 and 76 to same class, i'm assuming this is an encoding issue.

def to_categorical(tensor,class_dict):
  #maps pixel values to categories 1,2...num_classes
  for k,v in class_dict.items():
    tensor[tensor==k]=v
  return tensor

In [8]:
label1_fn = "/content/Kaggle_CoverChange/label1/00003.png"
label1 = tf.keras.preprocessing.image.load_img(label1_fn,color_mode="grayscale")
label1 = tf.keras.preprocessing.image.img_to_array(label1)

##Testing label change conversion

In [9]:
num_classes = 6

##Need to convert 2 labels into 1 change label

Trick: Let n be number of classes
Step 1. Use an outer product between one_hot encoded labels
(h,w,n),(h,w,n)->(h,w,n,n), this is gives a one hot encoded change matrix M.

Eg. M = 0 except for Mij = 1 -> class i changed to class j.

Step 2. Convert to categorical change -> (h,w,n,n)->(h,w,n^2)

In [10]:
"""#converting (label1,label2)->(label changemap)
#where change map has num_classes^2 (not all unique) change classes 
def make_label_change_dict(num_classes):
  label_change_dict = {}
  i = 0
  for x in range(num_classes):
    for y in range(num_classes):
      label_change_dict[(x,y)] = i *(x!=y)
      i += 1
  return label_change_dict

label_change_dict = make_label_change_dict(6)
print(label_change_dict)"""

#converting (label1,label2)->(label changemap)
#where change map has num_classes^2 (not all unique) change classes 
def make_label_change_array(num_classes):
  """Arg:
  num_classes:int, number of classes
  returns:
  num_classes^2 matrix of categorical change labels.
  """
  label_change_arr = np.zeros((num_classes,num_classes),dtype=np.uint8)
  i = 0
  for x in range(num_classes):
    for y in range(num_classes):
      label_change_arr[x,y] = i *(x!=y)
      i += 1
  return label_change_arr

label_change_arr = make_label_change_array(6)

print(label_change_arr)

[[ 0  1  2  3  4  5]
 [ 6  0  8  9 10 11]
 [12 13  0 15 16 17]
 [18 19 20  0 22 23]
 [24 25 26 27  0 29]
 [30 31 32 33 34  0]]


In [11]:
def get_class_change_label(before,after,label_change_arr):
  """
  Input:
  two np.array or tf.tensor of shape (batch,Width,Height, num_classes) corresponding to class labeled data in one hot encoding
  the arrays represent the before and after class labels for change detection
  Output:
  a np.array of shape (batch,Width,Height) corresponding to change labels (see figure 2)
  """
  labels_combined = np.einsum("abd,abe->abde",before,after)
  labels_combined = labels_combined*label_change_arr
  labels_combined = np.sum(labels_combined, axis=(-1,-2))
  return labels_combined

In [12]:
label1_fn = "/content/Kaggle_CoverChange/label1/00018.png"
label2_fn = "/content/Kaggle_CoverChange/label2/00018.png"

##Testing change label generation trick

In [13]:
 """   label1 = tf.keras.preprocessing.image.load_img(label1_fn, color_mode="grayscale")
    label1 = tf.keras.preprocessing.image.img_to_array(label1)
    label1old = label1
    label1 = to_categorical(label1,class_dict)
    label1 = tf.expand_dims(label1,axis=0)
    label1 = tf.one_hot(tf.cast(label1[:,:,:,0],dtype = tf.uint8),depth = num_classes)




    label2 = tf.keras.preprocessing.image.load_img(label2_fn, color_mode ="grayscale")
    label2 = tf.keras.preprocessing.image.img_to_array(label2)
    label1old = label1
    label2 = to_categorical(label2,class_dict)
    label2 = tf.expand_dims(label2,axis=0)
    label2old = label2
    label2 = tf.one_hot(tf.cast(label1[:,:,:,0],dtype = tf.uint8),depth=num_classes)

    print(label1.shape)
    print(label2.shape)
    print(label1[0,300,300,:])
    print(label2[0,300,300,:])
    lc = np.einsum("abcd,abce->abcde",label1,label2)

    print(lc[0,300,300,:,:])
    lc = lc * label_change_arr
    print(lc[0,300,300,:,:])"""

'   label1 = tf.keras.preprocessing.image.load_img(label1_fn, color_mode="grayscale")\n   label1 = tf.keras.preprocessing.image.img_to_array(label1)\n   label1old = label1\n   label1 = to_categorical(label1,class_dict)\n   label1 = tf.expand_dims(label1,axis=0)\n   label1 = tf.one_hot(tf.cast(label1[:,:,:,0],dtype = tf.uint8),depth = num_classes)\n\n\n\n\n   label2 = tf.keras.preprocessing.image.load_img(label2_fn, color_mode ="grayscale")\n   label2 = tf.keras.preprocessing.image.img_to_array(label2)\n   label1old = label1\n   label2 = to_categorical(label2,class_dict)\n   label2 = tf.expand_dims(label2,axis=0)\n   label2old = label2\n   label2 = tf.one_hot(tf.cast(label1[:,:,:,0],dtype = tf.uint8),depth=num_classes)\n\n   print(label1.shape)\n   print(label2.shape)\n   print(label1[0,300,300,:])\n   print(label2[0,300,300,:])\n   lc = np.einsum("abcd,abce->abcde",label1,label2)\n\n   print(lc[0,300,300,:,:])\n   lc = lc * label_change_arr\n   print(lc[0,300,300,:,:])'

##Testing over

In [14]:
def get_item(path):

  #wrap function in tf.numpy_function() or tf.data.dataset.map bugs out

  def _get_item(path):
    """
    args:
    path: Dataset.list_file dataset element

    returns: 
    (h,w,3),(h,w,1),(h,w,3),(h,w,1) image-label 4-tuple in tf.float32 tf.Tensor
    """
    fn = tf.strings.split(path,"/")
    base_dir = tf.strings.join(fn[:-2],separator="/")

    image1_fn = (base_dir+"/im1/"+fn[-1]).numpy()
    image2_fn = (base_dir+"/im2/"+fn[-1]).numpy()

    label1_fn = (base_dir+"/label1/"+fn[-1]).numpy()
    label2_fn = (base_dir+"/label2/"+fn[-1]).numpy()

    #label1_fn = "/content/Kaggle_CoverChange/label1/00018.png"
    #label2_fn = "/content/Kaggle_CoverChange/label2/00018.png"


    image1 = tf.keras.preprocessing.image.load_img(image1_fn)
    image1 = tf.keras.preprocessing.image.img_to_array(image1)

    image2 = tf.keras.preprocessing.image.load_img(image2_fn)
    image2 = tf.keras.preprocessing.image.img_to_array(image2)

    label1 = tf.keras.preprocessing.image.load_img(label1_fn, color_mode="grayscale")
    label1 = tf.keras.preprocessing.image.img_to_array(label1)
    label1 = to_categorical(label1,class_dict)
    #label1 = tf.expand_dims(label1,axis=0)

    label1 = tf.one_hot(tf.cast(label1[:,:,0],dtype = tf.uint8),depth = num_classes)




    label2 = tf.keras.preprocessing.image.load_img(label2_fn, color_mode ="grayscale")
    label2 = tf.keras.preprocessing.image.img_to_array(label2)
    label2 = to_categorical(label2,class_dict)
    #label2 = tf.expand_dims(label2,axis=0)

    label2 = tf.one_hot(tf.cast(label2[:,:,0],dtype = tf.uint8),depth=num_classes)

    change_label = get_class_change_label(label1,label2,label_change_arr=label_change_arr)
    change_label = tf.expand_dims(change_label, axis=-1)

    #classes = array([ 29.,  38.,  75.,  76., 128., 150., 255.])
    #(h,w,1) tensor of int from 1-6

    return image1,image2,change_label

  output = tf.numpy_function(_get_item,[path],[tf.float32,tf.float32,tf.float32])

  return output

In [15]:
def preprocessing(list_ds,batch_size=8,augmentation=None):
  """
  applies some preprocessing
  args:
  list_ds-Dataset.list_files dataset object
  batch_size-int 
  augmentation-a function (x,y)->(x,y)
  returns-batched dataset
  """
  #ds = list_ds.cache() 
  ds = list_ds.map(get_item,num_parallel_calls=tf.data.AUTOTUNE) #(h,w,3),(h,w,1)
  ds = ds.batch(batch_size)
  ds = ds.shuffle(100)

  if augmentation:
    ds = ds.map((lambda x,y : augmentation(x,y)),num_parallel_calls=tf.data.AUTOTUNE)

  
  ds = ds.prefetch(8)
  return ds

def transform(x,y):
  """
  Write your data transformations here
  """
  x=tf.keras.layers.experimental.preprocessing.Rescaling(1./255.0)(x)
  y=tf.one_hot(tf.cast(y[:,:,0],dtype = tf.int32),depth = 7)

  return x,y

#load pretrained model

In [16]:
shutil.unpack_archive("/content/drive/MyDrive/vgg_512_seg_model.zip","/content/Drive/Saved_Model/")

##Test

In [17]:
"""
vgg16_512 = tf.keras.models.load_model("/content/Drive/Saved_Model")
vgg16_512.summary()

feature_layers = [vgg16_512.get_layer("features_c1"),vgg16_512.get_layer("features_c2"),vgg16_512.get_layer('features_c3'),vgg16_512.get_layer('features_c4'),vgg16_512.get_layer('features_c5')]
feature_extractor = tf.keras.models.Model(inputs = vgg16_512.input, outputs = [layer.output for layer in feature_layers])

test_tensor = tf.random.normal((1,512,512,3))
test_out = feature_extractor(test_tensor)
for t in test_out:
  print(t.shape)
"""

'\nvgg16_512 = tf.keras.models.load_model("/content/Drive/Saved_Model")\nvgg16_512.summary()\n\nfeature_layers = [vgg16_512.get_layer("features_c1"),vgg16_512.get_layer("features_c2"),vgg16_512.get_layer(\'features_c3\'),vgg16_512.get_layer(\'features_c4\'),vgg16_512.get_layer(\'features_c5\')]\nfeature_extractor = tf.keras.models.Model(inputs = vgg16_512.input, outputs = [layer.output for layer in feature_layers])\n\ntest_tensor = tf.random.normal((1,512,512,3))\ntest_out = feature_extractor(test_tensor)\nfor t in test_out:\n  print(t.shape)\n'

#DDN

In [18]:
vgg16_512 = tf.keras.models.load_model("/content/Drive/Saved_Model")

In [19]:
vgg16_512.summary()

Model: "model_11"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_24 (InputLayer)           [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
base_vgg (Functional)           [(None, 128, 128, 25 14714688    input_24[0][0]                   
__________________________________________________________________________________________________
features_c5 (Dropout)           (None, 32, 32, 512)  0           base_vgg[0][2]                   
__________________________________________________________________________________________________
max_pooling2d_57 (MaxPooling2D) (None, 16, 16, 512)  0           features_c5[0][0]                
___________________________________________________________________________________________

In [20]:
def DDN (imsize, num_classes = 6):
    #if input image 512*512

    vgg16_512 = tf.keras.models.load_model("/content/Drive/Saved_Model")

    image1 = layers.Input((imsize,imsize,3),name = "image1")
    image2 = layers.Input((imsize,imsize,3),name = "image2")

    feature_layers = [vgg16_512.get_layer("features_c1"),vgg16_512.get_layer("features_c2"),vgg16_512.get_layer('features_c3'),vgg16_512.get_layer('features_c4'),vgg16_512.get_layer('features_c5')]
    feature_extractor = tf.keras.models.Model(inputs = vgg16_512.input, outputs = [layer.output for layer in feature_layers],trainable=False)

    features_1 = feature_extractor(image1)
    features_2 = feature_extractor(image1)

    t1_b5c3 = features_1[4] #(None, 32, 32, 512)
    t2_b5c3 = features_2[4] 

    t1_b4c3 = features_1[3] #(None, 64, 64, 512)
    t2_b4c3 = features_2[3] 
    
    t1_b3c3 = features_1[2] #(None, 128, 128, 256)  
    t2_b3c3 = features_2[2]

    t1_b2c2 = features_1[1]
    t2_b2c2 = features_2[1] #(None, 256, 256, 128) 

    t1_b1c2 = features_1[0]
    t2_b1c2 = features_2[0] #(None, 512, 512, 64)


    """
    pair5 = layers.Input((imsize,imsize,64*2), name='pair5') #rbg images concatenated channel-wise
    pair4 = layers.Input((imsize//2,imsize//2,128*2), name='pair4')
    pair3 = layers.Input((imsize//4,imsize//4,256*2), name='pair3')
    pair2 = layers.Input((imsize//8,imsize//8,512*2), name='pair2')
    pair1 = layers.Input((imsize//16,imsize//16,512*2), name='pair1')


    t1_b5c3 = pair1[:,:,:,:3] #(None, 32, 32, 512)
    t2_b5c3 = pair1[:,:,:,3:] 

    t1_b4c3 = pair2[:,:,:,:3] #(None, 64, 64, 512)
    t2_b4c3 = pair2[:,:,:,3:] 
    
    t1_b3c3 = pair3[:,:,:,:3] #(None, 128, 128, 256)  
    t2_b3c3 = pair3[:,:,:,3:]

    t1_b2c2 = pair4[:,:,:,:3]
    t2_b2c2 = pair4[:,:,:,3:] #(None, 256, 256, 128) 

    t1_b1c2 = pair5[:,:,:,:3]
    t2_b1c2 = pair5[:,:,:,3:] #(None, 512, 512, 64)
    """

    concat_b5c3 = concatenate([t1_b5c3, t2_b5c3], axis=3) #channel 1024
    x = Conv2d_BN(concat_b5c3,imsize, 3)
    x = Conv2d_BN(x,imsize,3)
    attention_map_1 = get_spatial_attention_map(x)
    x = multiply([x, attention_map_1])
    x = BatchNormalization(axis=3)(x)

    #branche1
    branch_1 =Conv2D(num_classes**2, kernel_size=1, padding='same',name='output_32')(x)

    x = Conv2DTranspose(imsize, kernel_size=2, strides=2, kernel_initializer="he_normal", padding='same')(x)
    x = concatenate([x,t1_b4c3,t2_b4c3],axis=3)
    x = channel_attention(x)
    x = Conv2d_BN(x,imsize,3)
    x = Conv2d_BN(x,256,3)
    x = Conv2d_BN(x,128,3)
    attention_map_2 = get_spatial_attention_map(x)
    x = multiply([x, attention_map_2])
    x = BatchNormalization(axis=3)(x)

    #branche2
    branch_2 =Conv2D(num_classes**2, kernel_size=1, padding='same',name='output_64')(x)

    x = Conv2DTranspose(256, kernel_size=2, strides=2, kernel_initializer="he_normal", padding='same')(x)
    x = concatenate([x,t1_b3c3,t2_b3c3],axis=3)
    x = channel_attention(x)
    x = Conv2d_BN(x,256,3)
    x = Conv2d_BN(x,128,3)
    x = Conv2d_BN(x, 64, 3)
    attention_map_3 = get_spatial_attention_map(x)
    x = multiply([x, attention_map_3])
    x = BatchNormalization(axis=3)(x)

    #branche3
    branch_3 =Conv2D(num_classes**2, kernel_size=1, padding='same',name='output_128')(x)

    x = Conv2DTranspose(128, kernel_size=2, strides=2, kernel_initializer="he_normal", padding='same')(x)
    x = concatenate([x,t1_b2c2,t2_b2c2],axis=3)
    x = channel_attention(x)
    x = Conv2d_BN(x,128,3)
    x = Conv2d_BN(x,128,3)
    x = Conv2d_BN(x, 128, 3)
    attention_map_4 = get_spatial_attention_map(x)
    x = multiply([x, attention_map_4])
    x = BatchNormalization(axis=3)(x)

    #branche4
    branch_4 =Conv2D(num_classes**2, kernel_size=1, padding='same',name='output_256')(x)

    x = Conv2DTranspose(64, kernel_size=2, strides=2, kernel_initializer="he_normal", padding='same')(x)
    x = concatenate([x,t1_b1c2,t2_b1c2],axis=3)
    x = channel_attention(x)
    x = Conv2d_BN(x,128,3)
    x = Conv2d_BN(x,64,3)
    x = Conv2d_BN(x, 64, 3)
    attention_map_5 = get_spatial_attention_map(x)
    x = multiply([x, attention_map_5])

    # branche5
    branch_5 =Conv2D(num_classes**2, kernel_size=1, padding='same',name='output_imsize')(x)

    DDN = Model(inputs=[image1,image2], outputs=[branch_5,branch_4,branch_3,branch_2,branch_1])

    return DDN

In [21]:
#TEST
"""
imsize = 512
image1 = tf.random.normal((1,512,512,3))
image2 = tf.random.normal((1,512,512,3))
ddn = DDN(512,6)
forward = ddn([image1,image2])
for f in forward:
  print (f.shape)"""

'\nimsize = 512\nimage1 = tf.random.normal((1,512,512,3))\nimage2 = tf.random.normal((1,512,512,3))\nddn = DDN(512,6)\nforward = ddn([image1,image2])\nfor f in forward:\n  print (f.shape)'

In [22]:
ddn = DDN(512,6)
ddn.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image1 (InputLayer)             [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
model (Functional)              [(None, 512, 512, 64 14714688    image1[0][0]                     
                                                                 image1[0][0]                     
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 32, 32, 1024) 0           model[0][4]                      
                                                                 model[1][4]                      
____________________________________________________________________________________________

In [23]:
lr = 0.003
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

train_accuracy = tf.keras.metrics.CategoricalAccuracy(
        name='train_accuracy')

entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)#adds a softmax step

#Using Pooling layers to reduce target image size for deep supervision

In [24]:
def DS_loss_fn(y_true, y_pred, loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), weights=[3,1,1,1,1]):
  loss = 0
  for i in range(len(weights)):
    print(y_true.shape)
    loss = loss+ weights[i]*loss_function(y_true,y_pred[i])
    y_true = tf.keras.layers.MaxPool2D()(y_true) 
  return loss


In [25]:
##Test
"""
y_pred = ddn([test_im1,test_im2])
DS_loss_fn(test_label,y_pred)"""

'\ny_pred = ddn([test_im1,test_im2])\nDS_loss_fn(test_label,y_pred)'

In [26]:
save_path = "/content/Saved_Models/"
if not(os.path.isdir(save_path)):
  os.mkdir(save_path)
batch_size=2
epochs = 1
model_savename="DDN_test"
train_ds = preprocessing(list_ds, batch_size = batch_size)

In [27]:
for epoch in range(epochs):
  print("Start of epoch {}".format(epoch))

  for step, (im1,im2,label) in enumerate(train_ds):
    with tf.GradientTape() as tape:

      y_pred = ddn([im1,im2],training=True)

      loss_value = DS_loss_fn(label,y_pred)
      

    grads = tape.gradient(loss_value,ddn.trainable_weights)
    optimizer.apply_gradients(zip(grads,ddn.trainable_weights))

  if step % 100 == 0:
      print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
      acc = train_accuracy(y_pred,y_batch)
      print ("Training accuracy at step {} is {}".format(step, acc))
      print("Seen so far: %s samples" % ((step + 1) * batch_size))
            
  test.save('saved_model/{}'.format(model_savename))

Start of epoch 0


ResourceExhaustedError: ignored