# Imports

In [1]:
import json
import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np 
import pandas as pd
import seaborn as sns
import os
import gc
import rasterio as rio
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import  cm
import cv2
from matplotlib import animation
from IPython.display import HTML
from tqdm import tqdm
from sklearn.model_selection import train_test_split

2024-11-11 15:29:58.764261: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-11 15:29:58.788633: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-11 15:29:58.815468: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-11 15:29:58.823231: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-11 15:29:58.845035: I tensorflow/core/platform/cpu_feature_guar

# Config

In [2]:
class CFG:
    seed = 10
    img_size = (256,256)
    BATCH_SIZE = 3
    Autotune = tf.data.AUTOTUNE
    validation_size = 0.2
    class_dict= {0:'No Flooding', 
                 1: 'Flooding'}
    
    test_run = False 
    


# Input data

    Read more about the dataset here : https://clmrmb.github.io/SEN12-FLOOD/

In [3]:
s1_labels = './archive/sen12flood/sen12floods_s1_labels/sen12floods_s1_labels/'
s1_tiles = './archive/sen12flood/sen12floods_s1_source/sen12floods_s1_source/'

s2_tiles = './archive/sen12flood/sen12floods_s2_source/sen12floods_s2_source/'
s2_labels = './archive/sen12flood/sen12floods_s2_labels/sen12floods_s2_labels/'


s1_check = 0
for file in os.listdir(s1_labels):
    if os.path.exists(s1_tiles + '/' + file.replace('labels','source')):
        s1_check += 1
        
         
assert s1_check == len(os.listdir(s1_tiles)), 'You my friend , are definintely a idiot!'
    
s2_check = 0
for file in os.listdir(s2_labels):
    if os.path.exists(s2_tiles + '/' + file.replace('labels','source')):
        s2_check += 1
        
        
assert s2_check == len(os.listdir(s2_tiles)), 'You my friend , are definintely  the idiot!'


s1_check,s2_check 

(3332, 2237)

# Make a dataset of paths and labels

# Helper Functions

In [4]:
def load_json(path):
    '''loads a json file'''
    with open(path,'r') as file:
        js = json.load(file)
        
    return js

# collectionss1 = load_json('../input/sen12flood/sen12flood/sen12floods_s1_source/sen12floods_s1_source/collection.json')
# collections2= load_json('../input/sen12flood/sen12flood/sen12floods_s2_source/sen12floods_s2_source/collection.json')
# collections2

In [5]:
def process_label_json(label_json):
    '''process a single label json'''
    info_dict = {}
    
    info_dict['geometry'] = label_json['geometry']['coordinates']
    info_dict['label'] = label_json['properties']['FLOODING']
    info_dict['date'] = label_json['properties']['date']
    info_dict['tile_number'] = label_json['properties']['tile']
#     info_dict['full_data_coverage']= label_json['properties']['FULL-DATA-COVERAGE']
    
    return info_dict


def process_label_stac(stac_json):
    return stac_json['id']
    
    


def image_path_from_label_dir(image_parent_dir:str,
                              label_file :str)->str:
    
    return image_parent_dir + '/' + label_file.replace('labels','source')
    
    

def process_json(label_path,image_directory):
    '''get the data for a single example
     Inputs 
     label_path : path to the label folder 
     image_directory: path to the corresponding image directory'''
    
    

    #get image directory for that label
    folder_id = label_path.rsplit('/',1)[1]
    image_dir_path = image_path_from_label_dir(image_directory,folder_id)

    if not os.path.exists(image_dir_path):
        return {'File_not_found':image_dir_path}
    
    
    for file in os.listdir(label_path):
        #if image dir exists 
        if file.startswith('labels'):
            label_json = load_json(os.path.join(label_path,file))
        else:
            stac_json = load_json(os.path.join(label_path,file))


    #get data 
    info_dict = process_label_json(label_json)

    #get id 
    info_dict['id'] = process_label_stac(stac_json)
    
    #location id 
    info_dict['location_id'] = info_dict['id'].split('_')[3]
    
    
    info_dict['image_dir'] = image_dir_path
    
    
    return info_dict


