In [None]:
# general libs (standard library)
import logging
import os
import random
import shutil
import socket

from datetime import datetime
from pathlib import Path
from timeit import default_timer as timer
from typing import Any

# general libs (third-party)
import gdown
import matplotlib.pyplot as plt
import numpy as np
import nvidia_smi
import pandas as pd
import wandb

# tensorflow libs
import tensorflow as tf
from tensorflow import keras

from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.utils import img_to_array, load_img

# current time
# One snapshot is formatted four ways so that project_name (date_str/time_str)
# and file_name (datetime_str) always describe the same instant; separate
# datetime.now() calls could straddle a second or day boundary.
_run_started_at = datetime.now()
now_str = _run_started_at.strftime("%a, %d %b %Y %H:%M:%S")
date_str = _run_started_at.strftime("%Y%m%d")
datetime_str = _run_started_at.strftime("%Y%m%d_%H%M%S")
time_str = _run_started_at.strftime("%H%M%S")

# model
model_name = "efficientnet_b0"

# current dir
current_dir = os.getcwd()

# batch size
batch_size = 256

# using wandb
is_wandb = True

# current dir name
current_dir_name = os.path.basename(current_dir)

# project name (also used as the sub-path of the log and model directories)
project_name = f"{current_dir_name}/06_transfer_learning_20230911/{date_str}/{model_name}_bs_{batch_size}_{time_str}"

# file name (shared stem of the log file, the figures and the checkpoint folders)
file_name = f"{model_name}_bs_{batch_size}_{datetime_str}"

# data directory
DATA_DIR = Path("../../data/")
DATA_DIR.mkdir(parents=True, exist_ok=True)

# log directory: always start from an empty folder.
# shutil.rmtree replaces the former `rm -rf` shell call, which was built from an
# unquoted path (breaks on spaces in the directory name) and ignored failures.
LOG_DIR = Path(f"../../logs/{project_name}")
if LOG_DIR.is_dir():
    shutil.rmtree(LOG_DIR)
LOG_DIR.mkdir(parents=True, exist_ok=True)

LOG_FILE_PATH = LOG_DIR.joinpath(f"{file_name}.log")
FE_HISTORY_CURVES_FILE_PATH = LOG_DIR.joinpath(f"fe_history_curve_{file_name}.png")
FT_HISTORY_CURVES_FILE_PATH = LOG_DIR.joinpath(f"ft_history_curve_{file_name}.png")
FE_PREDICTION_FILE_PATH = LOG_DIR.joinpath(f"fe_prediction_{file_name}.png")
FT_PREDICTION_FILE_PATH = LOG_DIR.joinpath(f"ft_prediction_{file_name}.png")

# model directory: same "start empty" policy as the log directory
MODEL_DIR = Path(f"../../models/{project_name}")
if MODEL_DIR.is_dir():
    shutil.rmtree(MODEL_DIR)
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Checkpoint locations. Despite the *_DIR names these are checkpoint path
# prefixes (".../checkpoint.ckpt") handed to ModelCheckpoint(filepath=...), so
# only the containing folder is created, not a folder called "checkpoint.ckpt".
MODEL_FEATURE_EXTRACTION_DIR = MODEL_DIR.joinpath(f"fe_{file_name}/checkpoint.ckpt")
MODEL_FEATURE_EXTRACTION_DIR.parent.mkdir(parents=True, exist_ok=True)
MODEL_FINE_TUNING_DIR = MODEL_DIR.joinpath(f"ft_{file_name}/checkpoint.ckpt")
MODEL_FINE_TUNING_DIR.parent.mkdir(parents=True, exist_ok=True)

# logging configuration
# Every record goes both to LOG_FILE_PATH (opened with mode="w", i.e. truncated)
# and to the default stream handler, so the notebook output mirrors the log file.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%a, %d %b %Y %H:%M:%S",
    format="[%(asctime)s.%(msecs)03d] %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler(filename=LOG_FILE_PATH, mode="w"),
        logging.StreamHandler()
    ]
)

# created date (hard-coded: the date this notebook was first written)
logging.info(f"Created date: Mon, 11 Sep 2023 13:23:45")

# modified date (the timestamp of the current run)
logging.info(f"Modified date: {now_str}")

# run identification
logging.info(f"Model name: {model_name}")
# logging.info(f"Current dir: {current_dir}")
logging.info(f"Current dir name: {current_dir_name}")
logging.info(f"Project name: {project_name}")
logging.info(f"Model feature extraction dir: {MODEL_FEATURE_EXTRACTION_DIR}")
logging.info(f"Model fine tuning dir: {MODEL_FINE_TUNING_DIR}")

