Welcome to this demonstration notebook for the Weather Classification Dataset and our trained models. This notebook provides an overview of the dataset, explores its structure, and showcases the performance of various models trained for weather classification.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
from data_loading import download_dataset, WeatherDataModule, get_val_transforms
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import torch
import seaborn as sns

In [None]:
# Location of the dataset on disk. download_dataset() is assumed to be
# provided by data_loading; the call is left commented out so the notebook can
# be re-run against an existing copy of the data.
dataset_path = "data/weather-dataset"
# download_dataset(dataset_path)
# Report the location only: nothing is downloaded while the call above is
# commented out, so the message must not claim that a download happened.
print(f"Dataset location: {dataset_path}")

# Each class is a sub-directory of the dataset root. Keep directories only, so
# stray files (README, .DS_Store, ...) are not counted as classes and do not
# make the per-class os.listdir() call below raise NotADirectoryError.
# Sorted for consistent ordering.
classes = sorted(
    entry
    for entry in os.listdir(dataset_path)
    if os.path.isdir(os.path.join(dataset_path, entry))
)
num_classes = len(classes)
print(f"Number of Classes: {num_classes}")
print("Classes:", classes)

# Count total images: every entry inside a class directory counts as one image.
num_images = sum(len(os.listdir(os.path.join(dataset_path, cls))) for cls in classes)
print(f"Total Number of Images: {num_images}")

