# This notebook assumes execution will be done in a Google Colab environment.

In [None]:
import os
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
from typing import Tuple, Optional
from sklearn.model_selection import train_test_split
import cv2

## Loading and preprocessing the four datasets used in the paper

In [None]:
# Extract every dataset archive uploaded to /content into /content/datasets/<NAME>.
# A dataset is skipped when its folder already exists (already extracted) or
# when its archive was not uploaded.
os.makedirs("/content/datasets", exist_ok=True)

for dataset_name in ("DSIFN", "WHU", "LEVIR", "CDD"):
    extract_dir = os.path.join("/content/datasets", dataset_name)
    archive_path = os.path.join("/content", dataset_name + ".zip")
    if not os.path.isdir(extract_dir) and os.path.exists(archive_path):
        # Pure-Python replacement for the IPython-only `!unzip` shell escape,
        # so this cell also runs as plain Python.
        shutil.unpack_archive(archive_path, extract_dir)

def load_DSIFN(percent_samples, dataset_path="/content/datasets/DSIFN"):
    """Load the DSIFN change-detection dataset and split it into train/val/test.

    Only file names present in all three sub-folders ``A`` (before), ``B``
    (after) and ``label`` (ground-truth mask) are kept, and every row pairs the
    three files that share that name.

    Args:
        percent_samples: Fraction of each split to keep (passed to
            ``DataFrame.sample(frac=...)``); rows are shuffled without a fixed seed.
        dataset_path: Root folder containing the ``A``, ``B`` and ``label``
            sub-folders.

    Returns:
        ``(train_data, val_data, test_data)`` DataFrames with the columns
        "Before Image", "After Image", "GT Mask", "Image ID" and "Dataset".
        The split is roughly 70/15/15, made with ``train_test_split(random_state=42)``.
    """
    image_exts = ('.jpg', '.png')

    def list_images(subfolder):
        # File names (not paths) of the images inside one sub-folder.
        folder = os.path.join(dataset_path, subfolder)
        return {name for name in os.listdir(folder) if name.endswith(image_exts)}

    # Pair files by name rather than by position in independently sorted
    # listings, so an extra or missing file in one folder cannot shift the
    # before/after/mask pairing of every later sample.
    common_files = sorted(list_images("A") & list_images("B") & list_images("label"))

    all_data = pd.DataFrame(
        {"Before Image": [dataset_path + "/A/" + name for name in common_files],
         "After Image": [dataset_path + "/B/" + name for name in common_files],
         "GT Mask": [dataset_path + "/label/" + name for name in common_files]}
        )
    all_data["Image ID"] = all_data["Before Image"].str.split("/").str[-1]
    all_data["Dataset"] = "DSIFN"

    # 15% test, then 15% of the full set (0.15 / 0.85 of the remainder) for val.
    temp_data, test_data = train_test_split(all_data, test_size=0.15, random_state=42)
    val_size = 0.15 / (1 - 0.15)
    train_data, val_data = train_test_split(temp_data, test_size=val_size, random_state=42)

    # Shuffle (and optionally subsample) each split.
    train_data = train_data.sample(frac=percent_samples).reset_index(drop=True)
    val_data = val_data.sample(frac=percent_samples).reset_index(drop=True)
    test_data = test_data.sample(frac=percent_samples).reset_index(drop=True)

    return train_data, val_data, test_data

def load_WHU(percent_samples, dataset_path="/content/datasets/WHU"):
    """Load the WHU change-detection dataset and split it into train/val/test.

    Only file names present in all three sub-folders ``A`` (before), ``B``
    (after) and ``label`` (ground-truth mask) are kept, and every row pairs the
    three files that share that name.

    Args:
        percent_samples: Fraction of each split to keep (passed to
            ``DataFrame.sample(frac=...)``); rows are shuffled without a fixed seed.
        dataset_path: Root folder containing the ``A``, ``B`` and ``label``
            sub-folders.

    Returns:
        ``(train_data, val_data, test_data)`` DataFrames with the columns
        "Before Image", "After Image", "GT Mask", "Image ID" and "Dataset".
        The split is roughly 70/15/15, made with ``train_test_split(random_state=42)``.
    """
    image_exts = ('.jpg', '.png')

    def list_images(subfolder):
        # File names (not paths) of the images inside one sub-folder.
        folder = os.path.join(dataset_path, subfolder)
        return {name for name in os.listdir(folder) if name.endswith(image_exts)}

    # Pair files by name rather than by position in independently sorted
    # listings, so an extra or missing file in one folder cannot shift the
    # before/after/mask pairing of every later sample.
    common_files = sorted(list_images("A") & list_images("B") & list_images("label"))

    all_data = pd.DataFrame(
        {"Before Image": [dataset_path + "/A/" + name for name in common_files],
         "After Image": [dataset_path + "/B/" + name for name in common_files],
         "GT Mask": [dataset_path + "/label/" + name for name in common_files]}
        )
    # NOTE(review): the ID is the text after the last "_" of the full path; a
    # file name without "_" yields the whole path as its ID — confirm naming.
    all_data["Image ID"] = all_data["Before Image"].str.split("_").str[-1]
    all_data["Dataset"] = "WHU"

    # 15% test, then 15% of the full set (0.15 / 0.85 of the remainder) for val.
    temp_data, test_data = train_test_split(all_data, test_size=0.15, random_state=42)
    val_size = 0.15 / (1 - 0.15)
    train_data, val_data = train_test_split(temp_data, test_size=val_size, random_state=42)

    # Shuffle (and optionally subsample) each split.
    train_data = train_data.sample(frac=percent_samples).reset_index(drop=True)
    val_data = val_data.sample(frac=percent_samples).reset_index(drop=True)
    test_data = test_data.sample(frac=percent_samples).reset_index(drop=True)

    return train_data, val_data, test_data