In [6]:
def get_dataframe(label_directory,image_directory):
    '''get dataframe from the nested label directory'''
    records = []
    
        
    for folder in os.listdir(label_directory):
        if folder.startswith('sen12'):
#             print(folder,label_directory)
            folder_path = label_directory + '/' + folder
            
            
            #get data for a single example
            feature = process_json(label_path=folder_path,
                                   image_directory=image_directory)
            
            
            records.append(feature)
            
            
    return pd.DataFrame.from_records(data = records)



def type_cast_dataset(dataset):
    '''typecasting columns in dataset'''
    dataset['label'] = dataset['label'].astype(int)
    
    dataset['date'] = pd.to_datetime(dataset['date'])
    dataset['tile_number'] = dataset['tile_number'].astype('int8')
    
    
    return dataset
    

In [7]:
%%time
s1_data = type_cast_dataset(
                            get_dataframe(
                                label_directory=s1_labels,
                                image_directory=s1_tiles
                                        )
                            )


s2_data = type_cast_dataset(
                            get_dataframe(label_directory=s2_labels,
                                          image_directory=s2_tiles)
                            )

print(f'Number of unique locations in Sentinel1 (SAR) data : {s1_data.location_id.nunique()}')
print(f'Number of unique locations in Sentinel2 (optical) data : {s2_data.location_id.nunique()}')

s1_data.shape,s2_data.shape

Number of unique locations in Sentinel1 (SAR) data : 335
Number of unique locations in Sentinel2 (optical) data : 335
CPU times: user 357 ms, sys: 193 ms, total: 550 ms
Wall time: 549 ms


((3331, 7), (2236, 7))

In [8]:
# saving datasets
s1_data.to_csv('s1_data.csv',index=False)
s2_data.to_csv('s2_data.csv',index=False)

In [9]:
def load_raster(filepath):
    '''load a single band raster'''
    with rio.open(filepath) as file: 
        raster = file.read().squeeze(axis=0)
        
    return raster



**Loading multiple raster bands as single raster**

In [10]:
def load_s1_tiffs(folder,
                  scaling_values=[50.,100.]):
    images = []
    i = 0
    for im in sorted(os.listdir(folder)):
         
        if im.rsplit('.',maxsplit=1)[1] == 'tif':
            
            path = folder + '/' + im
            band = load_raster(path)
            band = band / scaling_values[i]
            
            band = cv2.resize(band,
                              CFG.img_size)
            
            images.append(band)
            i+=1 
                    
    return np.dstack(images)


def load_s2_tiffs(folder,
                  scaling_value=10000.):
    images = []
    for im in sorted(os.listdir(folder)):
        if im.rsplit('.',maxsplit=1)[1] == 'tif':    
            path = folder + '/' + im
            band = load_raster(path)
            band = band/ scaling_value
            
            band = cv2.resize(band,CFG.img_size)
            images.append(band)   

    return np.dstack(images)
                    
def load_rgb_tiffs(folder,
                  scaling_value=10000.):
    '''load R,G and B bands'''
    
    images = []
    for im in sorted(os.listdir(folder)):
        name,file_format = im.rsplit('.',maxsplit=1)
        if ((file_format== 'tif') and (name in ['B02','B03','B04'])):    
            path = folder + '/' + im
            band = load_raster(path)
            band = band/ scaling_value
            
            band = cv2.resize(band,CFG.img_size)
            images.append(band)   

    return np.dstack(images)[:,:,::-1]


    
def tf_load_s1(path):    
    path = path.numpy().decode('utf-8')
    return load_s1_tiffs(path)
    
    

def tf_load_s2(path):    
    path = path.numpy().decode('utf-8')
    return load_s2_tiffs(path)


