In [1]:
import torchvision.datasets as datasets

# %load_ext autoreload
# %autoreload 2

In [2]:
# %pip install pythae

In [3]:
import os
import torch
import numpy as np
import imageio

# Create the on-disk layout read by the dataset cells further down:
#   data_folders/train  and  data_folders/eval
# ``os.makedirs(..., exist_ok=True)`` also creates the missing parent folder
# and is safe to re-run, so the cell is idempotent and has no
# check-then-create race.
os.makedirs("data_folders/train", exist_ok=True)
os.makedirs("data_folders/eval", exist_ok=True)


In [14]:
from torchvision import datasets, transforms
# Per-image preprocessing pipeline; the steps run in the order listed.
data_transform = transforms.Compose(
    [
        # Optional: uncomment to collapse the images to a single channel.
        # transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),  # the data must be tensors
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.GaussianBlur(kernel_size=3),
    ]
)

In [15]:
from pythae.data.datasets import DatasetOutput

class MyCustomDataset(datasets.ImageFolder):
    """``ImageFolder`` variant whose items are pythae ``DatasetOutput`` objects.

    Only the first element returned by the parent ``__getitem__`` is kept and
    exposed under the ``data`` field; the second element is discarded.
    """

    def __init__(self, root, transform=None, target_transform=None):
        """Forward ``root``, ``transform`` and ``target_transform`` to the parent class."""
        super().__init__(
            root=root,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index):
        """Return the item at ``index`` wrapped as ``DatasetOutput(data=...)``."""
        image, _unused_target = super().__getitem__(index)
        return DatasetOutput(data=image)


In [16]:
# Build one dataset per split; both read from their own folder and share the
# same preprocessing pipeline.
# NOTE(review): the shared pipeline contains random augmentations, so the eval
# split is augmented as well -- confirm this is intended.
train_dataset, eval_dataset = (
    MyCustomDataset(root=split_root, transform=data_transform)
    for split_root in ("data_folders/train", "data_folders/eval")
)

In [1]:
from pythae.models import VAE, VAEConfig
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline
# Trainer settings; training output goes under 'my_model', which the reload
# cell reads back from later.
config = BaseTrainerConfig(
    output_dir='my_model',
    num_epochs=50,  # Change this to train the model a bit more
    learning_rate=1e-4,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    steps_saving=5,  # presumably the checkpoint interval -- confirm against pythae docs
)

# Model definition: inputs are declared as (3, 512, 512), latent space has 8 dims.
model_config = VAEConfig(input_dim=(3, 512, 512), latent_dim=8)
model = VAE(model_config=model_config)

# Run training on the custom train / eval datasets.
pipeline = TrainingPipeline(training_config=config, model=model)
pipeline(train_data=train_dataset, eval_data=eval_dataset)

In [21]:
import os
from pythae.models import AutoModel
from pythae.samplers import NormalSampler

# Only consider entries of 'my_model' that contain a 'final_model' folder, so
# stray files or interrupted runs cannot be picked up by mistake.
finished_runs = sorted(
    run_name
    for run_name in os.listdir('my_model')
    if os.path.isdir(os.path.join('my_model', run_name, 'final_model'))
)
if not finished_runs:
    raise FileNotFoundError(
        "No training run containing a 'final_model' folder was found in 'my_model'"
    )

# Lexicographically last run; this is the newest one provided the run folder
# names embed a sortable timestamp (presumably pythae's default -- confirm).
last_training = finished_runs[-1]
trained_model = AutoModel.load_from_folder(
    os.path.join('my_model', last_training, 'final_model')
)

# create normal sampler
normal_samper = NormalSampler(model=trained_model)

# sample 25 images (enough to fill a 5 x 5 grid)
gen_data = normal_samper.sample(num_samples=25)

In [2]:
import matplotlib.pyplot as plt
# show results with normal sampler
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

# Bring the samples to host memory once and reorder every image from
# channels-first (C, H, W) to channels-last (H, W, C), the layout ``imshow``
# expects. The explicit ``.cpu()`` matters when sampling ran on a GPU:
# matplotlib cannot read CUDA tensors.
# Assumes ``gen_data`` is a torch tensor of shape (N, C, H, W), as returned by
# the sampler with its default settings -- confirm if the sampler call changes.
samples_hwc = gen_data.detach().cpu().permute(0, 2, 3, 1).numpy()

# Fill the grid row by row: sample k lands in row k // 5, column k % 5.
for ax, image in zip(axes.flat, samples_hwc):
    ax.imshow(image)
    ax.axis('off')

# Remove the padding between the panels.
plt.tight_layout(pad=0.)
