# **Setup**

In [26]:
import os
import tensorflow as tf
import tf_keras as tfk
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

from tf_keras import layers
from tf_keras import backend
from tf_keras.preprocessing.image import ImageDataGenerator
from tf_keras.preprocessing import image
from tf_keras.applications.imagenet_utils import preprocess_input, decode_predictions

from sklearn.metrics import confusion_matrix
from sklearn.utils import class_weight

# NOTE(review): this variable is assigned after tensorflow / tf_keras were
# already imported at the top of this cell, so it presumably cannot influence
# those imports — confirm whether it is still needed.
os.environ["KERAS_BACKEND"] = "tensorflow"

import warnings
# Install a global filter that hides every warning raised from here on.
warnings.filterwarnings("ignore")

# **Dataset Preparation**

In [2]:
# Image directories and label CSV files for the two splits.
TRAIN = "dataset/train_end"
TEST = "dataset/test_end"
TRAIN_CSV = 'train_benchmark.csv'
TEST_CSV = 'test_benchmark.csv'

# First dimension of the patch reshape performed in mobilevit_block.
PATCH_SIZE = 4
# Multiplier applied to the channel count to get expanded_channels in MobileViT.
EXPANSION_FACTOR = 2
# Side length (pixels) of the square model input and of the generator target_size.
IMG_SIZE = 256
# Batch size used by every flow_from_dataframe generator.
BATCH_SIZE = 64
# Seed passed to the data generators.
RANDOM_STATE = 42

In [3]:
def load_data(PATH):
    """Read a label CSV and prepare it for ``flow_from_dataframe``.

    Args:
        PATH: Path of a CSV file with at least ``id``, ``jenis`` and ``warna``
            columns.

    Returns:
        The DataFrame with ``id`` turned into a ``"<id>.png"`` file name and
        both label columns cast to strings.
    """
    frame = pd.read_csv(PATH)
    # The id column becomes the image file name.
    frame['id'] = frame['id'].astype(str) + '.png'
    # Both label columns are stored as strings.
    for label_column in ('jenis', 'warna'):
        frame[label_column] = frame[label_column].astype(str)
    return frame

# Load the label tables of both splits; load_data appends ".png" to each id
# and casts the jenis / warna labels to strings.
train_data = load_data(TRAIN_CSV)
test_data = load_data(TEST_CSV)

# **Mobile-ViT**

In [32]:
def conv_block(x, filters=16, kernel_size=3, strides=2):
    """Apply one "same"-padded Conv2D with a swish activation.

    Args:
        x: Input tensor.
        filters: Number of output channels.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.

    Returns:
        The output tensor of the convolution.
    """
    convolution = layers.Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding="same",
        activation=tfk.activations.swish,
    )
    return convolution(x)

def correct_pad(inputs, kernel_size):
    """Compute the ZeroPadding2D amounts used before a strided convolution.

    Args:
        inputs: Tensor-like object whose ``shape`` holds the spatial sizes.
        kernel_size: An int or a pair of ints.

    Returns:
        ``((top, bottom), (left, right))`` padding amounts.
    """
    # Spatial axes start at 1 for channels_last and at 2 for channels_first.
    first_spatial_axis = 1
    if backend.image_data_format() == "channels_first":
        first_spatial_axis = 2
    spatial = inputs.shape[first_spatial_axis : first_spatial_axis + 2]

    kernel = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size

    if spatial[0] is not None:
        # Even sizes give 1, odd sizes give 0.
        shrink_h, shrink_w = 1 - spatial[0] % 2, 1 - spatial[1] % 2
    else:
        # Unknown size: use an adjustment of 1 on both axes.
        shrink_h = shrink_w = 1

    half_h, half_w = kernel[0] // 2, kernel[1] // 2
    return ((half_h - shrink_h, half_h), (half_w - shrink_w, half_w))

