In [None]:
# STEP 1: Mount Google Drive
# The dataset paths configured later in this notebook live under
# /content/drive, so this cell has to run before any data is loaded.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import load_img
from PIL import Image
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.callbacks import EarlyStopping

# === Paths ===
# Folders on the mounted Drive holding the input images and their RGB masks.
image_dir = '/content/drive/MyDrive/food_dataset/images'
mask_dir = '/content/drive/MyDrive/food_dataset/mask'

# Size every image/mask is resized to, and the number of segmentation classes.
IMG_HEIGHT, IMG_WIDTH = 224, 224
N_CLASSES = 4

# === Color Map ===
# RGB mask colour -> integer class id. The position of a colour in the list
# below is its class id.
color_map = {
    rgb: class_id
    for class_id, rgb in enumerate([
        (0, 0, 0),
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
    ])
}

def rgb_to_label(mask_rgb, colors=None):
    """Convert an RGB-coded mask into a 2-D array of integer class ids.

    Args:
        mask_rgb: array of shape (height, width, 3) holding one RGB colour
            per pixel.
        colors: optional mapping ``{(r, g, b): class_id}``. Defaults to the
            module-level ``color_map`` so existing callers are unaffected.

    Returns:
        A ``uint8`` array of shape (height, width). Pixels whose colour is
        not a key of the mapping keep the initial value 0.
    """
    if colors is None:
        colors = color_map

    label_mask = np.zeros(mask_rgb.shape[:2], dtype=np.uint8)
    for rgb, label in colors.items():
        # A pixel matches only when all three channels equal the colour.
        label_mask[np.all(mask_rgb == rgb, axis=-1)] = label
    return label_mask

def augment_image(img, mask):
    """Apply the same random flips / 90-degree rotation to an image and its mask.

    Each of the three transforms (left-right flip, up-down flip, rot90) is
    applied independently when a fresh uniform draw exceeds 0.5. The image is
    handled as float32 and the mask as int32; both are returned as NumPy
    arrays. 2-D inputs get a temporary channel axis that is removed again
    before returning.
    """
    def _with_channel_axis(tensor):
        # Give 2-D inputs a trailing channel axis before the tf.image calls.
        return tf.expand_dims(tensor, -1) if len(tensor.shape) == 2 else tensor

    def _without_unit_channel(tensor):
        # Drop a trailing channel axis of size one.
        return tf.squeeze(tensor, axis=-1) if tensor.shape[-1] == 1 else tensor

    img_t = _with_channel_axis(tf.convert_to_tensor(img, dtype=tf.float32))
    mask_t = _with_channel_axis(tf.convert_to_tensor(mask, dtype=tf.int32))

    # One random draw per transform, in this fixed order; image and mask
    # always receive the same transform so they stay aligned.
    for transform in (tf.image.flip_left_right, tf.image.flip_up_down, tf.image.rot90):
        if tf.random.uniform(()) > 0.5:
            img_t = transform(img_t)
            mask_t = transform(mask_t)

    img_t = _without_unit_channel(img_t)
    mask_t = _without_unit_channel(mask_t)
    return img_t.numpy(), mask_t.numpy()

def load_data(img_dir, mask_dir, augment=False):
    """Load image/mask pairs from disk as model-ready arrays.

    For every ``.jpg`` / ``.png`` / ``.jpeg`` file in ``img_dir`` whose name
    does not contain ``_mask``, the mask ``<basename>_mask.png`` is looked up
    in ``mask_dir``. Pairs with a missing or unreadable file are reported
    with ``print`` and skipped.

    Args:
        img_dir: directory containing the input images.
        mask_dir: directory containing the RGB-coded masks.
        augment: when True, each pair is passed through ``augment_image``.

    Returns:
        ``(images, masks)``: images as float32 scaled to [0, 1] and resized to
        (IMG_HEIGHT, IMG_WIDTH); masks one-hot encoded over ``N_CLASSES``.
    """
    image_exts = ('.jpg', '.png', '.jpeg')
    images = []
    masks = []

    for file in sorted(os.listdir(img_dir)):
        if '_mask' in file or not file.endswith(image_exts):
            continue

        filename_base = os.path.splitext(file)[0]
        mask_filename = filename_base + '_mask.png'
        mask_path = os.path.join(mask_dir, mask_filename)
        if not os.path.exists(mask_path):
            print(f"Skipping {file} — mask {mask_filename} not found.")
            continue

        try:
            img = load_img(os.path.join(img_dir, file), target_size=(IMG_HEIGHT, IMG_WIDTH))
            img = np.array(img).astype('float32') / 255.0
        except Exception as e:
            print(f"Error loading {file}: {e}")
            continue

        try:
            # Close the file handle deterministically instead of leaking it.
            with Image.open(mask_path) as mask_file:
                # Convert first, then resize with NEAREST: an interpolating
                # filter would blend colours along class borders, and those
                # blended colours are not in the colour map, so the pixels
                # would silently be relabelled as class 0.
                mask_rgb = mask_file.convert("RGB").resize(
                    (IMG_WIDTH, IMG_HEIGHT), resample=Image.NEAREST
                )
            mask = rgb_to_label(np.array(mask_rgb))
        except Exception as e:
            # Treat a broken mask like a broken image: report and skip.
            print(f"Error loading mask {mask_filename}: {e}")
            continue

        if augment:
            img, mask = augment_image(img, mask)

        images.append(img)
        masks.append(mask)

    images = np.array(images)
    masks = tf.keras.utils.to_categorical(np.array(masks), num_classes=N_CLASSES)
    return images, masks

