In [5]:
"""
Import libraries
"""

import pandas as pd
import math
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.layers import (
    BatchNormalization, Conv2D, MaxPooling2D, Flatten, Dropout, Dense,Input,  MaxPool2D, GlobalAveragePooling2D, Layer, Add
)
import matplotlib.pyplot as plt
import numpy as np
import os
import re
import glob
from tqdm import tqdm
import cv2
import sklearn
import skimage
from sklearn.model_selection import train_test_split
from skimage.transform import resize
import random
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import accuracy_score
import seaborn as sns
sns.set()

In [9]:
"""
Choose dataset and model to train
"""

# Dataset source: True loads the cached .npy arrays, False rebuilds them from the image folders
LOAD_DATASET = False
# True selects the small (2-class) dataset, False the large (3-class) one
SMALL_DATASET = False

# Input modality for the large dataset: RGB only, thermal (IR) only, or both combined.
# These flags only apply when SMALL_DATASET is False.
RGB, IR, COMBI = False, False, True

# Size of the softmax output, determined by the chosen dataset
num_classes = 2 if SMALL_DATASET else 3

# Model to train: choose 'CNN', 'ResNet18' or 'ResNet_CNN'
MODEL = 'ResNet_CNN'

# Base name used for the saved model and the training-curve image
MODEL_FILE_NAME = 'largedata_ResNet_CNN'

# Hyperparameters
BATCH_SIZE, EPOCHS = 8, 200

# Paths
RGB_data = '../Data/254p RGB Images/'
IR_data = '../Data/254p Thermal Images/'
small = "../Data/forest_fire/All"

In [10]:
"""
Function definition: load or build dataset
"""
#Fire (Y/N) indicates whether or not there is fire visible in 254p RGB and/or 254p Thermal frame
#Smoke (Y/N) indicates whether smoke fills >= 50% of the 254p RGB frame (visual estimate)

#Functions definition
def _label_from_frame_id(frame_id):
    """Return the one-hot label [NN, YY, YN] for a frame ID, or None if the ID is not annotated."""
    # Half-open [start, stop) frame-ID spans for each class
    no_fire = ((1, 13701),)
    fire_and_smoke = ((13701, 14700), (15981, 19803), (19900, 27184), (27515, 31295),
                      (31510, 33598), (33930, 36551), (38031, 38154), (41642, 45280),
                      (51207, 52287))
    fire_no_smoke = ((14700, 15981), (19803, 19900), (27184, 27515), (31295, 31510),
                     (33598, 33930), (36551, 38031), (38154, 41642), (45280, 51207),
                     (52287, 53452))
    labelled_spans = (
        ([1, 0, 0], no_fire),         # NN = No fire, no smoke
        ([0, 1, 0], fire_and_smoke),  # YY = Yes fire, Yes smoke
        ([0, 0, 1], fire_no_smoke),   # YN = Yes fire, no smoke
    )
    for one_hot, spans in labelled_spans:
        if any(start <= frame_id < stop for start, stop in spans):
            return one_hot
    return None


def get_large_data(path):
    """
    Build images, one-hot labels and frame IDs from one folder of the large dataset.

    The frame ID is the first run of digits found in the last 10 characters of the
    file path. Files whose ID has three or more digits and does not end in 0 are
    moved to the matching "duplicates" folder (only when `path` is RGB_data or
    IR_data) and are never returned.

    A file is returned only when it has an annotated ID AND a readable image, so
    x[k], y[k] and IDs[k] always describe the same file. Files with no digits in
    their name, an ID outside the annotated ranges, or an unreadable image are
    skipped.

    Parameters
    ----------
    path : str
        Folder containing the image files.

    Returns
    -------
    x : np.ndarray
        Images as returned by cv2.imread.
    y : np.ndarray
        One-hot labels in the order [NN, YY, YN].
    IDs : list of int
        Frame IDs, in the same order as x and y.
    """
    duplicate_dirs = {
        RGB_data: '../Data/RGB duplicates/',
        IR_data: '../Data/Thermal duplicates/',
    }
    y = []
    x = []
    IDs = []
    for file in tqdm(os.listdir(path)):
        FileName = os.path.join(path, file)
        digit_runs = re.findall(r"\d+", FileName[-10:])
        if not digit_runs:
            # No frame number in the name: cannot be labelled
            continue
        ID = digit_runs[0]
        if len(ID) >= 3 and int(ID[-1]) != 0:
            # Frame is treated as a duplicate: move it aside for the known folders
            # and never load it, so it cannot end up in x without a label.
            target_dir = duplicate_dirs.get(path)
            if target_dir is not None:
                os.rename(FileName, target_dir + file)
            continue

        ID = int(ID)
        label = _label_from_frame_id(ID)
        if label is None:
            continue
        img_file = cv2.imread(FileName)
        if img_file is None:
            # Not a readable image
            continue

        # Append all three together to keep the outputs aligned
        x.append(np.asarray(img_file))
        y.append(label)
        IDs.append(ID)
    x = np.asarray(x)
    y = np.asarray(y)
    return x,y,IDs


