# VAE (Variational Autoencoder) for Training Data Generation

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import os 
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

### Create Dataset Object

In [3]:
class SingleCelebrityDataset(Dataset):
    '''
    This is the Dataset class for the Celebrity data. It's meant to conform
    to PyTorch's structure with the DataLoader. Typically, Datasets are premade,
    but this allows for customization.
    '''
    def __init__(self, data_dir, celebrity, idx, transform=None):
        self.data_dir = data_dir
        self.celebrity = celebrity
        self.transform = transform
        self.idx = idx # keeps track of celebrity numerical value
        self.encoder = LabelEncoder()
        self.image_paths, self.labels = self.load_data()

    def load_data(self):
        '''
        Interface method, called in the constructor of DataLoader (I think)
        This traverses through the ./data folder to assign X and y data to respective
        arrays.

        Returns the image_paths and numerical_labels (classes in a numeric encoding)
        '''
        fpath = f"{self.data_dir}"
        sub_folders = [item for item in os.listdir(fpath) if os.path.isdir(os.path.join(fpath, item))]
        image_paths = []
        labels = []
        numerical_labels = []

        for image in os.listdir(fpath):
            fpath_i = f"{self.data_dir}/{image}"
            image_paths.append(fpath_i)
            labels.append(f"{self.celebrity}")
            numerical_labels.append(self.idx)
                
        # print(image_paths)
        # print(labels)
        
        return image_paths, numerical_labels

    def __len__(self):
        '''
        Returns the length of the dataset.
        '''
        return len(self.image_paths)

    def __getitem__(self, idx):
        '''
        gets the item in a Dataset by index. Called by iterators.
        '''
        # Load an image and its label based on the index 'idx'.
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        
        # Load and preprocess the image
        image = Image.open(image_path)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if self.transform:
            image = self.transform(image)
            # print(image.shape)

        return image, label


### Loading Data into DataLoaders

In [4]:
transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.3, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

fpath_train = "./data/train"
fpath_val = "./data/val"
sub_folders_train = [item for item in os.listdir(fpath_train) if os.path.isdir(os.path.join(fpath_train, item))]
sub_folders_val = [item for item in os.listdir(fpath_val) if os.path.isdir(os.path.join(fpath_val, item))]
dataloader_arr_train = []
dataloader_arr_test = []
batch_size = 2

# loops through all the celebrities, creates their own dataset (important for labeling)
for idx, celebrity_folder in enumerate(sub_folders_val):
    #print(celebrity_folder)
    dataset_train_celeb = SingleCelebrityDataset(data_dir=f"{fpath_train}/{celebrity_folder}", celebrity=celebrity_folder, idx=idx, transform=transform)
    dataset_val_celeb = SingleCelebrityDataset(data_dir=f"{fpath_val}/{celebrity_folder}", celebrity=celebrity_folder, idx=idx, transform=transform)
    dataloader_arr_train.append(DataLoader(dataset_train_celeb, batch_size=batch_size, shuffle=True))
    dataloader_arr_test.append(DataLoader(dataset_val_celeb, batch_size=batch_size, shuffle=True))

print("Training Dataloaders: ", len(dataloader_arr_train))
print("Testing Dataloaders: ", len(dataloader_arr_test))

Training Dataloaders:  14
Testing Dataloaders:  14


## Create the VAE
<img src="https://avandekleut.github.io/assets/vae/variational-autoencoder.png" alt="variational autoencoder">

In [5]:
class VariationalEncoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalEncoder, self).__init__()
        self.linear1 = nn.Linear(90000, 20000)
        self.linear2 = nn.Linear(20000, latent_dims)
        self.linear3 = nn.Linear(20000, latent_dims)

        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.cuda() # hack to get sampling on the GPU
        self.N.scale = self.N.scale.cuda()
        self.kl = 0

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu =  self.linear2(x)
        sigma = torch.exp(self.linear3(x))
        z = mu + sigma*self.N.sample(mu.shape)
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
        return z

class Decoder(nn.Module):
    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(latent_dims, 20000)
        self.linear2 = nn.Linear(20000, 90000)

    def forward(self, z):
        z = F.relu(self.linear1(z))
        z = torch.sigmoid(self.linear2(z))
        return z.reshape((-1, 1, 300, 300))
    

class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = VariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

def train(autoencoder, data, epochs=20):
    opt = torch.optim.Adam(autoencoder.parameters())
    for epoch in range(epochs):
        for x, y in data:
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = ((x - x_hat)**2).sum() + autoencoder.encoder.kl
            loss.backward()
            opt.step()
            
    return autoencoder

### Train the VAE

In [4]:
vae = VariationalAutoencoder(10000)
vae = train(vae, dataloader_arr_train[0])

NameError: name 'VariationalAutoencoder' is not defined

Above code may not be useful... see code blocks below

### Image Generation with Pre-trained StableDiffusion

In [8]:
'''
I CANT RUN THIS BECAUSE I DONT HAVE SUFFICIENT GPU MEMORY. :)
''' 
from diffusers import StableDiffusionImageVariationPipeline
from PIL import Image
from io import BytesIO
import requests

pipe = StableDiffusionImageVariationPipeline.from_pretrained(
    "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
)
pipe = pipe.to("cuda")
pipe.safety_checker = lambda images, clip_input: (images, False)

unet\diffusion_pytorch_model.safetensors not found
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacty of 4.00 GiB of which 0 bytes is free. Of the allocated memory 3.43 GiB is allocated by PyTorch, and 31.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# get some random training images
dataiter = iter(dataloader_arr_train[0])
images, labels_numeric = next(dataiter)
#labels = dataset_train.encoder.inverse_transform(labels_numeric) # decodes the one-hot-encoding of the labels.
# show images
with open("/content/data/train/anne_hathaway/316px-Anne_Hathaway_@_2018.09.15_Human_Rights_Campaign_National_Dinner,_Washington,_DC_USA_06194_(43805104245)_(cropped).jpg", "rb") as image_file:
    # Read the image file into a BytesIO object
    image_data = BytesIO(image_file.read())
image = Image.open(image_data).convert("RGB")
out = pipe(image, num_images_per_prompt=1, guidance_scale=15)
out["images"][0].show()