Commit
* add custom datasets
* make sanity check on loader
* add DatasetOutput class
* unify data loading with other models
* allow custom datasets as inputs
* update tests
* update tests
* add tutorial
* update gitignore
* update README
* clean up inerpolate and reconstruct
* add interpolate and reconstruct
* update README
* apply black
* small update in test
* remove not needed imports
* add install on main
* clean up notebook
* black
* remove DoubleBatchDataset
* remove DoubleBatchDataset
* black isort
1 parent 96ca154 · commit 4c50f25 · Showing 22 changed files with 704 additions and 235 deletions.
@@ -0,0 +1,338 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tutorial - Using your own Dataset\n",
"\n",
"In this notebook, we will see how to use your own `Dataset` within `pythae`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install the library\n",
"%pip install git+https://github.com/clementchadebec/benchmark_VAE.git"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torchvision.datasets as datasets\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=None)\n",
"\n",
"train_dataset = mnist_trainset.data[:10000].reshape(-1, 1, 28, 28) / 255.\n",
"train_targets = mnist_trainset.targets[:10000]\n",
"eval_dataset = mnist_trainset.data[-10000:].reshape(-1, 1, 28, 28) / 255.\n",
"eval_targets = mnist_trainset.targets[-10000:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example\n",
"Below is an example where we build a dataset inheriting from [`torchvision.datasets.ImageFolder`](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html#torchvision.datasets.ImageFolder) that iterates over images located in a folder. In particular, this loader avoids loading all the data into memory at once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# We will save the data in folders to mimic the desired example\n",
"\n",
"import os\n",
"import torch\n",
"import numpy as np\n",
"import imageio\n",
"\n",
"if not os.path.exists(\"data_folders\"):\n",
"    os.mkdir(\"data_folders\")\n",
"if not os.path.exists(\"data_folders/train\"):\n",
"    os.mkdir(\"data_folders/train\")\n",
"if not os.path.exists(\"data_folders/eval\"):\n",
"    os.mkdir(\"data_folders/eval\")\n",
"\n",
"for i in range(len(train_dataset)):\n",
"    img = 255.0*train_dataset[i][0].unsqueeze(-1)\n",
"    img_folder = os.path.join(\"data_folders\", \"train\", f\"{train_targets[i]}\")\n",
"    if not os.path.exists(img_folder):\n",
"        os.mkdir(img_folder)\n",
"    imageio.imwrite(os.path.join(img_folder, \"%08d.jpg\" % i), np.repeat(img, repeats=3, axis=-1).type(torch.uint8))\n",
"\n",
"for i in range(len(eval_dataset)):\n",
"    img = 255.0*eval_dataset[i][0].unsqueeze(-1)\n",
"    img_folder = os.path.join(\"data_folders\", \"eval\", f\"{eval_targets[i]}\")\n",
"    if not os.path.exists(img_folder):\n",
"        os.mkdir(img_folder)\n",
"    imageio.imwrite(os.path.join(img_folder, \"%08d.jpg\" % i), np.repeat(img, repeats=3, axis=-1).type(torch.uint8))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define transforms to be applied on the data when reloaded\n",
"from torchvision import datasets, transforms\n",
"data_transform = transforms.Compose([\n",
"    transforms.Grayscale(num_output_channels=1),\n",
"    transforms.ToTensor() # the data must be tensors\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define your `CustomDataset`\n",
"\n",
"In this example, we build a custom dataset inheriting from `ImageFolder`. The only thing to keep in mind when building a custom dataset for use with `pythae` is that the `__getitem__` method must return a `DatasetOutput` instance containing at least the key `data`. Otherwise, you will not be able to combine your dataset with the `pipelines`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pythae.data.datasets import DatasetOutput\n",
"\n",
"class MyCustomDataset(datasets.ImageFolder):\n",
"\n",
"    def __init__(self, root, transform=None, target_transform=None):\n",
"        super().__init__(root=root, transform=transform, target_transform=target_transform)\n",
"\n",
"    def __getitem__(self, index):\n",
"        X, _ = super().__getitem__(index)\n",
"\n",
"        return DatasetOutput(\n",
"            data=X\n",
"        )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_dataset = MyCustomDataset(\n",
"    root=\"data_folders/train\",\n",
"    transform=data_transform,\n",
")\n",
"\n",
"eval_dataset = MyCustomDataset(\n",
"    root=\"data_folders/eval\",\n",
"    transform=data_transform\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use your `CustomDataset` to train a `pythae` model\n",
"\n",
"Now, the datasets can be passed to the `TrainingPipeline` to train any model implemented in `pythae`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pythae.models import VAE, VAEConfig\n",
"from pythae.trainers import BaseTrainerConfig\n",
"from pythae.pipelines.training import TrainingPipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = BaseTrainerConfig(\n",
"    output_dir='my_model',\n",
"    learning_rate=1e-3,\n",
"    batch_size=100,\n",
"    num_epochs=10, # Change this to train the model a bit more\n",
")\n",
"\n",
"\n",
"model_config = VAEConfig(\n",
"    input_dim=(1, 28, 28),\n",
"    latent_dim=16\n",
")\n",
"\n",
"model = VAE(\n",
"    model_config=model_config\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pipeline = TrainingPipeline(\n",
"    training_config=config,\n",
"    model=model\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pipeline(\n",
"    train_data=train_dataset, # here we use the custom train dataset\n",
"    eval_data=eval_dataset # here we use the custom eval dataset\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Let's have a look at the trained model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from pythae.models import AutoModel"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"last_training = sorted(os.listdir('my_model'))[-1]\n",
"trained_model = AutoModel.load_from_folder(os.path.join('my_model', last_training, 'final_model'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pythae.samplers import NormalSampler"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# create normal sampler\n",
"normal_sampler = NormalSampler(\n",
"    model=trained_model\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sample\n",
"gen_data = normal_sampler.sample(\n",
"    num_samples=25\n",
")"
]
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# show results with normal sampler\n", | ||
"fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", | ||
"\n", | ||
"for i in range(5):\n", | ||
" for j in range(5):\n", | ||
" axes[i][j].imshow(gen_data[i*5 +j].cpu().squeeze(0), cmap='gray')\n", | ||
" axes[i][j].axis('off')\n", | ||
"plt.tight_layout(pad=0.)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"interpreter": { | ||
"hash": "5e51c5ac46389dd7ba2bd8215d251ab84152720d3cad2ff91113d77594821aef" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3.8.12 ('pythae')", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.13" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
@@ -36,3 +36,7 @@ class BadInheritanceError(Exception):

class ModelError(Exception):
    pass


class DatasetError(Exception):
    pass
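The new `DatasetError` pairs naturally with the "sanity check on loader" mentioned in the commit message: a user-supplied dataset that does not return a `DatasetOutput`-like mapping can be rejected with a clear error before training starts. Below is a minimal, self-contained sketch of such a check; the name `sanity_check` and the exact validation logic are illustrative assumptions, not pythae's actual implementation.

```python
class DatasetError(Exception):
    """Raised when a user-provided dataset cannot be used by the pipeline."""
    pass


def sanity_check(dataset):
    """Hypothetical sketch: verify that dataset[0] returns a mapping
    containing at least the key 'data', as the pipelines expect."""
    try:
        sample = dataset[0]
    except Exception as e:
        raise DatasetError(f"Could not index into the dataset: {e}") from e
    if not hasattr(sample, "keys") or "data" not in sample:
        raise DatasetError(
            "__getitem__ must return a DatasetOutput-like mapping "
            "containing at least the key 'data'."
        )
    return sample


class GoodDataset:
    def __getitem__(self, index):
        return {"data": [0.0] * 4}  # stands in for an image tensor


class BadDataset:
    def __getitem__(self, index):
        return [0.0] * 4  # plain list: no 'data' key


sanity_check(GoodDataset())  # passes silently
try:
    sanity_check(BadDataset())
except DatasetError as e:
    print("caught:", e)
```

Failing fast here, rather than deep inside a training loop, is the usual motivation for a dedicated exception type.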