def get_small_data(folder):
    """
    Build images and one-hot labels for the small dataset.

    `folder` is expected to contain the sub-folders "fire" (label [1, 0]) and
    "nofire" (label [0, 1]). Hidden entries and any other entry are skipped, so
    an unexpected folder can no longer inherit the label of the previous one.

    Parameters
    ----------
    folder : str
        Root folder holding the class sub-folders.

    Returns
    -------
    x : np.ndarray
        Images as returned by cv2.imread (unreadable files are skipped).
    y : np.ndarray
        One-hot labels, aligned with x.
    """
    labels_by_folder = {"nofire": [0, 1], "fire": [1, 0]}
    x = []
    y = []
    for folderName in os.listdir(folder):
        if folderName.startswith("."):
            continue
        label = labels_by_folder.get(folderName)
        if label is None:
            # Not one of the two known class folders
            continue
        class_dir = os.path.join(folder, folderName)
        for image_filename in tqdm(os.listdir(class_dir)):
            img_file = cv2.imread(os.path.join(class_dir, image_filename))
            if img_file is not None:
                x.append(np.asarray(img_file))
                y.append(label)
    x = np.asarray(x)
    y = np.asarray(y)
    return x,y





In [None]:
"""
Load or build dataset
"""

if SMALL_DATASET:
    # Small dataset: rebuild from the image folders (and cache) or reuse the cached arrays
    if not LOAD_DATASET:
        X, y = get_small_data(small)

        np.save("../Data/X_small.npy", X)
        np.save("../Data/y_small.npy", y)
    else:
        X = np.load("../Data/X_small.npy")
        y = np.load("../Data/y_small.npy")

    # 80/20 train/test split, then 80/20 train/validation split of the training part
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, shuffle=True)
    X_train, X_valid, y_train, y_valid = train_test_split(
        X_train, y_train, test_size=0.2, shuffle=True)

    # Image normalization: divide by the 8-bit maximum
    X_train = X_train / 255.0
    X_valid = X_valid / 255.0
    X_test = X_test / 255.0

    # Side length read from axis 1; the input shape assumes square 3-channel images
    n = X_train.shape[1]
    inputshape = (None, n, n, 3)
else:
    # Large dataset: RGB and thermal (IR) arrays are built/cached separately
    if not LOAD_DATASET:
        X_RGB, Y_RGB, IDs_RGB = get_large_data(RGB_data)
        X_IR, Y_IR, IDs_IR = get_large_data(IR_data)

        np.save("../Data/X_RGB.npy", X_RGB)
        np.save("../Data/Y_RGB.npy", Y_RGB)
        np.save("../Data/IDs_RGB.npy", IDs_RGB)
        np.save("../Data/X_IR.npy", X_IR)
        np.save("../Data/Y_IR.npy", Y_IR)
        np.save("../Data/IDs_IR.npy", IDs_IR)
    else:
        X_RGB = np.load("../Data/X_RGB.npy")
        Y_RGB = np.load("../Data/Y_RGB.npy")
        IDs_RGB = np.load("../Data/IDs_RGB.npy")
        X_IR = np.load("../Data/X_IR.npy")
        Y_IR = np.load("../Data/Y_IR.npy")
        IDs_IR = np.load("../Data/IDs_IR.npy")

    # RGB: train/test split, train/validation split, then normalization
    X_RGB_train, X_RGB_test, y_RGB_train, y_RGB_test = train_test_split(
        X_RGB, Y_RGB, test_size=0.2, shuffle=True)
    X_RGB_train, X_RGB_valid, y_RGB_train, y_RGB_valid = train_test_split(
        X_RGB_train, y_RGB_train, test_size=0.2, shuffle=True)
    X_RGB_train = X_RGB_train / 255.0
    X_RGB_valid = X_RGB_valid / 255.0
    X_RGB_test = X_RGB_test / 255.0

    # Thermal: same procedure, split independently of the RGB split
    X_IR_train, X_IR_test, y_IR_train, y_IR_test = train_test_split(
        X_IR, Y_IR, test_size=0.2, shuffle=True)
    X_IR_train, X_IR_valid, y_IR_train, y_IR_valid = train_test_split(
        X_IR_train, y_IR_train, test_size=0.2, shuffle=True)
    X_IR_train = X_IR_train / 255.0
    X_IR_valid = X_IR_valid / 255.0
    X_IR_test = X_IR_test / 255.0

    # Side length read from the thermal images; the input shape assumes square 3-channel images
    n = X_IR_train.shape[1]
    inputshape = (None, n, n, 3)
