In [1]:
!rm *.npz
!pip install wandb

rm: cannot remove '*.npz': No such file or directory


# Setting up 

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import glob
import numpy as np
from tensorflow import keras

import wandb
from wandb.keras import WandbCallback
from datetime import datetime

import tensorflow.keras.losses as losses
import tensorflow.keras.optimizers as optimizers

import tensorflow as tf
import tensorflow.keras.metrics as metrics

import os
from tensorflow import argmax

os.chdir("/content/drive/MyDrive/Colab Notebooks/")
os.listdir()

['100-validation_y_0.npz',
 '100-validation_y_1.npz',
 '100-validation_y_2.npz',
 '100-validation_y_3.npz',
 '100-validation_y_4.npz',
 '100-validation_y_5.npz',
 '100-validation_y_6.npz',
 '100-validation_y_7.npz',
 '100-validation_y_8.npz',
 '100-validation_y_9.npz',
 '100-validation_y_10.npz',
 '100-validation_y_11.npz',
 '100-validation_y_12.npz',
 '100-validation_y_13.npz',
 '100-validation_y_14.npz',
 '100-validation_y_15.npz',
 '100-validation_y_16.npz',
 '100-validation_y_17.npz',
 'colab.ipynb',
 '1-test_y_0.npz',
 '100-validation_x_6.npz',
 '100-validation_x_7.npz',
 '100-validation_x_9.npz',
 '100-validation_x_8.npz',
 '100-validation_x_10.npz',
 '100-validation_x_12.npz',
 '100-validation_x_11.npz',
 '100-validation_x_13.npz',
 '100-validation_x_15.npz',
 '100-validation_x_14.npz',
 '100-validation_x_16.npz',
 '100-validation_x_0.npz',
 '1-test_x_0.npz',
 '100-validation_x_17.npz',
 '100-validation_x_1.npz',
 '100-validation_x_3.npz',
 '100-validation_x_2.npz',
 '100-valida

# Utils

In [4]:
class SanityCheck(keras.callbacks.Callback):

    id_batch = -1
    epoch = 0

    DEBUG = False

    def __init__(self, dataset, output="./", regulator=200, export_files=True, export_wandb=True):
        super(SanityCheck, self).__init__()
        self.dataset = dataset
        self.data = self.dataset.__getitem__(0)
        self.image_size = [self.data[0][0].shape[0], self.data[0][0].shape[1]]
        self.output = output
        self.regulator = regulator
        self.export_files = export_files
        self.export_wandb = export_wandb

    def on_epoch_end(self, epoch, logs=None):
        # if self.id_batch > 25:
        self.process_test()

        self.id_batch = -1
        self.epoch += 1

    def on_train_batch_end(self, batch, logs=None):
        self.id_batch += 1
        if self.id_batch > 25 and self.id_batch % self.regulator == 0:
            self.process_test()

    def predict_mask(self, img, mask):
        if self.DEBUG == True:
            result = np.repeat(np.expand_dims(np.zeros(self.image_size), axis=2), 2, axis=2)  # Forme du résultat : (X, Y, 2)
        else:
            result = self.model.predict(np.expand_dims(img, axis=0))[0]

        result = np.array(argmax(result, axis=-1), dtype=np.uint8)
        mask = np.array(argmax(mask, axis=-1), dtype=np.uint8)

        return result, mask

    def extract_file(self, result):
        os.makedirs(self.output, exist_ok=True)

        # plt.rcParams["figure.figsize"] = (14, 20)
        fig, axs = plt.subplots(3, len(result))
        fig.suptitle(("MODEL-NAME" if self.DEBUG else self.model.name) + " - I M S")

        colors = self.dataset.colors()
        for i_img, (img_i, mask_i, seg_i) in enumerate(result):

            # Colorisation du masque et du résultat
            mask = np.zeros(img_i.shape, dtype=np.uint8)
            seg = np.zeros(img_i.shape, dtype=np.uint8)
            for categorie in colors.keys():
                mask[mask_i == categorie] = colors[categorie]["color"]
                seg[seg_i == categorie] = colors[categorie]["color"]

            # Affichages des images

            axs[0, i_img].imshow(img_i)
            axs[0, i_img].axis('off')

            axs[1, i_img].imshow(mask)
            axs[1, i_img].axis('off')

            axs[2, i_img].imshow(seg)
            axs[2, i_img].axis('off')

        plt.subplots_adjust(wspace=.05, hspace=.05)
        fig.savefig("%s/%d_%d.png" % (self.output, self.epoch, self.id_batch), dpi=1000, bbox_inches='tight')
        plt.close()

    def extract_wandb(self, result):
        labels = self.dataset.labels()
        wandb_mask_list = list(map(lambda x: wandb.Image(x[0], masks={"prediction": {"mask_data": x[2], "class_labels": labels}, "ground truth": {"mask_data": x[1], "class_labels": labels}}), result))
        wandb.log({"Predictions" : wandb_mask_list})

    def process_test(self):

        result=[]
        imgs, masks=self.data

        for img_i, mask_i in zip(imgs, masks):
            seg, mask_i=self.predict_mask(img_i, mask_i)
            result.append((img_i, mask_i, seg))

        if self.export_files:
            self.extract_file(result)

        if self.export_wandb:
            self.extract_wandb(result)

In [5]:
class ArgmaxMeanIOU(metrics.MeanIoU):
    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(tf.argmax(y_true, axis=-1), tf.argmax(y_pred, axis=-1), sample_weight)

# Model

In [6]:

import tensorflow.keras.layers as layers
import tensorflow.keras.losses as losses
import tensorflow.keras.metrics as metrics
import tensorflow.keras.models as models
import tensorflow.keras.optimizers as optimizers


from tensorflow.keras.layers import Layer, InputSpec
import keras.utils.conv_utils as conv_utils
import tensorflow as tf
import keras.backend as K

# default input shape
INPUT_SHAPE = (512, 1024, 3)

def normalize_data_format(value):
    if value is None:
        value = K.image_data_format()
    data_format = value.lower()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('The `data_format` argument must be one of '
                         '"channels_first", "channels_last". Received: ' +
                         str(value))
    return data_format

class BilinearUpSampling2D(Layer):
    def __init__(self, size=(2, 2), data_format=None, **kwargs):
        super(BilinearUpSampling2D, self).__init__(**kwargs)
        self.data_format = normalize_data_format(data_format)
        self.size = conv_utils.normalize_tuple(size, 2, 'size')
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            height = self.size[0] * input_shape[2] if input_shape[2] is not None else None
            width = self.size[1] * input_shape[3] if input_shape[3] is not None else None
            return (input_shape[0],
                    input_shape[1],
                    height,
                    width)
        elif self.data_format == 'channels_last':
            height = self.size[0] * input_shape[1] if input_shape[1] is not None else None
            width = self.size[1] * input_shape[2] if input_shape[2] is not None else None
            return (input_shape[0],
                    height,
                    width,
                    input_shape[3])

    def call(self, inputs):
        input_shape = K.shape(inputs)
        if self.data_format == 'channels_first':
            height = self.size[0] * input_shape[2] if input_shape[2] is not None else None
            width = self.size[1] * input_shape[3] if input_shape[3] is not None else None
        elif self.data_format == 'channels_last':
            height = self.size[0] * input_shape[1] if input_shape[1] is not None else None
            width = self.size[1] * input_shape[2] if input_shape[2] is not None else None
        
        return tf.image.resize(inputs, [height, width], method=tf.image.ResizeMethod.BILINEAR)

    def get_config(self):
        config = {'size': self.size, 'data_format': self.data_format}
        base_config = super(BilinearUpSampling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def ge_layer(x_in, c, e=6, stride=1):
    x = layers.Conv2D(filters=c, kernel_size=(3, 3), padding='same')(x_in)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    if stride == 2:
        x = layers.DepthwiseConv2D(depth_multiplier=e, kernel_size=(3, 3), strides=2, padding='same')(x)
        x = layers.BatchNormalization()(x)

        y = layers.DepthwiseConv2D(depth_multiplier=e, kernel_size=(3, 3), strides=2, padding='same')(x_in)
        y = layers.BatchNormalization()(y)
        y = layers.Conv2D(filters=c, kernel_size=(1, 1), padding='same')(y)
        y = layers.BatchNormalization()(y)
    else:
        y = x_in

    x = layers.DepthwiseConv2D(depth_multiplier=e, kernel_size=(3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(filters=c, kernel_size=(1, 1), padding='same')(x)
    x = layers.BatchNormalization()(x)

    x = layers.Add()([x, y])
    x = layers.Activation('relu')(x)
    return x


def stem(x_in, c):
    x = layers.Conv2D(filters=c, kernel_size=(3, 3), strides=2, padding='same')(x_in)
    x = layers.BatchNormalization()(x)
    x_split = layers.Activation('relu')(x)

    x = layers.Conv2D(filters=c // 2, kernel_size=(1, 1), padding='same')(x_split)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters=c, kernel_size=(3, 3), strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    y = layers.MaxPooling2D()(x_split)

    x = layers.Concatenate()([x, y])
    x = layers.Conv2D(filters=c, kernel_size=(3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    return x


def detail_conv2d(x_in, c, stride=1):
    x = layers.Conv2D(filters=c, kernel_size=(3, 3), strides=stride, padding='same')(x_in)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    return x


def context_embedding(x_in, c):
    x = layers.GlobalAveragePooling2D()(x_in)
    x = layers.BatchNormalization()(x)

    x = layers.Reshape((1, 1, c))(x)

    x = layers.Conv2D(filters=c, kernel_size=(1, 1), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    # broadcasting no needed

    x = layers.Add()([x, x_in])
    x = layers.Conv2D(filters=c, kernel_size=(3, 3), padding='same')(x)
    return x


def bilateral_guided_aggregation(detail, semantic, c):
    # detail branch
    detail_a = layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same')(detail)
    detail_a = layers.BatchNormalization()(detail_a)

    detail_a = layers.Conv2D(filters=c, kernel_size=(1, 1), padding='same')(detail_a)

    detail_b = layers.Conv2D(filters=c, kernel_size=(3, 3), strides=2, padding='same')(detail)
    detail_b = layers.BatchNormalization()(detail_b)

    detail_b = layers.AveragePooling2D((3, 3), strides=2, padding='same')(detail_b)

    # semantic branch
    semantic_a = layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same')(semantic)
    semantic_a = layers.BatchNormalization()(semantic_a)

    semantic_a = layers.Conv2D(filters=c, kernel_size=(1, 1), padding='same')(semantic_a)
    semantic_a = layers.Activation('sigmoid')(semantic_a)

    semantic_b = layers.Conv2D(filters=c, kernel_size=(3, 3), padding='same')(semantic)
    semantic_b = layers.BatchNormalization()(semantic_b)

    semantic_b = layers.UpSampling2D((4, 4), interpolation='bilinear')(semantic_b)
    semantic_b = layers.Activation('sigmoid')(semantic_b)

    # combining
    detail = layers.Multiply()([detail_a, semantic_b])
    semantic = layers.Multiply()([semantic_a, detail_b])

    # this layer is not mentioned in the paper !?
    #semantic = layers.UpSampling2D((4,4))(semantic)
    semantic = layers.UpSampling2D((4, 4), interpolation='bilinear')(semantic)

    x = layers.Add()([detail, semantic])
    x = layers.Conv2D(filters=c, kernel_size=(3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)

    return x

def seg_head(x_in, c_t, out_scale, num_classes):
    x = layers.Conv2D(filters=c_t, kernel_size=(3, 3), padding='same')(x_in)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters=num_classes, kernel_size=(3, 3), padding='same')(x)
    x = layers.UpSampling2D((out_scale, out_scale), interpolation='bilinear')(x)

    return x

def BiSeNetV2(num_classes=2, out_scale=8, input_shape=INPUT_SHAPE, l=4, seghead_expand_ratio=2):
    x_in = layers.Input(input_shape)

    # semantic branch
    # S1 + S2
    x = stem(x_in, 64 // l)

    # S3
    x = ge_layer(x, 128 // l, stride=2)
    x = ge_layer(x, 128 // l, stride=1)

    # S4
    x = ge_layer(x, 64, stride=2)
    x = ge_layer(x, 64, stride=1)

    # S5
    x = ge_layer(x, 128, stride=2)

    x = ge_layer(x, 128, stride=1)
    x = ge_layer(x, 128, stride=1)
    x = ge_layer(x, 128, stride=1)

    x = context_embedding(x, 128)

    # detail branch
    # S1
    y = detail_conv2d(x_in, 64, stride=2)
    y = detail_conv2d(y, 64, stride=1)

    # S2
    y = detail_conv2d(y, 64, stride=2)
    y = detail_conv2d(y, 64, stride=1)
    y = detail_conv2d(y, 64, stride=1)

    # S3
    y = detail_conv2d(y, 128, stride=2)
    y = detail_conv2d(y, 128, stride=1)
    y = detail_conv2d(y, 128, stride=1)

    x = bilateral_guided_aggregation(y, x, 128)

    x = seg_head(x, num_classes * seghead_expand_ratio, out_scale, num_classes)

    model = models.Model(inputs=[x_in], outputs=[x], name="BiSeNet-V2")

    # set weight initializers
    for layer in model.layers:
        if hasattr(layer, 'kernel_initializer'):
            layer.kernel_initializer = tf.keras.initializers.HeNormal()
        if hasattr(layer, 'depthwise_initializer'):
            layer.depthwise_initializer = tf.keras.initializers.HeNormal()

    return model

def bisenetv2_DEEPER(num_classes=2, out_scale=8, input_shape=INPUT_SHAPE, l=4, seghead_expand_ratio=2):
    x_in = layers.Input(input_shape)

    # semantic branch
    # S1 + S2
    x = stem(x_in, 64 // l)

    # S3
    x = ge_layer(x, 128 // l, stride=2)
    x = ge_layer(x, 128 // l, stride=1)
    

    # S3 ++ 
    x = ge_layer(x, 256 // l, stride=2)
    x = ge_layer(x, 256 // l, stride=1)

    # S4
    x = ge_layer(x, 64, stride=2)
    x = ge_layer(x, 64, stride=1)

    # S5
    x = ge_layer(x, 128, stride=2)

    x = ge_layer(x, 128, stride=1)
    x = ge_layer(x, 128, stride=1)
    x = ge_layer(x, 128, stride=1)

    print(x.shape)

    x = context_embedding(x, 128)

    # detail branch
    # S1
    y = detail_conv2d(x_in, 64, stride=2)
    y = detail_conv2d(y, 64, stride=1)

    # S2
    y = detail_conv2d(y, 64, stride=2)
    y = detail_conv2d(y, 64, stride=1)
    y = detail_conv2d(y, 64, stride=1)

    # S3
    y = detail_conv2d(y, 128, stride=2)
    y = detail_conv2d(y, 128, stride=1)
    y = detail_conv2d(y, 128, stride=1)

    # S3 ++
    y = detail_conv2d(y, 256, stride=2)
    y = detail_conv2d(y, 256, stride=1)
    y = detail_conv2d(y, 256, stride=1)

    x = bilateral_guided_aggregation(y, x, 256) # AVANT 128 

    x = seg_head(x, num_classes * seghead_expand_ratio, out_scale, num_classes)

    model = models.Model(inputs=[x_in], outputs=[x], name="BiSeNet-V2-Deeper")

    # set weight initializers
    for layer in model.layers:
        if hasattr(layer, 'kernel_initializer'):
            layer.kernel_initializer = tf.keras.initializers.HeNormal()
        if hasattr(layer, 'depthwise_initializer'):
            layer.depthwise_initializer = tf.keras.initializers.HeNormal()

    return model

# Training

In [7]:
DATASET_FILE_SIZE = 100

USE_WANDB = True

IMG_SIZE = (512, 512)
BATCH_SIZE = 4
EPOCHS = 30
LR = 1e-4


MAPILLARY_VISTAS_CATEGORIES = {
    1: {"name": "Road", "color": [[128, 64, 128], [110, 110, 110]]},
    2: {"name": "Lane", "color": [[255, 255, 255], [250, 170, 29], [250, 170, 28], [250, 170, 26], [250, 170, 16], [250, 170, 15], [250, 170, 11], [250, 170, 12], [250, 170, 18], [250, 170, 19], [250, 170, 25], [250, 170, 20], [250, 170, 21], [250, 170, 22], [250, 170, 24]]},
    3: {"name": "Crosswalk", "color": [[140, 140, 200], [200, 128, 128]]},
    4: {"name": "Curb", "color": [[196, 196, 196], [90, 120, 150]]},
    5: {"name": "Sidewalk", "color": [[244, 35, 232]]},

    6: {"name": "Traffic Light", "color": [[250, 170, 30]]},
    7: {"name": "Traffic Sign", "color": [[220, 220, 0]]},

    8: {"name": "Person", "color": [[220, 20, 60]]},

    9: {"name": "Bicycle", "color": [[119, 11, 32], [255, 0, 0]]},
    10: {"name": "Bus", "color": [[0, 60, 100]]},
    11: {"name": "Car", "color": [[0, 0, 142], [0, 0, 90], [0, 0, 110]]},
    12: {"name": "Motorcycle", "color": [[0, 0, 230], [255, 0, 200], [255, 0, 100]]},
    13: {"name": "Truck", "color": [[0, 0, 70]]},
    
    14: {"name": "Sky", "color": [[70, 130, 180]]},
    15: {"name": "Nature", "color": [[107, 142, 35], [152, 251, 152]]}
}

AUDI_A2D2_CATEGORIES = {
    1: {"name": "Road", "color": [[180, 50, 180], [255, 0, 255]]},
    2: {"name": "Lane", "color": [[255, 193, 37], [200, 125, 210], [128, 0, 255]]},
    3: {"name": "Crosswalk", "color": [[210, 50, 115]]},
    4: {"name": "Curb", "color": [[128, 128, 0]]},
    5: {"name": "Sidewalk", "color": [[180, 150, 200]]},

    6: {"name": "Traffic Light", "color": [[0, 128, 255], [30, 28, 158], [60, 28, 100]]},
    7: {"name": "Traffic Sign", "color": [[0, 255, 255], [30, 220, 220], [60, 157, 199]]},

    8: {"name": "Person", "color": [[204, 153, 255], [189, 73, 155], [239, 89, 191]]},

    9: {"name": "Bicycle", "color": [[182, 89, 6], [150, 50, 4], [90, 30, 1], [90, 30, 30]]},
    10: {"name": "Bus", "color": []},
    11: {"name": "Car", "color": [[255, 0, 0], [200, 0, 0], [150, 0, 0], [128, 0, 0]]},
    12: {"name": "Motorcycle", "color": [[0, 255, 0], [0, 200, 0], [0, 150, 0]]},
    13: {"name": "Truck", "color": [[255, 128, 0], [200, 128, 0], [150, 128, 0], [255, 255, 0], [255, 255, 200]]},

    14: {"name": "Sky", "color": [[135, 206, 255]]},
    15: {"name": "Nature", "color": [[147, 253, 194]]}
}

In [8]:
class NPZDataset(keras.utils.Sequence):

    def __init__(self, dataset_type, length):

        self.files_name = self.get_files(dataset_type)
        self.current_batch = None
        self.current_batch_id = None
        self.length = length
        

    def get_files(self, dataset_type):
        data = list(sorted(glob.glob('*-' + dataset_type + '_x_*.npz')))

        print("Nom du dataset avec size:", list(map(lambda x: x[0:x.index("-")], data)))
        return data

    def classes(self):
        return len(AUDI_A2D2_CATEGORIES) + 1

    def labels(self):
        l = {0: "Background"}
        for i, label in enumerate(map(lambda x: AUDI_A2D2_CATEGORIES[x]["name"], AUDI_A2D2_CATEGORIES), start=1):
            l[i] = label
        return l

    def name(self):
        return "NPZDataset"

    def __len__(self):
        INDEX_DEBUT = 0
        return sum(list(map(lambda x: int(x[INDEX_DEBUT:x.index("-")]), self.files_name)))

    def __getitem__(self, batch_id):

        index_file = batch_id // self.length
        batch_number = batch_id % self.length

        if self.current_batch_id != index_file:
            self.current_batch_id = index_file
            print("---", batch_id, "-> Changing the batch file")
            self.current_batch = None
            with np.load(self.files_name[index_file]) as data_x:
                with np.load(self.files_name[index_file].replace("x", "y")) as data_y:
                    self.current_batch = {
                        "x" : data_x["arr_0"],
                        "y": data_y["arr_0"]
                    }

        x = self.current_batch["x"][batch_number]
        y = self.current_batch["y"][batch_number]
        
        return x, y


In [9]:
validation_dataset = NPZDataset('validation', 100)

Nom du dataset avec size: ['100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100']


In [10]:
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 13249834530450959027
xla_global_id: -1
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 11320098816
locality {
  bus_id: 1
  links {
  }
}
incarnation: 16388493014480480952
physical_device_desc: "device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7"
xla_global_id: 416903419
]


In [None]:
train_dataset = NPZDataset('validation', 100)
validation_dataset = NPZDataset('validation', 100)
test_dataset = NPZDataset('test', 1)


print("train_dataset :", len(train_dataset), "batchs -", len(train_dataset) * BATCH_SIZE, "images")
print("validation_dataset :", len(validation_dataset), "batchs -", len(validation_dataset) * BATCH_SIZE, "images")
print("test_dataset :", len(test_dataset), "batchs -", len(test_dataset) * BATCH_SIZE, "images")


# Creating model
print("\n> Creating model")
model = BiSeNetV2(num_classes=validation_dataset.classes(), input_shape=IMG_SIZE + (3,))

optimizer = optimizers.Adam(learning_rate=LR)
cce = losses.CategoricalCrossentropy(from_logits=True)

model.compile(optimizer, loss=cce, metrics=['accuracy', ArgmaxMeanIOU(validation_dataset.classes())])


model = keras.models.load_model("BiSeNet-V2_MultiDataset_512-512_epoch-13_loss-0.23_miou_0.54.h5", custom_objects={'ArgmaxMeanIOU': ArgmaxMeanIOU})


# Callbacks
now_str = datetime.now().strftime("%Y%m%d-%H%M%S")
callbacks = [
    SanityCheck(test_dataset, output="trained_models/" + now_str + "/check/", regulator=500, export_files=False, export_wandb=USE_WANDB),
    keras.callbacks.ModelCheckpoint("trained_models/" + now_str + "/" + model.name + "_" + test_dataset.name() + "_" + str(IMG_SIZE[0]) + "-" + str(IMG_SIZE[1]) + "_epoch-{epoch:02d}_loss-{val_loss:.2f}_miou_{val_argmax_mean_iou:.2f}.h5"),
    keras.callbacks.TensorBoard(log_dir="trained_models/" + now_str + "/logs/", histogram_freq=1)
]

if USE_WANDB:
    run = wandb.init(project="Road Segmentation", entity="nrocher", config={
        "learning_rate": LR,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "image_size": IMG_SIZE,
        "dataset": test_dataset.name(),
        "model": model.name
    })
    callbacks.append(WandbCallback())



# Training
print("\n> Training")
model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=validation_dataset,
    shuffle=False,
    callbacks=callbacks
)



# Weights & Biases - END
if USE_WANDB:
    run.finish()

Nom du dataset avec size: ['100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100']
Nom du dataset avec size: ['100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100']
Nom du dataset avec size: ['1']
train_dataset : 1800 batchs - 7200 images
validation_dataset : 1800 batchs - 7200 images
test_dataset : 1 batchs - 4 images

> Creating model
--- 0 -> Changing the batch file


[34m[1mwandb[0m: Currently logged in as: [33mnrocher[0m (use `wandb login --relogin` to force relogin)



> Training
--- 0 -> Changing the batch file


  layer_config = serialize_layer_fn(layer)


Epoch 1/30
  99/1800 [>.............................] - ETA: 12:19 - loss: 0.2262 - accuracy: 0.9197 - argmax_mean_iou: 0.5218--- 100 -> Changing the batch file
 199/1800 [==>...........................] - ETA: 13:04 - loss: 0.2191 - accuracy: 0.9229 - argmax_mean_iou: 0.5185--- 200 -> Changing the batch file
 299/1800 [===>..........................] - ETA: 12:40 - loss: 0.2123 - accuracy: 0.9254 - argmax_mean_iou: 0.5202--- 300 -> Changing the batch file
 399/1800 [=====>........................] - ETA: 12:04 - loss: 0.2104 - accuracy: 0.9261 - argmax_mean_iou: 0.5304--- 400 -> Changing the batch file
--- 100 -> Changing the batch file
--- 200 -> Changing the batch file
--- 300 -> Changing the batch file
--- 400 -> Changing the batch file
--- 500 -> Changing the batch file
--- 600 -> Changing the batch file
--- 700 -> Changing the batch file
--- 800 -> Changing the batch file
--- 900 -> Changing the batch file
--- 1000 -> Changing the batch file
--- 1100 -> Changing the batch file
--