def load_LEVIR(percent_samples, dataset_path="/content/datasets/LEVIR"):
    """Load LEVIR-CD using its predefined train/val/test split.

    The split is encoded in the file-name prefix ("train_", "val_", "test_")
    inside the shared ``A`` (before), ``B`` (after) and ``label`` (mask)
    sub-folders. Only names present in all three sub-folders are kept, and
    every row pairs the three files that share that name.

    Args:
        percent_samples: Fraction of each split to keep (passed to
            ``DataFrame.sample(frac=...)``); rows are shuffled without a fixed seed.
        dataset_path: Root folder containing the ``A``, ``B`` and ``label``
            sub-folders.

    Returns:
        ``(train_data, val_data, test_data)`` DataFrames with the columns
        "Before Image", "After Image", "GT Mask", "Image ID" and "Dataset".
    """
    image_exts = ('.jpg', '.png')

    def list_images(subfolder, prefix):
        # File names (not paths) of one split's images inside one sub-folder.
        folder = os.path.join(dataset_path, subfolder)
        return {name for name in os.listdir(folder)
                if name.startswith(prefix) and name.endswith(image_exts)}

    def build_split(prefix):
        # Pair files by name rather than by list position, so an extra or
        # missing file in one folder cannot shift every later triple.
        names = sorted(list_images("A", prefix)
                       & list_images("B", prefix)
                       & list_images("label", prefix))
        split = pd.DataFrame(
            {"Before Image": [dataset_path + "/A/" + name for name in names],
             "After Image": [dataset_path + "/B/" + name for name in names],
             "GT Mask": [dataset_path + "/label/" + name for name in names]}
            )
        # Image ID = last two "_"-separated parts of the path (e.g. "1_1.png"
        # for ".../train_1_1.png"), formatted identically for all three splits.
        parts = split["Before Image"].str.split("_")
        split["Image ID"] = parts.str[-2] + "_" + parts.str[-1]
        split["Dataset"] = "LEVIR"
        return split

    train_data = build_split("train_")
    val_data = build_split("val_")
    test_data = build_split("test_")

    # Shuffle (and optionally subsample) each split.
    train_data = train_data.sample(frac=percent_samples).reset_index(drop=True)
    val_data = val_data.sample(frac=percent_samples).reset_index(drop=True)
    test_data = test_data.sample(frac=percent_samples).reset_index(drop=True)

    return train_data, val_data, test_data

def load_CDD(percent_samples, dataset_path="/content/datasets/CDD/CDD"):
    """Load the CDD dataset using its predefined train/val/test folders.

    Expects ``<dataset_path>/<split>/{A,B,OUT}`` for each split in
    train/val/test, where ``OUT`` holds the ground-truth masks. Only files whose
    names start with the split name and that exist in all three sub-folders are
    used; every row pairs the three files sharing a name. The number of paired
    training samples (before subsampling) is printed.

    Args:
        percent_samples: Fraction of each split to keep (passed to
            ``DataFrame.sample(frac=...)``); rows are shuffled without a fixed seed.
        dataset_path: Folder containing the ``train``, ``val`` and ``test``
            split folders.

    Returns:
        ``(train_data, val_data, test_data)`` DataFrames with the columns
        "Before Image", "After Image", "GT Mask", "Image ID" and "Dataset".
    """
    image_exts = ('.jpg', '.png')

    def build_split(split_name):
        split_path = os.path.join(dataset_path, split_name)

        def list_images(subfolder):
            # File names (not paths) of this split's images in one sub-folder.
            folder = os.path.join(split_path, subfolder)
            return {name for name in os.listdir(folder)
                    if name.startswith(split_name) and name.endswith(image_exts)}

        # Pair files by name rather than by list position, so an extra or
        # missing file in one folder cannot shift every later triple.
        names = sorted(list_images("A") & list_images("B") & list_images("OUT"))
        split = pd.DataFrame(
            {"Before Image": [split_path + "/A/" + name for name in names],
             "After Image": [split_path + "/B/" + name for name in names],
             "GT Mask": [split_path + "/OUT/" + name for name in names]}
            )
        split["Image ID"] = split["Before Image"].str.split("_").str[-1]
        split["Dataset"] = "CDD"
        return split

    train_data = build_split("train")
    print(train_data.shape[0])
    val_data = build_split("val")
    test_data = build_split("test")

    # Shuffle (and optionally subsample) each split.
    train_data = train_data.sample(frac=percent_samples).reset_index(drop=True)
    val_data = val_data.sample(frac=percent_samples).reset_index(drop=True)
    test_data = test_data.sample(frac=percent_samples).reset_index(drop=True)

    return train_data, val_data, test_data