def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    """Inverted residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Args:
        x: Input feature map.
        expanded_channels: Channel count after the 1x1 expansion convolution.
        output_channels: Channel count after the 1x1 projection convolution.
        strides: Stride of the depthwise convolution (1 or 2).

    Returns:
        The projected feature map, with ``x`` added back when ``strides == 1``
        and the input already has ``output_channels`` channels.
    """
    # 1x1 expansion.
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization()(m)
    m = tfk.activations.swish(m)

    if strides == 2:
        # Pad explicitly (see correct_pad) because the strided depthwise
        # convolution below runs with "valid" padding.
        m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)

    # 3x3 depthwise convolution.
    m = layers.DepthwiseConv2D(
        3,
        strides=strides,
        padding="same" if strides == 1 else "valid",
        use_bias=False,
    )(m)
    m = layers.BatchNormalization()(m)
    m = tfk.activations.swish(m)

    # 1x1 projection, no activation.
    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)

    # The static channel count is a plain Python value, so compare it directly
    # instead of building a tensor with tf.equal.
    if strides == 1 and x.shape[-1] == output_channels:
        return layers.Add()([m, x])

    return m

def mlp(x, hidden_units, dropout_rate):
    """Stack Dense(swish) + Dropout pairs, one pair per entry of ``hidden_units``.

    Args:
        x: Input tensor.
        hidden_units: Iterable with the width of each Dense layer.
        dropout_rate: Dropout rate applied after every Dense layer.

    Returns:
        The output tensor of the last Dropout layer (``x`` itself when
        ``hidden_units`` is empty).
    """
    for width in hidden_units:
        hidden = layers.Dense(width, activation=tfk.activations.swish)(x)
        x = layers.Dropout(dropout_rate)(hidden)
    return x

def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    """Apply ``transformer_layers`` pre-norm transformer encoder layers.

    Each layer is LayerNorm -> self-attention -> skip connection, followed by
    LayerNorm -> MLP -> skip connection.

    Args:
        x: Input tensor; its last dimension sets the MLP widths.
        transformer_layers: Number of encoder layers to stack.
        projection_dim: ``key_dim`` of the attention layer.
        num_heads: Number of attention heads.

    Returns:
        The tensor produced by the last encoder layer.
    """
    tokens = x
    for _layer_index in range(transformer_layers):
        # Self-attention sub-layer with a skip connection.
        normed = layers.LayerNormalization(epsilon=1e-6)(tokens)
        attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )
        after_attention = layers.Add()([attention(normed, normed), tokens])

        # Feed-forward sub-layer: widen to 2x the channels, then project back.
        feed_forward = layers.LayerNormalization(epsilon=1e-6)(after_attention)
        channels = tokens.shape[-1]
        feed_forward = mlp(
            feed_forward,
            hidden_units=[channels * 2, channels],
            dropout_rate=0.1,
        )
        tokens = layers.Add()([feed_forward, after_attention])

    return tokens

def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    """MobileViT block: local convolutions, transformer layers, then fusion.

    Args:
        x: Input feature map.
        num_blocks: Number of transformer layers to apply to the patches.
        projection_dim: Channel width of the local features and the output.
        strides: Stride forwarded to every ``conv_block`` call.

    Returns:
        The fused local/global feature map with ``projection_dim`` channels.
    """
    # Local representation: 3x3 convolution followed by a 1x1 convolution.
    local = conv_block(x, filters=projection_dim, strides=strides)
    local = conv_block(local, filters=projection_dim, kernel_size=1, strides=strides)

    # Unfold the map to (PATCH_SIZE, num_patches, projection_dim).
    height, width = local.shape[1], local.shape[2]
    num_patches = int((height * width) / PATCH_SIZE)
    unfolded = layers.Reshape((PATCH_SIZE, num_patches, projection_dim))(local)

    attended = transformer_block(unfolded, num_blocks, projection_dim)

    # Fold back to the spatial layout of the local features.
    folded = layers.Reshape((height, width, projection_dim))(attended)

    # Project to the input channel count, concatenate with the input and fuse.
    folded = conv_block(folded, filters=x.shape[-1], kernel_size=1, strides=strides)
    fused = layers.Concatenate(axis=-1)([x, folded])
    return conv_block(fused, filters=projection_dim, strides=strides)