if COMBI:
    # The two-input models need paired RGB/thermal frames, which only the large dataset provides
    if SMALL_DATASET:
        raise ValueError("COMBI requires the large dataset: set SMALL_DATASET = False")

    # Reorder the RGB images so that row i shows the same frame ID as thermal row i.
    # One dict lookup per frame replaces a scan over all IDs; the first occurrence of an ID wins.
    rgb_row_by_id = {}
    for row, frame_id in enumerate(IDs_RGB):
        rgb_row_by_id.setdefault(int(frame_id), row)

    missing_ids = [int(frame_id) for frame_id in IDs_IR if int(frame_id) not in rgb_row_by_id]
    if missing_ids:
        raise ValueError(
            "%d thermal frame(s) have no RGB counterpart, e.g. IDs %s"
            % (len(missing_ids), missing_ids[:5]))

    rgb_order = np.asarray([rgb_row_by_id[int(frame_id)] for frame_id in IDs_IR], dtype=np.intp)
    X_RGBs = X_RGB[rgb_order]

    # Split the data (both modalities and the labels are shuffled together so pairs stay aligned)
    X_RGB_train,X_RGB_test,X_IR_train, X_IR_test, y_train, y_test = train_test_split(X_RGBs,X_IR,Y_IR,test_size=0.2,shuffle=True)
    X_RGB_train,X_RGB_valid,X_IR_train, X_IR_valid, y_train, y_valid = train_test_split(X_RGB_train,X_IR_train,y_train,test_size=0.2,shuffle=True)
    # Image Normalization
    X_RGB_train, X_RGB_valid, X_RGB_test = X_RGB_train / 255.0, X_RGB_valid / 255.0, X_RGB_test / 255.0
    X_IR_train, X_IR_valid, X_IR_test = X_IR_train / 255.0, X_IR_valid / 255.0, X_IR_test / 255.0
    # One input shape per modality; assumes square 3-channel images of the thermal size
    n=X_IR_train.shape[1]
    inputshape = [(None,n,n,3),(None,n,n,3)]

In [12]:
"""
Model definitions as Keras Model subclasses: CNN, CNN_combi, ResnetBlock, ResNet18, ResNet18_combi and ResNet_CNN
"""
class CNN(Model):
    """
    Single-input convolutional classifier.

    Five convolution layers, each followed by batch normalization (three of them
    also by max pooling), then two 4096-unit dense layers with dropout and a
    softmax output of size `num_classes`.

    NOTE(review): the constructor argument is named `channels` but is never read;
    the output size comes from the module-level `num_classes`, which is also the
    value the notebook passes in. Confirm whether the parameter was meant to be
    the class count.
    """

    def __init__(self, channels: int, **kwargs):
        super().__init__(**kwargs)
        self.cnn_conv1=Conv2D(96,(11,11),strides=(4, 4),activation="relu")
        self.cnn_bn1=BatchNormalization()
        self.cnn_pool1=MaxPooling2D((3,3), strides=(2,2))

        self.cnn_conv2=Conv2D(256,(5,5),activation="relu",padding="same")
        self.cnn_bn2=BatchNormalization()
        self.cnn_pool2=MaxPooling2D((3,3), strides=(2,2))

        self.cnn_conv3=Conv2D(384,(3,3),activation="relu",padding="same")
        self.cnn_bn3=BatchNormalization()

        self.cnn_conv4=Conv2D(384,(3,3),activation="relu",padding="same")
        self.cnn_bn4=BatchNormalization()

        self.cnn_conv5=Conv2D(256,(3,3),activation="relu",padding="same")
        self.cnn_bn5=BatchNormalization()
        self.cnn_pool3=MaxPooling2D((3,3), strides=(2,2))

        self.cnn_flat=Flatten()

        # Fully connected
        self.fc1=Dense(4096,activation="relu")
        self.drop1=Dropout(0.5)

        self.fc2=Dense(4096,activation="relu")
        self.drop2=Dropout(0.5)
        self.fc = Dense(num_classes, activation="softmax")

    def call(self, inputs, training=None):
        """
        Forward pass.

        `training` is forwarded explicitly to every BatchNormalization and Dropout
        layer so their mode does not depend on implicit propagation by the
        framework. None leaves the decision to Keras.
        """
        # Convolutional feature extractor
        out0=self.cnn_conv1(inputs)
        out0=self.cnn_bn1(out0, training=training)
        out0=self.cnn_pool1(out0)
        out0=self.cnn_conv2(out0)
        out0=self.cnn_bn2(out0, training=training)
        out0=self.cnn_pool2(out0)
        out0=self.cnn_conv3(out0)
        out0=self.cnn_bn3(out0, training=training)
        out0=self.cnn_conv4(out0)
        out0=self.cnn_bn4(out0, training=training)
        out0=self.cnn_conv5(out0)
        out0=self.cnn_bn5(out0, training=training)
        out0=self.cnn_pool3(out0)
        out0=self.cnn_flat(out0)

        # Fully connected
        out=self.fc1(out0)
        out=self.drop1(out, training=training)

        out=self.fc2(out)
        out=self.drop2(out, training=training)
        out=self.fc(out)
        return out

