In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet18
from torchmetrics.classification import BinaryAccuracy
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import pytorch_lightning as pl
from models import FaceID_CNN, Ready_faceID_CNN
import os
from torchvision import datasets
from PIL import Image
import pandas as pd
from torchvision.datasets import ImageFolder

base_path : str = os.path.dirname(os.getcwd())
CSV_PATH  : str = base_path + '\\csv'
src_path  : str = base_path + '\\src'
json_path : str = base_path + '\\json'

traits = ['Male']


class CelebADataset(torch.utils.data.Dataset):
    def __init__(self, img_folder, labels, transform=None):
        self.img_folder = img_folder
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_folder, self.labels.iloc[idx]["image_id"])
        label = self.labels.iloc[idx][1:].values.astype("float32")
        image = Image.open(img_name).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label)

In [21]:
transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

dataset_path = os.path.join(CSV_PATH, "celeba")
img_folder_path = os.path.join(dataset_path, "img_align_celeba")
attr_path = os.path.join(dataset_path, "list_attr_celeba.txt")
partition_path = os.path.join(dataset_path, "list_eval_partition.txt")

if not os.path.exists(img_folder_path) or not os.path.exists(attr_path) or not os.path.exists(partition_path):
    raise FileNotFoundError("The dataset folder or required files are missing.")

# Load attribute labels
attr_df = pd.read_csv(attr_path, sep=r'\s+', skiprows=1)
attr_df = attr_df.reset_index().rename(columns={"index": "image_id"})
attr_df["image_id"] = attr_df["image_id"].astype(str)

# Keep only the desired traits and convert -1/+1 to 0/1
attr_df = attr_df[["image_id"] + traits]
filtered_df = attr_df[(attr_df[traits] > 0).any(axis=1)]

# Load partition information
partition_df = pd.read_csv(partition_path, sep=' ', header=None, names=["image_id", "partition"])
partition_df["image_id"] = partition_df["image_id"].astype(str)

# Merge attributes with partition info
filtered_df = filtered_df.merge(partition_df, on="image_id")

# Split the data based on partitions
train_df = filtered_df[filtered_df["partition"] == 0].drop(columns=["partition"])
val_df = filtered_df[filtered_df["partition"] == 1].drop(columns=["partition"])
test_df = filtered_df[filtered_df["partition"] == 2].drop(columns=["partition"])

# Debugging: Check sizes of splits
print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}, Test size: {len(test_df)}")

# Create datasets
train_data = CelebADataset(img_folder_path, train_df, transform)
val_data = CelebADataset(img_folder_path, val_df, transform)
test_data = CelebADataset(img_folder_path, test_df, transform)

# Create DataLoaders
# train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=5, pin_memory=True)
# val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=5, pin_memory=True)
# test_loader = DataLoader(test_data, batch_size=32, shuffle=False, num_workers=5, pin_memory=True)

Train size: 68261, Validation size: 8458, Test size: 7715


In [14]:
filtered_df = attr_df[(attr_df[traits] > 0).any(axis=1)]

In [18]:
filtered_df[filtered_df["partition"] == 2]

Unnamed: 0,image_id,Eyeglasses,partition
182646,182647.jpg,1,2
182647,182648.jpg,1,2
182661,182662.jpg,1,2
182670,182671.jpg,1,2
182672,182673.jpg,1,2
...,...,...,...
202540,202541.jpg,1,2
202567,202568.jpg,1,2
202587,202588.jpg,1,2
202589,202590.jpg,1,2


In [13]:
attr_df[attr_df["partition"] == 0]

Unnamed: 0,image_id,Eyeglasses,partition
0,000001.jpg,0,0
1,000002.jpg,0,0
2,000003.jpg,0,0
3,000004.jpg,0,0
4,000005.jpg,0,0
...,...,...,...
162765,162766.jpg,0,0
162766,162767.jpg,0,0
162767,162768.jpg,0,0
162768,162769.jpg,0,0