# === Load + Augment ===
# Seed TensorFlow's global RNG so the random draws made during augmentation
# are repeatable across runs; the split below was already seeded, but the
# augmented data it operated on was not.
SEED = 42
tf.random.set_seed(SEED)

# NOTE(review): augmentation runs before the split, so the held-out samples
# are randomly flipped/rotated as well — confirm this is intended.
images, masks = load_data(image_dir, mask_dir, augment=True)

# === Split ===
x_train, x_test, y_train, y_test = train_test_split(
    images, masks, test_size=0.2, random_state=SEED
)

# === Loss ===
def dice_loss(y_true, y_pred):
    """Return 1 - Dice coefficient computed over all elements of the inputs.

    Both sums run over every axis (no ``axis`` argument), so the result is a
    single scalar. 1e-6 is added to the denominator to avoid dividing by zero.
    """
    overlap = tf.reduce_sum(y_true * y_pred)
    total = tf.reduce_sum(y_true + y_pred)
    dice = (2 * overlap) / (total + 1e-6)
    return 1 - dice

def combined_loss(y_true, y_pred):
    """Sum of Keras categorical cross-entropy and ``dice_loss`` for the same inputs."""
    cross_entropy = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    dice = dice_loss(y_true, y_pred)
    return cross_entropy + dice

# === U-Net with MobileNetV2 encoder ===
def build_unet_mobilenetv2(input_shape=(224, 224, 3), num_classes=4, dropout=0.1):
    """Build a U-Net style segmentation model on an ImageNet MobileNetV2 encoder.

    Args:
        input_shape: shape of the input images passed to MobileNetV2.
        num_classes: number of channels of the final softmax layer.
        dropout: dropout rate used after each decoder stage that has a skip
            connection.

    Returns:
        A Keras ``Model`` from the encoder input to the softmax output.
    """
    encoder = MobileNetV2(input_shape=input_shape, include_top=False, weights='imagenet')

    # Encoder activations reused as skip connections, listed in network order.
    skip_layer_names = (
        "block_1_expand_relu",
        "block_3_expand_relu",
        "block_6_expand_relu",
        "block_13_expand_relu",
    )
    skip_tensors = [encoder.get_layer(layer_name).output for layer_name in skip_layer_names]

    # The decoder starts from this encoder output.
    decoder = encoder.get_layer("block_16_project").output

    # Walk the skips in reverse, upsampling before each concatenation.
    for skip_tensor in skip_tensors[::-1]:
        upsampled = UpSampling2D()(decoder)
        merged = Concatenate()([upsampled, skip_tensor])
        decoder = Conv2D(256, 3, padding="same", activation="relu")(merged)
        decoder = BatchNormalization()(decoder)
        decoder = Dropout(dropout)(decoder)

    # Final stage: one more upsampling step, without a skip connection.
    decoder = UpSampling2D()(decoder)
    decoder = Conv2D(128, 3, padding="same", activation="relu")(decoder)
    decoder = BatchNormalization()(decoder)
    class_probs = Conv2D(num_classes, 1, activation="softmax")(decoder)

    return Model(inputs=encoder.input, outputs=class_probs)

# === Compile ===
# Build the network from the notebook-level constants so the model always
# matches the size and class count the data was prepared with.
model = build_unet_mobilenetv2(
    input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    num_classes=N_CLASSES,
)

# The targets are one-hot and the model emits softmax probabilities. Plain
# MeanIoU expects integer class ids for both, so it would report a meaningless
# score here; OneHotMeanIoU takes the argmax over the class axis first.
# The name keeps the history key that MeanIoU produced by default.
mean_iou = tf.keras.metrics.OneHotMeanIoU(num_classes=N_CLASSES, name='mean_io_u')

model.compile(optimizer='adam', loss=combined_loss, metrics=['accuracy', mean_iou])
model.summary()

# === Train ===
# Stop once val_loss has not improved for 5 epochs and roll back to the best
# weights seen. The x_test / y_test split serves as the validation data here.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
history = model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    batch_size=2,
    epochs=50,
    callbacks=[early_stop],
)

# === Plot History ===
# Training and validation loss per epoch on a single set of axes.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(history.history['loss'], label='Train Loss')
ax.plot(history.history['val_loss'], label='Val Loss')
ax.legend()
ax.set_title("Loss Over Epochs")
plt.show()

# === Predictions ===
# Run the trained model on the held-out images. The result is indexed per
# sample and reduced with argmax over its last axis when it is displayed.
preds = model.predict(x_test)

def show_result(idx):
    """Show input image, ground-truth mask and predicted mask side by side.

    Args:
        idx: index of the sample in ``x_test`` / ``y_test`` / ``preds``.
    """
    # Pin the colour scale to the full class range. Left to autoscale, each
    # panel is normalised to its own min/max, so the same class can be drawn
    # in different colours when the two masks contain different classes.
    mask_style = dict(cmap='jet', vmin=0, vmax=N_CLASSES - 1)

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(x_test[idx])
    ax[0].set_title("Input")
    # One-hot / probability maps are reduced to class ids for display.
    ax[1].imshow(np.argmax(y_test[idx], axis=-1), **mask_style)
    ax[1].set_title("Ground Truth")
    ax[2].imshow(np.argmax(preds[idx], axis=-1), **mask_style)
    ax[2].set_title("Prediction")
    for a in ax:
        a.axis('off')
    plt.tight_layout()
    plt.show()

# Display at most the first five held-out samples.
n_examples = min(5, len(x_test))
for i in range(n_examples):
    show_result(i)


Output hidden; open in https://colab.research.google.com to view.