# VAE (Variational Autoencoder) for Training Data Generation

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import os 
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

### Create Dataset Object

In [6]:
class SingleCelebrityDataset(Dataset):
    '''
    PyTorch Dataset holding the images of a single celebrity.

    Every regular, non-hidden file directly inside ``data_dir`` is treated as
    one image, and every sample is labelled with the same integer ``idx``.
    Building one Dataset per celebrity lets the caller choose the numeric
    class label explicitly, and the result plugs straight into a DataLoader.

    Args:
        data_dir: directory containing this celebrity's image files.
        celebrity: human-readable celebrity name (stored on the instance;
            labels use ``idx``, not this name).
        idx: integer class label attached to every sample.
        transform: optional callable applied to each loaded RGB PIL image.
    '''
    def __init__(self, data_dir, celebrity, idx, transform=None):
        self.data_dir = data_dir
        self.celebrity = celebrity
        self.transform = transform
        self.idx = idx  # numeric class label shared by every sample
        # NOTE(review): never used inside this class; kept so any external
        # code that accesses `.encoder` keeps working.
        self.encoder = LabelEncoder()
        self.image_paths, self.labels = self.load_data()

    def load_data(self):
        '''
        Collect the image file paths found directly under ``self.data_dir``.

        Called once from ``__init__``. Sub-directories and hidden files
        (e.g. ``.DS_Store``) are skipped because they are not images and
        would make ``Image.open`` fail later. Paths are sorted so that the
        sample order is deterministic regardless of filesystem listing order.

        Returns:
            (image_paths, labels): a list of file paths and a parallel list
            in which every entry is ``self.idx``.
        '''
        image_paths = sorted(
            os.path.join(self.data_dir, name)
            for name in os.listdir(self.data_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(self.data_dir, name))
        )
        labels = [self.idx] * len(image_paths)
        return image_paths, labels

    def __len__(self):
        '''
        Return the number of images in the dataset.
        '''
        return len(self.image_paths)

    def __getitem__(self, idx):
        '''
        Return ``(image, label)`` for sample number ``idx``.

        The image is loaded and converted to RGB. If a transform was given it
        is applied and its output is returned; otherwise the PIL image itself
        is returned. ``label`` is the dataset's integer class index.
        '''
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # Open inside a context manager so the file handle is released.
        # convert() loads the pixel data (and returns a copy even when the
        # image is already RGB), so `image` stays valid after the file closes.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        return image, label


### Loading Data into DataLoaders

In [14]:
# Training transform: resize, random augmentation, then normalise to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Validation transform: same resize and normalisation but no random
# augmentation, so validation samples are deterministic between evaluations.
val_transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

fpath_train = "./data/train"
fpath_val = "./data/val"
# sorted() makes the folder order, and therefore the idx -> celebrity
# mapping, deterministic; raw os.listdir order is arbitrary.
sub_folders_train = sorted(item for item in os.listdir(fpath_train) if os.path.isdir(os.path.join(fpath_train, item)))
sub_folders_val = sorted(item for item in os.listdir(fpath_val) if os.path.isdir(os.path.join(fpath_val, item)))
dataloader_arr_train = []
dataloader_arr_test = []
batch_size = 2

# Only celebrities present in both splits can get a train AND a val loader;
# report the others instead of crashing on a missing folder.
celebrities = sorted(set(sub_folders_train) & set(sub_folders_val))
skipped = sorted(set(sub_folders_train) ^ set(sub_folders_val))
if skipped:
    print("Skipping celebrities missing from one split: ", skipped)

# One dataset per celebrity so each gets its own numeric label (idx).
for idx, celebrity_folder in enumerate(celebrities):
    dataset_train_celeb = SingleCelebrityDataset(data_dir=os.path.join(fpath_train, celebrity_folder), celebrity=celebrity_folder, idx=idx, transform=transform)
    dataset_val_celeb = SingleCelebrityDataset(data_dir=os.path.join(fpath_val, celebrity_folder), celebrity=celebrity_folder, idx=idx, transform=val_transform)
    dataloader_arr_train.append(DataLoader(dataset_train_celeb, batch_size=batch_size, shuffle=True))
    dataloader_arr_test.append(DataLoader(dataset_val_celeb, batch_size=batch_size, shuffle=True))

print("Training Dataloaders: ", len(dataloader_arr_train))
print("Testing Dataloaders: ", len(dataloader_arr_test))

Training Dataloaders:  14
Testing Dataloaders:  14