# tensorflow version
logging.info(f"TensorFlow version: {tf.__version__}")
logging.info(f"Keras version: {tf.keras.__version__}")

# physical devices
# `gpus` is the number of GPUs TensorFlow can see; when there is at least one,
# all of them are made visible. A failure here is logged and the run continues.
gpus = len(tf.config.list_physical_devices("GPU"))
logging.info(f"GPUs: {gpus}")
if gpus > 0:
    try:
        tf.config.set_visible_devices(tf.config.list_physical_devices("GPU"), "GPU")
    except Exception as error:
        logging.error(f"Caught this error during setting available devices: {error}")

# cpus (logical CPU count reported by the OS; may be None if undeterminable)
cpus = os.cpu_count()
logging.info(f"CPUs: {cpus}")

# dataset name
dataset_name = "food10"
logging.info(f"Dataset name: {dataset_name}")

# batch_size
logging.info(f"Batch size: {batch_size}")

# wandb
logging.info(f"Use wandb: {is_wandb}")

## seed
# default seed, applied to Python's `random`, NumPy and TensorFlow below
default_seed = 42
logging.info(f"Default seed: {default_seed}")

random.seed(default_seed)
np.random.seed(default_seed)

tf.random.set_seed(default_seed)
tf.keras.utils.set_random_seed(default_seed)
# tf.config.experimental.enable_op_determinism()

# when running on the CuDNN backend, two further options must be set
# (left disabled here, so GPU runs are not guaranteed to be bit-for-bit reproducible)
# os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
# os.environ['TF_DETERMINISTIC_OPS'] = '1'

# set a fixed value for the hash seed
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes - confirm.
os.environ["PYTHONHASHSEED"] = str(default_seed)

## dataset downloading
# 10 percents
DATASET_FOOD10_10_PERCENT_URL = "https://drive.google.com/uc?id=17Yw1PGyDGpwV77Ds6VDBQ7Iv9Q2rQR6P"
DATASET_FOOD10_10_PERCENT_ZIPFILE_NAME = "10_food_classes_10_percent.zip"
DATASET_FOOD10_10_PERCENT_ZIPFILE_PATH = DATA_DIR.joinpath(DATASET_FOOD10_10_PERCENT_ZIPFILE_NAME)
DATASET_FOOD10_10_PERCENT_FOLDER_NAME = "10_food_classes_10_percent"
DATASET_FOOD10_10_PERCENT_FOLDER_PATH = DATA_DIR.joinpath(DATASET_FOOD10_10_PERCENT_FOLDER_NAME)

# full data
DATASET_FOOD10_FULL_URL = "https://drive.google.com/uc?id=1h9Zvm0UKeGMk8hXSdfCJ5E1hIgxvQHN6"
DATASET_FOOD10_FULL_ZIPFILE_NAME = "food10.zip"
DATASET_FOOD10_FULL_ZIPFILE_PATH = DATA_DIR.joinpath(DATASET_FOOD10_FULL_ZIPFILE_NAME)
DATASET_FOOD10_FULL_FOLDER_NAME = "food10"
DATASET_FOOD10_FULL_FOLDER_PATH = DATA_DIR.joinpath(DATASET_FOOD10_FULL_FOLDER_NAME)


def _download_if_missing(url: str, zip_path: Path) -> None:
    """Download ``url`` to ``zip_path`` unless that file already exists.

    Best effort: a failure is logged at ERROR level (including the error
    itself) and the run continues, so a missing archive surfaces later during
    extraction instead of aborting here.
    """
    if zip_path.is_file():
        logging.info(f"The {zip_path.name} already exists.")
        return

    logging.info(f"The {zip_path.name} is downloading...")
    try:
        gdown.download(url=url, output=str(zip_path))
    except Exception as error:
        logging.error(f"Caught this error during downloading {zip_path.name}: {error}")
        return

    # only report success when the archive is really on disk
    if zip_path.is_file():
        logging.info(f"The {zip_path.name} is downloaded successfully.")
    else:
        logging.error(f"The {zip_path.name} was not created by the download.")


def _extract_if_missing(zip_path: Path, folder_path: Path) -> None:
    """Extract ``zip_path`` into ``DATA_DIR`` unless ``folder_path`` already exists.

    Best effort, like the download step: failures are logged with the error
    and do not raise.
    """
    if folder_path.is_dir():
        logging.info(f"The {folder_path} already exists.")
        return

    logging.info(f"The {zip_path.name} is extracting...")
    try:
        gdown.extractall(path=str(zip_path), to=str(DATA_DIR))
    except Exception as error:
        logging.error(f"Caught this error during extracting {zip_path.name}: {error}")
    else:
        logging.info(f"The {zip_path.name} is extracted successfully.")