def tf_load_rgb(path):    
    path = path.numpy().decode('utf-8')
    return load_rgb_tiffs(path)
    
def process_image_s1(filename):
    '''function for preprocessing in tensorflow data'''
    
    image = tf.py_function(tf_load_s1, [filename], tf.float32)
    image.set_shape([None, None, 2])  # Explicitly set the shape (assuming grayscale images)
    return image



def process_image_s2(filename):
    '''function for preprocessing in tensorflow data'''
    
    return tf.py_function(tf_load_s2, 
                          [filename], 
                          tf.float32)



def process_image_rgb(filename):
    '''function for preprocessing in tensorflow data'''
    
    image = tf.py_function(tf_load_rgb, 
                          [filename], 
                          tf.float32)
    image.set_shape([None,None,3])
    return image
    

In [11]:
def count_rasters_in_folder(path):
    count = 0 
    
    for file in os.listdir(path):
        if file.rsplit('.',1)[1] == 'tif':
            count +=1 
            
    return count 
    
    
s2_data['raster_count'] = s2_data.image_dir.apply(lambda x : count_rasters_in_folder(x))

#value counts 
s2_data['raster_count'].value_counts()


s2_data=s2_data[s2_data['raster_count']==12] # take only valid rasters
# s2_data[s2_data['raster_count']==0]['location_id'].value_counts()

# Making a TF dataset

    First lets split the dataset into training and validation set. We will stratify based on location id to ensure that locations are well represented in traininng and validation set

In [12]:
#isolating single loaction ids (as they will be a problem for stratification)

# single example locations 
single_index = s2_data['location_id'].value_counts()[s2_data['location_id'].value_counts()==1].index

single_index_df = s2_data[s2_data['location_id'].isin(single_index)].reset_index(drop=True)
s2_data0 = s2_data[~(s2_data['location_id'].isin(single_index))].reset_index(drop=True)

s2_data0.shape,single_index_df.shape

((2126, 8), (12, 8))

**Split dataset into train and validation splits**

In [13]:
import pandas as pd
from sklearn.model_selection import train_test_split  # Assuming you have scikit-learn installed

# Train-test split with stratification
s1_data_tr, s1_data_val = train_test_split(s1_data,
                                          test_size=CFG.validation_size,
                                          random_state=CFG.seed,
                                          stratify=s1_data.location_id)

s2_data_tr, s2_data_val = train_test_split(s2_data0,
                                          test_size=CFG.validation_size,
                                          random_state=CFG.seed,
                                          stratify=s2_data0.location_id)

# Concatenate s2_data_tr and single_index_df while resetting index
s2_data_tr = pd.concat([s2_data_tr, single_index_df], ignore_index=True)

# No need for del or gc.collect() here (automatic memory management)

In [14]:
s1_data_tr.label.value_counts(1),s1_data_val.label.value_counts(1)

(label
 0    0.693318
 1    0.306682
 Name: proportion, dtype: float64,
 label
 0    0.67916
 1    0.32084
 Name: proportion, dtype: float64)

In [15]:
s2_data_tr.label.value_counts(1),s2_data_val.label.value_counts(1)

(label
 0    0.768107
 1    0.231893
 Name: proportion, dtype: float64,
 label
 0    0.746479
 1    0.253521
 Name: proportion, dtype: float64)

**Function for image augmentations**

    Although the Augmentations are simple, we cannot use them on SAR images , as even simple operations like flipping can change the meaning of the image

In [16]:
def augment_image_multispectral(image):
    '''perform simple image augmentations'''
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_crop(image, size=(*CFG.img_size,12))
    
    rot = tf.random.normal((1,),mean = 0.35, stddev=0.15)
    
    if rot > 0.5:
        image = tf.image.rot90(image)

    return image 

def augment_image(image):
    '''perform simple image augmentations'''
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_crop(image, size=(*CFG.img_size,3))
    
    rot = tf.random.normal((1,),mean = 0.35, stddev=0.15)
    
    if rot > 0.5:
        image = tf.image.rot90(image)

    return image 