In [33]:
def MobileViT(num_classes, rescale_inputs=False):
    """Build the MobileViT classifier for ``IMG_SIZE`` x ``IMG_SIZE`` RGB images.

    The data generators in this notebook already apply ``rescale=1./255``, so
    by default the model does not rescale again; doing both squeezes pixel
    values into [0, 1/255].

    Args:
        num_classes: Number of units of the softmax output layer.
        rescale_inputs: Set to True only when the model is fed raw 0-255
            pixels; a ``Rescaling(1/255)`` layer is then placed first.

    Returns:
        An uncompiled ``tfk.Model``.
    """
    tfk.backend.clear_session()

    inputs = tfk.Input((IMG_SIZE, IMG_SIZE, 3))
    x = inputs
    if rescale_inputs:
        x = layers.Rescaling(scale=1.0 / 255)(x)

    # Initial conv-stem -> MV2 block.
    x = conv_block(x, filters=16)
    x = inverted_residual_block(x, expanded_channels=16 * EXPANSION_FACTOR, output_channels=16)

    # Downsampling with MV2 block.
    x = inverted_residual_block(x, expanded_channels=16 * EXPANSION_FACTOR, output_channels=24, strides=2)
    x = inverted_residual_block(x, expanded_channels=24 * EXPANSION_FACTOR, output_channels=24)
    x = inverted_residual_block(x, expanded_channels=24 * EXPANSION_FACTOR, output_channels=24)

    # First MV2 -> MobileViT block.
    x = inverted_residual_block(x, expanded_channels=24 * EXPANSION_FACTOR, output_channels=48, strides=2)
    x = mobilevit_block(x, num_blocks=2, projection_dim=64)

    # Second MV2 -> MobileViT block.
    x = inverted_residual_block(x, expanded_channels=64 * EXPANSION_FACTOR, output_channels=64, strides=2)
    x = mobilevit_block(x, num_blocks=4, projection_dim=80)

    # Third MV2 -> MobileViT block.
    x = inverted_residual_block(x, expanded_channels=80 * EXPANSION_FACTOR, output_channels=80, strides=2)
    x = mobilevit_block(x, num_blocks=3, projection_dim=96)
    x = conv_block(x, filters=320, kernel_size=1, strides=1)

    # Classification head.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return tfk.Model(inputs, outputs)

In [None]:
def run_experiment(num_classes, train_data, test_data, cw, epochs=50, learning_rate=0.002, patience=10):
    """Build, train and evaluate a MobileViT classifier.

    Args:
        num_classes: Number of output classes.
        train_data: Generator yielding training batches.
        test_data: Generator used for validation during training and for the
            final evaluation.
        cw: Class weights, either a ``{class_index: weight}`` dict or a
            sequence of weights ordered by class index (for example the array
            returned by ``compute_class_weight``). ``None`` disables weighting.
        epochs: Maximum number of training epochs.
        learning_rate: Learning rate of the Adam optimizer.
        patience: Epochs without ``val_loss`` improvement before stopping.

    Returns:
        ``(history, model)``: the Keras History object and the trained model.
    """
    # fit() needs a mapping from class index to weight; an array of weights
    # ordered by class index is converted here.
    if cw is not None and not isinstance(cw, dict):
        cw = {class_index: float(weight) for class_index, weight in enumerate(cw)}

    mobilevit_xxs = MobileViT(num_classes=num_classes)

    mobilevit_xxs.compile(
        optimizer=tfk.optimizers.Adam(learning_rate=learning_rate),
        loss=tfk.losses.CategoricalCrossentropy(label_smoothing=0.1),
        metrics=["accuracy"],
    )

    # NOTE(review): test_data drives both early stopping and the final score
    # below, so the reported test metrics come from data used for model
    # selection — consider a separate validation split.
    early_stopping = tfk.callbacks.EarlyStopping(
        monitor='val_loss', patience=patience, restore_best_weights=True
    )
    history = mobilevit_xxs.fit(
        train_data,
        epochs=epochs,
        validation_data=test_data,
        callbacks=[early_stopping],
        class_weight=cw,
    )

    score = mobilevit_xxs.evaluate(test_data)
    print("Test Loss:", score[0])
    print("Test Accuracy:", score[1])

    return history, mobilevit_xxs

# **Define Evaluate Function**