# download 10 percents dataset
_download_if_missing(DATASET_FOOD10_10_PERCENT_URL, DATASET_FOOD10_10_PERCENT_ZIPFILE_PATH)

# download full dataset
_download_if_missing(DATASET_FOOD10_FULL_URL, DATASET_FOOD10_FULL_ZIPFILE_PATH)

# extract 10 percents dataset
_extract_if_missing(DATASET_FOOD10_10_PERCENT_ZIPFILE_PATH, DATASET_FOOD10_10_PERCENT_FOLDER_PATH)

# extract full dataset
_extract_if_missing(DATASET_FOOD10_FULL_ZIPFILE_PATH, DATASET_FOOD10_FULL_FOLDER_PATH)

In [None]:
def walk_through_data(data_dir: Path) -> int:
    """Count the files below ``data_dir`` (recursively) and log the total.

    Args:
        data_dir: Root folder of a dataset split.

    Returns:
        The number of files found in ``data_dir`` and all of its sub-folders.

    Raises:
        FileNotFoundError: If ``data_dir`` is not an existing directory.
    """
    if not data_dir.is_dir():
        raise FileNotFoundError("There is no given directory.")

    total_images = sum(len(filenames) for _, _, filenames in os.walk(str(data_dir)))

    # Report the accumulated total; the previous version logged only the file
    # count of the last folder visited by os.walk.
    logging.info(f"There are {total_images} images in the {data_dir}")
    return total_images

In [None]:
## Data Preparing
# Each split folder is resolved and then immediately inspected, so a missing
# folder stops the cell at the first split that is absent.

# 10 percent variant: training split
train_10_percents_dir = DATASET_FOOD10_10_PERCENT_FOLDER_PATH / "train"
walk_through_data(train_10_percents_dir)

# 10 percent variant: testing split
test_10_percents_dir = DATASET_FOOD10_10_PERCENT_FOLDER_PATH / "test"
walk_through_data(test_10_percents_dir)

# full variant: training split
train_full_data_dir = DATASET_FOOD10_FULL_FOLDER_PATH / "train"
walk_through_data(train_full_data_dir)

# full variant: testing split
test_full_data_dir = DATASET_FOOD10_FULL_FOLDER_PATH / "test"
walk_through_data(test_full_data_dir)


In [None]:
# class_names: one class per sub-folder of the training split, in sorted order.
# Plain files (e.g. a stray ".DS_Store") are skipped: they are not classes, and
# counting them would shift every label index used when decoding predictions.
# NOTE(review): hidden folders such as ".ipynb_checkpoints" are still listed;
# whether the dataset loader counts them depends on the Keras version - confirm.
class_names = sorted(entry.name for entry in train_10_percents_dir.iterdir() if entry.is_dir())
logging.info(f"Class names: {class_names}")
logging.info(f"len(class_names): {len(class_names)}")

In [None]:
## Data Preprocessing

# data augmentation: random horizontal flip, rotation and zoom
data_augmentation = keras.Sequential(
    [
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(0.2),
        tf.keras.layers.RandomZoom(0.2, 0.2),
    ]
)

# options shared by all four datasets; only the folder and `shuffle` differ
_dataset_options = dict(
    label_mode="int",
    batch_size=batch_size,
    image_size=(224, 224),
    seed=default_seed,
)

# train 10 percent dataset (shuffled)
train_10_percent_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    directory=train_10_percents_dir, shuffle=True, **_dataset_options
)

# test 10 percent dataset (kept in directory order)
test_10_percent_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    directory=test_10_percents_dir, shuffle=False, **_dataset_options
)

# train full data dataset (shuffled)
train_full_data_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    directory=train_full_data_dir, shuffle=True, **_dataset_options
)

# test full data dataset (kept in directory order)
test_full_data_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    directory=test_full_data_dir, shuffle=False, **_dataset_options
)

#### Functions