In [17]:
def get_tf_dataset(image_paths,
                   labels=None, # put none for test data set
                   image_processing_fn=None,
                   augment_fn = None
                  ):
    
    '''returns a tf dataset object
    Inputs: 
    image_paths : paths to images
    labels: labels of each image
    image_processing_fn:  function to load and preprocess images 
    augment_fn : function to augment images '''
    
    #seperate datasets
    if labels is not None:
        labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
    
    
    
    image_dataset = tf.data.Dataset.from_tensor_slices(image_paths)
    #load images 
    image_dataset = image_dataset.map(image_processing_fn,
                                      num_parallel_calls=tf.data.AUTOTUNE)
     
    if augment_fn is not None:
        
        image_dataset = image_dataset.map(augment_fn,
                                          num_parallel_calls=tf.data.AUTOTUNE)
     
    
    if labels is not None:
        return tf.data.Dataset.zip((image_dataset,labels_dataset))
    
    
    return image_dataset



def optimize_pipeline(tf_dataset,
                      batch_size = CFG.BATCH_SIZE,
                      Autotune_fn = CFG.Autotune,
                      cache= False,
                      batch = True):
    
    
    
    # prefetch(load the data with cpu,while gpu is training) the data in memory 
    tf_dataset = tf_dataset.prefetch(buffer_size=Autotune_fn)  
    if cache:
        tf_dataset = tf_dataset.cache()                        # store data in RAM  
        
    tf_dataset =  tf_dataset.shuffle(buffer_size=50)         # shuffle 
    
    if batch:
        tf_dataset = tf_dataset.batch(batch_size)              #split the data in batches  
    
    return tf_dataset

**Making dataset pipelines with TF data**

In [18]:
# Sentinel 1 dataset (not using augmentation here)

S1_dataset_tr = optimize_pipeline(tf_dataset=get_tf_dataset(image_paths = s1_data_tr.image_dir.values,
                                               labels = tf.one_hot(s1_data_tr.label,depth=2),
                                               image_processing_fn = process_image_s1),
                                  
                                  batch_size = 3 * CFG.BATCH_SIZE)


S1_dataset_val = optimize_pipeline(tf_dataset=get_tf_dataset(image_paths = s1_data_val.image_dir.values,
                                                labels = tf.one_hot(s1_data_val.label,depth=2),
                                                image_processing_fn = process_image_s1),
                                   batch_size = 3* CFG.BATCH_SIZE)

2024-11-11 15:30:03.403462: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 36383 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:3b:00.0, compute capability: 8.6


In [19]:
#sentinel 2 dataset 
S2_dataset_tr = optimize_pipeline(get_tf_dataset(image_paths = s2_data_tr.image_dir.values,
                                                   labels = s2_data_tr.label,
                                                   image_processing_fn = process_image_s2,
                                                   augment_fn = augment_image_multispectral)
                                 )


S2_dataset_val = optimize_pipeline(get_tf_dataset(image_paths = s2_data_val.image_dir.values,
                                                   labels = s2_data_val.label,
                                                   image_processing_fn = process_image_s2,
                                                   augment_fn = augment_image_multispectral)
                                  )




In [20]:
RGB_dataset_tr = optimize_pipeline(get_tf_dataset(image_paths = s2_data_tr.image_dir.values,
                                                   labels = s2_data_tr.label,
                                                   image_processing_fn = process_image_rgb,
                                                   augment_fn = augment_image),
                                   batch_size = 3* CFG.BATCH_SIZE
                                 )


RGB_dataset_val = optimize_pipeline(get_tf_dataset(image_paths = s2_data_val.image_dir.values,
                                                   labels = s2_data_val.label,
                                                   image_processing_fn = process_image_rgb,
                                                   augment_fn = augment_image),
                                    batch_size = 3* CFG.BATCH_SIZE
                                  )



**Checking to see if the Pipelines work as expected**

