In [1]:
import inspect

import pytorch_lightning as pl
import torch
import torchvision

In [None]:
class ExportCallback(pl.callbacks.Callback):
    """Sample images from a generative model and log them as a grid at epoch end.

    At the end of each epoch this callback:

    1. Draws ``num_samples`` latent vectors from a standard normal distribution.
    2. Passes them through ``pl_module`` in eval mode with gradients disabled.
    3. Tiles the outputs into one image with ``torchvision.utils.make_grid``.
    4. Writes the grid via ``trainer.logger.experiment.add_image``. The
       experiment object presumably is a TensorBoard ``SummaryWriter`` --
       confirm for the logger in use.

    Attributes read from ``pl_module``:
        * ``hparams.latent_dim`` -- length of each latent vector.
        * ``device`` -- device on which the latent vectors are created.
        * ``img_dim`` -- per-sample image shape. Only read when the module
          returns 2-D (flat) outputs, which are reshaped to
          ``(num_samples, *img_dim)``.

    Args:
        num_samples: Number of images generated per epoch.
        nrow: Number of images per grid row (``make_grid``'s ``nrow``).
        padding: Padding in pixels between grid cells.
        normalize: Forwarded to ``make_grid``'s ``normalize``.
        norm_range: Forwarded to ``make_grid`` as its normalization range.
            This is ``value_range`` in torchvision >= 0.9 and ``range`` in
            older releases; the right name is chosen at call time.
        scale_each: Forwarded to ``make_grid``'s ``scale_each``.
        pad_value: Value of the padding pixels.
    """

    def __init__(
        self,
        num_samples = 3,
        nrow = 8,
        padding = 2,
        normalize = False,
        norm_range = None,
        scale_each = False,
        pad_value = 0,
    ):
        super().__init__()
        self.num_samples = num_samples
        self.nrow = nrow
        self.padding = padding
        self.normalize = normalize
        self.norm_range = norm_range
        self.scale_each = scale_each
        self.pad_value = pad_value

    @staticmethod
    def _range_kwarg():
        """Return the keyword name ``make_grid`` uses for the normalization range."""
        # torchvision 0.9 renamed `range` to `value_range`, and later releases
        # dropped `range` entirely (passing it raises TypeError). Pick whichever
        # name the installed version actually declares.
        params = inspect.signature(torchvision.utils.make_grid).parameters
        return "value_range" if "value_range" in params else "range"

    def on_epoch_end(self, trainer, pl_module):
        """Generate ``num_samples`` images and log them as a single grid image.

        Returns early, without sampling, when the trainer has no logger.
        The module's train/eval mode is restored to what it was before the
        call, even if the forward pass raises.
        """
        # NOTE(review): `on_epoch_end` is deprecated/removed in newer Lightning
        # releases (split into `on_train_epoch_end` etc.) -- confirm against the
        # installed version before relying on this hook firing.
        logger = trainer.logger
        if logger is None:
            # Nowhere to log to; skip the sampling pass entirely.
            return

        dim = (self.num_samples, pl_module.hparams.latent_dim)  # type: ignore[union-attr]
        z = torch.normal(mean=0.0, std=1.0, size=dim, device=pl_module.device)

        # Sample in eval mode without gradient tracking. Afterwards restore the
        # previous mode instead of forcing train mode, and do so even if the
        # forward pass raises.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                images = pl_module(z)
        finally:
            pl_module.train(was_training)

        if images.dim() == 2:
            # Flat (num_samples, features) output: restore the image shape
            # so make_grid can tile it.
            images = images.view(self.num_samples, *pl_module.img_dim)

        grid = torchvision.utils.make_grid(
            tensor=images,
            nrow=self.nrow,
            padding=self.padding,
            normalize=self.normalize,
            scale_each=self.scale_each,
            pad_value=self.pad_value,
            **{self._range_kwarg(): self.norm_range},
        )
        str_title = f"{pl_module.__class__.__name__}_images"
        logger.experiment.add_image(str_title, grid, global_step=trainer.global_step)