In [None]:
def create_model(model: tf.keras.Model = None, num_trainable_layers: int = 10):
    """Build the feature-extraction model, or switch a model to fine-tuning.

    Args:
        model: ``None`` builds a new model: data augmentation, a frozen
            headless EfficientNetB0, global average pooling, dropout and a
            softmax layer with one unit per entry of ``class_names``.
            Passing an existing model switches it to fine-tuning mode
            **in place**.
        num_trainable_layers: Fine-tuning only. Number of trailing backbone
            layers left trainable; every earlier backbone layer is frozen.

    Returns:
        The new model, or the same ``model`` object that was passed in.
        The model is not compiled here.
    """
    if model is None:

        logging.info("Feature extraction mode")
        # inputs
        inputs = tf.keras.Input(shape=(224, 224, 3))

        # data augmentation
        x = data_augmentation(inputs)

        # base model
        base_model = tf.keras.applications.EfficientNetB0(include_top=False)

        # freezing all layers
        base_model.trainable = False

        # called with training=False, which stays baked into the graph even
        # after the backbone is unfrozen for fine-tuning
        x = base_model(x, training=False)

        # global average pooling
        x = tf.keras.layers.GlobalAveragePooling2D()(x)

        dropout_rate = 0.2
        logging.info(f"Dropout rate: {dropout_rate}")

        # dropout
        x = tf.keras.layers.Dropout(rate=dropout_rate)(x)

        # outputs
        outputs = tf.keras.layers.Dense(units=len(class_names), activation="softmax")(x)

        # model
        model = tf.keras.Model(inputs, outputs)
    else:
        logging.info("Fine-tuning mode")
        model.trainable = True

        # The outer model only has a handful of layers (input, augmentation,
        # backbone, pooling, dropout, dense), so slicing *its* layer list with
        # [:-10] selects nothing and would leave the whole backbone trainable.
        # Freeze inside the nested backbone instead: the nested model with the
        # most layers. A flat model without nested models falls back to its
        # own layer list (the previous behaviour).
        nested_models = [layer for layer in model.layers if isinstance(layer, tf.keras.Model)]
        backbone = max(nested_models, key=lambda nested: len(nested.layers), default=model)

        first_trainable = max(len(backbone.layers) - num_trainable_layers, 0)
        for layer in backbone.layers[:first_trainable]:
            layer.trainable = False

    # summary (summary() returns None, so it is not passed to logging.info itself)
    model.summary(print_fn=logging.info)

    return model


In [None]:
def display_plot(feature_extraction_history: dict):
    """Plot and save the feature-extraction accuracy and loss curves.

    Args:
        feature_extraction_history: Mapping with the per-epoch lists
            "accuracy", "loss", "val_accuracy" and "val_loss".

    The figure (accuracy on the left, loss on the right) is written to
    ``FE_HISTORY_CURVES_FILE_PATH``; it is not shown explicitly.
    """
    logging.info("Display history curves from feature extraction.")

    train_accuracy = feature_extraction_history["accuracy"]
    train_loss = feature_extraction_history["loss"]
    test_accuracy = feature_extraction_history["val_accuracy"]
    test_loss = feature_extraction_history["val_loss"]

    epoch_axis = list(range(len(train_accuracy)))

    # figure
    plt.figure(figsize=(10, 5))
    plt.suptitle(f"Transfer Learning - Feature Extraction : {now_str}")

    # one panel per metric: (y-axis label, training curve, testing curve)
    panels = (
        ("Accuracy", train_accuracy, test_accuracy),
        ("Loss", train_loss, test_loss),
    )
    for position, (y_label, train_curve, test_curve) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epoch_axis, train_curve, c="r", label="Training dataset")
        plt.plot(epoch_axis, test_curve, c="g", label="Testing dataset")
        plt.xlabel("Epochs")
        plt.ylabel(y_label)
        plt.legend()

    # save history
    plt.savefig(FE_HISTORY_CURVES_FILE_PATH, bbox_inches="tight")

    logging.info(f"Save history curves from feature extraction in this path: {FE_HISTORY_CURVES_FILE_PATH}")

