In [None]:
import os
import sys
import torch
from pathlib import Path

# Notebooks do not define __file__, so the kernel's working directory stands in
# for the folder that contains this notebook. Resolve it only once per kernel
# session: this cell changes the working directory at the end, so recomputing
# it on a re-run would climb one directory higher every time.
if 'parent_dir' not in globals():
    current_dir = Path.cwd().resolve()
    parent_dir = current_dir.parent

# Make the project's helper modules importable (without stacking duplicate
# entries on sys.path when the cell is executed again).
scripts_dir = str(parent_dir / 'scripts')
if scripts_dir not in sys.path:
    sys.path.append(scripts_dir)

from ConditionClassifier import ConditionClassifier
from ConditionDataset import ConditionDataset

from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt

# go into parent directory (repository); an absolute target keeps this
# idempotent, unlike a relative '..' which moves up again on each re-run
os.chdir(parent_dir)

In [None]:
# Preprocessing pipeline applied to every image, in this order:
#   1. resize to 128x128 pixels
#   2. convert to a PyTorch tensor
#   3. normalize each channel with the given mean / std
#      (values match the commonly used ImageNet channel statistics)
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset locations, relative to the current working directory.
# NOTE(review): the variables are named *_test_dir but both paths point at
# `train` folders -- confirm this is intentional.
original_test_dir, augmented_test_dir = (
    'data/cityscapes1/train/',
    'data/aug_cityscapes2/train',
)

# Dataset built from the original and augmented image folders
test_dataset = ConditionDataset(
    original_test_dir,
    augmented_test_dir,
    transform=transform,
)

# Batched, fixed-order iteration over the dataset
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [None]:
# Number of example images shown per condition (one grid row per condition)
samples_per_class = 9

# Class id -> row title, in the order the rows are drawn
class_names = {0: "Clear", 1: "Foggy", 2: "Glaring"}

# Initialize lists to store images for each category
clear_images, foggy_images, glaring_images = [], [], []
images_by_class = {0: clear_images, 1: foggy_images, 2: glaring_images}

# Iterate through the dataset to collect samples
for i in range(len(test_dataset)):
    image, label = test_dataset[i]

    for class_id, bucket in images_by_class.items():
        if label == class_id:
            # Keep at most `samples_per_class` examples of each condition
            if len(bucket) < samples_per_class:
                bucket.append((image, label))
            break

    # Stop when we have enough images
    if all(len(bucket) == samples_per_class for bucket in images_by_class.values()):
        break

# Combine all images into a single list
all_images = clear_images + foggy_images + glaring_images

# Define mean and std for un-normalization; these must mirror the values
# passed to transforms.Normalize in the preprocessing pipeline
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Create a grid with one row per condition
fig, axes = plt.subplots(len(class_names), samples_per_class, figsize=(18, 6))

# Hide the axes everywhere up front so cells that end up without an image
# (a condition with fewer samples than requested) stay blank
for ax in axes.ravel():
    ax.axis("off")

# Display images. Each condition is drawn into its own row, so a condition
# with fewer samples cannot push the next condition's images into the wrong row.
for row, (class_id, bucket) in enumerate(images_by_class.items()):
    for col, (image, label) in enumerate(bucket):
        # Convert tensor (channels first) to a NumPy image (channels last)
        image = image.permute(1, 2, 0).numpy()

        # Reverse normalization and clamp to the displayable [0, 1] range
        image = np.clip((image * std) + mean, 0, 1)

        # Plot image
        axes[row, col].imshow(image)

    # Title for the row, placed above its middle column
    axes[row, samples_per_class // 2].set_title(class_names[class_id], fontsize=14)

# Apply the layout before saving so the file on disk matches what is shown
plt.tight_layout()

save_dir = 'experiments/results'
os.makedirs(save_dir, exist_ok=True)

plt.savefig(os.path.join(save_dir, "condition_comparison.png"), dpi=300, bbox_inches='tight')

plt.show()