
Is this library compatible with custom datasets? #35

Closed
NamelessGaki opened this issue Jul 19, 2022 · 3 comments · Fixed by #49

Comments

@NamelessGaki

Hello,

Thank you for your excellent work! As my question states, I wonder how to use this library with a custom dataset. I am new to machine learning and wanted to train a VAE on a relatively large dataset. So, I looked at the provided examples for training different models. However, it seemed to me that I had to load the whole dataset from a .npz file similar to the MNIST or the CelebA datasets. Is there a way to write my own data loader for a custom dataset and then use it with this library?

Thank you again for your work!

@clementchadebec
Owner

Hi @NamelessGaki ,

Thank you for your kind message and for opening this issue! Yes, you can indeed use the library with your own data. As you may have seen in the notebook tutorials (e.g. this one), you only need to convert your data to a np.array or torch.Tensor and then pass it to the TrainingPipeline to launch a model training. However, as of now the dataloaders are defined automatically in pythae.trainers, so it may require a bit of work on your side to integrate your own dataloader. I will consider adding a feature in the near future that lets users pass either their data directly or a custom dataloader when launching the TrainingPipeline, since it may be of interest to other users as well.

For now, here is what you can do to train a pythae.models instance (note that this requires loading all the data at once):

1. Make the needed imports

```python
from pythae.models import VAE, VAEConfig
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline
```
2. Build the VAE model to be trained and specify the training config to use

```python
# Specify the config of the VAE model
model_config = VAEConfig(
    input_dim=(1, 28, 28),  # You must specify here the shape of your custom data
    latent_dim=16
)

# Build the VAE
model = VAE(
    model_config=model_config,
)

# Specify your training config. This lets the library create everything
# needed for training (dataloaders, optimizers, ...)
config = BaseTrainerConfig(
    output_dir='my_model',
    learning_rate=1e-4,
    batch_size=100,
    num_epochs=10,  # Increase this to train the model for longer
)
```
3. Build the TrainingPipeline and launch your training

```python
# Build the TrainingPipeline
pipeline = TrainingPipeline(
    training_config=config,  # pass the training config
    model=model  # pass the model to train
)

# Launch the training
pipeline(
    train_data=train_dataset,  # tensor or array of the shape specified above (in the model config)
    eval_data=eval_dataset  # tensor or array of the shape specified above (in the model config)
)
```

In this example, the library will automatically build both an encoder and a decoder shaped as multilayer perceptrons (MLPs). These networks may not be best suited to your custom data, so I suggest you have a look at this tutorial explaining how to pass your own neural-net architectures to the VAE.

I hope this helps.

In any case, do not hesitate if you have questions :)

@NamelessGaki
Author

Thank you for your quick reply :)
Then I will train on a subset of my dataset (it is too large to be loaded all at once). Until then, I will be looking forward to the update with a custom dataloader!

Cheers!
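Training on a subset of a dataset that is too large to load in full can be sketched with plain numpy (the array below is a small in-memory stand-in; with a real `.npy` file, `np.load(path, mmap_mode="r")` avoids reading everything into memory before slicing):

```python
import numpy as np

# Stand-in for a dataset too large to use in full.
full_data = np.arange(10_000 * 4, dtype=np.float32).reshape(10_000, 4)

# Draw a random subset of rows without replacement.
rng = np.random.default_rng(42)
subset_size = 1_000
idx = rng.choice(len(full_data), size=subset_size, replace=False)
train_subset = full_data[idx]

print(train_subset.shape)  # (1000, 4)
```

Sampling without replacement keeps the subset free of duplicate examples; the resulting array can then be passed as `train_data` to the pipeline as shown earlier in the thread.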

@clementchadebec clementchadebec linked a pull request Aug 30, 2022 that will close this issue
@clementchadebec
Owner

Hi @NamelessGaki ,

Sorry for the late reply.
Following your request, I have been working on allowing users to pass their custom datasets directly to the TrainingPipeline instance. I have opened #49, where you will find in particular a tutorial showing how to use a custom dataset within pythae.

Let me know through #49 whether these changes allow you to do what you want, and whether the tutorial is clear enough or further elements should be added :)

In any case, do not hesitate if you have any questions.

Best,

Clément