def display_transfer_learning_plot(feature_extraction_history: dict, fine_tuning_history: dict, initial_epoch: int = 0):
    """Plot and save the combined feature-extraction + fine-tuning curves.

    Args:
        feature_extraction_history: Mapping with the per-epoch lists
            "accuracy", "loss", "val_accuracy" and "val_loss" of the
            feature-extraction run.
        fine_tuning_history: Same keys, for the fine-tuning run.
        initial_epoch: x position of the vertical "Start fine-tuning" marker.

    Neither input mapping is modified, so the function can be called again
    (e.g. by re-running the cell) without the curves growing each time.
    The figure is written to ``FT_HISTORY_CURVES_FILE_PATH``.
    """
    logging.info("Display history curves from feature extraction and fine-tuning.")

    # Concatenate into new lists. The previous `+=` extended the caller's
    # feature-extraction lists in place, corrupting that history.
    accuracy = list(feature_extraction_history["accuracy"]) + list(fine_tuning_history["accuracy"])
    loss = list(feature_extraction_history["loss"]) + list(fine_tuning_history["loss"])
    val_accuracy = list(feature_extraction_history["val_accuracy"]) + list(fine_tuning_history["val_accuracy"])
    val_loss = list(feature_extraction_history["val_loss"]) + list(fine_tuning_history["val_loss"])

    start_fine_tuning_epoch = initial_epoch
    logging.info(f"Start fine-tuning from initial_epoch: {initial_epoch}.")

    epochs = list(range(len(accuracy)))

    # figure
    plt.figure(figsize=(10, 5))
    plt.suptitle(f"Transfer Learning - Fine-tuning : {now_str}")

    # accuracy
    plt.subplot(1, 2, 1)
    plt.plot(epochs, accuracy, c="r", label="Training dataset")
    plt.plot(epochs, val_accuracy, c="g", label="Testing dataset")
    # the marker spans the y-range the curves have produced so far
    plt.plot([start_fine_tuning_epoch, start_fine_tuning_epoch], plt.ylim(), c="b", label="Start fine-tuning")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()

    # loss
    plt.subplot(1, 2, 2)
    plt.plot(epochs, loss, c="r", label="Training dataset")
    plt.plot(epochs, val_loss, c="g", label="Testing dataset")
    plt.plot([start_fine_tuning_epoch, start_fine_tuning_epoch], plt.ylim(), c="b", label="Start fine-tuning")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()

    # save history
    plt.savefig(FT_HISTORY_CURVES_FILE_PATH, bbox_inches="tight")

    logging.info(f"Save history curves from from feature extraction and fine-tuning in this path:{FT_HISTORY_CURVES_FILE_PATH}")

def get_gpu_util():
    """Log memory usage and utilisation of every GPU visible to NVML.

    One INFO line is written per device: host name, device index and name,
    used / total memory in MB, and the GPU and memory utilisation percentages
    as reported by NVML. NVML is initialised on entry and always shut down
    again on exit, even if a query fails.
    """
    logging.info("GPU status")
    nvidia_smi.nvmlInit()
    try:
        host_name = socket.gethostname()

        for i in range(nvidia_smi.nvmlDeviceGetCount()):
            handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)

            util = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)

            # the name comes back as bytes or str depending on the binding
            # version; only decode when needed
            device_name = nvidia_smi.nvmlDeviceGetName(handle)
            if isinstance(device_name, bytes):
                device_name = device_name.decode("utf-8")

            mem_info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
            mem_total = int(mem_info.total / 1024 / 1024)
            mem_used = int(mem_info.used / 1024 / 1024)

            # util.gpu / util.memory are reported as they are; the previous
            # code divided them by 100 into an int and never logged them
            logging.info(
                f"{host_name} - [{i}] {device_name} | {mem_used} / {mem_total} MB"
                f" | GPU util: {util.gpu}% | memory util: {util.memory}%"
            )
    finally:
        nvidia_smi.nvmlShutdown()

#### Pre-trained weights

In [None]:
# Stand-alone headless EfficientNetB0 (include_top=False), created here only so
# its architecture can be inspected; create_model() builds its own instance.
efficientnet  = tf.keras.applications.EfficientNetB0(include_top=False)

# summary (printed to the cell output, not routed through logging)
efficientnet.summary()

#### Feature Extraction

In [None]:
class HistoryCallback(tf.keras.callbacks.Callback):
    """Log one line per epoch and optionally forward the metrics to wandb.

    Args:
        epochs: Total number of epochs, used only for the "Epoch: i / n" text.
        wandb: Object exposing ``log(dict)`` (the ``wandb`` module in this file).
        is_wandb: When ``False`` nothing is sent to ``wandb``.
    """

    def __init__(self, epochs: int, wandb: Any, is_wandb: bool = True):
        super().__init__()
        self.epochs = epochs
        self.wandb = wandb
        # Keep the flag on the instance: the previous version ignored this
        # argument and silently read the module-level `is_wandb` instead.
        self.is_wandb = is_wandb

    def on_epoch_begin(self, epoch, logs=None):
        """Start the wall-clock timer for this epoch."""
        self.start_timer = timer()

    def on_epoch_end(self, epoch, logs=None):
        """Log duration and metrics of the finished epoch; forward them to wandb."""
        self.end_timer = timer()
        duration = self.end_timer - self.start_timer

        # `logs` may be None, and the validation metrics are missing when the
        # model is fitted without validation data
        logs = logs or {}
        accuracy = logs.get("accuracy")
        loss = logs.get("loss")
        val_accuracy = logs.get("val_accuracy")
        val_loss = logs.get("val_loss")

        def _fmt(value, spec: str) -> str:
            """Format a metric, or return 'n/a' when it was not reported."""
            return "n/a" if value is None else format(value, spec)

        logging.info(f"Epoch: {epoch + 1} / {self.epochs} | "
                     f"{duration:.2f}s | "
                     f"accuracy: {_fmt(accuracy, '.2f')}| "
                     f"loss: {_fmt(loss, ' .3f')} | "
                     f"val_accuracy: {_fmt(val_accuracy, ' .2f')} | "
                     f"val_loss: {_fmt(val_loss, '.3f')}"
                     )

        if self.is_wandb:
            # history dict (metrics that were not reported are left out)
            history_dict = {
                "accuracy": accuracy,
                "loss": loss,
                "val_accuracy": val_accuracy,
                "val_loss": val_loss,
            }
            self.wandb.log({key: value for key, value in history_dict.items() if value is not None})