In [21]:
for x,y in S1_dataset_val.take(1): # take one batch for checking 
    print(f'shape of SAR dataset input(val) {x.shape}')

shape of SAR dataset input(val) (9, 256, 256, 2)


2024-11-11 15:30:06.713587: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [22]:
for x,y in S1_dataset_tr.take(1): # take one batch for checking 
    print(f'shape of SAR dataset input {x.shape}')

shape of SAR dataset input (9, 256, 256, 2)


2024-11-11 15:30:07.837010: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [23]:
for x,y in S2_dataset_tr.take(1): # take one batch for checking 
    print(f'shape of MultiSpectral dataset input {x.shape}')

shape of MultiSpectral dataset input (3, 256, 256, 12)


In [24]:
for x,y in RGB_dataset_tr.take(1): # take one batch for checking 
    print(f'shape of MultiSpectral dataset input {x.shape}')

shape of MultiSpectral dataset input (9, 256, 256, 3)


2024-11-11 15:30:19.236041: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


#  CNN Models    CNN models to identify flooding in opotical and SAR images

In [25]:
print(tf.__version__)

2.17.0


In [32]:
def multichannel_cnn(num_channels:int,
                     hidden_units:int, #number of  hidden dense 
                     weights = None  # none for random init, use imagenet for imagenet weights 
                    ):
    '''model that takes multiple channel as input, instead of using the rgb channels as by default'''
    
    
    # backbone = tf.keras.applications.resnet_v2.ResNet50V2(
    #                                         include_top=False,
    #                                         input_shape = (*CFG.img_size,num_channels),
    #                                         weights=weights,
    #                                         pooling = 'avg')
    
    def add_constant_channel(x):
        # Create a constant tensor with the same batch size and spatial dimensions, and value 0
        constant_channel = tf.zeros_like(x[:, :, :, :1])  # Shape (224, 224, 1)
        # Concatenate the constant channel to the existing two channels
        return tf.concat([x, constant_channel], axis=-1)

    if num_channels == 2:
    # Apply the lambda layer
        input_layer = tf.keras.layers.Input(shape=(*CFG.img_size, 2))
        input_tensor = tf.keras.layers.Lambda(add_constant_channel)(input_layer)
    else:
        input_tensor = tf.keras.layers.Input(shape=(*CFG.img_size, num_channels))
        
    backbone = tf.keras.applications.mobilenet_v2.MobileNetV2(
                                            include_top=False,
                                            input_tensor=input_tensor,
                                            input_shape = input_tensor.shape[0],
                                            weights='imagenet',
                                            pooling = 'avg')
    
    x = tf.keras.layers.BatchNormalization()(backbone.output)
    x = tf.keras.layers.Dense(hidden_units, activation = 'relu')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dropout(rate = 0.2)(x)

    final_out = tf.keras.layers.Dense(2, activation = 'softmax')(x)
    
    #make a model 
    model = tf.keras.Model(inputs = backbone.input, 
                  outputs = final_out)
    
    return model 