In [None]:
def plot_history(history):
    """Plot train/validation accuracy and loss per epoch in two stacked panels.

    Args:
        history: Object whose ``history`` dict has the keys ``accuracy``,
            ``val_accuracy``, ``loss`` and ``val_loss``.
    """
    logs = history.history
    # Read every series first so a missing key fails before a figure exists.
    accuracy, val_accuracy = logs['accuracy'], logs['val_accuracy']
    loss, val_loss = logs['loss'], logs['val_loss']

    fig, (accuracy_axis, loss_axis) = plt.subplots(nrows=2, ncols=1, figsize=(12, 10))

    panels = [
        (
            accuracy_axis,
            'Training Accuracy vs. Epochs',
            'Accuracy',
            [(accuracy, 'Train Accuracy'), (val_accuracy, 'Validation Accuracy')],
        ),
        (
            loss_axis,
            'Training/Validation Loss vs. Epochs',
            'Loss',
            [(loss, 'Train Loss'), (val_loss, 'Validation Loss')],
        ),
    ]
    for axis, title, y_label, curves in panels:
        axis.set_title(title)
        for values, legend_name in curves:
            axis.plot(values, 'o-', label=legend_name)
        axis.set_xlabel('Epochs')
        axis.set_ylabel(y_label)
        axis.legend(loc='best')

    plt.tight_layout()
    plt.show()

In [None]:
def plot_confusion_matrix(model, model_name, generator):
    """Draw a confusion-matrix heatmap of ``model`` predictions on ``generator``.

    Args:
        model: Model whose ``predict`` returns per-class scores.
        model_name: Text placed in front of "Confusion Matrix" in the title.
        generator: Data generator exposing ``classes`` and ``class_indices``.
            Assumes it iterates in the same order as ``classes`` (i.e. it is
            not shuffled) — verify against the caller.
    """
    scores = model.predict(generator)
    y_pred = np.argmax(scores, axis=1).tolist()
    y_true = generator.classes
    class_names = generator.class_indices.keys()

    matrix = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title(f'{model_name} Confusion Matrix')
    plt.show()

# **Training Jenis**

In [None]:
# Settings shared by the train and test generators for the "jenis" label.
_jenis_flow_args = dict(
    x_col='id',
    y_col='jenis',
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    seed=RANDOM_STATE,
)

# Training generator: pixels rescaled by 1/255, default shuffling.
train_jenis = ImageDataGenerator(rescale=1./255).flow_from_dataframe(
    dataframe=train_data, directory=TRAIN, **_jenis_flow_args
)

# Test generator: same preprocessing with shuffling turned off.
test_jenis = ImageDataGenerator(rescale=1./255).flow_from_dataframe(
    dataframe=test_data, directory=TEST, shuffle=False, **_jenis_flow_args
)

In [None]:
# Balanced class weights for the "jenis" label. compute_class_weight takes
# `classes` and `y` as keyword arguments, and Keras fit() expects a
# {class_index: weight} dict rather than an array.
jenis_class_ids = np.unique(train_jenis.classes)
jenis_class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=jenis_class_ids,
    y=train_jenis.classes,
)

history_jenis, mobilevit_jenis = run_experiment(
    # The one-hot labels of the generator have one column per class index.
    num_classes=len(train_jenis.class_indices),
    train_data=train_jenis,
    test_data=test_jenis,
    cw=dict(zip(jenis_class_ids.tolist(), jenis_class_weights.tolist())),
)
mobilevit_jenis.save('mobilevit_jenis.keras')

In [None]:
# Learning curves of the "jenis" model, then its confusion matrix on the
# unshuffled test generator.
plot_history(history_jenis)
plot_confusion_matrix(mobilevit_jenis, "MobileViT Jenis", test_jenis)

# **Training Warna**

In [None]:
# Settings shared by the train and test generators for the "warna" label.
_warna_flow_args = dict(
    x_col='id',
    y_col='warna',
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    seed=RANDOM_STATE,
)

# Training generator: pixels rescaled by 1/255, default shuffling.
train_warna = ImageDataGenerator(rescale=1./255).flow_from_dataframe(
    dataframe=train_data, directory=TRAIN, **_warna_flow_args
)

# Test generator: same preprocessing with shuffling turned off.
test_warna = ImageDataGenerator(rescale=1./255).flow_from_dataframe(
    dataframe=test_data, directory=TEST, shuffle=False, **_warna_flow_args
)

In [None]:
# Balanced class weights for the "warna" label. compute_class_weight takes
# `classes` and `y` as keyword arguments, and Keras fit() expects a
# {class_index: weight} dict rather than an array.
warna_class_ids = np.unique(train_warna.classes)
warna_class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=warna_class_ids,
    y=train_warna.classes,
)