In [None]:
# feature extraction with freezing all layers except top layers
feature_extraction_model = create_model()

# epochs
feature_extraction_epochs = 10
logging.info(f"Feature extraction (epochs): {feature_extraction_epochs}")

# loss function (integer labels, matching label_mode="int" of the datasets)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
logging.info("Feature extraction (loss_fn): SparseCategoricalCrossentropy")

# optimizer
feature_extraction_lr = 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate=feature_extraction_lr)
logging.info("Feature extraction (optimizer): Adam")
logging.info(f"Feature extraction (lr): {feature_extraction_lr}")

if is_wandb:
    # init wandb
    # NOTE(review): the config key "learing_rate" is misspelled; it is kept
    # unchanged so runs stay comparable with earlier ones.
    wandb.init(project="tensorflow-deep-learning", name=f"tadac/{project_name}_feature_extraction", config={
        "batch_size": batch_size,
        "architecture": model_name,
        "learing_rate": feature_extraction_lr,
        "dataset": "10_food_classes_10_percent",
        "epochs": feature_extraction_epochs,
    })

# callbacks: keep only the best weights, plus per-epoch logging / wandb
fe_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=MODEL_FEATURE_EXTRACTION_DIR, save_best_only=True, save_weights_only=True, verbose=0)
history_callback = HistoryCallback(epochs=feature_extraction_epochs, wandb=wandb, is_wandb=is_wandb)

# compile
feature_extraction_model.compile(loss=loss_fn,
                                 optimizer=optimizer,
                                 metrics=["accuracy"])

# len() of a batched dataset is already its number of batches, i.e. exactly the
# value passed as steps_per_epoch / validation_steps below (the old log lines
# divided it by batch_size a second time).
logging.info(f"len(train_10_percent_dataset): {len(train_10_percent_dataset)}")
logging.info(f"steps_per_epoch: {len(train_10_percent_dataset)}")
logging.info(f"len(test_10_percent_dataset): {len(test_10_percent_dataset)}")
logging.info(f"validation_steps: {len(test_10_percent_dataset)}")

# fit
# 0 = silent, 1 = progress bar, 2 = one line per epoch
# `workers` / `use_multiprocessing` are not passed: they apply to generator /
# Sequence inputs, not to tf.data datasets, and the fine-tuning fit() does not
# pass them either.
feature_extraction_history = feature_extraction_model.fit(train_10_percent_dataset,
                                                           epochs=feature_extraction_epochs,
                                                           steps_per_epoch=len(train_10_percent_dataset),
                                                           validation_data=test_10_percent_dataset,
                                                           validation_steps=len(test_10_percent_dataset),
                                                           shuffle=True, verbose=0,
                                                           callbacks=[fe_checkpoint_callback, history_callback])

if is_wandb:
    # finish wandb
    wandb.finish()

# display plot
pd.DataFrame(feature_extraction_history.history).plot()
plt.show()

In [None]:
# get gpu status (logs per-device memory usage after the feature-extraction run)
get_gpu_util()

##### History Curve

In [None]:
# display history curves of feature extraction
# (the figure is saved to FE_HISTORY_CURVES_FILE_PATH by display_plot)
display_plot(feature_extraction_history=feature_extraction_history.history)

##### Prediction