def combine_dataset_splits(
    train1: pd.DataFrame, val1: pd.DataFrame, test1: pd.DataFrame,
    train2: pd.DataFrame, val2: pd.DataFrame, test2: pd.DataFrame,
    train3: Optional[pd.DataFrame] = None,
    val3: Optional[pd.DataFrame] = None,
    test3: Optional[pd.DataFrame] = None,
    train4: Optional[pd.DataFrame] = None,
    val4: Optional[pd.DataFrame] = None,
    test4: Optional[pd.DataFrame] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Merge the train/val/test splits of two to four datasets and shuffle each.

    The first two datasets are always used. The third and fourth are included
    only when their train split is given; their val/test splits are then added
    too, and any ``None`` among them is skipped by ``pd.concat``.

    Returns:
        ``(train, val, test)`` DataFrames, each shuffled (no fixed seed) and
        re-indexed from 0.
    """
    triples = [(train1, val1, test1), (train2, val2, test2)]
    triples += [extra for extra in ((train3, val3, test3), (train4, val4, test4))
                if extra[0] is not None]

    # zip(*triples) regroups the frames as (all train), (all val), (all test).
    shuffled = []
    for split_frames in zip(*triples):
        merged = pd.concat(list(split_frames), ignore_index=True)
        shuffled.append(merged.sample(frac=1).reset_index(drop=True))

    train_combined, val_combined, test_combined = shuffled
    return train_combined, val_combined, test_combined

def display_before_after_mask(images, num_samples):
    """Plot before image, after image and ground-truth mask side by side.

    Args:
        images: DataFrame with "Before Image", "After Image", "GT Mask",
            "Dataset" and "Image ID" columns, as returned by the load_* helpers.
        num_samples: Number of leading rows to plot; values larger than
            ``len(images)`` are capped instead of raising IndexError.

    Raises:
        FileNotFoundError: If OpenCV cannot read one of the referenced files.
    """

    def read_image(path, flags):
        image = cv2.imread(path, flags)
        # cv2.imread signals a missing/unreadable file by returning None rather
        # than raising, which would otherwise surface as an obscure cvtColor error.
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return image

    for i in range(min(num_samples, len(images))):
        row = images.iloc[i]
        label = row['Dataset'] + " " + row['Image ID']

        # OpenCV loads colour images as BGR; matplotlib expects RGB.
        before_img = cv2.cvtColor(read_image(row['Before Image'], cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        after_img = cv2.cvtColor(read_image(row['After Image'], cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        mask_img = read_image(row['GT Mask'], cv2.IMREAD_GRAYSCALE)

        panels = [
            (before_img, "Before Image: ", None),
            (after_img, "After Image: ", None),
            (mask_img, "GT Mask: ", 'gray'),
        ]

        plt.figure(figsize=(10, 5))
        for position, (image, title, cmap) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(image, cmap=cmap)
            plt.title(title + label)
        plt.show()

## The datasets to use can be adjusted by the user
- Two of the four datasets (LEVIR and DSIFN) are used in this run.

In [None]:
# Load each dataset as (train, val, test) DataFrames; the argument is the
# fraction of every split to keep. Uncomment other loaders to use more datasets.
LEVIR_train, LEVIR_val, LEVIR_test = load_LEVIR(1.0)
# CDD_train, CDD_val, CDD_test = load_CDD(0.15)
# WHU_train, WHU_val, WHU_test = load_WHU(1.0)
DSIFN_train, DSIFN_val, DSIFN_test = load_DSIFN(1.0)

# Variant that merges all four datasets (requires all four loaders above):
# combined_train, combined_val, combined_test = combine_dataset_splits(
#     CDD_train, CDD_val, CDD_test,
#     DSIFN_train, DSIFN_val, DSIFN_test,
#     LEVIR_train, LEVIR_val, LEVIR_test,
#     WHU_train, WHU_val, WHU_test
#   )
# Merge and shuffle the splits of the two datasets loaded above.
combined_train, combined_val, combined_test = combine_dataset_splits(
    DSIFN_train, DSIFN_val, DSIFN_test,
    LEVIR_train, LEVIR_val, LEVIR_test
  )

In [None]:
def map_to_tensors(ds, target_size=(256, 256), batch_size=4):
    """Turn a DataFrame of file paths into a batched ``tf.data`` pipeline.

    Args:
        ds: DataFrame with "Before Image", "After Image" and "GT Mask" path columns.
        target_size: ``(height, width)`` every image and mask is resized to.
        batch_size: Samples per batch (default 4).

    Returns:
        A ``tf.data.Dataset`` of ``({"pre_images": before, "post_images": after}, mask)``
        batches. Images are 3-channel float32 scaled by 1/255; masks are
        1-channel float32 scaled by 1/255. The dict keys must match the names
        of the model's Input layers. The dataset is not shuffled here.
    """
    before_paths = ds['Before Image'].tolist()
    after_paths = ds['After Image'].tolist()
    gt_paths = ds['GT Mask'].tolist()

    dataset = tf.data.Dataset.from_tensor_slices((before_paths, after_paths, gt_paths))

    def read_and_resize(path, channels, method):
        raw = tf.io.read_file(path)
        # decode_image handles PNG and JPEG alike (the loaders accept both);
        # expand_animations=False guarantees a rank-3 tensor that resize accepts.
        decoded = tf.io.decode_image(raw, channels=channels, expand_animations=False)
        resized = tf.image.resize(decoded, target_size, method=method)
        # Nearest-neighbour resizing keeps the input dtype, so cast before scaling.
        return tf.cast(resized, tf.float32) / 255.0

    def load_and_preprocess(before_path, after_path, gt_path):
        before = read_and_resize(before_path, 3, 'bilinear')
        after = read_and_resize(after_path, 3, 'bilinear')
        # Masks use nearest-neighbour so label values are never blended;
        # bilinear interpolation creates fractional labels along change edges.
        gt = read_and_resize(gt_path, 1, 'nearest')
        return {"pre_images": before, "post_images": after}, gt

    dataset = dataset.map(load_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)


# Build batched tf.data pipelines for the combined train/val/test splits.
combined_tensor_train, combined_tensor_val, combined_tensor_test = (
    map_to_tensors(split_frame)
    for split_frame in (combined_train, combined_val, combined_test)
)

Set the desired number of training epochs.

In [None]:
EPOCHS: int = 15  # maximum number of passes over the training data given to model.fit

Showing a few example before-after images with their ground truth masks.

In [None]:
display_before_after_mask(combined_train, num_samples=5)  # preview the first five training rows

In [None]:
# Stem Module (Start)

In [None]:
def stem_module(input_shape=(256, 256, 3), embedding_dim=768):
    """Build the convolutional stem module (authors' dimensions: [96, 192, 384, 768]).

    Four Conv2D + ReLU stages with 96, 192, 384 and ``embedding_dim`` filters
    ('same' padding), each followed by 2x2 max pooling. The first stage uses a
    7x7 kernel and the others 3x3.

    Returns:
        A Keras Model mapping an image to four tensors: the pre-pooling outputs
        of stages 1-3 (full, 1/2 and 1/4 input resolution) and the pooled output
        of stage 4 (1/16 input resolution).
    """
    image_in = layers.Input(shape=input_shape)

    stage_specs = [(96, (7, 7)), (192, (3, 3)), (384, (3, 3)), (embedding_dim, (3, 3))]
    stage_outputs = []
    pooled = image_in
    for n_filters, kernel in stage_specs:
        activations = layers.Conv2D(n_filters, kernel, padding='same', activation='relu')(pooled)
        stage_outputs.append(activations)
        pooled = layers.MaxPooling2D(pool_size=(2, 2))(activations)

    # Stages 1-3 are exposed before pooling; stage 4 only after pooling.
    return models.Model(inputs=image_in, outputs=stage_outputs[:3] + [pooled])

The paper states that the Stem Module (SM) operates similarly to a Vision Transformer (ViT). If this basic convolutional SM does not work as expected, it could be replaced with a pretrained ViT.

In [None]:
# Stem Module (End)

Feeding both inputs through the shared stem module

In [None]:
# Named inputs: the names must match the dict keys ("pre_images" /
# "post_images") yielded by the tf.data pipeline so Keras can route them.
pre_images = layers.Input(shape=(256, 256, 3), name="pre_images")
post_images = layers.Input(shape=(256, 256, 3), name="post_images")

# One stem instance is applied to both inputs, so both branches share weights.
stem = stem_module(input_shape=(256, 256, 3))
# Each call returns three skip feature maps plus the deepest (pooled) features.
pre_skip1, pre_skip2, pre_skip3, pre_image_features = stem(pre_images)
post_skip1, post_skip2, post_skip3, post_image_features = stem(post_images)

The Stem Module should now be functioning as intended. Its deepest features will be passed to the VSS blocks, and its intermediate outputs are kept as skip connections for the mask decoder.

In [None]:
# Siamese Encoder + Difference Module Implementation (Start)

In [None]:
class SelectiveScan2D(layers.Layer):
    """Simplified, learnable four-direction "selective scan" (SS2D) layer.

    Each scan direction is modelled as a learnable scalar scale applied to a
    flipped copy of the input, plus a shared learnable offset; the four
    direction outputs are summed. The computation is elementwise (there is no
    recurrence along the scan paths). All parameters are scalars of shape (1,),
    so the layer accepts any channel count and preserves the input shape.
    """

    def __init__(self, **kwargs):
        super(SelectiveScan2D, self).__init__(**kwargs)
        # One learnable scale per scan direction (A-D) and a shared offset (delta).
        self.A = self.add_weight(name='A', shape=(1,), initializer='random_normal', trainable=True)
        self.B = self.add_weight(name='B', shape=(1,), initializer='random_normal', trainable=True)
        self.C = self.add_weight(name='C', shape=(1,), initializer='random_normal', trainable=True)
        self.D = self.add_weight(name='D', shape=(1,), initializer='random_normal', trainable=True)
        self.delta = self.add_weight(name='delta', shape=(1,), initializer='random_normal', trainable=True)

    def call(self, features):
        """Return A*x + B*flip_hw(x) + C*flip_w(x) + D*flip_h(x) + 4*delta.

        Args:
            features: Rank-4 tensor (batch, height, width, channels); axes 1
                and 2 are flipped as height and width.
        """
        # Top-left -> bottom-right: features in their original order.
        scan_tl_br = self.A * features + self.delta
        # Bottom-right -> top-left: flip both spatial axes.
        scan_br_tl = self.B * tf.reverse(features, axis=[1, 2]) + self.delta
        # Top-right -> bottom-left: flip the width axis.
        scan_tr_bl = self.C * tf.reverse(features, axis=[2]) + self.delta
        # Bottom-left -> top-right: flip the height axis.
        scan_bl_tr = self.D * tf.reverse(features, axis=[1]) + self.delta

        # Merge the four directions by summation (other merges could be tried).
        return scan_tl_br + scan_br_tl + scan_tr_bl + scan_bl_tr

class DifferenceModule(layers.Layer):
    """Fuse pre- and post-change features into one difference representation.

    The same sublayers (shared weights) process both inputs: 1x1 Conv2D(64) ->
    3x3 depthwise conv -> JointSelectiveScan2D -> LayerNorm. The two results
    are then concatenated on the channel axis and projected by Dense(64).
    """
    def __init__(self, ss2D, **kwargs):
        # ss2D: SelectiveScan2D layer whose scalar weights the joint scan reuses.
        super(DifferenceModule, self).__init__(**kwargs)
        self.ss2D = ss2D

        self.conv1 = layers.Conv2D(64, kernel_size=1, activation='relu')

        self.depthwise_conv = layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same', activation='relu')

        # Joint Selective Scan 2D Module (references ss2D's A, B, C, D and delta)
        self.joint_selective_scan_2D = JointSelectiveScan2D(self.ss2D)

        self.layernorm = layers.LayerNormalization()

        self.dense2 = layers.Dense(64, activation='relu')

        # NOTE(review): self.pool is never used in call().
        self.pool = layers.MaxPooling2D(pool_size=(2, 2), strides=2, padding='same')

    def call(self, pre_post_features):
        """Return a 64-channel difference map for the pair ``[pre, post]``.

        Assumes both inputs are rank-4, channels-last maps with equal batch,
        height and width — TODO confirm (the joint scan concatenates them).
        """
        pre, post = pre_post_features

        pre = self.conv1(pre)
        post = self.conv1(post)

        pre = self.depthwise_conv(pre)
        post = self.depthwise_conv(post)

        # Applying joint selective scan 2D; each output has twice the channels
        # because the scan runs on [pre, post] and [post, pre] concatenations.
        pre, post = self.joint_selective_scan_2D([pre, post])

        pre = self.layernorm(pre)
        post = self.layernorm(post)

        # Concatenating pre and post and passing it through a linear layer
        combined = layers.Concatenate(axis=-1)([pre, post])
        combined = self.dense2(combined)

        return combined

class JointSelectiveScan2D(tf.keras.layers.Layer):
    """Cross-temporal scan that reuses the weights of a SelectiveScan2D layer.

    Both channel-wise orderings of the inputs ([pre, post] and [post, pre]) go
    through the same four-direction scan, using the scalar parameters A, B, C,
    D and delta taken from ``ss2D``.
    """

    def __init__(self, ss2D, **kwargs):
        super(JointSelectiveScan2D, self).__init__(**kwargs)
        self.ss2D = ss2D
        # These reference (not copy) ss2D's variables, so both layers train the same weights.
        self.A = ss2D.A
        self.B = ss2D.B
        self.C = ss2D.C
        self.D = ss2D.D
        self.delta = ss2D.delta

    def call(self, pre_post_features):
        """Return ``(scan([pre, post]), scan([post, pre]))``, each with 2C channels.

        Both inputs must share batch, height and width (required by the
        channel-axis concat); assumes rank-4 channels-last tensors — TODO confirm.
        """
        pre_feat, post_feat = pre_post_features

        pre_post_concat = tf.concat([pre_feat, post_feat], axis=-1)  # [B, H, W, 2C]
        post_pre_concat = tf.concat([post_feat, pre_feat], axis=-1)  # [B, H, W, 2C]

        def selective_scan(input_feat):
            # Same four-direction formula as SelectiveScan2D.call.
            scan_tl_br = self.A * input_feat + self.delta
            scan_br_tl = self.B * tf.reverse(input_feat, axis=[1, 2]) + self.delta
            scan_tr_bl = self.C * tf.reverse(input_feat, axis=[2]) + self.delta
            scan_bl_tr = self.D * tf.reverse(input_feat, axis=[1]) + self.delta
            return scan_tl_br + scan_br_tl + scan_tr_bl + scan_bl_tr

        return selective_scan(pre_post_concat), selective_scan(post_pre_concat)

class VSSBlock(layers.Layer):
    """Siamese VSS encoder stage followed by a DifferenceModule.

    Pre- and post-change features pass through the same sublayers (shared
    weights): Dense(64) -> reshape to (8, 8, 256) -> depthwise conv -> SS2D ->
    LayerNorm -> Dense(64). The resulting pair feeds a DifferenceModule, and
    the collected difference maps are concatenated on the channel axis.
    """
    def __init__(self, **kwargs):
        super(VSSBlock, self).__init__(**kwargs)

        # Linear Layer (64 units; can be increased for more feature-space dimensionality)
        self.dense1 = layers.Dense(64, activation='relu')

        # Depth-wise Convolutional Layer
        self.depthwise_conv = layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same', activation='relu')

        # Selective Scan 2D Module
        self.selective_scan_2D = SelectiveScan2D()

        # Difference Module (its joint scan shares this block's SS2D weights)
        self.difference_module = DifferenceModule(self.selective_scan_2D)

        # Layer Normalization
        self.layernorm = layers.LayerNormalization()

        # Linear Layer (final output)
        self.dense2 = layers.Dense(64, activation='relu')

        # Downsampling (maxPooling; we can change the method of downsampling later if necessary)
        self.pool = layers.MaxPooling2D(pool_size=(2, 2), strides=2, padding='same')

    def call(self, pre_post_features, dm_vectors = None):
      """Return the channel-wise concatenation of the difference-module outputs.

      Args:
          pre_post_features: ``[pre_features, post_features]``.
          dm_vectors: Optional list the per-iteration outputs are appended to
              (a list passed in is mutated); a fresh list is used when None.
      """
      if dm_vectors is None:
        dm_vectors = []
      pre_image_features = pre_post_features[0]
      post_image_features = pre_post_features[1]

      # NOTE(review): every iteration recomputes `pre`/`post` from the same
      # pre_image_features/post_image_features with the same sublayers, so the
      # four dm_vectors are identical and the pooling at the end of the loop
      # never affects the output. A real 4-stage hierarchy would have to feed the
      # pooled tensors into the next iteration, but the hard-coded (8, 8, 256)
      # reshape only fits the first stage — confirm the intended design.
      for i in range(4):
        # Processing pre-image and post-image feature vectors
        pre = self.dense1(pre_image_features)  # Applying first linear transformation
        post = self.dense1(post_image_features)

        # The reshape needs H*W*64 == 8*8*256 == 16384 values per sample (e.g. a
        # 16x16 input map). tf.reshape is row-major, so each (8, 8) position
        # gathers 4 horizontally adjacent input pixels, not a 2x2 patch.
        batch_size_pre = tf.shape(pre)[0]
        pre = tf.reshape(pre, (batch_size_pre, 8, 8, 256))  # Reshaping to simulate spatial dimensions for the conv layer (adjust as needed)

        batch_size_post = tf.shape(post)[0]
        post = tf.reshape(post, (batch_size_post, 8, 8, 256))

        pre = self.depthwise_conv(pre)  # Applying depth-wise convolution
        post = self.depthwise_conv(post)
        pre = self.selective_scan_2D(pre)  # Applying SS2D
        post = self.selective_scan_2D(post)
        pre = self.layernorm(pre)  # Applying layer normalization
        post = self.layernorm(post)
        pre = self.dense2(pre)  # Final linear layer (output)
        post = self.dense2(post)
        dm_vector = self.difference_module([pre, post]) # Applying Difference Module
        dm_vectors.append(dm_vector)
        if i < 3:
          pre = self.pool(pre)  # Downsampling
          post = self.pool(post)

      # Returning feature vectors (from Difference Module) to be fed to Mask Decoder
      dm_vectors = layers.Concatenate(axis=-1)(dm_vectors)

      return dm_vectors

In [None]:
# Siamese Encoder + Difference Module Implementation (End)

In [None]:
# Run the Siamese VSS encoder + difference module on the deepest stem features;
# dm_vectors becomes the input of the mask decoder.
vss = VSSBlock()
dm_vectors = vss([pre_image_features, post_image_features])

In [None]:
# Begin Mask Decoder Module

In [None]:
class CAVSSBlock(layers.Layer):
    """Attention gate for the decoder, built around a VSSBlockMaskDecoder.

    The channel-mean of the input goes through the VSS block, LayerNorm and a
    1x1 conv to one channel. Stride-1 average and max pooling are added and
    passed through a sigmoid, and the resulting gate multiplies the input.
    NOTE(review): the gate has a single channel and is broadcast across all
    input channels, so it acts as a spatial (per-pixel) gate.
    """
    def __init__(self, **kwargs):
        super(CAVSSBlock, self).__init__(**kwargs)

        self.vss_block = VSSBlockMaskDecoder()

        self.norm = layers.LayerNormalization()
        self.conv = layers.Conv2D(1, kernel_size=1)

        # Stride 1 with 'same' padding keeps the spatial size unchanged.
        self.avg_pool = layers.AveragePooling2D(pool_size=(2, 2), strides=1, padding='same')
        self.max_pool = layers.MaxPooling2D(pool_size=(2, 2), strides=1, padding='same')

        self.sigmoid = layers.Activation('sigmoid')

    def call(self, x):
        """Return ``x`` scaled by a learned per-pixel gate in (0, 1)."""
        residual = x
        # Collapse channels to their mean: (..., C) -> (..., 1).
        avg = tf.reduce_mean(x, axis=-1, keepdims=True)

        x = self.vss_block(avg)

        x = self.norm(x)
        x = self.conv(x)

        avg_pool = self.avg_pool(x)
        max_pool = self.max_pool(x)

        combined = layers.Add()([avg_pool, max_pool])
        gate = self.sigmoid(combined)

        # Resize gate to match residual shape
        # NOTE(review): every op above appears to preserve height/width, so
        # this resize looks like a no-op safeguard — confirm.
        gate_resized = tf.image.resize(gate, size=tf.shape(residual)[1:3], method='bilinear')

        return gate_resized * residual

class MaskDecoder(layers.Layer):
    """Upsampling decoder that turns difference features into a change mask.

    Three bilinear upsampling stages (x8, x2, x2, i.e. x32 overall), each
    followed by a CAVSSBlock gate, then a 1x1 sigmoid convolution that
    produces a single-channel probability map.
    """

    def __init__(self, **kwargs):
        # Forward Keras kwargs such as name= or dtype= to the base Layer.
        super(MaskDecoder, self).__init__(**kwargs)
        self.upsample1 = layers.UpSampling2D(size=(8, 8), interpolation='bilinear')
        self.cavss1 = CAVSSBlock()

        self.upsample2 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
        self.cavss2 = CAVSSBlock()

        self.upsample3 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
        self.cavss3 = CAVSSBlock()

        self.final_conv = layers.Conv2D(1, 1, activation='sigmoid')  # binary segmentation

    def call(self, x, skip_features=None):
        """Decode ``x`` into a per-pixel probability map.

        Args:
            x: Difference-module features at the coarsest resolution.
            skip_features: Optional sequence of up to three ``(pre, post)``
                pairs, ordered coarsest to finest. At each stage the pair is
                concatenated onto the upsampled ``x``, so its spatial size must
                match. ``zip`` stops at the shortest input, so fewer pairs run
                fewer stages. When falsy, all three stages run without skips.

        Returns:
            Tensor with one sigmoid channel.
        """
        if skip_features:
            stages = zip(
                (self.upsample1, self.upsample2, self.upsample3),
                (self.cavss1, self.cavss2, self.cavss3),
                skip_features,
            )
            for upsample, cavss, (pre, post) in stages:
                x = upsample(x)
                # Channel order: [decoder features, pre skip, post skip].
                x = tf.concat([x, pre, post], axis=-1)
                x = cavss(x)
        else:
            x = self.upsample1(x)
            x = self.cavss1(x)
            x = self.upsample2(x)
            x = self.cavss2(x)
            x = self.upsample3(x)
            x = self.cavss3(x)

        return self.final_conv(x)


class VSSBlockMaskDecoder(layers.Layer):
    """Decoder-side VSS block: Dense -> depthwise conv -> SS2D -> LayerNorm -> Dense.

    Height and width are preserved (the depthwise conv uses 'same' padding and
    the scan is elementwise), and the output always has 64 channels.
    """

    def __init__(self, **kwargs):
        super(VSSBlockMaskDecoder, self).__init__(**kwargs)
        self.dense1 = layers.Dense(64, activation='relu')
        self.depthwise_conv = layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same', activation='relu')
        self.selective_scan_2D = SelectiveScan2D()
        self.layernorm = layers.LayerNormalization()
        self.dense2 = layers.Dense(64, activation='relu')

    def call(self, features):
        """Map ``features`` (assumed rank-4, channels last) to 64 channels."""
        # Dense acts only on the last axis, so x is already (batch, H, W, 64)
        # and can go straight into the depthwise conv without reshaping.
        x = self.dense1(features)
        x = self.depthwise_conv(x)
        x = self.selective_scan_2D(x)
        x = self.layernorm(x)
        return self.dense2(x)

In [None]:
# End Mask Decoder Module

Building the model and showing its summary

In [None]:
# Decode the difference features into a mask. Skip pairs are ordered from the
# coarsest (skip3) to the finest (skip1), matching the decoder's upsampling stages.
decoder = MaskDecoder()
output_mask = decoder(dm_vectors, skip_features=[
    (pre_skip3, post_skip3),
    (pre_skip2, post_skip2),
    (pre_skip1, post_skip1)
])

# Build the model; dict-shaped dataset inputs are matched to the Input layer names.
model = tf.keras.models.Model(inputs=[pre_images, post_images], outputs=output_mask)
model.summary()

The output shape should be (None, 256, 256, 1): height (256) and width (256) matching the input images, plus a single output channel for the binary change mask.

Instantiating the optimizer and loss function. Fitting the model.

In [None]:
def combo_loss(y_true, y_pred):
    """Equal-weight combination of binary cross-entropy and soft Dice loss.

    Combo loss is a common choice for change detection; the authors did not
    specify which loss they used. Both terms are computed per image, averaged
    50/50, then averaged over the batch. Expects rank-4 tensors (the Dice sums
    reduce axes 1-3), e.g. (batch, H, W, 1) masks.
    """
    eps = 1e-7
    spatial_axes = [1, 2, 3]
    truth = tf.cast(y_true, tf.float32)
    prob = tf.cast(y_pred, tf.float32)

    # binary_crossentropy averages over the last axis, leaving (batch, H, W);
    # averaging axes 1 and 2 gives one value per image.
    per_image_bce = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(truth, prob), axis=[1, 2])

    # Soft Dice per image; eps keeps empty-mask/empty-prediction pairs finite.
    overlap = tf.reduce_sum(truth * prob, axis=spatial_axes)
    total_mass = tf.reduce_sum(truth, axis=spatial_axes) + tf.reduce_sum(prob, axis=spatial_axes)
    per_image_dice = 1. - (2. * overlap + eps) / (total_mass + eps)

    return tf.reduce_mean(0.5 * per_image_bce + 0.5 * per_image_dice)

# AdamW with the learning rate and weight decay reported by the authors.
optimizer = tf.keras.optimizers.AdamW(learning_rate=6e-5, weight_decay=0.01)

model.compile(optimizer=optimizer, loss=combo_loss, metrics=['accuracy'])

# EarlyStopping monitors val_loss by default: stop after 3 epochs without
# improvement and restore the best weights seen.
early_stopping = callbacks.EarlyStopping(patience=3, restore_best_weights=True)
history = model.fit(
    combined_tensor_train,
    validation_data=combined_tensor_val,
    epochs=EPOCHS,
    callbacks=[early_stopping],
)

## Saving the model in two ways.
- Uncomment cells to save the model if desired

In [None]:
#model.save("siamese_detection_model.h5")

In [None]:
#model.save("siamese_detection_model.keras")

## Showing a few of our predictions next to before and after images, as well as ground truth images.

In [None]:
# For each of the first 10 test batches, show the first sample's prediction,
# ground truth, and before/after images.
for batch in combined_tensor_test.take(10):
    inputs, ground_truth = batch

    pred = model.predict(inputs)
    pre_image = inputs["pre_images"][0]
    post_image = inputs["post_images"][0]
    # Binarise the sigmoid output at 0.5.
    pred_mask = tf.cast(pred[0] > 0.5, tf.int32)
    true_mask = ground_truth[0]

    fig, axs = plt.subplots(1, 4, figsize=(10, 5))
    panels = [
        (pred_mask, "Predicted Mask"),
        (true_mask, "Ground Truth"),
        (pre_image, "Before Image"),
        (post_image, "After Image"),
    ]
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(tf.squeeze(image), cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

## Getting F1, IoU, and Overall Accuracy Scores based on authors' evaluation metrics.

### Additionally plotting training/validation accuracy and loss curves.

In [None]:
from sklearn.metrics import f1_score, jaccard_score, accuracy_score

def plot_training_accuracy_curves(model, history, val_ds):
    """Print validation loss/accuracy and plot per-epoch training curves.

    Args:
        model: Compiled Keras model with exactly one metric (accuracy), so that
            ``evaluate`` returns ``(loss, accuracy)``.
        history: ``History`` object returned by ``model.fit``.
        val_ds: Validation dataset passed to ``model.evaluate``.
    """
    val_loss, val_accuracy = model.evaluate(val_ds)
    print(f"Validation loss: {val_loss}, Validation Accuracy: {val_accuracy}")

    curves = history.history
    # (y-axis label, train key, train legend, val key, val legend) per subplot.
    panel_specs = (
        ('Accuracy', 'accuracy', 'Training Accuracy', 'val_accuracy', 'Validation Accuracy'),
        ('Loss', 'loss', 'Training Loss', 'val_loss', 'Validation Loss'),
    )

    plt.figure(figsize=(12, 4))
    for position, (ylabel, train_key, train_label, val_key, val_label) in enumerate(panel_specs, start=1):
        plt.subplot(1, 2, position)
        plt.plot(curves[train_key], label=train_label)
        plt.plot(curves[val_key], label=val_label)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
    plt.show()

def evaluate_model(model, test_dataset, threshold=0.5):
    """Compute pixel-level F1, IoU and overall accuracy on a test dataset.

    Args:
        model: Trained Keras model whose output is a per-pixel probability map.
        test_dataset: Dataset yielding ``(inputs, mask)`` batches, such as the
            ones built by ``map_to_tensors``.
        threshold: Probability above which a predicted pixel counts as changed.

    Returns:
        ``(f1, iou, accuracy)`` computed over all pixels of the dataset.

    Raises:
        ValueError: If ``test_dataset`` yields no batches.
    """
    y_true = []
    y_pred = []

    for inputs, y in test_dataset:
        pred = model.predict(inputs, verbose=0)
        y_pred.append((np.asarray(pred) > threshold).astype(np.uint8))
        # Binarise the ground truth explicitly: casting float masks straight to
        # uint8 truncates every value below 1.0 (e.g. interpolated edges) to 0.
        y_true.append((y.numpy() > 0.5).astype(np.uint8))

    if not y_true:
        raise ValueError("test_dataset yielded no batches")

    y_true = np.concatenate(y_true).flatten()
    y_pred = np.concatenate(y_pred).flatten()

    f1 = f1_score(y_true, y_pred)
    iou = jaccard_score(y_true, y_pred)
    accuracy = accuracy_score(y_true, y_pred)

    print(f"F1 Score: {f1}")
    print(f"IoU Score: {iou}")
    print(f"Accuracy: {accuracy}")

    return f1, iou, accuracy


In [None]:
plot_training_accuracy_curves(model=model, history=history, val_ds=combined_tensor_val)  # validation metrics + training curves

In [None]:
f1, iou, accuracy = evaluate_model(model=model, test_dataset=combined_tensor_test)  # pixel-level test metrics