
How does the Sampler work? #43

Closed
jprachir opened this issue Aug 9, 2022 · 2 comments
Labels
question Further information is requested

Comments


jprachir commented Aug 9, 2022

Hi Clément,
Great work on introducing this VAE-oriented library!
You have made it very modular, with predefined models, pipelines, and so forth. Can you share brief details on how the sampler works under the hood for generation?

Prachi

@clementchadebec clementchadebec added the question Further information is requested label Aug 11, 2022
@clementchadebec (Owner)

Hi @jprachir,

Thank you for opening this issue, and sorry for the late reply. I will try to make my answer as detailed as possible in case similar questions arise in the future.

Sampler design

The samplers are designed the same way as the models. This means that for any sampling technique (GMM, MAF, ...) you will find in pythae.samplers a folder containing a nameofsampler_config.py file, with a dataclass in which the parameters of the sampling scheme are defined, and a nameofsampler_sampler.py file, in which the actual sampling technique is implemented.
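This config/sampler pairing can be sketched in plain Python. The names below (MySamplerConfig, MySampler, n_components) are hypothetical stand-ins for a real pythae sampler pair, purely to illustrate the pattern:

```python
from dataclasses import dataclass

# Hypothetical mirror of a nameofsampler_config.py dataclass:
# it only stores the hyper-parameters of the sampling scheme.
@dataclass
class MySamplerConfig:
    n_components: int = 10

# Hypothetical mirror of a nameofsampler_sampler.py class:
# it consumes the config and implements the actual sampling logic.
class MySampler:
    def __init__(self, sampler_config: MySamplerConfig):
        self.n_components = sampler_config.n_components

config = MySamplerConfig(n_components=12)
sampler = MySampler(sampler_config=config)
print(sampler.n_components)  # 12
```

Keeping hyper-parameters in a dedicated dataclass is what lets every sampler be configured (or left at its defaults) independently of the sampling logic itself.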

Sampler definition

The samplers can then be used with any suitable model to generate new data. A pythae.samplers instance requires as input the trained model you want to generate from. Optionally, you can also provide a custom configuration when building the sampler, to use different sampler hyper-parameters.

Example - GMM sampler

For instance, let's say we have trained a VAE and want to sample from it using a pythae.samplers.GaussianMixtureSampler. All you have to do to build your sampler is the following:

>>> from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
>>> # Define your sampler configuration
... gmm_sampler_config = GaussianMixtureSamplerConfig(
...	n_components=12
... )
>>> # Build your sampler
... my_sampler = GaussianMixtureSampler(
...	sampler_config=gmm_sampler_config,
...	model=my_trained_vae # A trained `pythae.models` instance
... )

This works the same for any other sampler implemented in pythae.

The sample method

Once your sampler has been instantiated, you can use it to generate new data quite easily with the sample method, which will save and/or return the generated samples from your trained model. The sample method generates the embeddings (latent variables) that are then passed to the decoder of the model we are sampling from, producing the generated samples in the data space.

Example - N(0,1) sampler

For instance, if we take the simplest example with the pythae.samplers.NormalSampler, you can see that in the sample method we:

  1. Generate the embeddings with the defined technique (here, a draw from N(0, 1)) with

    z = torch.randn(batch_size, self.model.latent_dim).to(self.device)

  2. Then, we pass them to the decoder of the model we are sampling from with

    x_gen = self.model.decoder(z)["reconstruction"].detach()
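The two-step flow above can be illustrated without torch or pythae. In this minimal sketch, the fixed linear map standing in for the trained decoder is entirely hypothetical; only the generate-then-decode structure reflects the NormalSampler:

```python
import random

latent_dim, data_dim, batch_size = 2, 4, 3

# Hypothetical stand-in for a trained decoder: a fixed linear map
# from the latent space (dim 2) to the data space (dim 4).
weights = [[0.5 * (i + j) for j in range(data_dim)] for i in range(latent_dim)]

def decode(z):
    return [sum(z[i] * weights[i][j] for i in range(latent_dim))
            for j in range(data_dim)]

# Step 1: draw embeddings from N(0, 1), as the NormalSampler does.
z_batch = [[random.gauss(0.0, 1.0) for _ in range(latent_dim)]
           for _ in range(batch_size)]

# Step 2: pass the embeddings through the decoder to reach the data space.
x_gen = [decode(z) for z in z_batch]
print(len(x_gen), len(x_gen[0]))  # 3 4
```

The decoder never sees the data at sampling time: the quality of the generations depends entirely on how well the latent distribution assumed by the sampler matches the one the model actually learned.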

The other samplers work the same way but may use fancier schemes to generate the embeddings in step 1. In particular, some of them need to be fitted before generating.

The fit method

As you may have noticed, the sample method above only takes arguments that are relevant to the generation itself (number of samples, batch size, whether to return the generated samples or not, etc.). However, some samplers need to be fitted before we call the sample method. For instance, a GaussianMixtureSampler will first need to fit a Gaussian mixture on the embeddings learned by your model. Similarly, TwoStageVAESampler, MAFSampler, IAFSampler or PixelCNNSampler instances require you to call the fit method before sampling with the sample function, since they need a model (VAE, normalizing flow, autoregressive flow) to be fitted on the learned embeddings first.

Example - GMM sampler

Hence, if you want to sample with one of these samplers and do not call the fit method before sample, you will get the following error:

>>> my_sampler = GaussianMixtureSampler(
...	sampler_config=gmm_sampler_config,
...	model=my_trained_vae
... )
>>> my_sampler.sample(10)
Traceback (most recent call last):
   File "<stdin>", line 1, in <module>
   File "/home/clement/Documents/these/implem/benchmark_VAE/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py", line 128, in sample
      raise ArithmeticError(
ArithmeticError: The sampler needs to be fitted by calling smapler.fit() method before sampling.

While the correct usage is:

>>> my_sampler = GaussianMixtureSampler(
...	sampler_config=gmm_sampler_config,
...	model=my_trained_vae
... )
>>> # fit the sampler
>>> my_sampler.fit(train_dataset)
>>> # Generate samples
>>> gen_data = my_sampler.sample(
...	num_samples=50,
...	batch_size=10,
...	output_dir=None,
...	return_gen=True
... )
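The fit-before-sample contract can be sketched in a few lines of plain Python. The class below is hypothetical and only mimics the guard: a real pythae sampler does actual fitting work, but the error-raising pattern is the same idea:

```python
class FittableSampler:
    """Hypothetical sketch of a sampler that must be fitted first."""

    def __init__(self):
        self.is_fitted = False

    def fit(self, train_data):
        # A real sampler would fit e.g. a Gaussian mixture on the
        # model's embeddings here; we only record that fitting happened.
        self.is_fitted = True

    def sample(self, num_samples):
        # Guard: refuse to sample before fit() has been called.
        if not self.is_fitted:
            raise RuntimeError(
                "The sampler needs to be fitted by calling the "
                "fit() method before sampling."
            )
        return [0.0] * num_samples

sampler = FittableSampler()
try:
    sampler.sample(10)          # raises: not fitted yet
except RuntimeError as err:
    print("error:", err)
sampler.fit(train_data=[1, 2, 3])
print(len(sampler.sample(10)))  # 10
```

Failing loudly at the first sample call is friendlier than silently sampling from an unfitted model, which would produce meaningless generations.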

Now, what happens in the fit method? Calling this method means that your sampler needs to be fitted with some elements coming from your trained model (for instance, the learned embeddings/latent variables). This is why, when you call this method, you can pass your train/eval data, from which the embeddings are retrieved.

Example 1 - GMM sampler

If we look at the example of the GaussianMixtureSampler again, you will see that in the fit method we:

  1. Retrieve the needed train embeddings from our trained model
  2. Then, use these embeddings to fit our Gaussian mixture
    gmm = mixture.GaussianMixture(
        n_components=self.n_components,
        covariance_type="full",
        max_iter=2000,
        verbose=0,
        tol=1e-3,
    )
    gmm.fit(z.cpu().detach())
  3. Finally, assign the fitted GMM model to the sampler for further use in the sample method.

Then, as explained in the previous section for the NormalSampler, the GMM model is used in the sample method to

  1. Generate embeddings
    z = (
        torch.tensor(self.gmm.sample(batch_size)[0])
        .to(self.device)
        .type(torch.float)
    )
  2. Retrieve the generated samples in the data space.
    x_gen = self.model.decoder(z)["reconstruction"].detach()
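The fit-then-sample flow can be illustrated with a drastic simplification: a single 1-D Gaussian fitted on fake "embeddings" instead of a full mixture model on real encoder outputs. Everything below (the synthetic data, the single-Gaussian fit) is a hypothetical stand-in for the GaussianMixtureSampler's actual logic:

```python
import random
import statistics

random.seed(0)

# Hypothetical stand-in for step 1: embeddings that would normally be
# retrieved from the trained encoder.
train_embeddings = [random.gauss(2.0, 0.5) for _ in range(1000)]

# Simplified step 2: fit a single 1-D Gaussian instead of a mixture,
# purely to show that fitting means estimating a latent distribution.
mu = statistics.fmean(train_embeddings)
sigma = statistics.stdev(train_embeddings)

# Step 3 analogue: the fitted parameters are kept on the "sampler"
# and reused at sampling time to generate new embeddings.
def sample_embeddings(batch_size):
    return [random.gauss(mu, sigma) for _ in range(batch_size)]

z = sample_embeddings(5)
print(len(z))  # 5
```

The point of fitting is exactly this: new embeddings are drawn from a distribution estimated on the embeddings the model actually produced, rather than from an arbitrary prior.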

Example 2 - MAF sampler

The other samplers work the same.

For instance, the MAFSampler will

  1. Retrieve the needed train and eval embeddings from our trained model

    for _, inputs in enumerate(train_loader):
        encoder_output = self.model(inputs)
        z_ = encoder_output.z

    for _, inputs in enumerate(eval_loader):
        encoder_output = self.model(inputs)
        z_ = encoder_output.z

  2. Then use these embeddings to fit the normalizing flow

    trainer = BaseTrainer(
        model=self.flow_contained_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        training_config=training_config,
    )
    trainer.train()

  3. The flow model is assigned to the sampler for further use in the sample method

    self.maf_model = MAF.load_from_folder(
        os.path.join(trainer.training_dir, "final_model")
    ).to(self.device)

Then, the flow model is used in the sample method to

  1. Generate embeddings

    u = self.prior.sample((batch_size,))
    z = self.maf_model.inverse(u).out

  2. Retrieve the generated samples in the data space.

    x_gen = self.model.decoder(z).reconstruction.detach()
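The prior-then-inverse-flow pattern above can be sketched with a trivially invertible "flow". The single affine map below is a hypothetical stand-in for a trained MAF; only the draw-from-prior-then-invert structure mirrors the real sampler:

```python
import random

# Hypothetical stand-in for a trained normalizing flow: a single
# invertible affine map z = a * u + b.
a, b = 2.0, 1.0

def flow_inverse(u):
    # Maps a prior sample u into the latent space, playing the role
    # of maf_model.inverse(u) for a real MAF.
    return a * u + b

# Step 1: sample from the simple prior, then invert the flow to get
# embeddings that follow the (more complex) learned latent distribution.
u_batch = [random.gauss(0.0, 1.0) for _ in range(4)]
z_batch = [flow_inverse(u) for u in u_batch]

# Step 2 would then decode z_batch with the trained model's decoder.
print(len(z_batch))  # 4
```

A real flow composes many such invertible transforms, which is what lets it warp a simple prior into a close match of the model's latent distribution.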

Conclusion

I hope this helps you better understand how the samplers work. In any case, do not hesitate to reach out if you have any other questions or need me to clarify some points :)

Best,

Clément


jprachir commented Sep 3, 2022

Hi @clementchadebec,
Thanks for the detailed reply. It certainly demonstrates the gist of the sampler components. I will use this issue to ask further sampler-related questions in the future.

Best,
Prachi
