In [None]:
import tensorflow as tf
import os
import numpy as np

from Dataset import CreateDataset
from Deeplabv3p import DeepLabv3p
from Deeplabv3_mid import DeepLabv3p_mid, SimAM
from GourmetNet import GourmetNet
from Unet import Unet

In [None]:
# Build the {class_name: class_id} mapping for the selected dataset.
# The background class always has id 0.
DATASET_NAME = "camerfood10"  # one of: "camerfood10", "brazillian", "uecfoodpix"

if DATASET_NAME == "camerfood10":
    class_name_dict = {"background":0, "taro": 1, "sauce jaune": 2, "koki": 3, "haricot rouge ou noir": 4, "water fufu": 5, "riz blanc": 6,
    "baton de manioc": 7, "beignet farine de ble": 8, "sauce tomate": 9, "frites de plantains": 10}

elif DATASET_NAME == "brazillian":
    class_name_dict = {"background":0, "apple": 1, "bean": 2, "boiled egg": 3, "chicken breast": 4, "fried egg": 5, "lunch": 6,
    "rice": 7, "salad": 8, "spaguetti": 9, "steak": 10}

elif DATASET_NAME == "uecfoodpix":
    # NOTE(review): id 1 is mapped to 'name', which looks like a header artifact that shifts
    # every food id by one — confirm against the dataset's category file.
    class_name_dict = {'background': 0, 'name': 1, 'rice': 2, 'eels on rice': 3, 'pilaf': 4, "chicken-'n'-egg on rice": 5, 'pork cutlet on rice': 6,
    'beef curry': 7, 'sushi': 8, 'chicken rice': 9, 'fried rice': 10, 'tempura bowl': 11, 'bibimbap': 12, 'toast': 13, 'croissant': 14, 'roll bread': 15,
    'raisin bread': 16, 'chip butty': 17, 'hamburger': 18, 'pizza': 19, 'sandwiches': 20, 'udon noodle': 21, 'tempura udon': 22, 'soba noodle': 23,
    'ramen noodle': 24, 'beef noodle': 25, 'tensin noodle': 26, 'fried noodle': 27, 'spaghetti': 28, 'Japanese-style pancake': 29, 'takoyaki': 30,
    'gratin': 31, 'sauteed vegetables': 32, 'croquette': 33, 'grilled eggplant': 34, 'sauteed spinach': 35, 'vegetable tempura': 36, 'miso soup': 37,
    'potage': 38, 'sausage': 39, 'oden': 40, 'omelet': 41, 'ganmodoki': 42, 'jiaozi': 43, 'stew': 44, 'teriyaki grilled fish': 45, 'fried fish': 46,
    'grilled salmon': 47, 'salmon meuniere': 48, 'sashimi': 49, 'grilled pacific saury': 50, 'sukiyaki': 51, 'sweet and sour pork': 52,
    'lightly roasted fish': 53, 'steamed egg hotchpotch': 54, 'tempura': 55, 'fried chicken': 56, 'sirloin cutlet': 57, 'nanbanzuke': 58,
    'boiled fish': 59, 'seasoned beef with potatoes': 60, 'hambarg steak': 61, 'beef steak': 62, 'dried fish': 63, 'ginger pork saute': 64,
    'spicy chili-flavored tofu': 65, 'yakitori': 66, 'cabbage roll': 67, 'rolled omelet': 68, 'egg sunny-side up': 69, 'fermented soybeans': 70,
    'cold tofu': 71, 'egg roll': 72, 'chilled noodle': 73, 'stir-fried beef and peppers': 74, 'simmered pork': 75, 'boiled chicken and vegetables': 76,
    'sashimi bowl': 77, 'sushi bowl': 78, 'fish-shaped pancake with bean jam': 79, 'shrimp with chill source': 80, 'roast chicken': 81,
    'steamed meat dumpling': 82, 'omelet with fried rice': 83, 'cutlet curry': 84, 'spaghetti meat sauce': 85, 'fried shrimp': 86, 'potato salad': 87,
    'green salad': 88, 'macaroni salad': 89, 'Japanese tofu and vegetable chowder': 90, 'pork miso soup': 91, 'chinese soup': 92, 'beef bowl': 93,
    'kinpira-style sauteed burdock': 94, 'rice ball': 95, 'pizza toast': 96, 'dipping noodles': 97, 'hot dog': 98, 'french fries': 99, 'mixed rice': 100,
    'goya chanpuru': 101, 'others': 102, 'beverage': 103}

else:
    # Fail loudly instead of leaving class_name_dict undefined (or silently picking a wrong dataset).
    raise ValueError(
        f"Unknown DATASET_NAME {DATASET_NAME!r}; expected 'camerfood10', 'brazillian' or 'uecfoodpix'"
    )


## Hyperparameters
ROOT_DATASET_PATH = "public/CamerFood10v2"
EPOCHS = 250
LEARNING_RATE = 1e-4
BATCH_SIZE = 2
IMAGE_SIZE = 512  # side length passed to CreateDataset
# NOTE(review): the directory name says "res101" but BACKBONE below is "res50". The model cell
# reuses any model.json already present in LOGS_DIR, so an architecture from an earlier run
# can be loaded silently — confirm the intended directory before training.
LOGS_DIR = "public/logs2/camerfood_res101"
NB_CLASS = len(class_name_dict)  # counts the background class (id 0) too
BACKBONE = "res50" # res101, res50, xception

