Welcome to this demonstration notebook for the Weather Classification Dataset and our trained models. This notebook provides an overview of the dataset, explores its structure, and showcases the performance of various models trained for weather classification.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
from data_loading import download_dataset

In [None]:
# download_dataset() (imported from data_loading) presumably fetches the data;
# the call is left commented out, so this cell expects the dataset on disk.
dataset_path = "data/weather-dataset"
# download_dataset(dataset_path)
if not os.path.isdir(dataset_path):
    raise FileNotFoundError(
        f"Dataset directory not found: {dataset_path!r}. "
        "Uncomment the download_dataset() call above to fetch it."
    )
print(f"Using dataset at: {dataset_path}")

# List dataset contents: each class is a sub-directory of dataset_path.
# Stray top-level files (e.g. .DS_Store, README) are skipped; otherwise they
# would be reported as classes and os.listdir() on them below would raise
# NotADirectoryError.
classes = sorted(  # Sort for consistent ordering
    entry
    for entry in os.listdir(dataset_path)
    if os.path.isdir(os.path.join(dataset_path, entry))
)
num_classes = len(classes)
print(f"Number of Classes: {num_classes}")
print("Classes:", classes)

# Count total images (every entry inside each class directory).
num_images = sum(len(os.listdir(os.path.join(dataset_path, cls))) for cls in classes)
print(f"Total Number of Images: {num_images}")

In [None]:
def show_samples(dataset_path, classes, images_per_row=3, rng=None):
    """Plot one randomly chosen image per class in a grid.

    Args:
        dataset_path: Root directory with one sub-directory per class.
        classes: Class (sub-directory) names to sample from, in display order.
        images_per_row: Number of grid columns; must be at least 1.
        rng: Optional ``numpy.random.Generator`` (or any object with a
            compatible ``choice`` method) for reproducible sampling. When
            ``None``, NumPy's global random state is used.

    Raises:
        ValueError: If ``images_per_row`` is less than 1.
    """
    if images_per_row < 1:
        raise ValueError(f"images_per_row must be >= 1, got {images_per_row}")

    num_classes = len(classes)
    if num_classes == 0:
        # plt.subplots(0, ...) would raise, and there is nothing to draw.
        print("No classes to display.")
        return

    rows = (num_classes + images_per_row - 1) // images_per_row  # ceiling division
    choose = np.random.choice if rng is None else rng.choice

    _, axes = plt.subplots(rows, images_per_row, figsize=(images_per_row * 3, rows * 3))
    # subplots() returns a single Axes or a 1-D/2-D array depending on the grid
    # shape; flatten so each class maps to one axes entry by position.
    axes = np.array(axes).reshape(-1)

    for ax, cls in zip(axes, classes):
        class_path = os.path.join(dataset_path, cls)
        # Sort so a seeded rng picks the same file regardless of listdir() order.
        candidates = sorted(os.listdir(class_path))
        if candidates:
            sample_image = choose(candidates)
            # Use a context manager so the file handle is released; convert()
            # returns an independent in-memory copy that outlives the file.
            with Image.open(os.path.join(class_path, sample_image)) as img:
                ax.imshow(img.convert("RGB"))
        # An empty class directory yields a blank, titled panel instead of
        # making choice() raise on an empty list.
        ax.axis("off")
        ax.set_title(cls, fontsize=12)

    # Hide unused subplots if number of classes is not a multiple of images_per_row
    for ax in axes[num_classes:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()


show_samples(dataset_path, classes, images_per_row=3)

In [None]:
import matplotlib.pyplot as plt

# Images per class (number of entries in each class directory), keyed by
# class name in alphabetical order.
class_counts = {
    name: len(os.listdir(os.path.join(dataset_path, name)))
    for name in sorted(classes)
}

# Note on class imbalance, printed ahead of the chart.
print("As we can see in the following plot, the dataset has an uneven class distribution.")
print("During training, this is handled using class weights to balance the impact of each class.")

# Bar chart of per-class image counts.
plt.figure(figsize=(10, 5))
plt.bar(list(class_counts), list(class_counts.values()), color="royalblue")
plt.title("Class Distribution in Weather Classification Dataset", fontsize=14)
plt.xlabel("Weather Classes", fontsize=12)
plt.ylabel("Number of Images", fontsize=12)
plt.xticks(rotation=45, ha="right")
plt.grid(axis="y", linestyle="--", alpha=0.7)
plt.show()