In [None]:
import tensorflow as tf
from data_loader import load_tfds_cityscapes
from architecture import unet
from grad_cam import grad_cam_for_segmentation
from three_best_3_worst import display_best_worst_predictions
from visualize_results import display_predictions
from config import Config

In [None]:
# ---------------------------------------------------------------------------
# Data: training / validation pipelines at the configured batch size.
# ---------------------------------------------------------------------------
train_dataset, val_dataset = load_tfds_cityscapes(Config.BATCH_SIZE)

# ---------------------------------------------------------------------------
# Model: U-Net over 3-channel inputs at the configured spatial size,
# optimised with Adam on a sparse (integer-label) cross-entropy loss.
# ---------------------------------------------------------------------------
model = unet(
    input_size=(Config.IMAGE_SIZE[0], Config.IMAGE_SIZE[1], 3),
    n_classes=Config.NUM_CLASSES,
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(Config.LEARNING_RATE),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# ---------------------------------------------------------------------------
# Training: `history` keeps the per-epoch metrics for plotting below.
# ---------------------------------------------------------------------------
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=Config.EPOCHS,
)

# Display training history
def plot_results(history):
    """Plot accuracy and loss curves (train and validation) side by side.

    Args:
        history: The ``History`` object returned by ``model.fit`` (anything
            exposing a ``.history`` dict that maps metric names to per-epoch
            value lists), or that dict itself.

    Metric series that are missing or empty, e.g. the ``val_*`` keys when
    ``fit`` ran without ``validation_data``, are skipped instead of raising
    ``KeyError``. The x-axis numbers epochs from 1, matching the
    "Epoch 1/N" numbering Keras prints during training.
    """
    import matplotlib.pyplot as plt

    # Accept either a History object or its raw metrics dict.
    metrics = getattr(history, 'history', history)

    # (metric key, y-axis label / legend suffix, subplot title)
    panels = (
        ('accuracy', 'Accuracy', 'Accuracy over Epochs'),
        ('loss', 'Loss', 'Loss over Epochs'),
    )
    # (metric-key prefix, legend prefix) for each curve drawn in a panel.
    splits = (('', 'Train'), ('val_', 'Validation'))

    plt.figure(figsize=(12, 5))
    for position, (key, ylabel, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plotted = False
        for prefix, split_name in splits:
            values = metrics.get(prefix + key)
            # Explicit length check so array-like values never hit ambiguous truthiness.
            if values is None or len(values) == 0:
                continue
            epochs = range(1, len(values) + 1)
            plt.plot(epochs, values, label=f'{split_name} {ylabel}')
            plotted = True
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        # Avoid matplotlib's "no artists with labels" warning on an empty panel.
        if plotted:
            plt.legend()
        plt.title(title)

    # Keep the side-by-side subplot titles and labels from overlapping.
    plt.tight_layout()
    plt.show()

plot_results(history)

# Evaluate on the validation split. Because compile() used
# metrics=['accuracy'], evaluate() returns [loss, accuracy] in that order.
results = model.evaluate(val_dataset)
print("Validation Loss: {:.5f}, Validation Accuracy: {:.2f}%".format(results[0], results[1] * 100))

# One validation batch for the Grad-CAM visualization.
images, labels = next(iter(val_dataset))
class_idx = 0  # class whose activation map is visualized; change as needed

# Grad-CAM needs the name of a real layer in `model`. This replaces the old
# placeholder string 'last_conv_layer_name', which is not a real layer name.
# The default is the deepest Conv2D layer listed in model.layers. If that is a
# 1x1 classification head, or the conv layers live inside a nested sub-model,
# set GRAD_CAM_LAYER to the intended layer name by hand.
_conv_layer_names = [
    layer.name for layer in model.layers
    if isinstance(layer, tf.keras.layers.Conv2D)
]
if not _conv_layer_names:
    raise ValueError(
        "Grad-CAM needs a convolutional layer, but no Conv2D layer was found "
        "in model.layers; set GRAD_CAM_LAYER manually."
    )
GRAD_CAM_LAYER = _conv_layer_names[-1]

# Grad-CAM Visualization
grad_cam_for_segmentation(model, images.numpy(), class_idx, layer_name=GRAD_CAM_LAYER)

# Display best and worst predictions
display_best_worst_predictions(model, val_dataset, num_examples=3)

# Display generic predictions
display_predictions(model, val_dataset, num_display=3)