In [None]:
## load the model
# MODEL_FEATURE_EXTRACTION_DIR is the ".../checkpoint.ckpt" prefix; the
# checkpoint files live in its parent folder, which is where the latest one is
# looked up
fe_ckpt_dir = MODEL_FEATURE_EXTRACTION_DIR.parent
latest_fe_ckpt_file_path = tf.train.latest_checkpoint(str(fe_ckpt_dir))
logging.info(f"Loading the feature extraction model from a saved path: {latest_fe_ckpt_file_path}")
if latest_fe_ckpt_file_path is None:
    # fail with a clear message instead of passing None to load_weights()
    raise FileNotFoundError(f"There is no checkpoint in {fe_ckpt_dir}.")

# create an instance of the model (same architecture, fresh weights)
loaded_fe_model = create_model()

# load the latest weights
loaded_fe_model.load_weights(filepath=latest_fe_ckpt_file_path)

# load an image for evaluation
test_image_file_path = DATA_DIR.joinpath('sushi_test.jpg')
logging.info(f"Using the {test_image_file_path} for evaluation")
if not test_image_file_path.is_file():
    raise FileNotFoundError("There is no image in the given path.")

# load an image and convert it to (224, 224)
pil_image = tf.keras.utils.load_img(path=test_image_file_path, target_size=(224, 224))

# convert its image to array
arr_image = img_to_array(pil_image)

# create a batch of one image
image_t = tf.expand_dims(arr_image, axis=0)

# y probabilities (one row, because the batch holds a single image)
y_probs = loaded_fe_model.predict(image_t)

# predicted probability, as a plain float so it formats reliably
y_predicted_prob = float(tf.reduce_max(y_probs))
y_predicted_label_prob = y_predicted_prob

# predicted label id: argmax yields one id per row, so take element 0
# explicitly instead of calling int() on a shape-(1,) tensor
y_predicted_label_id = tf.argmax(y_probs, axis=1)
y_predicted_label = class_names[int(y_predicted_label_id[0])]
logging.info(f"Predicted: {y_predicted_label}")
logging.info(f"Predicted probability: {y_predicted_prob: .2f}")

# save file
plt.figure(figsize=(5,5))
plt.suptitle(f"Feature Extraction - Prediction - {now_str}")
plt.imshow(pil_image)
plt.title(f"Predicted: {y_predicted_label} | Prob: {y_predicted_prob: .2f}")
plt.savefig(FE_PREDICTION_FILE_PATH, bbox_inches="tight")

#### Fine-tuning

In [None]:
# fine-tuning: switch the trained feature-extraction model into fine-tuning mode
# (create_model changes the model in place, so both names refer to one object)
fine_tuning_model = create_model(model=feature_extraction_model)

# epochs: total epoch count; training continues after the feature-extraction epochs
fine_tuning_epochs = feature_extraction_epochs + 10
logging.info(f"Fine tuning (epochs): {fine_tuning_epochs}")

# loss function
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
logging.info("Fine tuning (loss_fn): SparseCategoricalCrossentropy")

# optimizer
fine_tuning_lr = 0.0001 # 10x lower
optimizer = tf.keras.optimizers.Adam(learning_rate=fine_tuning_lr)
logging.info("Fine tuning (optimizer): Adam")
logging.info(f"Fine tuning (lr): {fine_tuning_lr}")

if is_wandb:
    # init wandb (now guarded by is_wandb like every other wandb call in this
    # file; it used to run even when wandb was switched off)
    # NOTE(review): the config key "learing_rate" is misspelled; kept unchanged
    # so runs stay comparable with earlier ones.
    wandb.init(project="tensorflow-deep-learning", name=f"tadac/{project_name}_fine_tuning", config={
        "batch_size": batch_size,
        "architecture": model_name,
        "learing_rate": fine_tuning_lr,
        "dataset": "food10",
        "epochs": fine_tuning_epochs,
    })

    # replay the feature-extraction history so this run shows one continuous curve
    for i in range(len(feature_extraction_history.history["accuracy"])):
        # wandb dict
        fe_wantdb_dict = {
            "accuracy": feature_extraction_history.history["accuracy"][i],
            "loss" : feature_extraction_history.history["loss"][i],
            "val_accuracy": feature_extraction_history.history["val_accuracy"][i],
            "val_loss": feature_extraction_history.history["val_loss"][i]
        }

        # wandb logging
        wandb.log(fe_wantdb_dict)

# callbacks
fine_tuning_model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=MODEL_FINE_TUNING_DIR, save_best_only=True, save_weights_only=True, verbose=0)

history_callback = HistoryCallback(epochs=fine_tuning_epochs, wandb=wandb, is_wandb=is_wandb)

# compile (recompiling is what makes the changed `trainable` flags take effect)
fine_tuning_model.compile(loss=loss_fn,
                          optimizer=optimizer,
                          metrics=["accuracy"])

