In [46]:
import os
import random
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from tqdm import tqdm

In [47]:
# Check for CUDA availability
# Run tensor work on the default CUDA device when PyTorch can see a GPU;
# otherwise everything falls back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [40]:
# Define transformations
# Shared preprocessing: resize to 224x224, convert to a float tensor and
# normalize with the ImageNet channel statistics (undone later by denormalize()).
transform_base = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Same pipeline with a mirror step inserted right after the resize.
# p=1.0 makes the flip deterministic: every image passed through this
# transform comes out horizontally mirrored.
transform_augment = transforms.Compose(
    [
        transform_base.transforms[0],
        transforms.RandomHorizontalFlip(p=1.0),
        *transform_base.transforms[1:],
    ]
)

In [41]:
# Denormalization function
def denormalize(image_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Reverses the per-channel normalization applied during preprocessing.

    Computes ``image_tensor * std + mean`` channel-wise. The default statistics
    are the same values passed to ``transforms.Normalize`` in this notebook.

    Args:
        image_tensor: A normalized PyTorch tensor whose channel axis is dim -3,
            i.e. shape (C, H, W) or (N, C, H, W).
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided out.
    Returns:
        A denormalized tensor with the same shape and device as ``image_tensor``.
        Floating-point inputs keep their dtype; other dtypes are promoted to the
        default float dtype.
    Raises:
        ValueError: If ``mean`` and ``std`` differ in length, or ``image_tensor``
            has fewer than 3 dimensions or a channel count different from ``len(mean)``.
    """
    if len(mean) != len(std):
        raise ValueError(f"mean and std must have the same length, got {len(mean)} and {len(std)}")
    if image_tensor.dim() < 3 or image_tensor.shape[-3] != len(mean):
        raise ValueError(
            f"expected a (C, H, W) or (N, C, H, W) tensor with C={len(mean)}, "
            f"got shape {tuple(image_tensor.shape)}"
        )

    # Build the statistics in the input's float dtype so e.g. half/double
    # tensors are not silently promoted to float32.
    dtype = image_tensor.dtype if image_tensor.is_floating_point() else torch.get_default_dtype()
    # Shape (C, 1, 1) broadcasts over both (C, H, W) and (N, C, H, W) inputs
    # without adding a batch dimension to unbatched images.
    mean_t = torch.as_tensor(mean, dtype=dtype, device=image_tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=image_tensor.device).view(-1, 1, 1)
    return image_tensor * std_t + mean_t

In [42]:
# Custom Dataset Class
class AffectNetDataset(Dataset):
    """Image dataset laid out as ``root_dir/<integer label>/<image file>``.

    Optionally oversamples selected classes by duplicating randomly chosen image
    paths of that class (with replacement) until each reaches ``target_count``
    entries. Duplicates are appended after all original entries.

    Args:
        root_dir: Directory with one sub-directory per class; the sub-directory
            name is parsed as the integer label. Sub-directories whose names are
            not integers (e.g. ``.ipynb_checkpoints``) are skipped.
        transform: Optional callable applied to each loaded RGB PIL image.
        augment_classes: Labels to oversample; ``None`` means no oversampling.
        target_count: Number of entries each class in ``augment_classes`` is
            padded up to (classes already at or above it are left alone).
        seed: Optional seed for the oversampling draw. ``None`` uses the global
            ``random`` module state.
    """

    # Matched case-insensitively against file names.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir, transform=None, augment_classes=None, target_count=5000, seed=None):
        self.root_dir = root_dir
        self.transform = transform
        self.augment_classes = augment_classes or []
        self.target_count = target_count
        self.seed = seed
        self.image_paths = []
        self.labels = []

        # Load data. Listings are sorted because os.listdir order is arbitrary;
        # sorting keeps sample indices (and output file names) reproducible.
        for label_name in sorted(os.listdir(root_dir)):
            class_path = os.path.join(root_dir, label_name)
            if not os.path.isdir(class_path):
                continue
            try:
                label = int(label_name)
            except ValueError:
                # Not a class folder (e.g. a notebook checkpoint directory).
                continue
            images = [
                os.path.join(class_path, img)
                for img in sorted(os.listdir(class_path))
                if img.lower().endswith(self.IMAGE_EXTENSIONS)
            ]
            self.image_paths.extend(images)
            self.labels.extend([label] * len(images))

        # Augment classes if needed
        self._augment_data()

    def _augment_data(self):
        """Oversample each class in ``augment_classes`` up to ``target_count`` entries.

        Copies are drawn with replacement from the class's existing entries and
        appended to ``image_paths`` / ``labels``. A class with no images is skipped,
        since there is nothing to duplicate.
        """
        # A private generator makes the draw reproducible when a seed is given;
        # it is created here (not stored) so the dataset stays picklable.
        rng = random.Random(self.seed) if self.seed is not None else random
        for label in self.augment_classes:
            class_indices = [i for i, lbl in enumerate(self.labels) if lbl == label]
            shortfall = self.target_count - len(class_indices)
            if not class_indices or shortfall <= 0:
                continue
            for idx in rng.choices(class_indices, k=shortfall):
                self.image_paths.append(self.image_paths[idx])
                self.labels.append(label)

    def __len__(self):
        """Number of entries, including oversampled duplicates."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load entry ``idx`` as ``(image, label)``; the image is transformed if a transform is set."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # The context manager closes the underlying file once the pixels are converted.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [44]:
# Paths
# Input dataset: process_split reads one sub-directory per split from here.
dataset_root = "AffectNet"  # Replace with your dataset path
# Destination of the full preprocessing run, and the scratch directory
# used by the quick test attempt.
output_root = "AffOut"  # Replace with your output path
temp_output_root = "AffTemp"  # Replace with your temp output path

# exist_ok keeps this cell safe to re-run.
for _dir in (output_root, temp_output_root):
    os.makedirs(_dir, exist_ok=True)
del _dir

In [45]:
# Dataset Preparation
data_splits = ["train", "val", "test"]
augment_classes = [1, 7]  # Disgust and Contempt
target_count = 5000


def process_split(split, output_dir, test_attempt=False, samples_per_class=5):
    """Preprocess one split and write it to ``output_dir/<split>/<label>/<split>_<idx>.jpg``.

    Every image is resized and normalized, then denormalized again and saved as
    JPEG. For the train split, the classes in ``augment_classes`` are oversampled
    up to ``target_count``. The first occurrence of each file is written with
    ``transform_base``; every oversampled repeat of a file is written with
    ``transform_augment`` (a deterministic horizontal flip). The extra copies
    therefore differ from their source instead of being identical duplicates.
    A file that is oversampled more than once still yields identical mirrored copies.

    Args:
        split: Name of the split sub-directory under ``dataset_root``.
        output_dir: Root directory the processed split is written into.
        test_attempt: If True, write at most ``samples_per_class`` images per label
            as a quick smoke test instead of the whole split.
        samples_per_class: Images per label to write during a test attempt.

    Raises:
        Exception: Any error while loading, transforming or saving an image is
            printed with its index and re-raised.
    """
    split_dir = os.path.join(dataset_root, split)
    output_split_dir = os.path.join(output_dir, split)
    os.makedirs(output_split_dir, exist_ok=True)

    # transform=None: the dataset yields raw PIL images, and the transform is
    # picked per sample below so originals and oversampled copies can differ.
    dataset = AffectNetDataset(
        root_dir=split_dir,
        transform=None,
        augment_classes=augment_classes if split == "train" else None,
        target_count=target_count,
    )

    # Label folders 0-7, plus any other label actually present on disk.
    for label in sorted(set(range(8)).union(dataset.labels)):
        os.makedirs(os.path.join(output_split_dir, str(label)), exist_ok=True)

    # Oversampled copies repeat an earlier path; record where each path first occurs.
    first_index = {}
    for idx, path in enumerate(dataset.image_paths):
        first_index.setdefault(path, idx)

    if test_attempt:
        # Take the first `samples_per_class` entries of each label. Entries are
        # grouped by class, so taking the first N entries overall would
        # mostly hit a single label.
        taken = {}
        indices = []
        for idx, label in enumerate(dataset.labels):
            if taken.get(label, 0) < samples_per_class:
                taken[label] = taken.get(label, 0) + 1
                indices.append(idx)
    else:
        indices = range(len(dataset))

    to_pil = transforms.ToPILImage()
    suffix = " (Test Attempt)" if test_attempt else ""
    print(f"Processing {split} split{suffix}...")
    with tqdm(total=len(indices), desc=f"Processing {split}", unit="image") as pbar:
        for idx in indices:
            try:
                # Loading happens inside the try so unreadable files are reported too.
                image, label = dataset[idx]
                is_copy = first_index[dataset.image_paths[idx]] != idx
                transform = transform_augment if is_copy else transform_base
                image_tensor = transform(image).to(device)

                # Undo the normalization, then clamp float round-off back into
                # [0, 1] before the 8-bit conversion.
                image_tensor = denormalize(image_tensor.unsqueeze(0)).squeeze(0).clamp(0.0, 1.0)
                pil_image = to_pil(image_tensor.cpu())

                output_path = os.path.join(output_split_dir, str(label), f"{split}_{idx}.jpg")
                pil_image.save(output_path)
            except Exception as e:
                print(f"Error processing image {idx} in {split} split: {e}")
                raise
            pbar.update(1)

    print(f"{split} split{suffix} complete!")


# Run Test Attempt: a few images per label for every split, written to the temp
# directory, so configuration problems surface before the long full run.
print("Running test attempt for all splits...")
for split in data_splits:
    process_split(split, temp_output_root, test_attempt=True)

# process_split re-raises on any error, so reaching this line means the test attempt passed.
print("Test attempt successful. Starting full preprocessing...")
for split in data_splits:
    process_split(split, output_root, test_attempt=False)

print("Data preprocessing complete!")

Running test attempt for all splits...
Processing train split (Test Attempt)...


Processing train: 100%|██████████| 40/40 [00:00<00:00, 87.82image/s]


train split (Test Attempt) complete!
Processing val split (Test Attempt)...


Processing val: 100%|██████████| 40/40 [00:00<00:00, 93.72image/s]


val split (Test Attempt) complete!
Processing test split (Test Attempt)...


Processing test: 100%|██████████| 40/40 [00:00<00:00, 82.69image/s] 


test split (Test Attempt) complete!
Test attempt successful. Starting full preprocessing...
Processing train split...


Processing train: 100%|██████████| 40000/40000 [06:31<00:00, 102.08image/s]


train split complete!
Processing val split...


Processing val: 100%|██████████| 800/800 [00:08<00:00, 99.67image/s] 


val split complete!
Processing test split...


Processing test: 100%|██████████| 3200/3200 [00:31<00:00, 101.03image/s]

test split complete!
Data preprocessing complete!