history_warna, mobilevit_warna = run_experiment(
    # The one-hot labels of the generator have one column per class index.
    num_classes=len(train_warna.class_indices),
    train_data=train_warna,
    test_data=test_warna,
    cw=dict(zip(warna_class_ids.tolist(), warna_class_weights.tolist())),
)
mobilevit_warna.save('mobilevit_warna.keras')

In [None]:
# Learning curves of the "warna" model, then its confusion matrix on the
# unshuffled test generator.
plot_history(history_warna)
plot_confusion_matrix(mobilevit_warna, "MobileViT Warna", test_warna)

# **Testing**

In [None]:
def load_and_preprocess_image(img_path, img_size=None):
    """Load one image and preprocess it the same way as the training generators.

    The generators resize to ``IMG_SIZE`` x ``IMG_SIZE`` and multiply pixels by
    1/255, and the model input is fixed to that size. Inference must therefore
    use the same size and scaling rather than 224 x 224 with [-1, 1] scaling.

    Args:
        img_path: Path of the image file.
        img_size: Side length to resize to; ``None`` uses ``IMG_SIZE``.

    Returns:
        ``(img, img_array)``: the resized PIL image and a float array of shape
        ``(1, img_size, img_size, channels)`` scaled to [0, 1].
    """
    if img_size is None:
        img_size = IMG_SIZE
    img = image.load_img(img_path, target_size=(img_size, img_size))
    img_array = image.img_to_array(img)
    # Add the batch axis and apply the same 1/255 factor as the generators.
    img_array = np.expand_dims(img_array, axis=0) * (1.0 / 255)
    return img, img_array

def visualize_predictions(data: pd.DataFrame, pathOfImage: str, model_jenis, model_warna, num_images=15, random_state=None):
    """Show sampled images with actual and predicted "warna" / "jenis" labels.

    Titles are drawn in red when either prediction differs from its label.

    Args:
        data: Label table with ``id``, ``jenis`` and ``warna`` columns.
        pathOfImage: Directory that contains the image files named by ``id``.
        model_jenis: Model predicting the "jenis" class.
        model_warna: Model predicting the "warna" class.
        num_images: Maximum number of images to show; capped at ``len(data)``.
        random_state: Seed for the row sampling; ``None`` uses the
            notebook-wide ``RANDOM_STATE`` so reruns show the same images.

    Raises:
        ValueError: If there is no row to display.
    """
    n_show = min(num_images, len(data))
    if n_show <= 0:
        raise ValueError("visualize_predictions needs at least one image to display")
    if random_state is None:
        random_state = RANDOM_STATE

    sample_data = data.sample(n=n_show, random_state=random_state)

    # Load every image once and stack them so each model predicts on a single
    # batch instead of once per image.
    images, arrays = [], []
    for row in sample_data.itertuples():
        img, img_array = load_and_preprocess_image(os.path.join(pathOfImage, str(row.id)))
        images.append(img)
        arrays.append(img_array)
    batch = np.concatenate(arrays, axis=0)

    preds_jenis = np.argmax(model_jenis.predict(batch, verbose=0), axis=1)
    preds_warna = np.argmax(model_warna.predict(batch, verbose=0), axis=1)

    # Size the grid from the number of images; 15 images give the original
    # 3 x 5 layout on a 15 x 9 figure.
    n_cols = min(5, n_show)
    n_rows = int(np.ceil(n_show / n_cols))
    plt.figure(figsize=(3 * n_cols, 3 * n_rows))

    for i, row in enumerate(sample_data.itertuples()):
        pred_jenis = preds_jenis[i]
        pred_warna = preds_warna[i]

        # NOTE(review): the comparison assumes the integer label values equal
        # the class indices assigned by the training generators — confirm.
        actual_jenis = int(row.jenis)
        actual_warna = int(row.warna)

        label_color = 'black'
        if actual_jenis != pred_jenis or actual_warna != pred_warna:
            label_color = 'red'

        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(images[i])
        plt.axis('off')
        plt.title(
            f"Warna Aktual = {actual_warna}, Prediksi = {pred_warna}\n"
            f"Jenis Aktual = {actual_jenis}, Prediksi = {pred_jenis}",
            color=label_color
        )

    plt.tight_layout()
    plt.show()

# NOTE(review): this section is titled "Testing" but the call samples from
# train_data / TRAIN — confirm whether test_data / TEST was intended.
visualize_predictions(train_data, TRAIN, mobilevit_jenis, mobilevit_warna)