# plot train and val acc as  a function of epochs
def plot_history(history,addn_metric=None):
    '''
    Inputs
    history:history object from tensorflow
    add_metric: metric name in the history (like f1_score)'''
    his=pd.DataFrame(history.history)
    
    if addn_metric:
        plt.subplots(1,3,figsize=(20,6))
        #loss:
        ax1=plt.subplot(1,3,1)
        ax1.plot(range(len(his)),his['loss'],color='g',label='training')
        ax1.plot(range(len(his)),his['val_loss'],color='r',label='validation')
        ax1.set_xlabel('EPOCHS')
        ax1.set_ylabel('LOSS')
        ax1.legend()
        ax1.set_title('Loss Per Epoch')
        #accuracy
        ax2=plt.subplot(1,3,2)
        ax2.plot(range(len(his)),his['accuracy'],color='g',label='training_acc')
        ax2.plot(range(len(his)),his['val_accuracy'],color='r',label='validation_acc')
        ax2.set_xlabel('EPOCHS')
        ax2.set_ylabel('Accuracy')
        ax2.legend()
        ax2.set_title('Accuracy Per Epoch')
        
        ax3= plt.subplot(1,3,3)
        ax3.plot(range(len(his)),his[f'{addn_metric}'],color='g',label='training')
        ax3.plot(range(len(his)),his[f'val_{addn_metric}'],color='r',label='validation')
        ax3.set_xlabel('EPOCHS')
        ax3.set_ylabel(f'{addn_metric}')
        ax3.legend()
        ax3.set_title(f'{addn_metric} Per Epoch')
    else:
        plt.subplots(1,2,figsize=(20,8))
        #loss:
        ax1=plt.subplot(1,2,1)
        ax1.plot(range(len(his)),his['loss'],color='g',label='training')
        ax1.plot(range(len(his)),his['val_loss'],color='r',label='validation')
        ax1.set_xlabel('EPOCHS')
        ax1.set_ylabel('LOSS')
        ax1.legend()
        ax1.set_title('Loss Per Epoch')
        #accuracy
        ax2=plt.subplot(1,2,2)
        ax2.plot(range(len(his)),his['accuracy'],color='g',label='training_acc')
        ax2.plot(range(len(his)),his['val_accuracy'],color='r',label='validation_acc')
        ax2.set_xlabel('EPOCHS')
        ax2.set_ylabel('Accuracy')
        ax2.legend()
        ax2.set_title('Accuracy Per Epoch')
    plt.show()  

# Metrics 

In [33]:


#from https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model

def recall_m(y_true, y_pred):
    y_true = K.cast(y_true, dtype='float32')  # Ensure both y_true and y_pred are float32
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision_m(y_true, y_pred):
    y_true = K.cast(y_true, dtype='float32')  # Ensure both y_true and y_pred are float32
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def f1_score(y_true, y_pred):
    precision = precision_m(y_true, y_pred)
    recall = recall_m(y_true, y_pred)
    
    return 2 * ((precision * recall) / (precision + recall + K.epsilon()))

In [34]:
SAR_CNN = multichannel_cnn(num_channels = 2,
                           hidden_units = 512, #number of  hidden dense 
                          )



SAR_CNN.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001),
                loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False),
                metrics = ['accuracy',f1_score,recall_m,precision_m]
               )

#check on some data 
for x,y in S1_dataset_val.take(1): # take one batch for checking 
    SAR_CNN(x)


# !mkdir CNN_models

  backbone = tf.keras.applications.mobilenet_v2.MobileNetV2(


In [35]:
# from tensorflow.keras.utils import plot_model
# plot_model(SAR_CNN, to_file="multichannel_cnn_architecture.png", show_shapes=True, show_layer_names=True)

# Defining callbacks

In [36]:
EPOCHS = 2 if CFG.test_run else 75
# callbacks 
#reduce_learning rate
reduce_lr=tf.keras.callbacks.ReduceLROnPlateau(patience=5,
                                                factor=0.8,
                                                min_delta=1e-2,
                                                monitor='val_accuracy',
                                                verbose=1,
                                                mode='max')

#early stopping 
early_stopping=tf.keras.callbacks.EarlyStopping(patience=10,
                                              min_delta=1e-3,
                                              monitor='val_accuracy',
                                              restore_best_weights=True,
                                              mode='max')


# exponential decay 

def lr_scheduler(epoch, lr):
    '''learning rate scheduler, decays expo after the tenth epoch'''

    if epoch < 10:
        return lr
    else:
        return float(lr * tf.math.exp(-0.1))
    

    
learning_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_scheduler)


callbacks_1= [reduce_lr,early_stopping,learning_scheduler]

# Building and training the SAR CNN

**Training on SAR data**

In [37]:
hist1 = SAR_CNN.fit(S1_dataset_tr,
                    validation_data = S1_dataset_val,
                    epochs = EPOCHS,
                    callbacks = callbacks_1
                   )