class CNN_combi(Model):
    """
    Two-input convolutional classifier.

    Two independent convolutional branches with the same layer configuration
    process `inputs[0]` and `inputs[1]`; their flattened features are concatenated
    and classified by two 4096-unit dense layers with dropout and a softmax output
    of size `num_classes`. The notebook feeds [RGB, thermal] in that order.

    NOTE(review): the constructor argument is named `channels` but is never read;
    the output size comes from the module-level `num_classes`, which is also the
    value the notebook passes in. Confirm whether the parameter was meant to be
    the class count.
    """

    def __init__(self, channels: int, **kwargs):
        super().__init__(**kwargs)
        # Branch for inputs[0]
        self.cnn0_conv1=Conv2D(96,(11,11),strides=(4, 4),activation="relu")
        self.cnn0_bn1=BatchNormalization()
        self.cnn0_pool1=MaxPooling2D((3,3), strides=(2,2))
        self.cnn0_conv2=Conv2D(256,(5,5),activation="relu",padding="same")
        self.cnn0_bn2=BatchNormalization()
        self.cnn0_pool2=MaxPooling2D((3,3), strides=(2,2))
        self.cnn0_conv3=Conv2D(384,(3,3),activation="relu",padding="same")
        self.cnn0_bn3=BatchNormalization()
        self.cnn0_conv4=Conv2D(384,(3,3),activation="relu",padding="same")
        self.cnn0_bn4=BatchNormalization()
        self.cnn0_conv5=Conv2D(256,(3,3),activation="relu",padding="same")
        self.cnn0_bn5=BatchNormalization()
        self.cnn0_pool3=MaxPooling2D((3,3), strides=(2,2))
        self.cnn0_flat=Flatten()

        # Branch for inputs[1]
        self.cnn1_conv1=Conv2D(96,(11,11),strides=(4, 4),activation="relu")
        self.cnn1_bn1=BatchNormalization()
        self.cnn1_pool1=MaxPooling2D((3,3), strides=(2,2))
        self.cnn1_conv2=Conv2D(256,(5,5),activation="relu",padding="same")
        self.cnn1_bn2=BatchNormalization()
        self.cnn1_pool2=MaxPooling2D((3,3), strides=(2,2))
        self.cnn1_conv3=Conv2D(384,(3,3),activation="relu",padding="same")
        self.cnn1_bn3=BatchNormalization()
        self.cnn1_conv4=Conv2D(384,(3,3),activation="relu",padding="same")
        self.cnn1_bn4=BatchNormalization()
        self.cnn1_conv5=Conv2D(256,(3,3),activation="relu",padding="same")
        self.cnn1_bn5=BatchNormalization()
        self.cnn1_pool3=MaxPooling2D((3,3), strides=(2,2))
        self.cnn1_flat=Flatten()

        # Fully connected
        self.fc1=Dense(4096,activation="relu")
        self.drop1=Dropout(0.5)

        self.fc2=Dense(4096,activation="relu")
        self.drop2=Dropout(0.5)
        self.fc = Dense(num_classes, activation="softmax")

    def call(self, inputs, training=None):
        """
        Forward pass on a pair of inputs.

        `training` is forwarded explicitly to every BatchNormalization and Dropout
        layer so their mode does not depend on implicit propagation by the
        framework. None leaves the decision to Keras.
        """
        # First branch
        out0=self.cnn0_conv1(inputs[0])
        out0=self.cnn0_bn1(out0, training=training)
        out0=self.cnn0_pool1(out0)
        out0=self.cnn0_conv2(out0)
        out0=self.cnn0_bn2(out0, training=training)
        out0=self.cnn0_pool2(out0)
        out0=self.cnn0_conv3(out0)
        out0=self.cnn0_bn3(out0, training=training)
        out0=self.cnn0_conv4(out0)
        out0=self.cnn0_bn4(out0, training=training)
        out0=self.cnn0_conv5(out0)
        out0=self.cnn0_bn5(out0, training=training)
        out0=self.cnn0_pool3(out0)
        out0=self.cnn0_flat(out0)

        # Second branch
        out1=self.cnn1_conv1(inputs[1])
        out1=self.cnn1_bn1(out1, training=training)
        out1=self.cnn1_pool1(out1)
        out1=self.cnn1_conv2(out1)
        out1=self.cnn1_bn2(out1, training=training)
        out1=self.cnn1_pool2(out1)
        out1=self.cnn1_conv3(out1)
        out1=self.cnn1_bn3(out1, training=training)
        out1=self.cnn1_conv4(out1)
        out1=self.cnn1_bn4(out1, training=training)
        out1=self.cnn1_conv5(out1)
        out1=self.cnn1_bn5(out1, training=training)
        out1=self.cnn1_pool3(out1)
        out1=self.cnn1_flat(out1)

        # Merge the two feature vectors
        concat = tf.keras.layers.concatenate([out0, out1], name='Concatenate')

        # Fully connected
        out=self.fc1(concat)
        out=self.drop1(out, training=training)

        out=self.fc2(out)
        out=self.drop2(out, training=training)
        out=self.fc(out)
        return out


