# 07 - Model 2 Training

In [None]:
# importing libraries
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, UpSampling2D, concatenate, LeakyReLU, ReLU
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.metrics import MeanIoU
from sklearn.model_selection import train_test_split
import tifffile
from pathlib import Path
import os
import mlflow
import mlflow.tensorflow
from tqdm import tqdm
import time

# Project root path. This is a relative path (the parent directory), so it is
# resolved against the working directory the notebook is run from.
TERRAFLOOD = Path('../')

In [None]:
# ------------ inputs --------------
# Experiment metadata. These values are embedded in the output directory and
# file names built below, so changing any of them produces a separate set of
# artifacts instead of overwriting an earlier run.
experiment_number = 1
model_architecture = 2
# Alias for the original misspelled variable name, kept for backward
# compatibility.
# NOTE(review): remove once confirmed that no later cell references it.
model_architectre = model_architecture
epochs = 200
# Patience values; presumably consumed by the early-stopping and
# learning-rate callbacks defined in a later cell -- TODO confirm.
early_stopping_patience = 10
learning_rate_patience = 5

# Input and output paths
load_path = TERRAFLOOD.joinpath("dataset_balanced/")
save_path = TERRAFLOOD.joinpath(f"experiments/model_{model_architecture}_exp_{experiment_number}/")

# Structure of logging and saving, with naming convention
# Checkpoint of the model
checkpoint_dir = save_path.joinpath("checkpoint/")
checkpoint_path = checkpoint_dir / f"model_{model_architecture}_check_exp_{experiment_number}_epochs_{epochs}_patience_{early_stopping_patience}_{learning_rate_patience}.keras"

# Logging on TensorBoard
tensorboard_logs_dir = save_path.joinpath("tensorboard_log/")

# Saving final log on MLflow
mlflow_logs_dir = save_path.joinpath("mlflow_log/")

# Saving the final model
model_save_dir = save_path.joinpath("model/")
model_path = model_save_dir / f"model_{model_architecture}_exp_{experiment_number}_epochs_{epochs}_patience_{early_stopping_patience}_{learning_rate_patience}.keras"

# Make sure every output directory exists before anything is written to it.
# parents=True also creates the missing experiment folder; exist_ok=True makes
# re-running this cell harmless.
for output_dir in (checkpoint_dir, tensorboard_logs_dir, mlflow_logs_dir, model_save_dir):
    output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Hardware info
# Print TensorFlow version
print("TensorFlow version:", tf.__version__)

# List available devices. No device type is passed, so the call is not
# restricted to a single kind of device.
devices = tf.config.list_physical_devices()
print("Available devices:", devices)

# GPU info in detail (assuming NVIDIA as the GPU device). The leading `!` is
# the IPython/Jupyter shell escape, so this line only works inside a notebook
# kernel and requires the `nvidia-smi` tool to be available on PATH.
!nvidia-smi