In [11]:
import tensorflow as tf
import datetime
import os
os. environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import *
from tensorflow.keras.metrics import IoU
from tensorflow.keras.layers import Input
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import ModelCheckpoint

from dataset_loader import DataLoader
from model import UNet
from losses import CustomLoss

In [12]:
DATASET = {
    'DATASET_PATH': './Dataset',
    'TRAIN_VAL_RATIO': 0.2,
    'BATCH_SIZE': 4,
    'IMG_SIZE': (128,128),
    'SHUFFLE': True,
    'NUM_CLASSES': 3
    }

AUG = {
    'FLIP_H': 0.3,
    'FLIP_V': 0.6,
    'ROTATE_CW': 0.4,
    'ROTATE_CCW': 0.7,
    'TRANSLATE': [5, 0.3],
    'RAND_BRIGHTNESS': 0.5,
}

RANDOM_SEED = 120

MODEL = {
    'ARCH': 'UNet',
    'INPUT_SIZE': (128,128,3)
    }
LOSS = {
    'DICE_COEFF': 0.1,
    'IOU_COEFF': 0.1,
    'FOCAL_LOSS_COEFF': 0.8
    }

TRAIN = {
    'NUM_EPOCHS': 2,
    'LR': 0.0001,
    'DECAY_STEP': 0.0005,
    'EARLY_STOP': True,
    'MODEL_SAVE_DIR': './train_results'    
    }


In [13]:
# Define the dataset loader
tf.random.set_seed(RANDOM_SEED)
datasetLoader = DataLoader(DATASET['DATASET_PATH'])
train_dataset, val_dataset, test_dataset = datasetLoader.create_dataset()


Importing the datasets with the following parameters...
   Dataset path                    : ./Dataset
   Train-Val dataset ratio        : 0.2

Splitting training and validation sets...
num of training data:  2944
num of validation data:  736
num of testing data:  3710


In [14]:
# Define the model
model_arch = MODEL['ARCH']
if model_arch == 'UNet':
    unet = UNet(num_class=3, dropout=0.2)
    inputs = Input(shape=MODEL['INPUT_SIZE'])
    out = unet(inputs, training=True)
    model = tf.keras.Model(inputs=inputs, outputs=out)
elif model_arch == 'Autoencoder':
    pass
elif model_arch == 'CLIP':
    pass
elif model_arch == 'Prompt':
    pass
else:
    print('Model architecture is not assigned')




In [16]:
# Compile the model
learning_rate = TRAIN['LR']
decay_step = TRAIN['DECAY_STEP']
if decay_step is not None:
    schedule = optimizers.schedules.PolynomialDecay(
                    initial_learning_rate=learning_rate,
                    decay_steps=decay_step,
                    power=0.9
                )
    optimizer = Adam(learning_rate=schedule)
else:
    optimizer = Adam(learning_rate=0.001)


loss = CustomLoss(LOSS)
iou_metrics = IoU(3, [0])
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

if TRAIN['EARLY_STOP'] == True:
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=8)
    
model.compile(optimizer=optimizer, 
              loss=loss, 
              metrics=['acc', iou_metrics])

checkpoint = ModelCheckpoint(filepath=os.path.join(TRAIN['MODEL_SAVE_DIR'], '{model_arch}.epoch{epoch:02d}-loss{loss:.2f}.weights.h5'),
                            monitor='val_loss',
                            verbose=1,
                            save_weights_only=True,
                            save_best_only=True,
                            mode='min')

In [17]:
history = model.fit(train_dataset, 
                    validation_data=val_dataset, 
                    epochs=TRAIN['NUM_EPOCHS'], 
                    verbose=1,
                    callbacks=[tensorboard_callback, checkpoint, early_stop])

Epoch 1/2


InvalidArgumentError: Graph execution error:

Detected at node stack defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\ipykernel_launcher.py", line 18, in <module>

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\traitlets\config\application.py", line 1075, in launch_instance

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelapp.py", line 739, in start

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\tornado\platform\asyncio.py", line 205, in start

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\asyncio\base_events.py", line 645, in run_forever

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\asyncio\base_events.py", line 1999, in _run_once

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\asyncio\events.py", line 88, in _run

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 534, in process_one

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\ipykernel\ipkernel.py", line 362, in execute_request

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\ipykernel\kernelbase.py", line 778, in execute_request

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\ipykernel\ipkernel.py", line 449, in do_execute

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\ipykernel\zmqshell.py", line 549, in run_cell

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3075, in run_cell

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3130, in _run_cell

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\IPython\core\async_helpers.py", line 128, in _pseudo_sync_runner

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3334, in run_cell_async

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3517, in run_ast_nodes

  File "C:\Users\Diaz Angga Permana\AppData\Roaming\Python\Python312\site-packages\IPython\core\interactiveshell.py", line 3577, in run_code

  File "C:\Users\Diaz Angga Permana\AppData\Local\Temp\ipykernel_2824\665906156.py", line 1, in <module>

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 371, in fit

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 219, in function

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 132, in multi_step_on_iterator

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 113, in one_step_on_data

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\backend\tensorflow\trainer.py", line 84, in train_step

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\trainers\trainer.py", line 490, in compute_metrics

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\trainers\compile_utils.py", line 334, in update_state

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\trainers\compile_utils.py", line 21, in update_state

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\metrics\iou_metrics.py", line 142, in update_state

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\metrics\metrics_utils.py", line 677, in confusion_matrix

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\ops\numpy.py", line 5228, in stack

  File "c:\ProgramData\miniconda3\envs\cvLab\Lib\site-packages\keras\src\backend\tensorflow\numpy.py", line 2014, in stack

Shapes of all inputs must match: values[0].shape = [65536] != values[1].shape = [196608]
	 [[{{node stack}}]] [Op:__inference_multi_step_on_iterator_41517]