In [None]:
def show_samples(dataset_path, classes, images_per_row=3):
    """Display one randomly chosen image per class in a grid of subplots.

    Args:
        dataset_path: Root directory that contains one sub-directory per class.
        classes: Names of the class sub-directories to sample from.
        images_per_row: Number of subplot columns in the grid.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    num_classes = len(classes)
    if num_classes == 0:
        # plt.subplots() rejects a grid with zero rows, so there is nothing
        # sensible to draw; report it instead of failing.
        print("No classes to display.")
        return

    # Ceiling division: enough rows to hold every class.
    rows = (num_classes + images_per_row - 1) // images_per_row

    # squeeze=False always yields a 2-D array of axes, whatever the grid shape.
    fig, axes = plt.subplots(
        rows,
        images_per_row,
        figsize=(images_per_row * 3, rows * 3),
        squeeze=False,
    )
    axes = axes.ravel()  # Flatten axes array for easier pairing with classes

    for ax, cls in zip(axes, classes):
        class_path = os.path.join(dataset_path, cls)
        candidates = os.listdir(class_path)
        if not candidates:
            # An empty class folder has nothing to sample; label the empty cell.
            ax.set_title(f"{cls} (no images)", fontsize=12)
            continue

        sample_image = np.random.choice(candidates)
        img_path = os.path.join(class_path, sample_image)
        # The context manager closes the underlying file handle; convert()
        # returns an independent, fully loaded copy that is safe to plot.
        with Image.open(img_path) as img:
            ax.imshow(img.convert("RGB"))
        ax.set_title(cls, fontsize=12)

    # Hide ticks and frames everywhere, including the unused trailing subplots
    # left over when the class count is not a multiple of images_per_row.
    for ax in axes:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

show_samples(dataset_path, classes, images_per_row=3)

In [None]:
import matplotlib.pyplot as plt

# Tally the number of entries in every class directory, inserting the classes
# in alphabetical order so the bars appear sorted by class name.
class_counts = {}
for cls in sorted(classes):
    class_counts[cls] = len(os.listdir(os.path.join(dataset_path, cls)))

# Explain what the plot shows and how the imbalance is dealt with in training.
print("As we can see in the following plot, the dataset has an uneven class distribution.")
print("During training, this is handled using class weights to balance the impact of each class.")

# Bar chart: one bar per class, bar height = number of images.
plt.figure(figsize=(10, 5))
plt.bar(list(class_counts), list(class_counts.values()), color="royalblue")
plt.title("Class Distribution in Weather Classification Dataset", fontsize=14)
plt.xlabel("Weather Classes", fontsize=12)
plt.ylabel("Number of Images", fontsize=12)
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.xticks(rotation=45, ha="right")  # Slanted labels so class names do not overlap
plt.show()

In [None]:
# Dataset
# Build the data module on the same dataset_path used by the exploration cells
# above, so the dataset location is defined in exactly one place.
# NOTE(review): the positional 32 and 1 are presumably batch size and number
# of workers -- confirm against the WeatherDataModule signature.
datamodule = WeatherDataModule(dataset_path, 32, 1, get_val_transforms(), get_val_transforms())
datamodule.setup()
test_dataloader = datamodule.test_dataloader()

# Gather every (images, labels) batch from the test loader and concatenate
# them, so the whole test split is available as two tensors.
image_batches = []
label_batches = []
for images, labels in test_dataloader:
    image_batches.append(images)
    label_batches.append(labels)
X_test = torch.cat(image_batches)
y_test = torch.cat(label_batches)

In [None]:
# Model loading
# Placeholder: starts empty; trained models are meant to be added here.
models: list = []

In [None]:
# Predictions
# Placeholder: starts empty; the evaluation cells below iterate over its
# (model name, predictions) items.
model_predictions: dict = {}

In [None]:
# model_predictions maps a model name to that model's predicted labels,
# e.g. {"Model_1": y_pred_1, "Model_2": y_pred_2, ...}.
# Ground-truth labels of the test split.
y_true = y_test

# For every model, compute accuracy and the macro-averaged F1 score against
# the ground truth and store both under the model's name.
model_scores = {
    name: {
        "Accuracy": accuracy_score(y_true, predicted),
        "F1-Score": f1_score(y_true, predicted, average="macro"),
    }
    for name, predicted in model_predictions.items()
}


In [None]:
# Pull the model names and both metrics out of model_scores, keeping the
# dictionary's insertion order so names and values stay aligned.
model_names = list(model_scores)
accuracies = [scores["Accuracy"] for scores in model_scores.values()]
f1_scores = [scores["F1-Score"] for scores in model_scores.values()]

x = np.arange(len(model_names))  # One X-axis slot per model
bar_width = 0.4
half_width = bar_width / 2

# Grouped bar chart: accuracy sits left of each slot's centre, F1 right of it.
plt.figure(figsize=(10, 5))
bar_series = (
    (-half_width, accuracies, "Accuracy", "royalblue"),
    (half_width, f1_scores, "F1-Score", "darkorange"),
)
for offset, values, label, color in bar_series:
    plt.bar(x + offset, values, width=bar_width, label=label, color=color)

plt.title("Model Comparison: Accuracy & F1-Score")
plt.ylabel("Score")
plt.xticks(x, model_names, rotation=30, ha="right")
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.legend()
plt.show()

In [None]:
# Plot confusion matrices for all models, one heatmap per model side by side.
if not model_predictions:
    # plt.subplots(1, 0) raises ValueError, so skip plotting when no model
    # has produced predictions yet.
    print("No model predictions available; skipping confusion matrices.")
else:
    num_models = len(model_predictions)
    # squeeze=False always returns a 2-D array of axes, so a single model
    # needs no special handling.
    fig, axes = plt.subplots(1, num_models, figsize=(num_models * 5, 5), squeeze=False)

    for ax, (model_name, y_pred) in zip(axes[0], model_predictions.items()):
        cm = confusion_matrix(y_true, y_pred)

        # NOTE(review): the tick labels assume the matrix rows/columns line up
        # with the sorted `classes` list -- confirm against the label encoding
        # used by the data module.
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes, ax=ax)
        ax.set_title(f"Confusion Matrix: {model_name}")
        ax.set_xlabel("Predicted Labels")
        ax.set_ylabel("True Labels")

    plt.tight_layout()
    plt.show()

In [None]:
# Cleanup