# A standard resnet block.
class ResnetBlock(Model):
    """
    A standard resnet block.

    Two 3x3 convolutions, each followed by batch normalization, added to a
    shortcut of the input and passed through ReLU. With `down_sample=True` the
    first convolution uses stride 2 and the shortcut goes through a 1x1 stride-2
    convolution plus batch normalization so both paths have the same shape.
    """

    def __init__(self, channels: int, down_sample=False):
        """
            channels: number of filters of every convolution in the block.
            down_sample: halve the spatial resolution with stride 2 when True.
        """
        super().__init__()

        self.__channels = channels
        self.__down_sample = down_sample
        # Strides of the first and second convolution
        self.__strides = [2, 1] if down_sample else [1, 1]

        KERNEL_SIZE = (3, 3)
        INIT_SCHEME = "he_normal"

        self.conv_1 = Conv2D(self.__channels, strides=self.__strides[0],
                             kernel_size=KERNEL_SIZE, padding="same", kernel_initializer=INIT_SCHEME)
        self.bn_1 = BatchNormalization()
        self.conv_2 = Conv2D(self.__channels, strides=self.__strides[1],
                             kernel_size=KERNEL_SIZE, padding="same", kernel_initializer=INIT_SCHEME)
        self.bn_2 = BatchNormalization()
        self.merge = Add()

        if self.__down_sample:
            # perform down sampling using stride of 2
            self.res_conv = Conv2D(
                self.__channels, strides=2, kernel_size=(1, 1), kernel_initializer=INIT_SCHEME, padding="same")
            self.res_bn = BatchNormalization()

    def call(self, inputs, training=None):
        """
        Forward pass.

        `training` is forwarded explicitly to the BatchNormalization layers so
        their mode does not depend on implicit propagation by the framework.
        None leaves the decision to Keras.
        """
        res = inputs

        x = self.conv_1(inputs)
        x = self.bn_1(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv_2(x)
        x = self.bn_2(x, training=training)

        if self.__down_sample:
            # Project the shortcut so it matches the main path's shape
            res = self.res_conv(res)
            res = self.res_bn(res, training=training)

        # if not perform down sample, then add a shortcut directly
        x = self.merge([x, res])
        out = tf.nn.relu(x)
        return out
    
    
    
"""
A ResNet18 model
"""

class ResNet18(Model):
    """
    Single-input residual classifier.

    A 7x7 stride-2 convolution with batch normalization, ReLU and max pooling,
    eight `ResnetBlock`s (64, 64, 128, 128, 256, 256, 512, 512 filters), global
    average pooling, then two 4096-unit dense layers with dropout and a softmax
    output.
    """

    def __init__(self, num_classes, **kwargs):
        """
            num_classes: number of classes in specific classification task.
        """
        super().__init__(**kwargs)
        self.conv_1 = Conv2D(64, (7, 7), strides=2,
                             padding="same", kernel_initializer="he_normal")
        self.init_bn = BatchNormalization()
        self.pool_2 = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
        self.res_1_1 = ResnetBlock(64)
        self.res_1_2 = ResnetBlock(64)
        self.res_2_1 = ResnetBlock(128, down_sample=True)
        self.res_2_2 = ResnetBlock(128)
        self.res_3_1 = ResnetBlock(256, down_sample=True)
        self.res_3_2 = ResnetBlock(256)
        self.res_4_1 = ResnetBlock(512, down_sample=True)
        self.res_4_2 = ResnetBlock(512)
        self.avg_pool = GlobalAveragePooling2D()
        self.flat = Flatten()
        # Fully connected
        self.fc1=Dense(4096,activation="relu")
        self.drop1=Dropout(0.5)

        self.fc2=Dense(4096,activation="relu")
        self.drop2=Dropout(0.5)
        self.fc = Dense(num_classes, activation="softmax")

    def call(self, inputs, training=None):
        """
        Forward pass.

        `training` is forwarded explicitly to the batch normalization, residual
        block and dropout layers so their mode does not depend on implicit
        propagation by the framework. None leaves the decision to Keras.
        """
        out = self.conv_1(inputs)
        out = self.init_bn(out, training=training)
        out = tf.nn.relu(out)
        out = self.pool_2(out)
        for res_block in [self.res_1_1, self.res_1_2, self.res_2_1, self.res_2_2, self.res_3_1, self.res_3_2, self.res_4_1, self.res_4_2]:
            out = res_block(out, training=training)
        out = self.avg_pool(out)
        out = self.flat(out)

        # Fully connected
        out = self.fc1(out)
        out = self.drop1(out, training=training)
        out = self.fc2(out)
        out = self.drop2(out, training=training)
        out = self.fc(out)
        return out
    

"""
combination of two ResNets to use RGB and Thermal images at the same time
"""
class ResNet18_combi(Model):
    """
    Two-input residual classifier.

    Two independent residual branches with the same layer configuration process
    `inputs[0]` and `inputs[1]`; their pooled features are concatenated and
    classified by two 4096-unit dense layers with dropout and a softmax output.
    The notebook feeds [RGB, thermal] in that order.
    """

    def __init__(self, num_classes, **kwargs):
        """
            num_classes: number of classes in specific classification task.
        """
        super().__init__(**kwargs)
        # Branch for inputs[0]
        self.conv_1 = Conv2D(64, (7, 7), strides=2,
                             padding="same", kernel_initializer="he_normal")
        self.init_bn = BatchNormalization()
        self.pool_2 = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
        self.res_1_1 = ResnetBlock(64)
        self.res_1_2 = ResnetBlock(64)
        self.res_2_1 = ResnetBlock(128, down_sample=True)
        self.res_2_2 = ResnetBlock(128)
        self.res_3_1 = ResnetBlock(256, down_sample=True)
        self.res_3_2 = ResnetBlock(256)
        self.res_4_1 = ResnetBlock(512, down_sample=True)
        self.res_4_2 = ResnetBlock(512)
        self.avg_pool = GlobalAveragePooling2D()
        self.flat = Flatten()

        # Branch for inputs[1]
        self.conv2_1 = Conv2D(64, (7, 7), strides=2,
                             padding="same", kernel_initializer="he_normal")
        self.init2_bn = BatchNormalization()
        self.pool2_2 = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
        self.res2_1_1 = ResnetBlock(64)
        self.res2_1_2 = ResnetBlock(64)
        self.res2_2_1 = ResnetBlock(128, down_sample=True)
        self.res2_2_2 = ResnetBlock(128)
        self.res2_3_1 = ResnetBlock(256, down_sample=True)
        self.res2_3_2 = ResnetBlock(256)
        self.res2_4_1 = ResnetBlock(512, down_sample=True)
        self.res2_4_2 = ResnetBlock(512)
        self.avg2_pool = GlobalAveragePooling2D()
        self.flat2 = Flatten()

        # Fully connected
        self.fc1=Dense(4096,activation="relu")
        self.drop1=Dropout(0.5)

        self.fc2=Dense(4096,activation="relu")
        self.drop2=Dropout(0.5)
        self.fc = Dense(num_classes, activation="softmax")

    def call(self,inputs, training=None):
        """
        Forward pass on a pair of inputs.

        `training` is forwarded explicitly to the batch normalization, residual
        block and dropout layers so their mode does not depend on implicit
        propagation by the framework. None leaves the decision to Keras.
        """
        # First branch
        out1 = self.conv_1(inputs[0])
        out1 = self.init_bn(out1, training=training)
        out1 = tf.nn.relu(out1)
        out1 = self.pool_2(out1)
        for res_block in [self.res_1_1, self.res_1_2, self.res_2_1, self.res_2_2, self.res_3_1, self.res_3_2, self.res_4_1, self.res_4_2]:
            out1 = res_block(out1, training=training)
        out1 = self.avg_pool(out1)
        out1 = self.flat(out1)

        # Second branch
        out2 = self.conv2_1(inputs[1])
        out2 = self.init2_bn(out2, training=training)
        out2 = tf.nn.relu(out2)
        out2 = self.pool2_2(out2)
        for res_block in [self.res2_1_1, self.res2_1_2, self.res2_2_1, self.res2_2_2, self.res2_3_1, self.res2_3_2, self.res2_4_1, self.res2_4_2]:
            out2 = res_block(out2, training=training)
        out2 = self.avg2_pool(out2)
        out2 = self.flat2(out2)

        # Merge the two feature vectors
        concat = tf.keras.layers.concatenate([out1, out2], name='Concatenate')

        # Fully connected
        out = self.fc1(concat)
        out = self.drop1(out, training=training)
        out = self.fc2(out)
        out = self.drop2(out, training=training)
        out = self.fc(out)
        return out

"""
Final model: combination of CNN for RGB images and ResNet for Thermal images
"""
class ResNet_CNN(Model):
    """
    Two-input classifier mixing a plain convolutional branch and a residual branch.

    `inputs[0]` goes through the convolutional branch and `inputs[1]` through the
    residual branch; the notebook feeds [RGB, thermal] in that order. The two
    feature vectors are concatenated and classified by two 4096-unit dense layers
    with dropout and a softmax output.
    """

    def __init__(self, num_classes, **kwargs):
        """
            num_classes: number of classes in specific classification task.
        """
        super().__init__(**kwargs)
        # Residual branch (applied to inputs[1])
        self.conv_1 = Conv2D(64, (7, 7), strides=2,
                             padding="same", kernel_initializer="he_normal")
        self.init_bn = BatchNormalization()
        self.pool_2 = MaxPool2D(pool_size=(2, 2), strides=2, padding="same")
        self.res_1_1 = ResnetBlock(64)
        self.res_1_2 = ResnetBlock(64)
        self.res_2_1 = ResnetBlock(128, down_sample=True)
        self.res_2_2 = ResnetBlock(128)
        self.res_3_1 = ResnetBlock(256, down_sample=True)
        self.res_3_2 = ResnetBlock(256)
        self.res_4_1 = ResnetBlock(512, down_sample=True)
        self.res_4_2 = ResnetBlock(512)
        self.avg_pool = GlobalAveragePooling2D()
        self.flat = Flatten()

        # Convolutional branch (applied to inputs[0])
        self.cnn_conv1=Conv2D(96,(11,11),strides=(4, 4),activation="relu")
        self.cnn_bn1=BatchNormalization()
        self.cnn_pool1=MaxPooling2D((3,3), strides=(2,2))
        self.cnn_conv2=Conv2D(256,(5,5),activation="relu",padding="same")
        self.cnn_bn2=BatchNormalization()
        self.cnn_pool2=MaxPooling2D((3,3), strides=(2,2))
        self.cnn_conv3=Conv2D(384,(3,3),activation="relu",padding="same")
        self.cnn_bn3=BatchNormalization()
        self.cnn_conv4=Conv2D(384,(3,3),activation="relu",padding="same")
        self.cnn_bn4=BatchNormalization()
        self.cnn_conv5=Conv2D(256,(3,3),activation="relu",padding="same")
        self.cnn_bn5=BatchNormalization()
        self.cnn_pool3=MaxPooling2D((3,3), strides=(2,2))
        self.cnn_flat=Flatten()

        # Fully connected
        self.fc1=Dense(4096,activation="relu")
        self.drop1=Dropout(0.5)

        self.fc2=Dense(4096,activation="relu")
        self.drop2=Dropout(0.5)
        self.fc = Dense(num_classes, activation="softmax")

    def call(self,inputs, training=None):
        """
        Forward pass on a pair of inputs.

        `training` is forwarded explicitly to the batch normalization, residual
        block and dropout layers so their mode does not depend on implicit
        propagation by the framework. None leaves the decision to Keras.
        """
        # Convolutional branch
        out0=self.cnn_conv1(inputs[0])
        out0=self.cnn_bn1(out0, training=training)
        out0=self.cnn_pool1(out0)
        out0=self.cnn_conv2(out0)
        out0=self.cnn_bn2(out0, training=training)
        out0=self.cnn_pool2(out0)
        out0=self.cnn_conv3(out0)
        out0=self.cnn_bn3(out0, training=training)
        out0=self.cnn_conv4(out0)
        out0=self.cnn_bn4(out0, training=training)
        out0=self.cnn_conv5(out0)
        out0=self.cnn_bn5(out0, training=training)
        out0=self.cnn_pool3(out0)
        out0=self.cnn_flat(out0)

        # Residual branch
        out1 = self.conv_1(inputs[1])
        out1 = self.init_bn(out1, training=training)
        out1 = tf.nn.relu(out1)
        out1 = self.pool_2(out1)
        for res_block in [self.res_1_1, self.res_1_2, self.res_2_1, self.res_2_2, self.res_3_1, self.res_3_2, self.res_4_1, self.res_4_2]:
            out1 = res_block(out1, training=training)
        out1 = self.avg_pool(out1)
        out1 = self.flat(out1)

        # Merge the two feature vectors
        concat = tf.keras.layers.concatenate([out0, out1], name='Concatenate')
        # Fully connected
        out=self.fc1(concat)
        out=self.drop1(out, training=training)

        out=self.fc2(out)
        out=self.drop2(out, training=training)
        out=self.fc(out)
        return out

In [None]:
# Instantiate the architecture named by MODEL; COMBI selects the two-input variant.
# An unknown name or an impossible combination is reported here instead of
# surfacing later as an undefined `model` or a shape error inside build().
if MODEL == 'CNN':
    model = CNN_combi(num_classes) if COMBI else CNN(num_classes)
elif MODEL == 'ResNet18':
    model = ResNet18_combi(num_classes) if COMBI else ResNet18(num_classes)
elif MODEL == 'ResNet_CNN':
    # ResNet_CNN reads inputs[0] and inputs[1], so it only works with paired inputs
    if not COMBI:
        raise ValueError("MODEL 'ResNet_CNN' needs two inputs: set COMBI = True")
    model = ResNet_CNN(num_classes)
else:
    raise ValueError("Unknown MODEL %r: choose 'CNN', 'ResNet18' or 'ResNet_CNN'" % (MODEL,))

model.build(input_shape = inputshape)
model.compile(optimizer = "adam",loss='categorical_crossentropy', metrics=["accuracy"])
model.summary()

In [15]:
from tensorflow.keras.callbacks import EarlyStopping

# Stop when validation accuracy has not improved for 10 epochs and restore the best weights
early_stopping = EarlyStopping(monitor="val_accuracy", patience=10, restore_best_weights=True)

# Checkpoint that keeps only the model with the highest validation accuracy.
# NOTE: the training cell passes only `early_stopping` to model.fit, so `mc` is currently unused.
mc = tf.keras.callbacks.ModelCheckpoint(
    MODEL_FILE_NAME+'.keras',
    monitor='val_accuracy',
    mode='max',
    verbose=1,
    save_best_only=True,
)


In [None]:
batch_size=BATCH_SIZE
epochs=EPOCHS

# Options shared by every model.fit call below
fit_options = dict(batch_size=batch_size, epochs=epochs, verbose=1, callbacks=[early_stopping])

# Stays None when no dataset/modality flag is enabled, so that case can be reported
# instead of saving an untrained model and failing later on an undefined `history`.
history = None
if COMBI:
    history = model.fit([X_RGB_train,X_IR_train], y_train,
                        validation_data=([X_RGB_valid,X_IR_valid], y_valid), **fit_options)
else:
    if RGB:
        history = model.fit(X_RGB_train, y_RGB_train,
                            validation_data=(X_RGB_valid, y_RGB_valid), **fit_options)
    if IR:
        history = model.fit(X_IR_train, y_IR_train,
                            validation_data=(X_IR_valid, y_IR_valid), **fit_options)
    if SMALL_DATASET:
        history = model.fit(X_train, y_train,
                            validation_data=(X_valid, y_valid), **fit_options)

if history is None:
    raise ValueError("Nothing was trained: enable COMBI, RGB, IR or SMALL_DATASET in the configuration cell")

model.save('./'+MODEL_FILE_NAME)

In [None]:
# Reload the model saved by the training cell and evaluate it on the held-out test split
model = tf.keras.models.load_model(MODEL_FILE_NAME)

# Stays None when no dataset/modality flag is enabled, so that case is reported clearly
score = None
if COMBI:
    score = model.evaluate([X_RGB_test,X_IR_test], y_test, batch_size=batch_size, verbose=1)
else:
    if RGB:
        score = model.evaluate(X_RGB_test, y_RGB_test, batch_size=batch_size, verbose=1)
    if IR:
        score = model.evaluate(X_IR_test, y_IR_test, batch_size=batch_size, verbose=1)
    if SMALL_DATASET:
        score = model.evaluate(X_test, y_test, batch_size=batch_size, verbose=1)

if score is None:
    raise ValueError("Nothing was evaluated: enable COMBI, RGB, IR or SMALL_DATASET in the configuration cell")

# score is [loss, accuracy], matching the metrics given to model.compile
print('Test loss:', score[0])
print('Test accuracy:', score[1])

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline 

# Training (blue) and validation (red) accuracy per epoch, saved next to the notebook
plt.figure(figsize=(6, 5))
for curve, colour in (('accuracy', 'b'), ('val_accuracy', 'r')):
    plt.plot(history.history[curve], color=colour)

plt.title('Model Accuracy', fontsize=16, weight='bold')
plt.xlabel('epoch', fontsize=14, weight='bold')
plt.ylabel('accuracy', fontsize=14, weight='bold')
plt.ylim(0.4, 1.0)
plt.legend(['train', 'val'], loc='lower right', prop={'size': 14})

plt.savefig('training_curve_'+MODEL_FILE_NAME+'.png')