How does the Sampler work? #43
Comments
Hi @jprachir, thank you for opening this issue and sorry for the late reply. I will try to make my answer as detailed as possible in case future questions arise.

Samplers design

The samplers are designed the same way the models are. This means that for any sampling technique (GMM, MAF ...) you will find a sampler definition along with its configuration. The samplers can then be used with any suited model to generate new data.

Example - GMM sampler

For instance, let's say we have trained a VAE and want to sample from it using a `GaussianMixtureSampler`:

```python
>>> from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
>>> # Define your sampler configuration
... gmm_sampler_config = GaussianMixtureSamplerConfig(
...     n_components=12
... )
>>> # Build your sampler
... my_sampler = GaussianMixtureSampler(
...     sampler_config=gmm_sampler_config,
...     model=my_trained_vae  # A trained `pythae.models` instance
... )
```

This works the same for any other sampler that is implemented in pythae.

The sample method

For the simplest sampler, the `NormalSampler`, the `sample` method first generates embeddings by drawing from a standard Gaussian prior:

```python
z = torch.randn(batch_size, self.model.latent_dim).to(self.device)
```

Then, we pass them to the decoder of the model we are sampling from with

```python
x_gen = self.model.decoder(z)["reconstruction"].detach()
```
The other samplers work the same but may use a fancier scheme to generate the embeddings in step 1. In particular, some of them need to be fitted before generating.
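The two-step generation described above can be sketched in a few lines. The decoder here is a hypothetical stand-in (a fixed linear map), not pythae's API; it only illustrates drawing embeddings from a standard Gaussian prior and mapping them to the data space:

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, data_dim, batch_size = 4, 8, 10

# Hypothetical stand-in for a trained decoder: a fixed linear
# map from latent space to data space.
W = rng.standard_normal((latent_dim, data_dim))

def decoder(z):
    return z @ W

# 1. Draw embeddings from a standard Gaussian prior
z = rng.standard_normal((batch_size, latent_dim))

# 2. Map them back to the data space with the decoder
x_gen = decoder(z)

print(x_gen.shape)  # (10, 8)
```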
The fit method

As you may have noticed, the previous `sample` method only takes arguments that are relevant to the generation function itself (number of samples, batch size, whether to return the generated samples or not, etc.). However, some samplers need to be fitted before we call the `sample` method. For instance, a `GaussianMixtureSampler` will first need to fit a Gaussian mixture on the embeddings learned by your model. Similarly, the `TwoStageVAESampler`, `MAFSampler`, `IAFSampler` or `PixelCNNSampler` instances will require you to call the `fit` method before sampling with the `sample` function, since they require a model (VAE, normalizing flow, autoregressive flow) to be fitted on the learned embeddings first as well.
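The fit-before-sample contract can be illustrated with a minimal sketch. This is a hypothetical class, not pythae code; it only mirrors the guard that refuses to sample from an unfitted sampler:

```python
import numpy as np

class FittedSamplerSketch:
    """Hypothetical sketch of a sampler that must be fitted first."""

    def __init__(self, latent_dim):
        self.latent_dim = latent_dim
        self.mean = None  # set by fit()

    def fit(self, embeddings):
        # Here we only estimate a mean; real samplers fit a GMM,
        # a normalizing flow, etc. on the embeddings.
        self.mean = np.asarray(embeddings).mean(axis=0)

    def sample(self, num_samples):
        if self.mean is None:
            raise ArithmeticError(
                "The sampler needs to be fitted by calling the fit() "
                "method before sampling."
            )
        rng = np.random.default_rng(0)
        return self.mean + rng.standard_normal((num_samples, self.latent_dim))

sampler = FittedSamplerSketch(latent_dim=2)
try:
    sampler.sample(5)
except ArithmeticError as err:
    print("refused:", err)

sampler.fit(np.zeros((100, 2)))
print(sampler.sample(5).shape)  # (5, 2)
```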
Example - GMM sampler

Hence, if you want to sample with these samplers and you do not call the `fit` method before `sample`, you should get the following error:

```python
>>> my_sampler = GaussianMixtureSampler(
...     sampler_config=gmm_sampler_config,
...     model=my_trained_vae
... )
>>> my_sampler.sample(10)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/clement/Documents/these/implem/benchmark_VAE/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py", line 128, in sample
    raise ArithmeticError(
ArithmeticError: The sampler needs to be fitted by calling smapler.fit() method before sampling.
```
While the correct usage is:

```python
>>> my_sampler = GaussianMixtureSampler(
...     sampler_config=gmm_sampler_config,
...     model=my_trained_vae
... )
>>> # Fit the sampler
>>> my_sampler.fit(train_dataset)
>>> # Generate samples
>>> gen_data = my_sampler.sample(
...     num_samples=50,
...     batch_size=10,
...     output_dir=None,
...     return_gen=True
... )
```
Now, what happens in the `fit` method? Calling this method means that your sampler needs to be fitted with some elements coming from your trained model (for instance the learned embeddings/latent variables). This is why, when you call this method, you are able to pass your train/eval data, from which the embeddings are retrieved.
Example 1 - GMM sampler

If we look at the example of the `GaussianMixtureSampler` again, you will see that in the `fit` method we:

- Retrieve the needed train embeddings from our trained model:

```python
z_ = self.model(inputs).z
```

- Then use these embeddings to fit our Gaussian mixture (benchmark_VAE/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py, lines 92 to 99 in a9a5388):

```python
gmm = mixture.GaussianMixture(
    n_components=self.n_components,
    covariance_type="full",
    max_iter=2000,
    verbose=0,
    tol=1e-3,
)
gmm.fit(z.cpu().detach())
```

- Finally, assign the GMM model to the sampler to further use it in the `sample` method:

```python
self.gmm = gmm
```
Then, as explained in the previous section for the `NormalSampler`, the GMM model is used in the `sample` method to:

- Generate embeddings (benchmark_VAE/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py, lines 140 to 144 in a9a5388):

```python
z = (
    torch.tensor(self.gmm.sample(batch_size)[0])
    .to(self.device)
    .type(torch.float)
)
```

- Retrieve the generated samples in the data space:

```python
x_gen = self.model.decoder(z)["reconstruction"].detach()
```
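Outside of pythae, the same fit-then-sample steps can be reproduced directly with scikit-learn's `GaussianMixture` on synthetic embeddings. The embeddings and dimensions below are made up for illustration; they stand in for the `z_` retrieved from a trained encoder:

```python
import numpy as np
from sklearn import mixture

rng = np.random.default_rng(0)

# Synthetic "learned embeddings", standing in for z_ = self.model(inputs).z
train_embeddings = np.concatenate([
    rng.normal(-2.0, 0.5, size=(200, 2)),
    rng.normal(+2.0, 0.5, size=(200, 2)),
])

# Fit the Gaussian mixture on the embeddings (mirrors the fit() step)
gmm = mixture.GaussianMixture(
    n_components=2,
    covariance_type="full",
    max_iter=2000,
    tol=1e-3,
)
gmm.fit(train_embeddings)

# Generate new embeddings from the fitted mixture (mirrors the sample() step);
# in pythae these would then be passed through the model's decoder.
z, _ = gmm.sample(10)
print(z.shape)  # (10, 2)
```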
Example 2 - MAF sampler

The other samplers work the same. For instance, the `MAFSampler` will:

- Retrieve the needed train and eval embeddings from our trained model (benchmark_VAE/src/pythae/samplers/maf_sampler/maf_sampler.py, lines 93 to 95 and 126 to 128 in a9a5388):

```python
for _, inputs in enumerate(train_loader):
    encoder_output = self.model(inputs)
    z_ = encoder_output.z
```

```python
for _, inputs in enumerate(eval_loader):
    encoder_output = self.model(inputs)
    z_ = encoder_output.z
```

- Then use these embeddings to fit the normalizing flow (lines 140 to 147 in a9a5388):

```python
trainer = BaseTrainer(
    model=self.flow_contained_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    training_config=training_config,
)
trainer.train()
```

- Finally, assign the flow model to the sampler for further use in the `sample` method (lines 149 to 151 in a9a5388):

```python
self.maf_model = MAF.load_from_folder(
    os.path.join(trainer.training_dir, "final_model")
).to(self.device)
```
Then, the flow model is used in the `sample` method to:

- Generate embeddings (benchmark_VAE/src/pythae/samplers/maf_sampler/maf_sampler.py, lines 194 to 195 in a9a5388):

```python
u = self.prior.sample((batch_size,))
z = self.maf_model.inverse(u).out
```

- Retrieve the generated samples in the data space:

```python
x_gen = self.model.decoder(z).reconstruction.detach()
```
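A real MAF is a learned invertible network; to keep a sketch self-contained, the snippet below replaces it with a hypothetical affine "flow" (a fixed scale and shift) and a linear decoder, and shows the same two-step generation: sample u from the prior, invert the flow to get z, then decode. All names and shapes here are illustrative assumptions, not pythae internals:

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, data_dim, batch_size = 3, 6, 5

# Hypothetical stand-ins: an affine "flow" u = (z - shift) / scale and a
# linear decoder. A real MAF learns its transformation from the embeddings.
scale = np.array([1.5, 0.5, 2.0])
shift = np.array([0.1, -0.3, 0.7])
W = rng.standard_normal((latent_dim, data_dim))

def flow_inverse(u):
    # Map prior samples back to the embedding space
    return u * scale + shift

def decoder(z):
    return z @ W

# 1. Sample from the prior:  u = self.prior.sample((batch_size,))
u = rng.standard_normal((batch_size, latent_dim))

# 2. Invert the flow:  z = self.maf_model.inverse(u).out
z = flow_inverse(u)

# 3. Decode to the data space:  x_gen = self.model.decoder(z).reconstruction
x_gen = decoder(z)
print(x_gen.shape)  # (5, 6)
```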
Conclusion
I hope this helps you better understand how the samplers work. In any case, do not hesitate to reach out if you have any other questions or need me to clarify some points :)
Best,
Clément
Hi @clementchadebec, ... Best,
Hi Clément,
Great work on introducing the VAE-oriented library! You have made it modular with predefined models, pipelines, and so forth. Could you share brief details on how the samplers work under the hood for generation?
Prachi