# len() of a batched dataset is already its number of batches, i.e. exactly the
# value passed as steps_per_epoch / validation_steps below (the old log lines
# divided it by batch_size a second time).
logging.info(f"len(train_full_data_dataset): {len(train_full_data_dataset)}")
logging.info(f"steps_per_epoch: {len(train_full_data_dataset)}")
logging.info(f"len(test_full_data_dataset): {len(test_full_data_dataset)}")
logging.info(f"validation_steps: {len(test_full_data_dataset)}")

# fit
#  0 = silent, 1 = progress bar, 2 = one line per epoch
fine_tuning_history = fine_tuning_model.fit(train_full_data_dataset, initial_epoch=feature_extraction_epochs,
                                            epochs=fine_tuning_epochs,
                                            steps_per_epoch=len(train_full_data_dataset),
                                            validation_data=test_full_data_dataset,
                                            validation_steps=len(test_full_data_dataset),
                                            shuffle=True, verbose=0,
                                            callbacks=[fine_tuning_model_checkpoint_callback, history_callback])


if is_wandb:
    # finish wandb
    wandb.finish()

# display plot
pd.DataFrame(fine_tuning_history.history).plot()
plt.show()


In [None]:
# get gpu status (logs per-device memory usage after the fine-tuning run)
get_gpu_util()

##### History Curve

In [None]:
# display history curves of both feature extraction and fine-tuning
# The "Start fine-tuning" marker is drawn at the last entry of
# feature_extraction_history.epoch.
# NOTE(review): presumably that is the zero-based index of the final
# feature-extraction epoch, one before the first fine-tuning epoch - confirm.
display_transfer_learning_plot(feature_extraction_history = feature_extraction_history.history, fine_tuning_history = fine_tuning_history.history, initial_epoch=feature_extraction_history.epoch[-1])

##### Prediction

In [None]:
## load the model
# MODEL_FINE_TUNING_DIR is the ".../checkpoint.ckpt" prefix; the checkpoint
# files live in its parent folder, which is where the latest one is looked up
ft_ckpt_dir = MODEL_FINE_TUNING_DIR.parent
latest_ft_ckpt_file_path = tf.train.latest_checkpoint(str(ft_ckpt_dir))
logging.info(f"Loading the fine-tuning model from a saved path: {latest_ft_ckpt_file_path}")
if latest_ft_ckpt_file_path is None:
    # fail with a clear message instead of passing None to load_weights()
    raise FileNotFoundError(f"There is no checkpoint in {ft_ckpt_dir}.")

# create an instance of the model (same architecture, fresh weights)
loaded_ft_model = create_model()

# load the latest weights
loaded_ft_model.load_weights(filepath=latest_ft_ckpt_file_path)

# load an image for evaluation
test_image_file_path = DATA_DIR.joinpath('sushi_test.jpg')
logging.info(f"Using the {test_image_file_path} for evaluation")
if not test_image_file_path.is_file():
    raise FileNotFoundError("There is no image in the given path.")

# load an image and convert it to (224, 224)
pil_image = tf.keras.utils.load_img(path=test_image_file_path, target_size=(224, 224))

# convert its image to array
arr_image = img_to_array(pil_image)

# create a batch of one image
image_t = tf.expand_dims(arr_image, axis=0)

# y probabilities (one row, because the batch holds a single image)
y_probs = loaded_ft_model.predict(image_t)

# predicted probability, as a plain float so it formats reliably
y_predicted_prob = float(tf.reduce_max(y_probs))
y_predicted_label_prob = y_predicted_prob

# predicted label id: argmax yields one id per row, so take element 0
# explicitly instead of calling int() on a shape-(1,) tensor
y_predicted_label_id = tf.argmax(y_probs, axis=1)
y_predicted_label = class_names[int(y_predicted_label_id[0])]
logging.info(f"Predicted: {y_predicted_label}")
logging.info(f"Predicted probability: {y_predicted_prob: .2f}")

# save file
plt.figure(figsize=(5,5))
plt.suptitle(f"Fine-tuning - Prediction - {now_str}")
plt.imshow(pil_image)
plt.title(f"Predicted: {y_predicted_label} | Prob: {y_predicted_prob: .2f}")
plt.savefig(FT_PREDICTION_FILE_PATH, bbox_inches="tight")

#### References

1. https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit
2. https://www.tensorflow.org/guide/keras/making_new_layers_and_models_via_subclassing 
3. https://matplotlib.org/stable/gallery/subplots_axes_and_figures/subplots_demo.html