In [None]:
# Test Notebook for Project

## Objective
#This notebook is dedicated to testing the core functions and modules of the project. Each section corresponds to a specific functionality, ensuring that individual components work as expected and integrate seamlessly.

### Import Libraries and Setup

import pytest
import numpy as np
import torch
# Make the project's `src` directory importable.
# The location defaults to the Colab Drive path used by this project and can be
# overridden with the GAN_THESIS_SRC environment variable, so the notebook is
# not tied to a single machine layout.
import os
import sys

SRC_DIR = os.environ.get('GAN_THESIS_SRC', '/content/drive/MyDrive/GAN-thesis-project/src')
if SRC_DIR not in sys.path:  # avoid stacking duplicate entries when the cell is re-run
    sys.path.append(SRC_DIR)
from data_utils import preprocess_images, load_mnist_data, split_dataset

In [None]:
### Test `preprocess_images`

#### Description
#The function `preprocess_images` normalizes images to the range [-1, 1].

#### Test Cases
#- Verify that output values are in the range [-1, 1].
#- Check the output shape matches the expected dimensions.

# Generate a sample batch of random uint8 images.
# A fixed seed keeps the input identical on every run, so a failure here is
# reproducible instead of depending on whatever values were drawn.
rng = np.random.default_rng(42)
sample_images = rng.integers(0, 256, size=(10, 28, 28), dtype=np.uint8)
preprocessed_images = preprocess_images(sample_images)

# Assertions: value range first, then the expected (batch, channel, height, width) shape.
assert preprocessed_images.min() >= -1.0, "Preprocessed images should not have values below -1."
assert preprocessed_images.max() <= 1.0, "Preprocessed images should not have values above 1."
assert preprocessed_images.shape == (10, 1, 28, 28), "Image shape after preprocessing is incorrect."

print("`preprocess_images` test passed.")

In [None]:
### Test `load_mnist_data`

#### Description
#The function `load_mnist_data` loads and preprocesses the MNIST dataset, allowing for fraction-based sampling and batching.

#### Test Cases
#- Verify the dataset fraction is applied correctly. (TODO: no assertion below checks the resulting dataset size yet.)
#- Ensure the batch size is respected.

data_loader = load_mnist_data(fraction=0.1, batch_size=8, shuffle=True)

# Fetch a single batch for testing.
# The batch is pulled explicitly (rather than via `for ... break`) so that an
# empty loader fails the test instead of silently skipping every assertion
# and still printing the "passed" message.
batch = next(iter(data_loader), None)
assert batch is not None, "Data loader yielded no batches."
images, labels = batch

# Assertions
assert images.shape[0] <= 8, "Batch size exceeds the specified limit."
assert tuple(images.shape[1:]) == (1, 28, 28), "Image dimensions are incorrect."
assert labels.shape[0] == images.shape[0], "Number of labels does not match number of images."
print(f"Batch shape: {images.shape}, Labels shape: {labels.shape}")

print("`load_mnist_data` test passed.")

In [None]:
### Test `split_dataset`

#### Description
#The function `split_dataset` splits a dataset into training and validation sets based on a specified ratio.

#### Test Cases
#- Ensure proper splitting ratios.
#- Verify no data leakage between training and validation sets.

# Generate reproducible random data and labels.
NUM_SAMPLES = 100
rng = np.random.default_rng(42)
data = rng.random((NUM_SAMPLES, 28, 28))
# Stamp every sample with a unique id in its first pixel so individual samples
# can be recognised after the split (label values alone repeat and cannot
# reveal leakage).
data[:, 0, 0] = np.arange(NUM_SAMPLES)
labels = rng.integers(0, 10, NUM_SAMPLES)

(train_x, train_y), (val_x, val_y) = split_dataset(data, labels, validation_split=0.2)

# Assertions: split sizes
assert len(train_x) == 80, "Training data size is incorrect."
assert len(val_x) == 20, "Validation data size is incorrect."
assert len(train_y) == 80, "Training labels size is incorrect."
assert len(val_y) == 20, "Validation labels size is incorrect."

# Recover the per-sample ids. Flattening each sample makes this independent of
# any extra axes in the returned arrays.
# NOTE(review): assumes split_dataset returns array-likes with the sample axis
# first and the sample values unchanged - confirm against data_utils.
train_ids = np.rint(np.asarray(train_x).reshape(len(train_x), -1)[:, 0]).astype(int)
val_ids = np.rint(np.asarray(val_x).reshape(len(val_x), -1)[:, 0]).astype(int)

# Check for data leakage: no sample may appear in both splits, and none may be lost.
assert not set(train_ids.tolist()) & set(val_ids.tolist()), "Data leakage detected between training and validation sets."
assert len(set(train_ids.tolist()) | set(val_ids.tolist())) == NUM_SAMPLES, "Some samples are missing or duplicated after the split."

# Check that each sample is still paired with its original label.
assert np.array_equal(np.asarray(train_y), labels[train_ids]), "Training labels are misaligned with training data."
assert np.array_equal(np.asarray(val_y), labels[val_ids]), "Validation labels are misaligned with validation data."

print("`split_dataset` test passed.")