#save model
sar_model_path = 'CNN_models/SAR_CNN.h5'
SAR_CNN.save(filepath = 'CNN_models/SAR_CNN.h5')


#plot history 
plot_history(hist1,'f1_score')

Epoch 1/75
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m91s[0m 153ms/step - accuracy: 0.6652 - f1_score: 0.6652 - loss: 0.8314 - precision_m: 0.6652 - recall_m: 0.6652 - val_accuracy: 0.4588 - val_f1_score: 0.4652 - val_loss: 1.9344 - val_precision_m: 0.4652 - val_recall_m: 0.4652 - learning_rate: 1.0000e-04
Epoch 2/75
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 131ms/step - accuracy: 0.7512 - f1_score: 0.7512 - loss: 0.5922 - precision_m: 0.7512 - recall_m: 0.7512 - val_accuracy: 0.4903 - val_f1_score: 0.4963 - val_loss: 0.6578 - val_precision_m: 0.4963 - val_recall_m: 0.4963 - learning_rate: 1.0000e-04
Epoch 3/75
[1m296/296[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 131ms/step - accuracy: 0.7857 - f1_score: 0.7857 - loss: 0.5035 - precision_m: 0.7857 - recall_m: 0.7857 - val_accuracy: 0.3973 - val_f1_score: 0.4044 - val_loss: 0.8285 - val_precision_m: 0.4044 - val_recall_m: 0.4044 - learning_rate: 1.0000e-04
Epoch 4/75
[1m296/296[0

: 

In [None]:
# SAR_CNN.summary()


****Evaluate on validation dataset****

In [None]:
SAR_CNN.evaluate(S1_dataset_val)

# del SAR_CNN;gc.collect()

K.clear_session()

[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 85ms/step - accuracy: 0.8285 - f1_score: 0.8283 - loss: 0.8320 - precision_m: 0.8283 - recall_m: 0.8283


# Building and training the RGB - CNN

In [39]:
RGB_CNN = multichannel_cnn(num_channels = 3,
                           hidden_units = 512, #number of  hidden dense
                           weights = 'imagenet'
                          )


RGB_CNN.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=0.001),
                loss = 'sparse_categorical_crossentropy',
                metrics = ['accuracy',f1_score,recall_m,precision_m])

  backbone = tf.keras.applications.mobilenet_v2.MobileNetV2(


In [None]:
hist2 = RGB_CNN.fit(RGB_dataset_tr,
                    validation_data = RGB_dataset_val,
                    epochs = EPOCHS,
                    callbacks = callbacks_1)

#save model

RGB_CNN.save(filepath = 'CNN_models/RGB_CNN.h5')

plot_history(hist2,'f1_score')

Epoch 1/75


I0000 00:00:1731308893.686633 3131987 service.cc:146] XLA service 0x7fee700028f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1731308893.686683 3131987 service.cc:154]   StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6
2024-11-11 15:08:15.230074: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-11-11 15:08:16.772873: W tensorflow/core/framework/op_kernel.cc:1840] OP_REQUIRES failed at xla_ops.cc:577 : INVALID_ARGUMENT: Incompatible shapes: [9] vs. [9,2]
	 [[{{node mul_1}}]]
	tf2xla conversion failed while converting __inference_one_step_on_data_40786[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.


InvalidArgumentError: Graph execution error:

Detected at node mul_1 defined at (most recent call last):
<stack traces unavailable>
Incompatible shapes: [9] vs. [9,2]
	 [[{{node mul_1}}]]
	tf2xla conversion failed while converting __inference_one_step_on_data_40786[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
	 [[StatefulPartitionedCall]] [Op:__inference_one_step_on_iterator_42023]

In [None]:
#delete 
del RGB_CNN;gc.collect()

K.clear_session()

# Checking the GradCAM and saliency Maps for RGB CNN

In [None]:


plt.subplots(4,4,figsize=(8*3,8*3))
n = 4
idx= 1


for images,labels in RGB_dataset_val.shuffle(buffer_size=12).take(1):

    for i in range(4):
        #get label 
        img = images[i]
        lab = int(labels[i].numpy())


#         print(img.shape,lab.shape)
        score1 = CategoricalScore(lab)



        #predict on image
        prd= np.argmax(RGB_CNN.predict(img[tf.newaxis,:,:,:]))


        plt.subplot(4,4,idx)
        plt.title(f'orignal image ({CFG.class_dict[lab]})')
        plt.axis('off')
        plt.imshow(img)
        idx+=1

        #saliency

        plt.subplot(4,4,idx)
        plt.title(f'predicted {CFG.class_dict[prd]}(saliency map)')
        sal = get_saliency(img,
                           score1,
                           cnn_model = RGB_CNN).squeeze(axis=0)
        
#         print(sal.shape)
        plt.axis('off')
        plt.imshow(img)
        plt.imshow(sal,alpha=0.45,cmap='jet') #overlay
        idx+=1

        #gradcam
        plt.subplot(4,4,idx)
        gdcam = get_gradcam(img,
                            score1,
                           cnn_model = RGB_CNN)
        plt.imshow(img)
        plt.imshow(gdcam,alpha=0.30,cmap='jet') #overlay
        plt.title(f'predicted {CFG.class_dict[prd]}(gradcam)')
        plt.axis('off')
        idx+=1


        #gradcam ++
        plt.subplot(4,4,idx)
        gdcam_pls = get_gradcam_plus(img,
                                     score1,
                                     model = RGB_CNN)
        plt.imshow(img)
        plt.imshow(gdcam_pls,alpha=0.30,cmap='jet') #overlay
        plt.title(f'predicted {CFG.class_dict[prd]}(gradcam + +)')
        plt.axis('off')
        idx+=1

        if idx>16:
            break

    plt.tight_layout()
    plt.show()


In [None]:
plt.subplots(4,4,figsize=(8*3,8*3))
n = 4
idx= 1


for images,labels in RGB_dataset_val.shuffle(buffer_size=12).take(1):

    for i in range(4):
        #get label 
        img = images[i]
        lab = int(labels[i].numpy())


#         print(img.shape,lab.shape)
        score1 = CategoricalScore(lab)



        #predict on image
        prd= np.argmax(RGB_CNN.predict(img[tf.newaxis,:,:,:]))


        plt.subplot(4,4,idx)
        plt.title(f'orignal image ({CFG.class_dict[lab]})')
        plt.axis('off')
        plt.imshow(img)
        idx+=1

        #saliency

        plt.subplot(4,4,idx)
        plt.title(f'predicted {CFG.class_dict[prd]}(saliency map)')
        sal = get_saliency(img,
                           score1,
                           cnn_model = RGB_CNN).squeeze(axis=0)
        
#         print(sal.shape)
        plt.axis('off')
        plt.imshow(img)
        plt.imshow(sal,alpha=0.45,cmap='jet') #overlay
        idx+=1

        #gradcam
        plt.subplot(4,4,idx)
        gdcam = get_gradcam(img,
                            score1,
                           cnn_model = RGB_CNN)
        plt.imshow(img)
        plt.imshow(gdcam,alpha=0.30,cmap='jet') #overlay
        plt.title(f'predicted {CFG.class_dict[prd]}(gradcam)')
        plt.axis('off')
        idx+=1


        #gradcam ++
        plt.subplot(4,4,idx)
        gdcam_pls = get_gradcam_plus(img,
                                     score1,
                                     model = RGB_CNN)
        plt.imshow(img)
        plt.imshow(gdcam_pls,alpha=0.30,cmap='jet') #overlay
        plt.title(f'predicted {CFG.class_dict[prd]}(gradcam + +)')
        plt.axis('off')
        idx+=1

        if idx>16:
            break

    plt.tight_layout()
    plt.show()