# Create the log directory if it does not exist yet; exist_ok avoids the check-then-create race.
os.makedirs(LOGS_DIR, exist_ok=True)

## Create the datasets
# Both splits are built with the same argument order as the train split:
# (dataset name, split directory, image size, batch size).
# NOTE(review): the test split previously passed DATASET_PATH as the first argument —
# confirm against Dataset.CreateDataset's signature.
DATASET_PATH = os.path.join(ROOT_DATASET_PATH, "test")
test_dataset = CreateDataset(DATASET_NAME, DATASET_PATH, IMAGE_SIZE, BATCH_SIZE).get()

DATASET_PATH = os.path.join(ROOT_DATASET_PATH, "train")
train_dataset = CreateDataset(DATASET_NAME, DATASET_PATH, IMAGE_SIZE, BATCH_SIZE).get()

In [None]:
## Create and save the model architecture (only once per LOGS_DIR)
model_json_path = os.path.join(LOGS_DIR, "model.json")

if os.path.exists(model_json_path):
    print("Oops!  This file already exist. Maybe the model have already been saved...")
else:
    # Build the model BEFORE creating the file: opening with "x" first would leave an
    # empty model.json behind if the build failed, and every later run would skip this
    # step and then fail to parse the empty file.
    # model = DeepLabv3p(num_classes=NB_CLASS, encoder_name="res101", input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3))()
    model = DeepLabv3p_mid(num_classes=NB_CLASS, backbone_name=BACKBONE, finetune=True,
                           input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3))()
    # model = Unet(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), classes=NB_CLASS)
    # "x" mode still refuses to overwrite a file that appeared in the meantime.
    with open(model_json_path, "x") as json_file:
        json_file.write(model.to_json())

## Load the model architecture
with open(model_json_path, "r") as json_file:
    loaded_model_json = json_file.read()

# SimAM (from Deeplabv3_mid) is a custom object, so it must be registered for deserialization.
# NOTE(review): model_from_json restores only the architecture; any weights set while
# building the model above (e.g. a pretrained backbone) are not carried over — confirm intended.
custom_objects = {"SimAM": SimAM}
with tf.keras.utils.custom_object_scope(custom_objects):
    model = tf.keras.models.model_from_json(loaded_model_json)
    # model.summary()

class MeanIoU(tf.keras.metrics.MeanIoU):
    """Mean IoU metric that accepts per-class scores as ``y_pred``.

    The parent ``tf.keras.metrics.MeanIoU`` is fed class indices here, so
    ``update_state`` first reduces ``y_pred`` with an arg-max over its last axis
    and then delegates to the parent implementation.

    Args:
        y_true: Unused; kept only so existing positional/keyword callers keep working.
        y_pred: Unused; kept only so existing positional/keyword callers keep working.
        num_classes: Number of classes, forwarded to the parent metric.
        name: Metric name (e.g. ``'mIoU'``, which Keras reports as ``val_mIoU`` on validation).
        dtype: Metric result dtype, forwarded to the parent metric.
        **kwargs: Any further parent constructor arguments. Newer TensorFlow versions write
            keys such as ``ignore_class`` / ``sparse_y_true`` / ``sparse_y_pred`` / ``axis``
            into ``get_config()``; accepting them here lets ``from_config()`` (and therefore
            reloading a saved model that uses this metric) succeed.
    """

    def __init__(self,
                 y_true=None,
                 y_pred=None,
                 num_classes=None,
                 name=None,
                 dtype=None,
                 **kwargs):
        super().__init__(num_classes=num_classes, name=name, dtype=dtype, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Arg-max ``y_pred`` over the last axis, then accumulate the confusion matrix."""
        y_pred = tf.math.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)


# Epoch to start (or resume) training from; passed to fit() as initial_epoch.
epoch_start = 0

# Output locations inside LOGS_DIR.
csv_log_path = os.path.join(LOGS_DIR,"data.csv")
checkpoint_pattern = os.path.join(LOGS_DIR,"saved_model_{epoch:02d}.h5")
final_weights_path = os.path.join(LOGS_DIR,"model.h5")

# Multiply the LR by 0.8 after 30 epochs without val_loss improvement.
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', mode='min', factor=0.8, patience=30)
# Per-epoch metrics, appended to any existing CSV.
csv_logger = tf.keras.callbacks.CSVLogger(csv_log_path, append=True, separator=",")
# tensorboard = tf.keras.callbacks.TensorBoard(log_dir=LOGS_DIR, histogram_freq=0, write_graph=True, write_images=False)
# Write a full-model checkpoint only when val_mIoU improves.
best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_pattern, monitor='val_mIoU', verbose=0,
    save_best_only=True, save_weights_only=False, mode='max', save_freq='epoch')

callbacks = [lr_on_plateau, csv_logger, best_checkpoint]

optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
miou_metric = MeanIoU(num_classes=NB_CLASS, name='mIoU')
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy', miou_metric])

model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=test_dataset,
    callbacks=callbacks,
    initial_epoch=epoch_start,
)
model.save_weights(final_weights_path, overwrite=True)
model.evaluate(test_dataset)