In [None]:

# from tensorflow.keras.layers import Input, Conv2D, ReLU, MaxPooling2D, Flatten, Dense, Dropout, Add


# === 1. Define the model ===
# Functional-API CNN with two classification heads that share Blocks 1-2:
# an auxiliary head tapped right after Block 2 and the main head after
# Block 3. Both heads are plain Dense(2) layers with no activation, i.e. they
# emit raw logits (the losses in compile() use from_logits=True).
# Shapes below assume Keras' default channels_last data format.
inputs = Input(shape=(128, 128, 1))  # 128x128 images with a single channel

# Block 1: (128, 128, 1) -> conv -> (128, 128, 32) -> pool -> (64, 64, 32)
x1 = Conv2D(32, 3, padding='same')(inputs)
x1 = ReLU()(x1)
x1 = MaxPooling2D(2)(x1)

# Block 2: (64, 64, 32) -> conv -> (64, 64, 64) -> pool -> (32, 32, 64)
x2 = Conv2D(64, 3, padding='same')(x1)
x2 = ReLU()(x2)
x2 = MaxPooling2D(2)(x2)

# Auxiliary output: flatten Block 2's (32, 32, 64) map (65,536 features)
# directly into a 2-unit linear layer. The layer name 'aux_output' is the
# key used in the loss/metric/target dicts passed to compile() and fit().
aux_x = Flatten()(x2)
aux_output = Dense(2, name='aux_output')(aux_x)

# Block 3: (32, 32, 64) -> conv -> (32, 32, 128); not pooled before the add
x3 = Conv2D(128, 3, padding='same')(x2)
x3 = ReLU()(x3)

# Skip connection from Block 1: a 1x1 conv projects x1 from 32 to 128
# channels, giving (64, 64, 128); pooling then halves it to (32, 32, 128)
# so it can be summed element-wise with x3.
skip = Conv2D(128, 1, padding='same')(x1)
skip = MaxPooling2D(2)(skip)  # (64, 64, 128) -> (32, 32, 128), same shape as x3
x3 = Add()([x3, skip])

# Main head: (32, 32, 128) -> pool -> (16, 16, 128) -> flatten (32,768)
# -> Dense(256, relu) -> Dropout(0.3)
x3 = MaxPooling2D(2)(x3)
x3 = Flatten()(x3)
x3 = Dense(256, activation='relu')(x3)
x3 = Dropout(0.3)(x3)

# Final output: 2 raw logits; the name 'main_output' keys the compile()/fit() dicts
final_output = Dense(2, name='main_output')(x3)

# Create model; predictions come back in the order [main, aux]
model_api = Model(inputs=inputs, outputs=[final_output, aux_output])

# === 2. Compile the model ===
# Each head is a 2-way classifier that outputs raw logits and is trained on
# integer class labels, so each gets its own sparse categorical
# cross-entropy with from_logits=True. The total training loss is
# 1.0 * main_loss + 0.3 * aux_loss; accuracy is tracked for both heads.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

model_api.compile(
    optimizer=optimizer,
    loss={
        head: tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        for head in ('main_output', 'aux_output')
    },
    loss_weights={'main_output': 1.0, 'aux_output': 0.3},
    metrics={head: ['accuracy'] for head in ('main_output', 'aux_output')},
)

# === 3. Training callbacks ===
# Stop when the main head's validation accuracy has not improved for 3
# consecutive epochs, restoring the weights from the best epoch.
# mode='max' is stated explicitly: with an output-prefixed metric name such
# as 'val_main_output_accuracy', mode='auto' may be unable to infer the
# direction (recent Keras versions raise a ValueError in that case).
early_stop = EarlyStopping(
    monitor='val_main_output_accuracy',
    mode='max',
    patience=3,
    restore_best_weights=True,
)

# === 4. Fit the model ===
# Both heads are fitted to the same labels: the auxiliary head is a second
# classifier on intermediate features, not a separate task.
history = model_api.fit(
    images_train,
    {'main_output': labels_train, 'aux_output': labels_train},
    validation_data=(images_val, {'main_output': labels_val, 'aux_output': labels_val}),
    epochs=30,  # upper bound; early_stop may end training earlier
    batch_size=256,
    verbose=1,
    callbacks=[early_stop],
)


def plot_training_history(history):
    """Plot per-epoch accuracy and loss curves for both model heads, then show them.

    Lays out a 2x2 grid: accuracy on the top row, loss on the bottom row,
    main head in the left column, auxiliary head in the right column. Each
    panel overlays the training curve and the validation curve.

    Args:
        history: The object returned by ``Model.fit``. Its ``history`` dict
            must contain ``main_output_accuracy``, ``aux_output_accuracy``,
            ``main_output_loss``, ``aux_output_loss`` and the matching
            ``val_``-prefixed keys; a missing key raises ``KeyError``.
    """
    logs = history.history
    _, axes_grid = plt.subplots(2, 2, figsize=(14, 10))

    # Each entry: (grid position, title, y-axis label, curves), where every
    # curve is (history key, legend label, extra plot kwargs). The main-head
    # panels use the default colour cycle; the aux panels pin orange/red.
    panels = (
        ((0, 0), 'Main Output Accuracy', 'Accuracy', (
            ('main_output_accuracy', 'Train Main Accuracy', {}),
            ('val_main_output_accuracy', 'Val Main Accuracy', {}),
        )),
        ((0, 1), 'Auxiliary Output Accuracy', 'Accuracy', (
            ('aux_output_accuracy', 'Train Aux Accuracy', {'color': 'orange'}),
            ('val_aux_output_accuracy', 'Val Aux Accuracy', {'color': 'red'}),
        )),
        ((1, 0), 'Main Output Loss', 'Loss', (
            ('main_output_loss', 'Train Main Loss', {}),
            ('val_main_output_loss', 'Val Main Loss', {}),
        )),
        ((1, 1), 'Auxiliary Output Loss', 'Loss', (
            ('aux_output_loss', 'Train Aux Loss', {'color': 'orange'}),
            ('val_aux_output_loss', 'Val Aux Loss', {'color': 'red'}),
        )),
    )

    for position, title, y_label, curves in panels:
        ax = axes_grid[position]
        for key, legend_label, style in curves:
            ax.plot(logs[key], label=legend_label, **style)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

# Visualise the curves recorded during training
plot_training_history(history=history)



# === 5. Evaluate the model ===
# Score the held-out test set silently. Both heads are evaluated against the
# same labels, mirroring how they were trained. With evaluate()'s default
# return format, test_results is a list of scalars whose labels are given by
# model_api.metrics_names.
test_results = model_api.evaluate(
    images_test,
    dict.fromkeys(('main_output', 'aux_output'), labels_test),
    verbose=0,
)