### 📝 Imports

In [1]:
import torch
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
import os

from cats_dataset import CatsDataset
from tools.augmentation.data_augmenter import DataAugmenter

from PIL import Image



### 🔧 Config

In [2]:
# Input folders holding the 128x128 source images, one folder per class
bathroom_cat_dir_path = 'data/bathroom-cat-128x128/'
captain_dir_path = 'data/captain-128x128/'
control_dir_path = 'data/control-128x128/'

# Output folder for the augmented samples; created up front so later
# cells can write into it (no error if it is already there)
augmented_images_dir = 'data/augmented_images'
os.makedirs(name=augmented_images_dir, exist_ok=True)

### 🌐 Create Transforms

In [None]:
# Pipeline applied to every sample: augment first, then convert to a tensor,
# then normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    DataAugmenter(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

### 🚦 Load Dataset

In [4]:
# Build the dataset from the three image folders; every sample goes
# through the transform pipeline defined above in this notebook.
dataset = CatsDataset(
    transform=transform,
    bathroom_cat_dir=bathroom_cat_dir_path,
    captain_dir=captain_dir_path,
    control_dir=control_dir_path,
)

### 🖼 Visualize and Save Augmented Images

In [None]:
# Pull a single shuffled batch of eight samples to inspect.
sample_loader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)
data_iter = iter(sample_loader)
images, labels = next(data_iter)

def save_augmented_images(images, directory):
    """Write every image of a normalized batch to ``directory`` as a PNG.

    Each image is mapped back from the normalized range with ``x * 0.5 + 0.5``,
    rearranged from (C, H, W) to (H, W, C), converted to 8-bit pixels and saved
    as ``augmented_image_<idx>.png``, where ``idx`` is the position in the batch.
    Files from a previous call with the same indices are overwritten.

    Args:
        images: Batch tensor indexed along dim 0, each entry laid out (C, H, W).
        directory: Destination folder; created if it does not exist yet.
    """
    # Make the function usable on its own, not only after the config cell ran.
    os.makedirs(directory, exist_ok=True)

    for idx, img in enumerate(images):
        # .numpy() refuses tensors that track gradients or live on an
        # accelerator, so drop the graph and move to host memory first.
        img = img.detach().cpu()
        # Unnormalize
        img = img * 0.5 + 0.5
        # Convert to NumPy array and transpose to (H, W, C)
        np_img = np.transpose(img.numpy(), (1, 2, 0))
        # Round to the nearest 8-bit level instead of truncating: float
        # round-off can leave a value a hair below its integer (126.99999),
        # which a bare astype(uint8) would store one level too dark. Clip so
        # that anything outside [0, 255] saturates rather than wrapping.
        pixels = np.clip(np.rint(np_img * 255.0), 0, 255).astype(np.uint8)
        # Convert to PIL Image
        pil_img = Image.fromarray(pixels)
        # Save image
        pil_img.save(os.path.join(directory, f'augmented_image_{idx}.png'))

# Persist the sampled batch to disk, then report where the files were written.
save_augmented_images(images=images, directory=augmented_images_dir)
print(
    f"Augmented images have been saved to '{augmented_images_dir}' directory."
)

def imshow(img):
    """Show one normalized (C, H, W) image tensor with matplotlib, axes hidden."""
    # Map values back from the normalized range.
    restored = img * 0.5 + 0.5
    # matplotlib expects channels last, so move axis 0 to the end.
    channels_last = restored.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.axis('off')
    plt.show()

# Display every image of the sampled batch, one figure per image.
for single_image in images:
    imshow(single_image)

