From 1103a99d54e655e2d012e50fe56dd7708c501f4b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Chadebec?=
Date: Sun, 22 Oct 2023 19:42:01 +0200
Subject: [PATCH] fixes #110 & fix some tests & black 23.10

---
 src/pythae/data/datasets.py                   |  1 -
 src/pythae/data/preprocessors.py              |  2 -
 .../adversarial_ae/adversarial_ae_model.py    | 19 +---
 src/pythae/models/ae/ae_model.py              |  2 -
 src/pythae/models/base/base_model.py          |  6 +-
 .../models/beta_tc_vae/beta_tc_vae_model.py   | 17 +--
 src/pythae/models/beta_vae/beta_vae_model.py  | 17 +--
 src/pythae/models/ciwae/ciwae_model.py        | 23 ++--
 .../disentangled_beta_vae_model.py            | 17 +--
 .../models/factor_vae/factor_vae_model.py     | 17 +--
 .../models/factor_vae/factor_vae_utils.py     |  1 -
 src/pythae/models/hvae/hvae_model.py          | 26 ++---
 src/pythae/models/info_vae/info_vae_model.py  | 23 ++--
 src/pythae/models/iwae/iwae_model.py          | 23 ++--
 src/pythae/models/miwae/miwae_model.py        | 25 ++---
 .../models/msssim_vae/msssim_vae_model.py     |  2 -
 .../models/msssim_vae/msssim_vae_utils.py     |  6 +-
 .../models/nn/benchmarks/celeba/convnets.py   |  6 -
 .../models/nn/benchmarks/celeba/resnets.py    |  6 -
 .../models/nn/benchmarks/cifar/convnets.py    |  6 -
 .../models/nn/benchmarks/cifar/resnets.py     |  6 -
 .../models/nn/benchmarks/mnist/convnets.py    |  6 -
 .../models/nn/benchmarks/mnist/resnets.py     |  6 -
 src/pythae/models/nn/default_architectures.py |  7 --
 .../normalizing_flows/base/base_nf_model.py   |  1 -
 .../models/normalizing_flows/iaf/iaf_model.py |  1 -
 src/pythae/models/normalizing_flows/layers.py |  1 -
 .../normalizing_flows/made/made_model.py      |  3 -
 .../models/normalizing_flows/maf/maf_model.py |  1 -
 .../pixelcnn/pixelcnn_model.py                |  1 -
 .../planar_flow/planar_flow_model.py          |  1 -
 .../radial_flow/radial_flow_model.py          |  1 -
 src/pythae/models/piwae/piwae_model.py        | 25 ++---
 src/pythae/models/pvae/pvae_model.py          | 20 +---
 src/pythae/models/pvae/pvae_utils.py          | 22 ++--
 src/pythae/models/rae_gp/rae_gp_model.py      |  2 -
 src/pythae/models/rae_l2/rae_l2_model.py      |  2 -
 src/pythae/models/rhvae/rhvae_model.py        | 36 ++----
 src/pythae/models/rhvae/rhvae_utils.py        |  4 +-
 src/pythae/models/svae/svae_model.py          | 27 ++---
 src/pythae/models/vae/vae_model.py            | 20 +---
 src/pythae/models/vae_gan/vae_gan_model.py    | 18 +--
 src/pythae/models/vae_iaf/vae_iaf_model.py    | 22 +---
 .../models/vae_lin_nf/vae_lin_nf_model.py     | 22 +---
 src/pythae/models/vamp/vamp_model.py          | 36 ++----
 src/pythae/models/vq_vae/vq_vae_model.py      |  3 -
 src/pythae/models/vq_vae/vq_vae_utils.py      |  9 +-
 src/pythae/models/wae_mmd/wae_mmd_model.py    |  8 +-
 src/pythae/pipelines/generation.py            |  1 -
 src/pythae/pipelines/training.py              |  9 +-
 src/pythae/samplers/base/base_sampler.py      |  2 -
 .../gaussian_mixture_sampler.py               |  2 -
 .../hypersphere_uniform_sampler.py            |  1 -
 .../samplers/iaf_sampler/iaf_sampler.py       |  3 -
 .../samplers/maf_sampler/maf_sampler.py       |  3 -
 .../manifold_sampler/rhvae_sampler.py         | 10 +-
 .../normal_sampling/normal_sampler.py         |  1 -
 .../pixelcnn_sampler/pixelcnn_sampler.py      |  3 -
 .../samplers/pvae_sampler/pvae_sampler.py     |  3 -
 .../two_stage_sampler.py                      |  5 -
 .../samplers/vamp_sampler/vamp_sampler.py     |  1 -
 .../adversarial_trainer.py                    |  8 --
 .../trainers/base_trainer/base_trainer.py     |  8 --
 .../coupled_optimizer_adversarial_trainer.py  |  7 --
 .../coupled_optimizer_trainer.py              |  9 --
 src/pythae/trainers/training_callbacks.py     |  4 -
 tests/test_AE.py                              | 58 ++++------
 tests/test_Adversarial_AE.py                  | 59 +++-------
 tests/test_BetaTCVAE.py                       | 35 ++----
 tests/test_BetaVAE.py                         | 35 ++----
 tests/test_CIWAE.py                           | 37 ++-----
 tests/test_DisentangledBetaVAE.py             | 35 ++----
tests/test_FactorVAE.py | 68 ++++-------- tests/test_HVAE.py | 35 ++---- tests/test_IWAE.py | 35 ++---- tests/test_MIWAE.py | 35 ++---- tests/test_PIWAE.py | 64 ++++------- tests/test_PoincareVAE.py | 36 ++---- tests/test_RHVAE.py | 103 +++++++----------- tests/test_SVAE.py | 35 ++---- tests/test_VAE.py | 35 ++---- tests/test_VAEGAN.py | 67 ++++-------- tests/test_VAE_IAF.py | 35 ++---- tests/test_VAE_LinFlow.py | 38 ++----- tests/test_VAMP.py | 34 ++---- tests/test_VQVAE.py | 37 ++----- tests/test_WAE_MMD.py | 35 ++---- tests/test_info_vae_mmd.py | 35 ++---- tests/test_rae_gp.py | 35 ++---- tests/test_rae_l2.py | 64 ++++------- 90 files changed, 462 insertions(+), 1259 deletions(-) diff --git a/src/pythae/data/datasets.py b/src/pythae/data/datasets.py index 99772656..59aa6ae0 100644 --- a/src/pythae/data/datasets.py +++ b/src/pythae/data/datasets.py @@ -55,7 +55,6 @@ class BaseDataset(Dataset): """ def __init__(self, data, labels): - self.labels = labels.type(torch.float) self.data = data.type(torch.float) diff --git a/src/pythae/data/preprocessors.py b/src/pythae/data/preprocessors.py index a51f2aa4..476a4f91 100644 --- a/src/pythae/data/preprocessors.py +++ b/src/pythae/data/preprocessors.py @@ -92,7 +92,6 @@ def to_dataset(data: torch.Tensor, labels: Optional[torch.Tensor] = None): return dataset def _process_data_array(self, data: np.ndarray, batch_size: int = 100): - num_samples = data.shape[0] samples_shape = data.shape @@ -102,7 +101,6 @@ def _process_data_array(self, data: np.ndarray, batch_size: int = 100): full_data = [] for i in range(num_complete_batch): - # Detect potential nan if DataProcessor.has_nan(data[i * batch_size : (i + 1) * batch_size]): raise ValueError("Nan detected in input data!") diff --git a/src/pythae/models/adversarial_ae/adversarial_ae_model.py b/src/pythae/models/adversarial_ae/adversarial_ae_model.py index 61125393..bbb6e3e4 100644 --- a/src/pythae/models/adversarial_ae/adversarial_ae_model.py +++ b/src/pythae/models/adversarial_ae/adversarial_ae_model.py @@ -58,7 +58,6 @@ def __init__( decoder: Optional[BaseDecoder] = None, discriminator: Optional[BaseDiscriminator] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) if discriminator is None: @@ -149,22 +148,16 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output def loss_function(self, recon_x, x, z, z_prior): - N = z.shape[0] # batch size if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), @@ -232,7 +225,6 @@ def save(self, dir_path: str): @classmethod def _load_custom_discriminator_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) cls._check_python_version_from_folder(dir_path=dir_path) @@ -361,7 +353,6 @@ def load_from_hf_hub( ) else: - if not model_config.uses_default_encoder: _ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl") encoder = cls._load_custom_encoder_from_folder(dir_path) diff --git a/src/pythae/models/ae/ae_model.py b/src/pythae/models/ae/ae_model.py index 6be833c6..c99d0efb 100644 --- a/src/pythae/models/ae/ae_model.py +++ b/src/pythae/models/ae/ae_model.py @@ -39,7 
+39,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - BaseAE.__init__(self, model_config=model_config, decoder=decoder) self.model_name = "AE" @@ -83,7 +82,6 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output def loss_function(self, recon_x, x): - MSE = F.mse_loss( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none" ).sum(dim=-1) diff --git a/src/pythae/models/base/base_model.py b/src/pythae/models/base/base_model.py index f74f7aa4..c11c0e14 100644 --- a/src/pythae/models/base/base_model.py +++ b/src/pythae/models/base/base_model.py @@ -57,10 +57,9 @@ class BaseAE(nn.Module): def __init__( self, model_config: BaseAEConfig, - encoder: Optional[BaseDecoder] = None, + encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - nn.Module.__init__(self) self.model_name = "BaseAE" @@ -374,7 +373,6 @@ def _load_model_weights_from_folder(cls, dir_path): @classmethod def _load_custom_encoder_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) cls._check_python_version_from_folder(dir_path=dir_path) @@ -393,7 +391,6 @@ def _load_custom_encoder_from_folder(cls, dir_path): @classmethod def _load_custom_decoder_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) cls._check_python_version_from_folder(dir_path=dir_path) @@ -510,7 +507,6 @@ def load_from_hf_hub(cls, hf_hub_path: str, allow_pickle=False): # pragma: no c ) else: - if not model_config.uses_default_encoder: _ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl") encoder = cls._load_custom_encoder_from_folder(dir_path) diff --git a/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py b/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py index f88ce29b..6f7e3527 100644 --- a/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py +++ b/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py @@ -41,7 +41,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "BetaTCVAE" @@ -89,20 +88,14 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, mu, log_var, z, dataset_size): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), diff --git a/src/pythae/models/beta_vae/beta_vae_model.py b/src/pythae/models/beta_vae/beta_vae_model.py index 37f4b051..f78d574c 100644 --- a/src/pythae/models/beta_vae/beta_vae_model.py +++ b/src/pythae/models/beta_vae/beta_vae_model.py @@ -40,7 +40,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "BetaVAE" @@ -81,20 +80,14 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, mu, log_var, z): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - 
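# Illustrative sketch (not part of the patch): the hunks around here all
# reflow the same reconstruction term of the ELBO. With a unit-variance
# Gaussian decoder the negative log-likelihood reduces to 0.5 * MSE (up to an
# additive constant); with a Bernoulli decoder it is the binary cross-entropy.
# A minimal, self-contained version of that pattern:

import torch
import torch.nn.functional as F


def reconstruction_loss(recon_x, x, kind="mse"):
    """Per-sample reconstruction loss, summed over flattened pixel dims."""
    recon_flat = recon_x.reshape(x.shape[0], -1)
    x_flat = x.reshape(x.shape[0], -1)
    if kind == "mse":
        # -log N(x | recon_x, I) = 0.5 * ||x - recon_x||^2 + const
        return 0.5 * F.mse_loss(recon_flat, x_flat, reduction="none").sum(dim=-1)
    # -log Bernoulli(x | recon_x), with recon_x interpreted as probabilities
    return F.binary_cross_entropy(recon_flat, x_flat, reduction="none").sum(dim=-1)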
).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), diff --git a/src/pythae/models/ciwae/ciwae_model.py b/src/pythae/models/ciwae/ciwae_model.py index 7323bb53..51f7d0b8 100644 --- a/src/pythae/models/ciwae/ciwae_model.py +++ b/src/pythae/models/ciwae/ciwae_model.py @@ -39,7 +39,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "CIWAE" @@ -92,22 +91,16 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, mu, log_var, z): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x, - x.reshape(recon_x.shape[0], -1) - .unsqueeze(1) - .repeat(1, self.n_samples, 1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x, + x.reshape(recon_x.shape[0], -1) + .unsqueeze(1) + .repeat(1, self.n_samples, 1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x, x.reshape(recon_x.shape[0], -1) @@ -117,7 +110,7 @@ def loss_function(self, recon_x, x, mu, log_var, z): ).sum(dim=-1) log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1) - log_p_z = -0.5 * (z ** 2).sum(dim=-1) + log_p_z = -0.5 * (z**2).sum(dim=-1) KLD = -(log_p_z - log_q_z) diff --git a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py index 2ab9976a..cb07485b 100644 --- a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py +++ b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py @@ -40,7 +40,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) assert ( @@ -89,20 +88,14 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, mu, log_var, z, epoch): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), diff --git a/src/pythae/models/factor_vae/factor_vae_model.py b/src/pythae/models/factor_vae/factor_vae_model.py index c2af1700..2b4b02cd 100644 --- a/src/pythae/models/factor_vae/factor_vae_model.py +++ b/src/pythae/models/factor_vae/factor_vae_model.py @@ -43,7 +43,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.discriminator = FactorVAEDiscriminator(latent_dim=model_config.latent_dim) @@ -132,22 +131,16 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted): - N = 
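# Illustrative sketch (not part of the patch): CIWAE, whose loss is reflowed
# above, optimises a convex combination of the plain ELBO and the
# importance-weighted (IWAE) bound, L = beta * ELBO + (1 - beta) * L_IWAE,
# both computable from the same per-sample log-weights. `log_w` and its
# assumed shape [batch, n_samples] are mine, not from the patch:

import math

import torch


def ciwae_bound(log_w, beta):
    elbo = log_w.mean(dim=-1)  # plain ELBO: mean of the log-weights
    iwae = torch.logsumexp(log_w, dim=-1) - math.log(log_w.shape[-1])  # log-mean-exp
    return beta * elbo + (1 - beta) * iwae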
z.shape[0] # batch size if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), diff --git a/src/pythae/models/factor_vae/factor_vae_utils.py b/src/pythae/models/factor_vae/factor_vae_utils.py index 62fd68ea..046685b0 100644 --- a/src/pythae/models/factor_vae/factor_vae_utils.py +++ b/src/pythae/models/factor_vae/factor_vae_utils.py @@ -4,7 +4,6 @@ class FactorVAEDiscriminator(nn.Module): def __init__(self, latent_dim=16, hidden_units=1000) -> None: - nn.Module.__init__(self) self.layers = nn.Sequential( diff --git a/src/pythae/models/hvae/hvae_model.py b/src/pythae/models/hvae/hvae_model.py index f6d0afd4..25fff5bd 100644 --- a/src/pythae/models/hvae/hvae_model.py +++ b/src/pythae/models/hvae/hvae_model.py @@ -42,7 +42,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "HVAE" @@ -89,7 +88,6 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: beta_sqrt_old = self.beta_zero_sqrt for k in range(self.n_lf): - # perform leapfrog steps # 1st leapfrog step @@ -136,7 +134,6 @@ def _dU_dz(self, z, x): return g + z def loss_function(self, x, zK, rhoK, z0, mu, log_var): - recon_x = self.decoder(zK)["reconstruction"] logpx_given_z = self._log_p_x_given_z(recon_x, x) # log p(x|z_K) @@ -161,22 +158,17 @@ def _tempering(self, k, K): return 1 / beta_k def _log_p_x_given_z(self, recon_x, x): - if self.model_config.reconstruction_loss == "mse": # sigma is taken as I_D - logp_x_given_z = ( - -0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + logp_x_given_z = -0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) # - torch.log(torch.tensor([2 * np.pi]).to(x.device)) \ # * np.prod(self.input_dim) / 2 elif self.model_config.reconstruction_loss == "bce": - logp_x_given_z = ( torch.distributions.Bernoulli(logits=recon_x.reshape(x.shape[0], -1)) .log_prob(x.reshape(x.shape[0], -1)) @@ -211,7 +203,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p_x = [] for j in range(n_full_batch): - x_rep = torch.cat(batch_size * [x]).reshape(-1, 1, 28, 28) encoder_output = self.encoder(x_rep) @@ -226,7 +217,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): beta_sqrt_old = self.beta_zero_sqrt for k in range(self.n_lf): - # 1st leapfrog step rho_ = rho - (self.eps_lf / 2) * self._dU_dz(z, x_rep) @@ -244,10 +234,10 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_q_z0_given_x = -0.5 * ( log_var + (z0 - mu) ** 2 / torch.exp(log_var) ).sum(dim=-1) - log_p_z = -0.5 * (z ** 2).sum(dim=-1) - log_p_rho = -0.5 * (rho ** 2).sum(dim=-1) + log_p_z = -0.5 * (z**2).sum(dim=-1) + log_p_rho = -0.5 * (rho**2).sum(dim=-1) - log_p_rho0 = -0.5 * (rho ** 2).sum(dim=-1) * self.beta_zero_sqrt + log_p_rho0 = -0.5 * (rho**2).sum(dim=-1) * self.beta_zero_sqrt recon_x = self.decoder(z)["reconstruction"] diff --git a/src/pythae/models/info_vae/info_vae_model.py 
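# Illustrative sketch (not part of the patch): the HVAE hunks above iterate a
# standard leapfrog integrator -- momentum half-step, position full step,
# momentum half-step -- with U(z) = -log p(x, z) as the potential. A generic
# version, assuming `grad_U` returns dU/dz:

def leapfrog_step(z, rho, grad_U, eps):
    """One leapfrog step of Hamiltonian dynamics with identity mass matrix."""
    rho_half = rho - 0.5 * eps * grad_U(z)            # momentum half-step
    z_next = z + eps * rho_half                       # position full step
    rho_next = rho_half - 0.5 * eps * grad_U(z_next)  # momentum half-step
    return z_next, rho_next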
b/src/pythae/models/info_vae/info_vae_model.py index 207bc3f5..0de61f68 100644 --- a/src/pythae/models/info_vae/info_vae_model.py +++ b/src/pythae/models/info_vae/info_vae_model.py @@ -39,7 +39,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "INFOVAE_MMD" @@ -87,22 +86,16 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output def loss_function(self, recon_x, x, z, z_prior, mu, log_var): - N = z.shape[0] # batch size if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), @@ -123,7 +116,7 @@ def loss_function(self, recon_x, x, z, z_prior, mu, log_var): mmd_z = (k_z - k_z.diag().diag()).sum() / ((N - 1) * N) mmd_z_prior = (k_z_prior - k_z_prior.diag().diag()).sum() / ((N - 1) * N) - mmd_cross = k_cross.sum() / (N ** 2) + mmd_cross = k_cross.sum() / (N**2) mmd_loss = mmd_z + mmd_z_prior - 2 * mmd_cross @@ -148,7 +141,7 @@ def imq_kernel(self, z1, z2): """Returns a matrix of shape [batch x batch] containing the pairwise kernel computation""" Cbase = ( - 2.0 * self.model_config.latent_dim * self.model_config.kernel_bandwidth ** 2 + 2.0 * self.model_config.latent_dim * self.model_config.kernel_bandwidth**2 ) k = 0 @@ -162,7 +155,7 @@ def imq_kernel(self, z1, z2): def rbf_kernel(self, z1, z2): """Returns a matrix of shape [batch x batch] containing the pairwise kernel computation""" - C = 2.0 * self.model_config.latent_dim * self.model_config.kernel_bandwidth ** 2 + C = 2.0 * self.model_config.latent_dim * self.model_config.kernel_bandwidth**2 k = torch.exp(-torch.norm(z1.unsqueeze(1) - z2.unsqueeze(0), dim=-1) ** 2 / C) diff --git a/src/pythae/models/iwae/iwae_model.py b/src/pythae/models/iwae/iwae_model.py index 06b72335..f70053ce 100644 --- a/src/pythae/models/iwae/iwae_model.py +++ b/src/pythae/models/iwae/iwae_model.py @@ -39,7 +39,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "IWAE" @@ -91,22 +90,16 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, mu, log_var, z): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x, - x.reshape(recon_x.shape[0], -1) - .unsqueeze(1) - .repeat(1, self.n_samples, 1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x, + x.reshape(recon_x.shape[0], -1) + .unsqueeze(1) + .repeat(1, self.n_samples, 1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x, x.reshape(recon_x.shape[0], -1) @@ -116,7 +109,7 @@ def loss_function(self, recon_x, x, mu, log_var, z): ).sum(dim=-1) log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1) - log_p_z = -0.5 * (z ** 2).sum(dim=-1) + log_p_z = -0.5 * (z**2).sum(dim=-1) KLD = -(log_p_z - log_q_z) diff --git 
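# Illustrative sketch (not part of the patch): the INFO-VAE hunk above
# computes an unbiased MMD estimate between posterior samples z and prior
# samples z_prior -- off-diagonal kernel means within each set, minus twice
# the cross-kernel mean. A compact RBF-kernel version for [N, d] inputs:

import torch


def mmd_rbf(z, z_prior, C):
    def k(a, b):
        return torch.exp(-torch.cdist(a, b).pow(2) / C)

    n = z.shape[0]
    k_z, k_p, k_cross = k(z, z), k(z_prior, z_prior), k(z, z_prior)
    mmd_z = (k_z.sum() - k_z.diag().sum()) / (n * (n - 1))
    mmd_p = (k_p.sum() - k_p.diag().sum()) / (n * (n - 1))
    return mmd_z + mmd_p - 2 * k_cross.mean()  # k_cross.mean() == sum / n**2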
a/src/pythae/models/miwae/miwae_model.py b/src/pythae/models/miwae/miwae_model.py index 768fae0e..52a5c11f 100644 --- a/src/pythae/models/miwae/miwae_model.py +++ b/src/pythae/models/miwae/miwae_model.py @@ -39,7 +39,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "MIWAE" @@ -100,23 +99,17 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, mu, log_var, z): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x, - x.reshape(recon_x.shape[0], -1) - .unsqueeze(1) - .unsqueeze(1) - .repeat(1, self.gradient_n_estimates, self.n_samples, 1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x, + x.reshape(recon_x.shape[0], -1) + .unsqueeze(1) + .unsqueeze(1) + .repeat(1, self.gradient_n_estimates, self.n_samples, 1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x, x.reshape(recon_x.shape[0], -1) @@ -127,7 +120,7 @@ def loss_function(self, recon_x, x, mu, log_var, z): ).sum(dim=-1) log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1) - log_p_z = -0.5 * (z ** 2).sum(dim=-1) + log_p_z = -0.5 * (z**2).sum(dim=-1) KLD = -(log_p_z - log_q_z) diff --git a/src/pythae/models/msssim_vae/msssim_vae_model.py b/src/pythae/models/msssim_vae/msssim_vae_model.py index b942c867..0c892b87 100644 --- a/src/pythae/models/msssim_vae/msssim_vae_model.py +++ b/src/pythae/models/msssim_vae/msssim_vae_model.py @@ -40,7 +40,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "MSSSIM_VAE" @@ -82,7 +81,6 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, mu, log_var, z): - recon_loss = self.msssim(recon_x, x) KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) diff --git a/src/pythae/models/msssim_vae/msssim_vae_utils.py b/src/pythae/models/msssim_vae/msssim_vae_utils.py index 6e7d8e6a..7fe972de 100644 --- a/src/pythae/models/msssim_vae/msssim_vae_utils.py +++ b/src/pythae/models/msssim_vae/msssim_vae_utils.py @@ -13,7 +13,7 @@ def __init__(self, window_size=11): def _gaussian(self, sigma): gauss = torch.Tensor( [ - np.exp(-((x - self.window_size // 2) ** 2) / float(2 * sigma ** 2)) + np.exp(-((x - self.window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(self.window_size) ] ) @@ -129,8 +129,8 @@ def forward(self, img1, img2): mssim = (mssim + 1) / 2 mcs = (mcs + 1) / 2 - pow1 = mcs ** weights - pow2 = mssim ** weights + pow1 = mcs**weights + pow2 = mssim**weights output = torch.prod(pow1[:-1] * pow2[-1]) return 1 - output diff --git a/src/pythae/models/nn/benchmarks/celeba/convnets.py b/src/pythae/models/nn/benchmarks/celeba/convnets.py index 650de4e1..8e8f20c7 100644 --- a/src/pythae/models/nn/benchmarks/celeba/convnets.py +++ b/src/pythae/models/nn/benchmarks/celeba/convnets.py @@ -132,7 +132,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -289,7 +288,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = 
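# Illustrative sketch (not part of the patch): MIWAE, reformatted below, draws
# M independent groups (`gradient_n_estimates`) of K importance samples
# (`n_samples`) and averages the IWAE bound across groups. `log_w` and its
# assumed shape [batch, M, K] are mine:

import math

import torch


def miwae_bound(log_w):
    K = log_w.shape[-1]
    per_group = torch.logsumexp(log_w, dim=-1) - math.log(K)  # [batch, M]
    return per_group.mean(dim=-1)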
None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -448,7 +446,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -603,7 +600,6 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -748,7 +744,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -765,7 +760,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): out = x for i in range(max_depth): - if i == 4: out = out.reshape(x.shape[0], -1) diff --git a/src/pythae/models/nn/benchmarks/celeba/resnets.py b/src/pythae/models/nn/benchmarks/celeba/resnets.py index a661853b..4450c8e0 100644 --- a/src/pythae/models/nn/benchmarks/celeba/resnets.py +++ b/src/pythae/models/nn/benchmarks/celeba/resnets.py @@ -130,7 +130,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -276,7 +275,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -425,7 +423,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -589,7 +586,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -743,7 +739,6 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -913,7 +908,6 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels diff --git a/src/pythae/models/nn/benchmarks/cifar/convnets.py b/src/pythae/models/nn/benchmarks/cifar/convnets.py index 3a59d75d..41836a6f 100644 --- a/src/pythae/models/nn/benchmarks/cifar/convnets.py +++ b/src/pythae/models/nn/benchmarks/cifar/convnets.py @@ -128,7 +128,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -284,7 +283,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -443,7 +441,6 @@ def 
forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -584,7 +581,6 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -728,7 +724,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -745,7 +740,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): out = x for i in range(max_depth): - if i == 4: out = out.reshape(x.shape[0], -1) diff --git a/src/pythae/models/nn/benchmarks/cifar/resnets.py b/src/pythae/models/nn/benchmarks/cifar/resnets.py index 84ecedaf..5df606aa 100644 --- a/src/pythae/models/nn/benchmarks/cifar/resnets.py +++ b/src/pythae/models/nn/benchmarks/cifar/resnets.py @@ -124,7 +124,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -267,7 +266,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -411,7 +409,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -555,7 +552,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -695,7 +691,6 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -836,7 +831,6 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels diff --git a/src/pythae/models/nn/benchmarks/mnist/convnets.py b/src/pythae/models/nn/benchmarks/mnist/convnets.py index 560616bd..102a8a15 100644 --- a/src/pythae/models/nn/benchmarks/mnist/convnets.py +++ b/src/pythae/models/nn/benchmarks/mnist/convnets.py @@ -132,7 +132,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -287,7 +286,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -444,7 +442,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for 
levels in output_layer_levels @@ -585,7 +582,6 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -729,7 +725,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -746,7 +741,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): out = x for i in range(max_depth): - if i == 4: out = out.reshape(x.shape[0], -1) diff --git a/src/pythae/models/nn/benchmarks/mnist/resnets.py b/src/pythae/models/nn/benchmarks/mnist/resnets.py index 7a937dde..4176f5a9 100644 --- a/src/pythae/models/nn/benchmarks/mnist/resnets.py +++ b/src/pythae/models/nn/benchmarks/mnist/resnets.py @@ -125,7 +125,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -268,7 +267,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -411,7 +409,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -554,7 +551,6 @@ def forward(self, x: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -708,7 +704,6 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -863,7 +858,6 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels diff --git a/src/pythae/models/nn/default_architectures.py b/src/pythae/models/nn/default_architectures.py index 9eb52ddd..480231db 100644 --- a/src/pythae/models/nn/default_architectures.py +++ b/src/pythae/models/nn/default_architectures.py @@ -30,7 +30,6 @@ def forward(self, x, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -80,7 +79,6 @@ def forward(self, x, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -131,7 +129,6 @@ def forward(self, x, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -178,13 +175,11 @@ def __init__(self, args: dict): self.depth = len(layers) def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): - output = ModelOutput() max_depth = self.depth if 
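# Illustrative sketch (not part of the patch): every benchmark network touched
# here repeats the same pattern -- validate `output_layer_levels`, then walk
# the layers and record intermediate outputs under `embedding_layer_i` keys.
# A distilled, assumed version (the helper name and dict output are mine):

import torch.nn as nn


def forward_with_levels(layers, x, output_layer_levels=None):
    depth = len(layers)
    if output_layer_levels is not None:
        assert all(
            depth >= level > 0 or level == -1 for level in output_layer_levels
        ), f"Levels must lie in [1, {depth}] or be -1 (last layer)"
    out, results = x, {}
    for i, layer in enumerate(layers):
        out = layer(out)
        if output_layer_levels is not None and i + 1 in output_layer_levels:
            results[f"embedding_layer_{i + 1}"] = out
    results["out"] = out
    return results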
output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels @@ -233,7 +228,6 @@ def __init__(self, args: dict): self.lower = nn.Linear(400, k) def forward(self, x): - h1 = self.layers(x.reshape(-1, np.prod(self.input_dim))) h21, h22 = self.diag(h1), self.lower(h1) @@ -284,7 +278,6 @@ def forward(self, z: torch.Tensor, output_layer_levels: List[int] = None): max_depth = self.depth if output_layer_levels is not None: - assert all( self.depth >= levels > 0 or levels == -1 for levels in output_layer_levels diff --git a/src/pythae/models/normalizing_flows/base/base_nf_model.py b/src/pythae/models/normalizing_flows/base/base_nf_model.py index 6f64063c..84840c8d 100644 --- a/src/pythae/models/normalizing_flows/base/base_nf_model.py +++ b/src/pythae/models/normalizing_flows/base/base_nf_model.py @@ -22,7 +22,6 @@ class BaseNF(nn.Module): """ def __init__(self, model_config: BaseNFConfig): - nn.Module.__init__(self) if model_config.input_dim is None: diff --git a/src/pythae/models/normalizing_flows/iaf/iaf_model.py b/src/pythae/models/normalizing_flows/iaf/iaf_model.py index 6366a652..186c561d 100644 --- a/src/pythae/models/normalizing_flows/iaf/iaf_model.py +++ b/src/pythae/models/normalizing_flows/iaf/iaf_model.py @@ -21,7 +21,6 @@ class IAF(BaseNF): """ def __init__(self, model_config: IAFConfig): - BaseNF.__init__(self, model_config=model_config) self.net = [] diff --git a/src/pythae/models/normalizing_flows/layers.py b/src/pythae/models/normalizing_flows/layers.py index aa410c0d..b6fc4749 100644 --- a/src/pythae/models/normalizing_flows/layers.py +++ b/src/pythae/models/normalizing_flows/layers.py @@ -72,7 +72,6 @@ def forward(self, x): def inverse(self, y): if self.training: - if not hasattr(self, "batch_mean") or not hasattr(self, "batch_var"): mean = torch.zeros(1).to(y.device) var = torch.ones(1).to(y.device) diff --git a/src/pythae/models/normalizing_flows/made/made_model.py b/src/pythae/models/normalizing_flows/made/made_model.py index 191e703d..2655176a 100644 --- a/src/pythae/models/normalizing_flows/made/made_model.py +++ b/src/pythae/models/normalizing_flows/made/made_model.py @@ -20,7 +20,6 @@ class MADE(BaseNF): """ def __init__(self, model_config: MADEConfig): - BaseNF.__init__(self, model_config=model_config) self.net = [] @@ -52,7 +51,6 @@ def __init__(self, model_config: MADEConfig): masks = self._make_mask(ordering=self.model_config.degrees_ordering) for inp, out, mask in zip(hidden_sizes[:-1], hidden_sizes[1:-1], masks[:-1]): - self.net.extend([MaskedLinear(inp, out, mask), nn.ReLU()]) # outputs mean and logvar @@ -67,7 +65,6 @@ def __init__(self, model_config: MADEConfig): self.net = nn.Sequential(*self.net) def _make_mask(self, ordering="sequential"): - # Get degrees for mask creation if ordering == "sequential": diff --git a/src/pythae/models/normalizing_flows/maf/maf_model.py b/src/pythae/models/normalizing_flows/maf/maf_model.py index 1b2f11f2..3e06eef5 100644 --- a/src/pythae/models/normalizing_flows/maf/maf_model.py +++ b/src/pythae/models/normalizing_flows/maf/maf_model.py @@ -19,7 +19,6 @@ class MAF(BaseNF): """ def __init__(self, model_config: MAFConfig): - BaseNF.__init__(self, model_config=model_config) self.net = [] diff --git a/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py b/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py index 3465e449..45c86127 100644 --- a/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py +++ 
b/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_model.py @@ -19,7 +19,6 @@ class PixelCNN(BaseNF): """ def __init__(self, model_config: PixelCNNConfig): - BaseNF.__init__(self, model_config=model_config) self.model_config = model_config diff --git a/src/pythae/models/normalizing_flows/planar_flow/planar_flow_model.py b/src/pythae/models/normalizing_flows/planar_flow/planar_flow_model.py index 4ef15611..2e6ac3f1 100644 --- a/src/pythae/models/normalizing_flows/planar_flow/planar_flow_model.py +++ b/src/pythae/models/normalizing_flows/planar_flow/planar_flow_model.py @@ -26,7 +26,6 @@ class PlanarFlow(BaseNF): """ def __init__(self, model_config: PlanarFlowConfig): - BaseNF.__init__(self, model_config) self.w = nn.Parameter(torch.randn(1, self.input_dim)) diff --git a/src/pythae/models/normalizing_flows/radial_flow/radial_flow_model.py b/src/pythae/models/normalizing_flows/radial_flow/radial_flow_model.py index 4f2b9e5a..a177a9b2 100644 --- a/src/pythae/models/normalizing_flows/radial_flow/radial_flow_model.py +++ b/src/pythae/models/normalizing_flows/radial_flow/radial_flow_model.py @@ -17,7 +17,6 @@ class RadialFlow(BaseNF): """ def __init__(self, model_config: RadialFlowConfig): - BaseNF.__init__(self, model_config) self.x0 = nn.Parameter(torch.randn(1, self.input_dim)) diff --git a/src/pythae/models/piwae/piwae_model.py b/src/pythae/models/piwae/piwae_model.py index 82dbc174..a882d8d0 100644 --- a/src/pythae/models/piwae/piwae_model.py +++ b/src/pythae/models/piwae/piwae_model.py @@ -39,7 +39,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "PIWAE" @@ -108,23 +107,17 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, mu, log_var, z): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x, - x.reshape(recon_x.shape[0], -1) - .unsqueeze(1) - .unsqueeze(1) - .repeat(1, self.gradient_n_estimates, self.n_samples, 1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x, + x.reshape(recon_x.shape[0], -1) + .unsqueeze(1) + .unsqueeze(1) + .repeat(1, self.gradient_n_estimates, self.n_samples, 1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x, x.reshape(recon_x.shape[0], -1) @@ -135,7 +128,7 @@ def loss_function(self, recon_x, x, mu, log_var, z): ).sum(dim=-1) log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1) - log_p_z = -0.5 * (z ** 2).sum(dim=-1) + log_p_z = -0.5 * (z**2).sum(dim=-1) KLD = -(log_p_z - log_q_z) diff --git a/src/pythae/models/pvae/pvae_model.py b/src/pythae/models/pvae/pvae_model.py index 1fbe6498..a6581d3f 100644 --- a/src/pythae/models/pvae/pvae_model.py +++ b/src/pythae/models/pvae/pvae_model.py @@ -43,7 +43,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "PoincareVAE" @@ -125,20 +124,14 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, z, qz_x): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( 
+ recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), @@ -230,7 +223,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p_x = [] for j in range(n_full_batch): - x_rep = torch.cat(batch_size * [x]) encoder_output = self.encoder(x_rep) @@ -262,7 +254,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): recon_x = self.decoder(z.squeeze(0))["reconstruction"] if self.model_config.reconstruction_loss == "mse": - log_p_x_given_z = -0.5 * F.mse_loss( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), @@ -274,7 +265,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): ) # decoding distribution is assumed unit variance N(mu, I) elif self.model_config.reconstruction_loss == "bce": - log_p_x_given_z = -F.binary_cross_entropy( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), diff --git a/src/pythae/models/pvae/pvae_utils.py b/src/pythae/models/pvae/pvae_utils.py index a836da72..c486612b 100644 --- a/src/pythae/models/pvae/pvae_utils.py +++ b/src/pythae/models/pvae/pvae_utils.py @@ -59,13 +59,13 @@ def _mobius_add(x, y, c, dim=-1): ## OK y2 = y.pow(2).sum(dim=dim, keepdim=True) xy = (x * y).sum(dim=dim, keepdim=True) num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y - denom = 1 + 2 * c * xy + c ** 2 * x2 * y2 + denom = 1 + 2 * c * xy + c**2 * x2 * y2 return num / denom.clamp_min(MIN_NORM) def _mobius_scalar_mul(r, x, c, dim: int = -1): ## OK x_norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(MIN_NORM) - sqrt_c = c ** 0.5 + sqrt_c = c**0.5 res_c = tanh(r * artanh(sqrt_c * x_norm)) * x / (x_norm * sqrt_c) return res_c @@ -74,7 +74,7 @@ def _project(x, c, dim: int = -1, eps: float = None): ## OK norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(MIN_NORM) if eps is None: eps = BALL_EPS[x.dtype] - maxnorm = (1 - eps) / (c ** 0.5) + maxnorm = (1 - eps) / (c**0.5) cond = norm > maxnorm projected = x / norm * maxnorm return torch.where(cond, projected, x) @@ -86,7 +86,7 @@ def _gyration(u, v, w, c, dim: int = -1): ## OK uv = (u * v).sum(dim=dim, keepdim=True) uw = (u * w).sum(dim=dim, keepdim=True) vw = (v * w).sum(dim=dim, keepdim=True) - c2 = c ** 2 + c2 = c**2 a = -c2 * uw * v2 + c * vw + 2 * c2 * uv * vw b = -c2 * vw * u2 - c * uw d = 1 + 2 * c * uv + c2 * u2 * v2 @@ -115,7 +115,7 @@ def zero(self): def dist( ## OK self, x: torch.Tensor, y: torch.Tensor, *, keepdim=False, dim=-1 ) -> torch.Tensor: ## OK - sqrt_c = self.c ** 0.5 + sqrt_c = self.c**0.5 dist_c = artanh( sqrt_c * _mobius_add(-x, y, self.c, dim=dim).norm(dim=dim, p=2, keepdim=keepdim) @@ -137,7 +137,7 @@ def mobius_add( return res def logmap0(self, x: torch.Tensor, y: torch.Tensor, *, dim=-1) -> torch.Tensor: - sqrt_c = self.c ** 0.5 + sqrt_c = self.c**0.5 y_norm = y.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM) return y / y_norm / sqrt_c * artanh(sqrt_c * y_norm) @@ -147,7 +147,7 @@ def logmap( sub = _mobius_add(-x, y, self.c, dim=dim) sub_norm = sub.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM) lam = _lambda_x(x, self.c, keepdim=True, dim=dim) - sqrt_c = self.c ** 0.5 + sqrt_c = self.c**0.5 return 2 / sqrt_c / lam * artanh(sqrt_c * sub_norm) * sub / sub_norm def transp0(self, y: torch.Tensor, v: torch.Tensor, *, dim=-1) -> torch.Tensor: @@ -175,13 +175,13 @@ def logdetexp(self, x, y, is_vector=False, keepdim=False): ## OK ).log() def expmap0(self, u, dim: int 
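# Illustrative sketch (not part of the patch, though it mirrors _mobius_add
# below): the pvae_utils hunks only reflow exponents (c ** 2 -> c**2), but the
# operation being touched is Möbius addition on the Poincaré ball,
#   x (+)_c y = ((1 + 2c<x,y> + c|y|^2) x + (1 - c|x|^2) y)
#               / (1 + 2c<x,y> + c^2 |x|^2 |y|^2).
# A standalone version for quick experiments:

import torch

MIN_NORM = 1e-15  # same clamping idea as in the patch


def mobius_add(x, y, c, dim=-1):
    x2 = x.pow(2).sum(dim=dim, keepdim=True)
    y2 = y.pow(2).sum(dim=dim, keepdim=True)
    xy = (x * y).sum(dim=dim, keepdim=True)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    denom = 1 + 2 * c * xy + c**2 * x2 * y2
    return num / denom.clamp_min(MIN_NORM)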
= -1): - sqrt_c = self.c ** 0.5 + sqrt_c = self.c**0.5 u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM) gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm) return gamma_1 def expmap(self, x, u, dim: int = -1): - sqrt_c = self.c ** 0.5 + sqrt_c = self.c**0.5 u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM) second_term = ( tanh(sqrt_c / 2 * _lambda_x(x, self.c, keepdim=True, dim=dim) * u_norm) @@ -192,7 +192,7 @@ def expmap(self, x, u, dim: int = -1): return gamma_1 def expmap_polar(self, x, u, r, dim: int = -1): ## OK - sqrt_c = self.c ** 0.5 + sqrt_c = self.c**0.5 u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM) second_term = ( tanh(torch.tensor([sqrt_c]).to(x.device) / 2 * r) * u / (sqrt_c * u_norm) @@ -217,7 +217,7 @@ def normdist2plane( norm: bool = False, ): c = self.c - sqrt_c = c ** 0.5 + sqrt_c = c**0.5 diff = self.mobius_add(-p, x, dim=dim) diff_norm2 = diff.pow(2).sum(dim=dim, keepdim=keepdim).clamp_min(MIN_NORM) sc_diff_a = (diff * a).sum(dim=dim, keepdim=keepdim) diff --git a/src/pythae/models/rae_gp/rae_gp_model.py b/src/pythae/models/rae_gp/rae_gp_model.py index 9cec7d28..979a8406 100644 --- a/src/pythae/models/rae_gp/rae_gp_model.py +++ b/src/pythae/models/rae_gp/rae_gp_model.py @@ -39,7 +39,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - AE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "RAE_GP" @@ -75,7 +74,6 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output def loss_function(self, recon_x, x, z): - recon_loss = F.mse_loss( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none" ).sum(dim=-1) diff --git a/src/pythae/models/rae_l2/rae_l2_model.py b/src/pythae/models/rae_l2/rae_l2_model.py index dfc09293..54901501 100644 --- a/src/pythae/models/rae_l2/rae_l2_model.py +++ b/src/pythae/models/rae_l2/rae_l2_model.py @@ -39,7 +39,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - AE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "RAE_L2" @@ -76,7 +75,6 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output def loss_function(self, recon_x, x, z): - recon_loss = F.mse_loss( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none" ).sum(dim=-1) diff --git a/src/pythae/models/rhvae/rhvae_model.py b/src/pythae/models/rhvae/rhvae_model.py index f692b1a5..2d18f7c7 100644 --- a/src/pythae/models/rhvae/rhvae_model.py +++ b/src/pythae/models/rhvae/rhvae_model.py @@ -59,7 +59,6 @@ def __init__( decoder: Optional[BaseDecoder] = None, metric: Optional[BaseMetric] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "RHVAE" @@ -182,7 +181,7 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: M.unsqueeze(0) * torch.exp( -torch.norm(mu.unsqueeze(0) - z.unsqueeze(1), dim=-1) ** 2 - / (self.temperature ** 2) + / (self.temperature**2) ) .unsqueeze(-1) .unsqueeze(-1) @@ -205,7 +204,6 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: recon_x = self.decoder(z)["reconstruction"] for k in range(self.n_lf): - # perform leapfrog steps # step 1 @@ -217,12 +215,11 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: recon_x = self.decoder(z)["reconstruction"] if self.training: - G_inv = ( M.unsqueeze(0) * torch.exp( -torch.norm(mu.unsqueeze(0) - z.unsqueeze(1), 
dim=-1) ** 2 - / (self.temperature ** 2) + / (self.temperature**2) ) .unsqueeze(-1) .unsqueeze(-1) @@ -296,7 +293,6 @@ def predict(self, inputs: torch.Tensor) -> ModelOutput: recon_x = self.decoder(z)["reconstruction"] for k in range(self.n_lf): - # perform leapfrog steps # step 1 @@ -398,7 +394,7 @@ def G(z): self.centroids_tens.unsqueeze(0) - z.unsqueeze(1), dim=-1 ) ** 2 - / (self.temperature ** 2) + / (self.temperature**2) ) .unsqueeze(-1) .unsqueeze(-1) @@ -414,7 +410,7 @@ def G_inv(z): self.centroids_tens.unsqueeze(0) - z.unsqueeze(1), dim=-1 ) ** 2 - / (self.temperature ** 2) + / (self.temperature**2) ) .unsqueeze(-1) .unsqueeze(-1) @@ -428,7 +424,6 @@ def G_inv(z): def loss_function( self, recon_x, x, z0, zK, rhoK, eps0, gamma, mu, log_var, G_inv, G_log_det ): - logpxz = self._log_p_xz(recon_x, x, zK) # log p(x, z_K) logrhoK = ( -0.5 @@ -467,23 +462,18 @@ def _tempering(self, k, K): return 1 / beta_k def _log_p_x_given_z(self, recon_x, x): - if self.model_config.reconstruction_loss == "mse": # sigma is taken as I_D - recon_loss = ( - -0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = -0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) -torch.log(torch.tensor([2 * np.pi]).to(x.device)) * np.prod( self.input_dim ) / 2 elif self.model_config.reconstruction_loss == "bce": - recon_loss = -F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), @@ -542,7 +532,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p_x = [] for j in range(n_full_batch): - x_rep = torch.cat(batch_size * [x]) encoder_output = self.encoder(x_rep) @@ -573,7 +562,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): recon_x = self.decoder(z)["reconstruction"] for k in range(self.n_lf): - # perform leapfrog steps # step 1 @@ -598,7 +586,7 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_q_z0_given_x = -0.5 * ( log_var + (z0 - mu) ** 2 / torch.exp(log_var) ).sum(dim=-1) - log_p_z = -0.5 * (z ** 2).sum(dim=-1) + log_p_z = -0.5 * (z**2).sum(dim=-1) log_p_rho0 = normal.log_prob(gamma) - torch.logdet( L / self.beta_zero_sqrt @@ -619,7 +607,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): ) * self.latent_dim / 2 # rho0 ~ N(0, G(z)) if self.model_config.reconstruction_loss == "mse": - log_p_x_given_z = -0.5 * F.mse_loss( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), @@ -631,7 +618,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): ) # decoding distribution is assumed unit variance N(mu, I) elif self.model_config.reconstruction_loss == "bce": - log_p_x_given_z = -F.binary_cross_entropy( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), @@ -680,7 +666,6 @@ def save(self, dir_path: str): @classmethod def _load_custom_metric_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) cls._check_python_version_from_folder(dir_path=dir_path) @@ -844,7 +829,6 @@ def load_from_hf_hub( ) else: - if not model_config.uses_default_encoder: _ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl") encoder = cls._load_custom_encoder_from_folder(dir_path) diff --git a/src/pythae/models/rhvae/rhvae_utils.py b/src/pythae/models/rhvae/rhvae_utils.py index ad1a54c8..29c485ea 100644 --- a/src/pythae/models/rhvae/rhvae_utils.py +++ b/src/pythae/models/rhvae/rhvae_utils.py @@ -12,7 +12,7 @@ def G(z): dim=-1, ) ** 2 - / (model.temperature ** 2) + / 
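# Illustrative sketch (not part of the patch): the RHVAE hunks reflow the
# temperature term of the learned inverse metric, which is (schematically) a
# kernel-weighted sum of per-centroid matrices,
#   G^{-1}(z) = sum_i M_i * exp(-||c_i - z||^2 / T^2)
# plus a small regulariser in the full model. Assuming M is [n_centroids, d, d]
# and centroids is [n_centroids, d]:

import torch


def inverse_metric(z, centroids, M, temperature):
    """G^{-1}(z) for a batch of latent points z of shape [batch, d]."""
    sq_dist = torch.norm(centroids.unsqueeze(0) - z.unsqueeze(1), dim=-1) ** 2
    weights = torch.exp(-sq_dist / temperature**2)  # [batch, n_centroids]
    return (M.unsqueeze(0) * weights.unsqueeze(-1).unsqueeze(-1)).sum(dim=1)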
(model.temperature**2) ) .unsqueeze(-1) .unsqueeze(-1) @@ -33,7 +33,7 @@ def G_inv(z): dim=-1, ) ** 2 - / (model.temperature ** 2) + / (model.temperature**2) ) .unsqueeze(-1) .unsqueeze(-1) diff --git a/src/pythae/models/svae/svae_model.py b/src/pythae/models/svae/svae_model.py index c2ac697f..09afecad 100644 --- a/src/pythae/models/svae/svae_model.py +++ b/src/pythae/models/svae/svae_model.py @@ -43,7 +43,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "SVAE" @@ -98,20 +97,14 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, loc, concentration, z): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), @@ -143,14 +136,13 @@ def _compute_kl(self, m, concentration): return (term1 + term2 + term3).squeeze(-1) def _sample_von_mises(self, loc, concentration): - # Generate uniformly on sphere v = torch.randn_like(loc[:, 1:]) v = v / v.norm(dim=-1, keepdim=True) w = self._acc_rej_steps(m=loc.shape[-1], k=concentration) - w_ = torch.sqrt(torch.clamp(1 - (w ** 2), 1e-10)) + w_ = torch.sqrt(torch.clamp(1 - (w**2), 1e-10)) z = torch.cat((w, w_ * v), dim=-1) return self._householder_rotation(loc, z) @@ -164,10 +156,9 @@ def _householder_rotation(self, loc, z): return z - 2 * u * (u * z).sum(dim=-1, keepdim=True) def _acc_rej_steps(self, m: int, k: torch.Tensor, device: str = "cpu"): - batch_size = k.shape[0] - c = torch.sqrt(4 * k ** 2 + (m - 1) ** 2) + c = torch.sqrt(4 * k**2 + (m - 1) ** 2) b = (-2 * k + c) / (m - 1) a = (m - 1 + 2 * k + c) / 4 @@ -183,7 +174,6 @@ def _acc_rej_steps(self, m: int, k: torch.Tensor, device: str = "cpu"): i = 0 while stopping_mask.sum() > 0 and i < 100: - i += 1 eps = ( @@ -234,7 +224,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p_x = [] for j in range(n_full_batch): - x_rep = torch.cat(batch_size * [x]) encoder_output = self.encoder(x_rep) @@ -270,7 +259,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): recon_x = self.decoder(z)["reconstruction"] if self.model_config.reconstruction_loss == "mse": - log_p_x_given_z = -0.5 * F.mse_loss( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), @@ -282,7 +270,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): ) # decoding distribution is assumed unit variance N(mu, I) elif self.model_config.reconstruction_loss == "bce": - log_p_x_given_z = -F.binary_cross_entropy( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), diff --git a/src/pythae/models/vae/vae_model.py b/src/pythae/models/vae/vae_model.py index 2fea4eea..63e1dc1d 100644 --- a/src/pythae/models/vae/vae_model.py +++ b/src/pythae/models/vae/vae_model.py @@ -41,7 +41,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - BaseAE.__init__(self, model_config=model_config, decoder=decoder) self.model_name = "VAE" @@ -98,19 +97,14 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, mu, 
log_var, z): - if self.model_config.reconstruction_loss == "mse": - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), @@ -164,12 +158,11 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_q_z_given_x = -0.5 * ( log_var + (z - mu) ** 2 / torch.exp(log_var) ).sum(dim=-1) - log_p_z = -0.5 * (z ** 2).sum(dim=-1) + log_p_z = -0.5 * (z**2).sum(dim=-1) recon_x = self.decoder(z)["reconstruction"] if self.model_config.reconstruction_loss == "mse": - log_p_x_given_z = -0.5 * F.mse_loss( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), @@ -181,7 +174,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): ) # decoding distribution is assumed unit variance N(mu, I) elif self.model_config.reconstruction_loss == "bce": - log_p_x_given_z = -F.binary_cross_entropy( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), diff --git a/src/pythae/models/vae_gan/vae_gan_model.py b/src/pythae/models/vae_gan/vae_gan_model.py index 1ccea46a..b6ddf9e7 100644 --- a/src/pythae/models/vae_gan/vae_gan_model.py +++ b/src/pythae/models/vae_gan/vae_gan_model.py @@ -58,7 +58,6 @@ def __init__( decoder: Optional[BaseDecoder] = None, discriminator: Optional[BaseDiscriminator] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) if discriminator is None: @@ -76,7 +75,6 @@ def __init__( self.model_config.uses_default_discriminator = True else: - self.model_config.uses_default_discriminator = False self.set_discriminator(discriminator) @@ -175,7 +173,6 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output def loss_function(self, recon_x, x, z, z_prior, mu, log_var): - N = z.shape[0] # batch size # KL between prior and posterior @@ -192,14 +189,11 @@ def loss_function(self, recon_x, x, z, z_prior, mu, log_var): )[f"embedding_layer_{self.reconstruction_layer}"] # MSE in feature space - recon_loss = ( - 0.5 - * F.mse_loss( - true_discr_layer.reshape(N, -1), - recon_discr_layer.reshape(N, -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + true_discr_layer.reshape(N, -1), + recon_discr_layer.reshape(N, -1), + reduction="none", + ).sum(dim=-1) encoder_loss = KLD + recon_loss @@ -296,7 +290,6 @@ def save(self, dir_path: str): @classmethod def _load_custom_discriminator_from_folder(cls, dir_path): - file_list = os.listdir(dir_path) cls._check_python_version_from_folder(dir_path=dir_path) @@ -425,7 +418,6 @@ def load_from_hf_hub( ) else: - if not model_config.uses_default_encoder: _ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl") encoder = cls._load_custom_encoder_from_folder(dir_path) diff --git a/src/pythae/models/vae_iaf/vae_iaf_model.py b/src/pythae/models/vae_iaf/vae_iaf_model.py index 1979dfa4..2db0dfd7 100644 --- a/src/pythae/models/vae_iaf/vae_iaf_model.py +++ b/src/pythae/models/vae_iaf/vae_iaf_model.py @@ -42,7 +42,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "VAE_IAF" @@ -103,20 +102,14 @@ def forward(self, inputs: BaseDataset, **kwargs): 
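# Illustrative sketch (not part of the patch): the get_nll hunks touched in
# this patch all build the same importance-sampling estimate of the marginal
# log-likelihood,
#   log p(x) ~= logsumexp_i [log p(x|z_i) + log p(z_i) - log q(z_i|x)] - log n,
# with z_i ~ q(z|x). A compact version, assuming the three inputs are
# [n_samples] tensors for a single x:

import math

import torch


def log_marginal_estimate(log_p_x_given_z, log_p_z, log_q_z_given_x):
    log_w = log_p_x_given_z + log_p_z - log_q_z_given_x
    return torch.logsumexp(log_w, dim=0) - math.log(log_w.shape[0])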
return output def loss_function(self, recon_x, x, mu, log_var, z0, zk, log_abs_det_jac): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), @@ -171,7 +164,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p_x = [] for j in range(n_full_batch): - x_rep = torch.cat(batch_size * [x]) encoder_output = self.encoder(x_rep) @@ -191,12 +183,11 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_q_z_given_x = ( -0.5 * (log_var + torch.pow(z0 - mu, 2) / torch.exp(log_var)) ).sum(dim=1) - log_abs_det_jac - log_p_z = -0.5 * (z ** 2).sum(dim=-1) + log_p_z = -0.5 * (z**2).sum(dim=-1) recon_x = self.decoder(z)["reconstruction"] if self.model_config.reconstruction_loss == "mse": - log_p_x_given_z = -0.5 * F.mse_loss( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), @@ -208,7 +199,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): ) # decoding distribution is assumed unit variance N(mu, I) elif self.model_config.reconstruction_loss == "bce": - log_p_x_given_z = -F.binary_cross_entropy( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), diff --git a/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py b/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py index a034fda6..180a0c9b 100644 --- a/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py +++ b/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py @@ -50,7 +50,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "VAE_LinNF" @@ -116,20 +115,14 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, mu, log_var, z0, zk, log_abs_det_jac): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), @@ -184,7 +177,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p_x = [] for j in range(n_full_batch): - x_rep = torch.cat(batch_size * [x]) encoder_output = self.encoder(x_rep) @@ -205,12 +197,11 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_q_z_given_x = ( -0.5 * (log_var + torch.pow(z0 - mu, 2) / torch.exp(log_var)) ).sum(dim=1) - log_abs_det_jac - log_p_z = -0.5 * (z ** 2).sum(dim=-1) + log_p_z = -0.5 * (z**2).sum(dim=-1) recon_x = self.decoder(z)["reconstruction"] if self.model_config.reconstruction_loss == "mse": - log_p_x_given_z = -0.5 * F.mse_loss( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), @@ -222,7 +213,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): ) # decoding distribution is assumed unit variance N(mu, I) elif self.model_config.reconstruction_loss == "bce": - log_p_x_given_z = 
-F.binary_cross_entropy( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), diff --git a/src/pythae/models/vamp/vamp_model.py b/src/pythae/models/vamp/vamp_model.py index a85f7062..b84a60f5 100644 --- a/src/pythae/models/vamp/vamp_model.py +++ b/src/pythae/models/vamp/vamp_model.py @@ -40,7 +40,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "VAMP" @@ -101,20 +100,14 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, mu, log_var, z, epoch): - if self.model_config.reconstruction_loss == "mse": - - recon_loss = ( - 0.5 - * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) - ) + recon_loss = 0.5 * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) elif self.model_config.reconstruction_loss == "bce": - recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), @@ -160,17 +153,11 @@ def _log_p_z(self, z): prior_mu = prior_mu.unsqueeze(0) prior_log_var = prior_log_var.unsqueeze(0) - log_p_z = ( - torch.sum( - -0.5 - * ( - prior_log_var - + (z_expand - prior_mu) ** 2 / torch.exp(prior_log_var) - ), - dim=2, - ) - - torch.log(torch.tensor(C).type(torch.float)) - ) + log_p_z = torch.sum( + -0.5 + * (prior_log_var + (z_expand - prior_mu) ** 2 / torch.exp(prior_log_var)), + dim=2, + ) - torch.log(torch.tensor(C).type(torch.float)) log_p_z = torch.logsumexp(log_p_z, dim=1) @@ -210,7 +197,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): log_p_x = [] for j in range(n_full_batch): - x_rep = torch.cat(batch_size * [x]) encoder_output = self.encoder(x_rep) @@ -227,7 +213,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): recon_x = self.decoder(z)["reconstruction"] if self.model_config.reconstruction_loss == "mse": - log_p_x_given_z = -0.5 * F.mse_loss( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), @@ -239,7 +224,6 @@ def get_nll(self, data, n_samples=1, batch_size=100): ) # decoding distribution is assumed unit variance N(mu, I) elif self.model_config.reconstruction_loss == "bce": - log_p_x_given_z = -F.binary_cross_entropy( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), diff --git a/src/pythae/models/vq_vae/vq_vae_model.py b/src/pythae/models/vq_vae/vq_vae_model.py index d36ccb37..72c7fc40 100644 --- a/src/pythae/models/vq_vae/vq_vae_model.py +++ b/src/pythae/models/vq_vae/vq_vae_model.py @@ -41,7 +41,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - AE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self._set_quantizer(model_config) @@ -49,7 +48,6 @@ def __init__( self.model_name = "VQVAE" def _set_quantizer(self, model_config): - if model_config.input_dim is None: raise AttributeError( "No input dimension provided !" 
@@ -122,7 +120,6 @@ def forward(self, inputs: BaseDataset, **kwargs): return output def loss_function(self, recon_x, x, quantizer_output): - recon_loss = F.mse_loss( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none" ).sum(dim=-1) diff --git a/src/pythae/models/vq_vae/vq_vae_utils.py b/src/pythae/models/vq_vae/vq_vae_utils.py index 7eaae14c..6a4b6662 100644 --- a/src/pythae/models/vq_vae/vq_vae_utils.py +++ b/src/pythae/models/vq_vae/vq_vae_utils.py @@ -11,7 +11,6 @@ class Quantizer(nn.Module): def __init__(self, model_config: VQVAEConfig): - nn.Module.__init__(self) self.model_config = model_config @@ -28,10 +27,9 @@ def __init__(self, model_config: VQVAEConfig): ) def forward(self, z: torch.Tensor, uses_ddp: bool = False): - distances = ( (z.reshape(-1, self.embedding_dim) ** 2).sum(dim=-1, keepdim=True) - + (self.embeddings.weight ** 2).sum(dim=-1) + + (self.embeddings.weight**2).sum(dim=-1) - 2 * z.reshape(-1, self.embedding_dim) @ self.embeddings.weight.T ) @@ -80,7 +78,6 @@ def forward(self, z: torch.Tensor, uses_ddp: bool = False): class QuantizerEMA(nn.Module): def __init__(self, model_config: VQVAEConfig): - nn.Module.__init__(self) self.model_config = model_config @@ -103,10 +100,9 @@ def __init__(self, model_config: VQVAEConfig): self.register_buffer("embeddings", embeddings) def forward(self, z: torch.Tensor, uses_ddp: bool = False): - distances = ( (z.reshape(-1, self.embedding_dim) ** 2).sum(dim=-1, keepdim=True) - + (self.embeddings ** 2).sum(dim=-1) + + (self.embeddings**2).sum(dim=-1) - 2 * z.reshape(-1, self.embedding_dim) @ self.embeddings.T ) @@ -125,7 +121,6 @@ def forward(self, z: torch.Tensor, uses_ddp: bool = False): quantized = quantized.reshape_as(z) if self.training: - n_i = torch.sum(one_hot_encoding, dim=0) if uses_ddp: diff --git a/src/pythae/models/wae_mmd/wae_mmd_model.py b/src/pythae/models/wae_mmd/wae_mmd_model.py index 592d3b01..2370d8c4 100644 --- a/src/pythae/models/wae_mmd/wae_mmd_model.py +++ b/src/pythae/models/wae_mmd/wae_mmd_model.py @@ -40,7 +40,6 @@ def __init__( encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): - AE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "WAE_MMD" @@ -75,7 +74,6 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output def loss_function(self, recon_x, x, z, z_prior): - N = z.shape[0] # batch size recon_loss = self.reconstruction_loss_scale * F.mse_loss( @@ -94,7 +92,7 @@ def loss_function(self, recon_x, x, z, z_prior): mmd_z = (k_z - k_z.diag().diag()).sum() / ((N - 1) * N) mmd_z_prior = (k_z_prior - k_z_prior.diag().diag()).sum() / ((N - 1) * N) - mmd_cross = k_cross.sum() / (N ** 2) + mmd_cross = k_cross.sum() / (N**2) mmd_loss = mmd_z + mmd_z_prior - 2 * mmd_cross @@ -108,7 +106,7 @@ def imq_kernel(self, z1, z2): """Returns a matrix of shape [batch x batch] containing the pairwise kernel computation""" Cbase = ( - 2.0 * self.model_config.latent_dim * self.model_config.kernel_bandwidth ** 2 + 2.0 * self.model_config.latent_dim * self.model_config.kernel_bandwidth**2 ) k = 0 @@ -122,7 +120,7 @@ def imq_kernel(self, z1, z2): def rbf_kernel(self, z1, z2): """Returns a matrix of shape [batch x batch] containing the pairwise kernel computation""" - C = 2.0 * self.model_config.latent_dim * self.model_config.kernel_bandwidth ** 2 + C = 2.0 * self.model_config.latent_dim * self.model_config.kernel_bandwidth**2 k = torch.exp(-torch.norm(z1.unsqueeze(1) - z2.unsqueeze(0), dim=-1) ** 2 / C) diff --git 
a/src/pythae/pipelines/generation.py b/src/pythae/pipelines/generation.py index 97bfd3b9..4b6759c7 100644 --- a/src/pythae/pipelines/generation.py +++ b/src/pythae/pipelines/generation.py @@ -37,7 +37,6 @@ def __init__( model: Optional[BaseAE], sampler_config: Optional[BaseSamplerConfig] = None, ): - if sampler_config is None: sampler_config = NormalSamplerConfig() diff --git a/src/pythae/pipelines/training.py b/src/pythae/pipelines/training.py index 080b9dd9..6527261c 100644 --- a/src/pythae/pipelines/training.py +++ b/src/pythae/pipelines/training.py @@ -47,12 +47,13 @@ def __init__( model: Optional[BaseAE], training_config: Optional[BaseTrainerConfig] = None, ): - if training_config is None: if model.model_name == "RAE_L2": training_config = CoupledOptimizerTrainerConfig( encoder_optimizer_params={"weight_decay": 0}, - decoder_optimizer_params={"weight_decay": model.model_config.reg_weight}, + decoder_optimizer_params={ + "weight_decay": model.model_config.reg_weight + }, ) elif ( @@ -68,7 +69,6 @@ def __init__( elif model.model_name == "RAE_L2" or model.model_name == "PIWAE": if not isinstance(training_config, CoupledOptimizerTrainerConfig): - raise AssertionError( "A 'CoupledOptimizerTrainerConfig' " f"is expected for training a {model.model_name}" @@ -85,7 +85,6 @@ def __init__( elif model.model_name == "Adversarial_AE" or model.model_name == "FactorVAE": if not isinstance(training_config, AdversarialTrainerConfig): - raise AssertionError( "A 'AdversarialTrainer' " f"is expected for training a {model.model_name}" @@ -95,7 +94,6 @@ def __init__( if not isinstance( training_config, CoupledOptimizerAdversarialTrainerConfig ): - raise AssertionError( "A 'CoupledOptimizerAdversarialTrainer' " "is expected for training a VAEGAN" @@ -111,7 +109,6 @@ def __init__( self.training_config = training_config def _check_dataset(self, dataset: BaseDataset): - try: dataset_output = dataset[0] diff --git a/src/pythae/samplers/base/base_sampler.py b/src/pythae/samplers/base/base_sampler.py index 3a6ad06f..2e32aa17 100644 --- a/src/pythae/samplers/base/base_sampler.py +++ b/src/pythae/samplers/base/base_sampler.py @@ -27,7 +27,6 @@ class BaseSampler: """ def __init__(self, model: BaseAE, sampler_config: BaseSamplerConfig = None): - if sampler_config is None: sampler_config = BaseSamplerConfig() @@ -96,7 +95,6 @@ def save_img(self, img_tensor: torch.Tensor, dir_path: str, img_name: str): imwrite(os.path.join(dir_path, f"{img_name}"), img) def _set_inputs_to_device(self, inputs: Dict[str, Any]): - inputs_on_device = inputs if self.device == "cuda": diff --git a/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py b/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py index 33d47770..5095841c 100644 --- a/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py +++ b/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py @@ -38,7 +38,6 @@ class GaussianMixtureSampler(BaseSampler): def __init__( self, model: BaseAE, sampler_config: GaussianMixtureSamplerConfig = None ): - if sampler_config is None: sampler_config = GaussianMixtureSamplerConfig() @@ -146,7 +145,6 @@ def sample( x_gen_list = [] for i in range(full_batch_nbr): - z = ( torch.tensor(self.gmm.sample(batch_size)[0]) .to(self.device) diff --git a/src/pythae/samplers/hypersphere_uniform_sampler/hypersphere_uniform_sampler.py b/src/pythae/samplers/hypersphere_uniform_sampler/hypersphere_uniform_sampler.py index 87660d1f..2f159ad3 100644 --- 
a/src/pythae/samplers/hypersphere_uniform_sampler/hypersphere_uniform_sampler.py +++ b/src/pythae/samplers/hypersphere_uniform_sampler/hypersphere_uniform_sampler.py @@ -18,7 +18,6 @@ class HypersphereUniformSampler(BaseSampler): def __init__( self, model: BaseAE, sampler_config: HypersphereUniformSamplerConfig = None ): - if sampler_config is None: sampler_config = HypersphereUniformSamplerConfig() diff --git a/src/pythae/samplers/iaf_sampler/iaf_sampler.py b/src/pythae/samplers/iaf_sampler/iaf_sampler.py index 2fc36f4d..111d8a05 100644 --- a/src/pythae/samplers/iaf_sampler/iaf_sampler.py +++ b/src/pythae/samplers/iaf_sampler/iaf_sampler.py @@ -32,7 +32,6 @@ class IAFSampler(BaseSampler): """ def __init__(self, model: BaseAE, sampler_config: IAFSamplerConfig = None): - self.is_fitted = False if sampler_config is None: @@ -114,7 +113,6 @@ def fit( eval_dataset = None if eval_data is not None: - if not isinstance(eval_data, Dataset): eval_data = data_processor.process_data(eval_data) eval_dataset = data_processor.to_dataset(eval_data) @@ -202,7 +200,6 @@ def sample( x_gen_list = [] for i in range(full_batch_nbr): - u = self.prior.sample((batch_size,)) z = self.iaf_model.inverse(u).out x_gen = self.model.decoder(z).reconstruction.detach() diff --git a/src/pythae/samplers/maf_sampler/maf_sampler.py b/src/pythae/samplers/maf_sampler/maf_sampler.py index 3fd14ed0..27bdab44 100644 --- a/src/pythae/samplers/maf_sampler/maf_sampler.py +++ b/src/pythae/samplers/maf_sampler/maf_sampler.py @@ -32,7 +32,6 @@ class MAFSampler(BaseSampler): """ def __init__(self, model: BaseAE, sampler_config: MAFSamplerConfig = None): - self.is_fitted = False if sampler_config is None: @@ -115,7 +114,6 @@ def fit( eval_dataset = None if eval_data is not None: - if not isinstance(eval_data, Dataset): eval_data = data_processor.process_data(eval_data) eval_dataset = data_processor.to_dataset(eval_data) @@ -202,7 +200,6 @@ def sample( x_gen_list = [] for i in range(full_batch_nbr): - u = self.prior.sample((batch_size,)) z = self.maf_model.inverse(u).out x_gen = self.model.decoder(z).reconstruction.detach() diff --git a/src/pythae/samplers/manifold_sampler/rhvae_sampler.py b/src/pythae/samplers/manifold_sampler/rhvae_sampler.py index 85cc7494..fd5e5f3d 100644 --- a/src/pythae/samplers/manifold_sampler/rhvae_sampler.py +++ b/src/pythae/samplers/manifold_sampler/rhvae_sampler.py @@ -16,7 +16,6 @@ class RHVAESampler(BaseSampler): """ def __init__(self, model: RHVAE, sampler_config: RHVAESamplerConfig = None): - if sampler_config is None: sampler_config = RHVAESamplerConfig() @@ -65,7 +64,6 @@ def sample( x_gen_list = [] for i in range(full_batch_nbr): - samples = self.hmc_sampling(batch_size) x_gen = self.model.decoder(z=samples)["reconstruction"].detach() @@ -98,9 +96,7 @@ def sample( return torch.cat(x_gen_list, dim=0) def hmc_sampling(self, n_samples: int): - with torch.no_grad(): - idx = torch.randint(len(self.model.centroids_tens), (n_samples,)) z0 = self.model.centroids_tens[idx] @@ -108,7 +104,6 @@ def hmc_sampling(self, n_samples: int): beta_sqrt_old = self.beta_zero_sqrt z = z0 for i in range(self.mcmc_steps_nbr): - gamma = torch.randn_like(z, device=self.device) rho = gamma / self.beta_zero_sqrt @@ -116,7 +111,6 @@ def hmc_sampling(self, n_samples: int): # print(model.G_inv(z).det()) for k in range(self.n_lf): - g = -self.grad_func(z, self.model).reshape( n_samples, self.model.latent_dim ) @@ -171,7 +165,7 @@ def grad_log_sqrt_det_G_inv(z, model): @ torch.transpose( ( -2 - / (model.temperature ** 2) + / 
(model.temperature**2) * (model.centroids_tens.unsqueeze(0) - z.unsqueeze(1)).unsqueeze(2) @ ( model.M_tens.unsqueeze(0) @@ -181,7 +175,7 @@ def grad_log_sqrt_det_G_inv(z, model): dim=-1, ) ** 2 - / (model.temperature ** 2) + / (model.temperature**2) ) .unsqueeze(-1) .unsqueeze(-1) diff --git a/src/pythae/samplers/normal_sampling/normal_sampler.py b/src/pythae/samplers/normal_sampling/normal_sampler.py index 629e7690..afdd9e50 100644 --- a/src/pythae/samplers/normal_sampling/normal_sampler.py +++ b/src/pythae/samplers/normal_sampling/normal_sampler.py @@ -16,7 +16,6 @@ class NormalSampler(BaseSampler): """ def __init__(self, model: BaseAE, sampler_config: NormalSamplerConfig = None): - if sampler_config is None: sampler_config = NormalSamplerConfig() diff --git a/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py b/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py index 4bdc8d18..9521fa81 100644 --- a/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py +++ b/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py @@ -32,7 +32,6 @@ class PixelCNNSampler(BaseSampler): """ def __init__(self, model: VQVAE, sampler_config: PixelCNNSamplerConfig = None): - self.is_fitted = False if sampler_config is None: @@ -111,7 +110,6 @@ def fit( eval_dataset = None if eval_data is not None: - if not isinstance(eval_data, Dataset): eval_data = data_processor.process_data(eval_data) eval_dataset = data_processor.to_dataset(eval_data) @@ -196,7 +194,6 @@ def sample( x_gen_list = [] for i in range(full_batch_nbr): - z = torch.zeros( (batch_size,) + self.pixelcnn_model.model_config.input_dim ).to(self.device) diff --git a/src/pythae/samplers/pvae_sampler/pvae_sampler.py b/src/pythae/samplers/pvae_sampler/pvae_sampler.py index 8eb45b17..f929381e 100644 --- a/src/pythae/samplers/pvae_sampler/pvae_sampler.py +++ b/src/pythae/samplers/pvae_sampler/pvae_sampler.py @@ -19,7 +19,6 @@ class PoincareDiskSampler(BaseSampler): def __init__( self, model: PoincareVAE, sampler_config: PoincareDiskSamplerConfig = None ): - assert isinstance( model, PoincareVAE ), "This sampler is only suitable for PoincareVAE model" @@ -64,7 +63,6 @@ def sample( x_gen_list = [] for i in range(full_batch_nbr): - z = self.gen_distribution.rsample(torch.Size([batch_size])).reshape( batch_size, -1 ) @@ -79,7 +77,6 @@ def sample( x_gen_list.append(x_gen) if last_batch_samples_nbr > 0: - z = self.gen_distribution.rsample( torch.Size([last_batch_samples_nbr]) ).reshape(last_batch_samples_nbr, -1) diff --git a/src/pythae/samplers/two_stage_vae_sampler/two_stage_sampler.py b/src/pythae/samplers/two_stage_vae_sampler/two_stage_sampler.py index d2c5c874..dd3e8169 100644 --- a/src/pythae/samplers/two_stage_vae_sampler/two_stage_sampler.py +++ b/src/pythae/samplers/two_stage_vae_sampler/two_stage_sampler.py @@ -19,7 +19,6 @@ class SecondEncoder(BaseEncoder): def __init__(self, model: VAE, sampler_config: TwoStageVAESamplerConfig): - BaseEncoder.__init__(self) layers = [] @@ -55,7 +54,6 @@ def forward(self, z: torch.Tensor): class SecondDecoder(BaseDecoder): def __init__(self, model: VAE, sampler_config: TwoStageVAESamplerConfig): - BaseDecoder.__init__(self) self.gamma_z = nn.Parameter(torch.ones(1, 1), requires_grad=True) @@ -110,7 +108,6 @@ class TwoStageVAESampler(BaseSampler): """ def __init__(self, model: VAE, sampler_config: TwoStageVAESamplerConfig = None): - assert issubclass(model.__class__, VAE), ( "The TwoStageVAESampler is only" f"applicable for VAE based models. Got {model.__class__}." 
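Note on the recurring power-operator hunks (`model.temperature ** 2` -> `model.temperature**2`, and the `z ** 2`, `k ** 2`, `N ** 2` variants above): black's stable style hugs `**` whenever both operands are simple (a name, a literal, or an attribute chain) and keeps the surrounding spaces otherwise, which is what the black 23.10 run named in the commit subject applies here. A minimal sketch with hypothetical variables, for illustration only:

    y = x ** 2          # pre-black layout
    y = x**2            # black: both operands simple, spaces dropped
    y = (a + b) ** 2    # compound operand, spaces kept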
@@ -194,7 +191,6 @@ def fit( eval_dataset = None if eval_data is not None: - if not isinstance(eval_data, Dataset): eval_data = data_processor.process_data(eval_data) eval_dataset = data_processor.to_dataset(eval_data) @@ -282,7 +278,6 @@ def sample( x_gen_list = [] for i in range(full_batch_nbr): - u = torch.randn(batch_size, self.model.latent_dim).to(self.device) z = self.second_vae.decoder(u).reconstruction x_gen = self.model.decoder(z)["reconstruction"].detach() diff --git a/src/pythae/samplers/vamp_sampler/vamp_sampler.py b/src/pythae/samplers/vamp_sampler/vamp_sampler.py index 5f9100e3..e4de9353 100644 --- a/src/pythae/samplers/vamp_sampler/vamp_sampler.py +++ b/src/pythae/samplers/vamp_sampler/vamp_sampler.py @@ -16,7 +16,6 @@ class VAMPSampler(BaseSampler): """ def __init__(self, model: VAMP, sampler_config: VAMPSamplerConfig = None): - assert isinstance(model, VAMP), "This sampler is only suitable for VAMP model" if sampler_config is None: diff --git a/src/pythae/trainers/adversarial_trainer/adversarial_trainer.py b/src/pythae/trainers/adversarial_trainer/adversarial_trainer.py index 4dc1d6e9..4ebb7fca 100644 --- a/src/pythae/trainers/adversarial_trainer/adversarial_trainer.py +++ b/src/pythae/trainers/adversarial_trainer/adversarial_trainer.py @@ -55,7 +55,6 @@ def __init__( training_config: Optional[AdversarialTrainerConfig] = None, callbacks: List[TrainingCallback] = None, ): - BaseTrainer.__init__( self, model=model, @@ -179,7 +178,6 @@ def set_discriminator_scheduler(self) -> torch.optim.lr_scheduler: self.discriminator_scheduler = scheduler def _optimizers_step(self, model_output): - autoencoder_loss = model_output.autoencoder_loss discriminator_loss = model_output.discriminator_loss @@ -193,7 +191,6 @@ def _optimizers_step(self, model_output): self.discriminator_optimizer.step() def _schedulers_step(self, autoencoder_metrics=None, discriminator_metrics=None): - if self.autoencoder_scheduler is None: pass @@ -278,7 +275,6 @@ def train(self, log_output_dir: str = None): best_eval_loss = 1e10 for epoch in range(1, self.training_config.num_epochs + 1): - self.callback_handler.on_epoch_begin( training_config=self.training_config, epoch=epoch, @@ -345,7 +341,6 @@ def train(self, log_output_dir: str = None): and epoch % self.training_config.steps_predict == 0 and self.is_main_process ): - true_data, reconstructions, generations = self.predict(best_model) self.callback_handler.on_prediction_step( @@ -416,12 +411,10 @@ def eval_step(self, epoch: int): epoch_loss = 0 for inputs in self.eval_loader: - inputs = self._set_inputs_to_device(inputs) try: with torch.no_grad(): - model_output = self.model( inputs, epoch=epoch, @@ -481,7 +474,6 @@ def train_step(self, epoch: int): epoch_loss = 0 for inputs in self.train_loader: - inputs = self._set_inputs_to_device(inputs) model_output = self.model( diff --git a/src/pythae/trainers/base_trainer/base_trainer.py b/src/pythae/trainers/base_trainer/base_trainer.py index 7212b9b2..073deede 100644 --- a/src/pythae/trainers/base_trainer/base_trainer.py +++ b/src/pythae/trainers/base_trainer/base_trainer.py @@ -61,7 +61,6 @@ def __init__( training_config: Optional[BaseTrainerConfig] = None, callbacks: List[TrainingCallback] = None, ): - if training_config is None: training_config = BaseTrainerConfig() @@ -353,7 +352,6 @@ def _set_optimizer_on_device(self, optim, device): return optim def _set_inputs_to_device(self, inputs: Dict[str, Any]): - inputs_on_device = inputs if self.device == "cuda": @@ -370,7 +368,6 @@ def _set_inputs_to_device(self, 
inputs: Dict[str, Any]): return inputs_on_device def _optimizers_step(self, model_output=None): - loss = model_output.loss self.optimizer.zero_grad() @@ -448,7 +445,6 @@ def train(self, log_output_dir: str = None): best_eval_loss = 1e10 for epoch in range(1, self.training_config.num_epochs + 1): - self.callback_handler.on_epoch_begin( training_config=self.training_config, epoch=epoch, @@ -561,12 +557,10 @@ def eval_step(self, epoch: int): with self.amp_context: for inputs in self.eval_loader: - inputs = self._set_inputs_to_device(inputs) try: with torch.no_grad(): - model_output = self.model( inputs, epoch=epoch, @@ -619,7 +613,6 @@ def train_step(self, epoch: int): epoch_loss = 0 for inputs in self.train_loader: - inputs = self._set_inputs_to_device(inputs) with self.amp_context: @@ -712,7 +705,6 @@ def save_checkpoint(self, model: BaseAE, dir_path, epoch: int): self.training_config.save_json(checkpoint_dir, "training_config") def predict(self, model: BaseAE): - model.eval() with self.amp_context: diff --git a/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer.py b/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer.py index 5e46dfeb..2449d1a7 100644 --- a/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer.py +++ b/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer.py @@ -61,7 +61,6 @@ def __init__( training_config: Optional[CoupledOptimizerAdversarialTrainerConfig] = None, callbacks: List[TrainingCallback] = None, ): - BaseTrainer.__init__( self, model=model, @@ -226,7 +225,6 @@ def set_discriminator_scheduler(self) -> torch.optim.lr_scheduler: self.discriminator_scheduler = scheduler def _optimizers_step(self, model_output): - encoder_loss = model_output.encoder_loss decoder_loss = model_output.decoder_loss discriminator_loss = model_output.discriminator_loss @@ -355,7 +353,6 @@ def train(self, log_output_dir: str = None): best_eval_loss = 1e10 for epoch in range(1, self.training_config.num_epochs + 1): - self.callback_handler.on_epoch_begin( training_config=self.training_config, epoch=epoch, @@ -427,7 +424,6 @@ def train(self, log_output_dir: str = None): and epoch % self.training_config.steps_predict == 0 and self.is_main_process ): - true_data, reconstructions, generations = self.predict(best_model) self.callback_handler.on_prediction_step( @@ -499,12 +495,10 @@ def eval_step(self, epoch: int): epoch_loss = 0 for inputs in self.eval_loader: - inputs = self._set_inputs_to_device(inputs) try: with torch.no_grad(): - model_output = self.model( inputs, epoch=epoch, @@ -573,7 +567,6 @@ def train_step(self, epoch: int): epoch_loss = 0 for inputs in self.train_loader: - inputs = self._set_inputs_to_device(inputs) model_output = self.model( diff --git a/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer.py b/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer.py index 5c1b6dea..39faa784 100644 --- a/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer.py +++ b/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer.py @@ -54,7 +54,6 @@ def __init__( training_config: Optional[CoupledOptimizerTrainerConfig] = None, callbacks: List[TrainingCallback] = None, ): - BaseTrainer.__init__( self, model=model, @@ -166,7 +165,6 @@ def set_decoder_scheduler(self) -> torch.optim.lr_scheduler: self.decoder_scheduler = scheduler def _optimizers_step(self, 
model_output): - encoder_loss = model_output.encoder_loss decoder_loss = model_output.decoder_loss @@ -186,7 +184,6 @@ def _optimizers_step(self, model_output): self.decoder_optimizer.step() def _schedulers_step(self, encoder_metrics=None, decoder_metrics=None): - if self.encoder_scheduler is None: pass @@ -206,7 +203,6 @@ def _schedulers_step(self, encoder_metrics=None, decoder_metrics=None): self.decoder_scheduler.step() def prepare_training(self): - # set random seed set_seed(self.training_config.seed) @@ -271,7 +267,6 @@ def train(self, log_output_dir: str = None): best_eval_loss = 1e10 for epoch in range(1, self.training_config.num_epochs + 1): - self.callback_handler.on_epoch_begin( training_config=self.training_config, epoch=epoch, @@ -337,7 +332,6 @@ def train(self, log_output_dir: str = None): and epoch % self.training_config.steps_predict == 0 and self.is_main_process ): - true_data, reconstructions, generations = self.predict(best_model) self.callback_handler.on_prediction_step( @@ -408,12 +402,10 @@ def eval_step(self, epoch: int): epoch_loss = 0 for inputs in self.eval_loader: - inputs = self._set_inputs_to_device(inputs) try: with torch.no_grad(): - model_output = self.model( inputs, epoch=epoch, @@ -473,7 +465,6 @@ def train_step(self, epoch: int): epoch_loss = 0 for inputs in self.train_loader: - inputs = self._set_inputs_to_device(inputs) model_output = self.model( diff --git a/src/pythae/trainers/training_callbacks.py b/src/pythae/trainers/training_callbacks.py index 9064b8f0..32d43d47 100644 --- a/src/pythae/trainers/training_callbacks.py +++ b/src/pythae/trainers/training_callbacks.py @@ -202,7 +202,6 @@ def __init__(self): self.logger.setLevel(logging.INFO) def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs): - logger = kwargs.pop("logger", self.logger) rank = kwargs.pop("rank", -1) @@ -393,7 +392,6 @@ def on_prediction_step(self, training_config: BaseTrainerConfig, **kwargs): and generations is not None ): for i in range(len(true_data)): - data_to_log.append( [ f"img_{i}", @@ -561,7 +559,6 @@ def setup( offline_directory: str = "./", **kwargs, ): - """ Setup the CometCallback. 
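Note on the blank-line deletions that dominate the trainer hunks above: each one removes the single empty line that sat between a `def ...:` signature (or a `for`/`with` block opener) and the first statement of the body. This matches black 23.x's stable style, which strips empty lines at the start of an indented block; behavior is unchanged. A before/after sketch on a hypothetical method:

    # before
    def train_step(self, epoch):

        return epoch

    # after black 23.x
    def train_step(self, epoch):
        return epoch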
@@ -631,7 +628,6 @@ def on_prediction_step(self, training_config: BaseTrainerConfig, **kwargs): and generations is not None ): for i in range(len(true_data)): - experiment.log_image( np.moveaxis(true_data[i].cpu().detach().numpy(), 0, -1), name=f"{i}_truth", diff --git a/tests/test_AE.py b/tests/test_AE.py index c70ad44e..de077ebd 100644 --- a/tests/test_AE.py +++ b/tests/test_AE.py @@ -87,7 +87,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = AE(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -110,7 +109,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -138,7 +136,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -166,7 +163,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -196,7 +192,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -232,7 +227,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -281,7 +275,6 @@ def ae(self, model_configs, demo_data): return AE(model_configs) def test_model_train_output(self, ae, demo_data): - ae.train() out = ae(demo_data) @@ -297,8 +290,8 @@ def test_model_train_output(self, ae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -322,21 +315,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -351,15 +340,15 @@ def ae(self, model_configs, demo_data): return AE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape + class Test_Model_predict: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -374,15 +363,18 @@ def ae(self, model_configs, demo_data): 
return AE(model_configs) def test_predict(self, ae, demo_data): - model_output = ae.predict(demo_data) - assert tuple(model_output.embedding.shape) == (demo_data.shape[0], ae.model_config.latent_dim) + assert tuple(model_output.embedding.shape) == ( + demo_data.shape[0], + ae.model_config.latent_dim, + ) + class Test_Model_embed: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -397,9 +389,12 @@ def ae(self, model_configs, demo_data): return AE(model_configs) def test_embed(self, ae, demo_data): - embedding = ae.embed(demo_data) - assert tuple(embedding.shape) == (demo_data.shape[0], ae.model_config.latent_dim) + assert tuple(embedding.shape) == ( + demo_data.shape[0], + ae.model_config.latent_dim, + ) + @pytest.mark.slow class Test_AE_Training: @@ -458,7 +453,6 @@ def trainer(self, ae, train_dataset, training_configs): return trainer def test_ae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -474,7 +468,6 @@ def test_ae_train_step(self, trainer): ) def test_ae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -490,7 +483,6 @@ def test_ae_eval_step(self, trainer): ) def test_ae_predict_step(self, train_dataset, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -510,7 +502,6 @@ def test_ae_predict_step(self, train_dataset, trainer): assert generated.shape == inputs.shape def test_ae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -526,7 +517,6 @@ def test_ae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, ae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -664,7 +654,6 @@ def test_checkpoint_saving_during_training(self, ae, trainer, training_configs): ) def test_final_model_saving(self, ae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -715,7 +704,6 @@ def test_final_model_saving(self, ae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_ae_training_pipeline(self, tmpdir, ae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_Adversarial_AE.py b/tests/test_Adversarial_AE.py index 0265fa8e..74cc29d8 100644 --- a/tests/test_Adversarial_AE.py +++ b/tests/test_Adversarial_AE.py @@ -142,7 +142,6 @@ def test_raises_no_input_dim( def test_build_custom_arch( self, model_configs, custom_encoder, custom_decoder, custom_discriminator ): - adversarial_ae = Adversarial_AE( model_configs, encoder=custom_encoder, decoder=custom_decoder ) @@ -168,7 +167,6 @@ def test_build_custom_arch( class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -196,7 +194,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -224,7 +221,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, 
custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -254,7 +250,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_custom_discriminator_model_saving( self, tmpdir, model_configs, custom_discriminator ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -289,7 +284,6 @@ def test_full_custom_model_saving( custom_decoder, custom_discriminator, ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -336,7 +330,6 @@ def test_raises_missing_files( custom_decoder, custom_discriminator, ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -396,7 +389,6 @@ def adversarial_ae(self, model_configs, demo_data): return Adversarial_AE(model_configs) def test_model_train_output(self, adversarial_ae, demo_data): - # model_configs.input_dim = demo_data['data'][0].shape[-1] # adversarial_ae = Adversarial_AE(model_configs) @@ -407,19 +399,16 @@ def test_model_train_output(self, adversarial_ae, demo_data): assert isinstance(out, ModelOutput) - assert ( - set( - [ - "loss", - "recon_loss", - "autoencoder_loss", - "discriminator_loss", - "recon_x", - "z", - ] - ) - == set(out.keys()) - ) + assert set( + [ + "loss", + "recon_loss", + "autoencoder_loss", + "discriminator_loss", + "recon_x", + "z", + ] + ) == set(out.keys()) assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape @@ -428,8 +417,8 @@ def test_model_train_output(self, adversarial_ae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -453,21 +442,17 @@ def test_interpolate(self, adversarial_ae, demo_data, granularity): interp = adversarial_ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -482,7 +467,6 @@ def adversarial_ae(self, model_configs, demo_data): return Adversarial_AE(model_configs) def test_reconstruct(self, adversarial_ae, demo_data): - recon = adversarial_ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -580,7 +564,6 @@ def trainer(self, adversarial_ae, train_dataset, training_configs): return trainer def test_adversarial_ae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -596,7 +579,6 @@ def test_adversarial_ae_train_step(self, trainer): ) def test_adversarial_ae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -612,7 +594,6 @@ def test_adversarial_ae_eval_step(self, trainer): ) def test_adversarial_ae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = 
trainer.predict(trainer.model) @@ -632,7 +613,6 @@ def test_adversarial_ae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_adversarial_ae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -648,7 +628,6 @@ def test_adversarial_ae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, adversarial_ae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -844,7 +823,6 @@ def test_checkpoint_saving_during_training( ) def test_final_model_saving(self, adversarial_ae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -905,7 +883,6 @@ def test_final_model_saving(self, adversarial_ae, trainer, training_configs): def test_adversarial_ae_training_pipeline( self, adversarial_ae, train_dataset, training_configs ): - with pytest.raises(AssertionError): pipeline = TrainingPipeline( model=adversarial_ae, training_config=BaseTrainerConfig() diff --git a/tests/test_BetaTCVAE.py b/tests/test_BetaTCVAE.py index 396be3ee..02a69f19 100644 --- a/tests/test_BetaTCVAE.py +++ b/tests/test_BetaTCVAE.py @@ -97,7 +97,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = BetaTCVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -120,7 +119,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -148,7 +146,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -176,7 +173,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -206,7 +202,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -242,7 +237,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -291,7 +285,6 @@ def betavae(self, model_configs, demo_data): return BetaTCVAE(model_configs) def test_model_train_output(self, betavae, demo_data): - betavae.train() out = betavae(demo_data) @@ -309,8 +302,8 @@ def test_model_train_output(self, betavae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -334,21 +327,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert 
tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -363,7 +352,6 @@ def ae(self, model_configs, demo_data): return BetaTCVAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -453,7 +441,6 @@ def trainer(self, betavae, train_dataset, training_configs): return trainer def test_betavae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -469,7 +456,6 @@ def test_betavae_train_step(self, trainer): ) def test_betavae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -485,7 +471,6 @@ def test_betavae_eval_step(self, trainer): ) def test_betavae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -505,7 +490,6 @@ def test_betavae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_betavae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -521,7 +505,6 @@ def test_betavae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, betavae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -661,7 +644,6 @@ def test_checkpoint_saving_during_training( ) def test_final_model_saving(self, betavae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -714,7 +696,6 @@ def test_final_model_saving(self, betavae, trainer, training_configs): def test_betavae_training_pipeline( self, tmpdir, betavae, train_dataset, training_configs ): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_BetaVAE.py b/tests/test_BetaVAE.py index 051837cc..02515b98 100644 --- a/tests/test_BetaVAE.py +++ b/tests/test_BetaVAE.py @@ -88,7 +88,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = BetaVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -111,7 +110,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -139,7 +137,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -167,7 +164,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -197,7 +193,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - 
tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -233,7 +228,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -282,7 +276,6 @@ def betavae(self, model_configs, demo_data): return BetaVAE(model_configs) def test_model_train_output(self, betavae, demo_data): - betavae.train() out = betavae(demo_data) @@ -300,8 +293,8 @@ def test_model_train_output(self, betavae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -325,21 +318,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -354,7 +343,6 @@ def ae(self, model_configs, demo_data): return BetaVAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -444,7 +432,6 @@ def trainer(self, betavae, train_dataset, training_configs): return trainer def test_betavae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -460,7 +447,6 @@ def test_betavae_train_step(self, trainer): ) def test_betavae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -476,7 +462,6 @@ def test_betavae_eval_step(self, trainer): ) def test_betavae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -496,7 +481,6 @@ def test_betavae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_betavae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -512,7 +496,6 @@ def test_betavae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, betavae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -652,7 +635,6 @@ def test_checkpoint_saving_during_training( ) def test_final_model_saving(self, betavae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -703,7 +685,6 @@ def test_final_model_saving(self, betavae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_betavae_training_pipeline(self, betavae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_CIWAE.py b/tests/test_CIWAE.py index e537c893..8a74680d 100644 --- a/tests/test_CIWAE.py +++ b/tests/test_CIWAE.py @@ -3,8 +3,8 @@ import pytest import torch - from pydantic import ValidationError + from pythae.customexception import 
BadInheritanceError from pythae.models import CIWAE, AutoModel, CIWAEConfig from pythae.models.base.base_utils import ModelOutput @@ -97,7 +97,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = CIWAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -120,7 +119,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -148,7 +146,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -176,7 +173,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -206,7 +202,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -242,7 +237,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -291,7 +285,6 @@ def CIWAE(self, model_configs, demo_data): return CIWAE(model_configs) def test_model_train_output(self, CIWAE, demo_data): - CIWAE.train() out = CIWAE(demo_data) @@ -312,8 +305,8 @@ def test_model_train_output(self, CIWAE, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -337,21 +330,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -366,7 +355,6 @@ def ae(self, model_configs, demo_data): return CIWAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -454,7 +442,6 @@ def trainer(self, ciwae, train_dataset, training_configs): return trainer def test_ciwae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -470,7 +457,6 @@ def test_ciwae_train_step(self, trainer): ) def test_ciwae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -486,7 +472,6 @@ def test_ciwae_eval_step(self, trainer): ) def 
test_ciwae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -506,7 +491,6 @@ def test_ciwae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_ciwae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -522,7 +506,6 @@ def test_ciwae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, ciwae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -660,7 +643,6 @@ def test_checkpoint_saving_during_training(self, ciwae, trainer, training_config ) def test_final_model_saving(self, ciwae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -711,7 +693,6 @@ def test_final_model_saving(self, ciwae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_CIWAE_training_pipeline(self, ciwae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_DisentangledBetaVAE.py b/tests/test_DisentangledBetaVAE.py index 8c567a18..8bf12275 100644 --- a/tests/test_DisentangledBetaVAE.py +++ b/tests/test_DisentangledBetaVAE.py @@ -101,7 +101,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = DisentangledBetaVAE( model_configs, encoder=custom_encoder, decoder=custom_decoder ) @@ -126,7 +125,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -154,7 +152,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -182,7 +179,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -212,7 +208,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -250,7 +245,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -301,7 +295,6 @@ def betavae(self, model_configs, demo_data): return DisentangledBetaVAE(model_configs) def test_model_train_output(self, betavae, demo_data): - betavae.train() out = betavae(demo_data) @@ -319,8 +312,8 @@ def test_model_train_output(self, betavae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -344,21 +337,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, 
granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -373,7 +362,6 @@ def ae(self, model_configs, demo_data): return DisentangledBetaVAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -463,7 +451,6 @@ def trainer(self, betavae, train_dataset, training_configs): return trainer def test_betavae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -479,7 +466,6 @@ def test_betavae_train_step(self, trainer): ) def test_betavae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -495,7 +481,6 @@ def test_betavae_eval_step(self, trainer): ) def test_betavae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -515,7 +500,6 @@ def test_betavae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_betavae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -531,7 +515,6 @@ def test_betavae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, betavae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -671,7 +654,6 @@ def test_checkpoint_saving_during_training( ) def test_final_model_saving(self, betavae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -722,7 +704,6 @@ def test_final_model_saving(self, betavae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_betavae_training_pipeline(self, betavae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_FactorVAE.py b/tests/test_FactorVAE.py index e8a99765..c5742ba0 100644 --- a/tests/test_FactorVAE.py +++ b/tests/test_FactorVAE.py @@ -97,7 +97,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - factor_ae = FactorVAE( model_configs, encoder=custom_encoder, decoder=custom_decoder ) @@ -113,7 +112,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -141,7 +139,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -169,7 +166,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -199,7 +195,6 @@ def 
test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -235,7 +230,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -284,7 +278,6 @@ def factor_ae(self, model_configs, demo_data): return FactorVAE(model_configs) def test_model_train_output(self, factor_ae, demo_data): - # model_configs.input_dim = demo_data['data'][0].shape[-1] # factor_ae = FactorVAE(model_configs) @@ -298,21 +291,18 @@ def test_model_train_output(self, factor_ae, demo_data): assert isinstance(out, ModelOutput) - assert ( - set( - [ - "loss", - "recon_loss", - "autoencoder_loss", - "discriminator_loss", - "recon_x", - "recon_x_indices", - "z", - "z_bis_permuted", - ] - ) - == set(out.keys()) - ) + assert set( + [ + "loss", + "recon_loss", + "autoencoder_loss", + "discriminator_loss", + "recon_x", + "recon_x_indices", + "z", + "z_bis_permuted", + ] + ) == set(out.keys()) assert out.z.shape[0] == int(demo_data["data"].shape[0] / 2) + 1 * ( demo_data["data"].shape[0] % 2 != 0 @@ -322,9 +312,9 @@ def test_model_train_output(self, factor_ae, demo_data): int(demo_data["data"].shape[0] / 2) + 1 * (demo_data["data"].shape[0] % 2 != 0), ) + (demo_data["data"].shape[1:]) - assert out.recon_x_indices.shape[0] == int(demo_data["data"].shape[0] / 2) + 1 * ( - demo_data["data"].shape[0] % 2 != 0 - ) + assert out.recon_x_indices.shape[0] == int( + demo_data["data"].shape[0] / 2 + ) + 1 * (demo_data["data"].shape[0] % 2 != 0) assert not torch.equal(out.z, out.z_bis_permuted) @@ -332,8 +322,8 @@ def test_model_train_output(self, factor_ae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -357,21 +347,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -386,7 +372,6 @@ def ae(self, model_configs, demo_data): return FactorVAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -495,7 +480,6 @@ def trainer(self, factor_ae, train_dataset, training_configs): return trainer def test_factor_ae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -511,7 +495,6 @@ def test_factor_ae_train_step(self, trainer): ) def test_factor_ae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -527,7 +510,6 @@ def test_factor_ae_eval_step(self, trainer): ) def test_factor_ae_predict_step(self, trainer, 
train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -551,7 +533,6 @@ def test_factor_ae_predict_step(self, trainer, train_dataset): ) + (inputs.shape[1:]) def test_factor_ae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -567,7 +548,6 @@ def test_factor_ae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, factor_ae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -749,7 +729,6 @@ def test_checkpoint_saving_during_training( ) def test_final_model_saving(self, factor_ae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -803,7 +782,6 @@ def test_final_model_saving(self, factor_ae, trainer, training_configs): def test_factor_ae_training_pipeline( self, factor_ae, train_dataset, training_configs ): - with pytest.raises(AssertionError): pipeline = TrainingPipeline( model=factor_ae, training_config=BaseTrainerConfig() diff --git a/tests/test_HVAE.py b/tests/test_HVAE.py index 56f03ea5..7fbdf3e5 100644 --- a/tests/test_HVAE.py +++ b/tests/test_HVAE.py @@ -98,7 +98,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = HVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -121,7 +120,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -149,7 +147,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -177,7 +174,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -207,7 +203,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -243,7 +238,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -292,7 +286,6 @@ def hvae(self, model_configs, demo_data): return HVAE(model_configs) def test_model_train_output(self, hvae, demo_data): - hvae.train() out = hvae(demo_data) @@ -336,8 +329,8 @@ def test_nll_compute(self, hvae, demo_data, nll_params): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -361,21 +354,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert 
tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -390,7 +379,6 @@ def ae(self, model_configs, demo_data): return HVAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -452,7 +440,6 @@ def trainer(self, hvae, train_dataset, training_configs): return trainer def test_hvae_train_step(self, hvae, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -486,7 +473,6 @@ def test_hvae_train_step(self, hvae, trainer): ) def test_hvae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -502,7 +488,6 @@ def test_hvae_eval_step(self, trainer): ) def test_hvae_predict_step(self, train_dataset, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -522,7 +507,6 @@ def test_hvae_predict_step(self, train_dataset, trainer): assert generated.shape == inputs.shape def test_hvae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -538,7 +522,6 @@ def test_hvae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, hvae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -676,7 +659,6 @@ def test_checkpoint_saving_during_training(self, hvae, trainer, training_configs ) def test_final_model_saving(self, hvae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -729,7 +711,6 @@ def test_final_model_saving(self, hvae, trainer, training_configs): def test_hvae_training_pipeline( self, tmpdir, hvae, train_dataset, training_configs ): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_IWAE.py b/tests/test_IWAE.py index 57220f96..be2d63cd 100644 --- a/tests/test_IWAE.py +++ b/tests/test_IWAE.py @@ -100,7 +100,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = IWAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -123,7 +122,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -151,7 +149,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -179,7 +176,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -209,7 +205,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - 
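A change that recurs in every test file in this patch is the switch of the fixture data from torch.randn to torch.rand. A plausible reading, consistent with the reconstruction losses shown earlier in the patch, is that torch.randn samples a standard normal and so produces values outside [0, 1], which F.binary_cross_entropy rejects as targets, whereas torch.rand samples uniformly from [0, 1). The sketch below is illustrative only, not part of the patch, and the link to issue #110 is an assumption:

import torch
import torch.nn.functional as F

recon_x = torch.sigmoid(torch.randn(4, 10))  # stand-in model output, already in (0, 1)

bad_x = torch.randn(4, 10)  # standard normal: values outside [0, 1] with near certainty
good_x = torch.rand(4, 10)  # uniform on [0, 1): always a valid BCE target

try:
    F.binary_cross_entropy(recon_x, bad_x, reduction="none")
except RuntimeError as err:
    print("randn fixture rejected:", err)  # torch requires BCE targets in [0, 1]

recon_loss = F.binary_cross_entropy(recon_x, good_x, reduction="none").sum(dim=-1)
print(recon_loss.shape)  # torch.Size([4])

Under that assumption, the rand fixtures keep the interpolate and reconstruct tests valid for every configuration the suites cover, including models built with reconstruction_loss="bce".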
tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -245,7 +240,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -294,7 +288,6 @@ def iwae(self, model_configs, demo_data): return IWAE(model_configs) def test_model_train_output(self, iwae, demo_data): - iwae.train() out = iwae(demo_data) @@ -315,8 +308,8 @@ def test_model_train_output(self, iwae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -340,21 +333,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -369,7 +358,6 @@ def ae(self, model_configs, demo_data): return IWAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -457,7 +445,6 @@ def trainer(self, iwae, train_dataset, training_configs): return trainer def test_iwae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -473,7 +460,6 @@ def test_iwae_train_step(self, trainer): ) def test_iwae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -489,7 +475,6 @@ def test_iwae_eval_step(self, trainer): ) def test_iwae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -509,7 +494,6 @@ def test_iwae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_iwae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -525,7 +509,6 @@ def test_iwae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, iwae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -663,7 +646,6 @@ def test_checkpoint_saving_during_training(self, iwae, trainer, training_configs ) def test_final_model_saving(self, iwae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -716,7 +698,6 @@ def test_final_model_saving(self, iwae, trainer, training_configs): def test_iwae_training_pipeline( self, tmpdir, iwae, train_dataset, training_configs ): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_MIWAE.py b/tests/test_MIWAE.py index 6030c029..8a2cd007 100644 --- a/tests/test_MIWAE.py +++ b/tests/test_MIWAE.py @@ -95,7 +95,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = MIWAE(model_configs, encoder=custom_encoder, 
decoder=custom_decoder) assert model.encoder == custom_encoder @@ -118,7 +117,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -146,7 +144,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -174,7 +171,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -204,7 +200,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -240,7 +235,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -289,7 +283,6 @@ def MIWAE(self, model_configs, demo_data): return MIWAE(model_configs) def test_model_train_output(self, MIWAE, demo_data): - MIWAE.train() out = MIWAE(demo_data) @@ -310,8 +303,8 @@ def test_model_train_output(self, MIWAE, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -335,21 +328,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -364,7 +353,6 @@ def ae(self, model_configs, demo_data): return MIWAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -452,7 +440,6 @@ def trainer(self, miwae, train_dataset, training_configs): return trainer def test_miwae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -468,7 +455,6 @@ def test_miwae_train_step(self, trainer): ) def test_miwae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -484,7 +470,6 @@ def test_miwae_eval_step(self, trainer): ) def test_miwae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -504,7 +489,6 @@ def test_miwae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_miwae_main_train_loop(self, 
trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -520,7 +504,6 @@ def test_miwae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, miwae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -658,7 +641,6 @@ def test_checkpoint_saving_during_training(self, miwae, trainer, training_config ) def test_final_model_saving(self, miwae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -709,7 +691,6 @@ def test_final_model_saving(self, miwae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_MIWAE_training_pipeline(self, miwae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_PIWAE.py b/tests/test_PIWAE.py index 78c66928..4c8a364e 100644 --- a/tests/test_PIWAE.py +++ b/tests/test_PIWAE.py @@ -91,7 +91,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = PIWAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -114,7 +113,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -142,7 +140,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -170,7 +167,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -200,7 +196,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -236,7 +231,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -285,29 +279,25 @@ def piwae(self, model_configs, demo_data): return PIWAE(model_configs) def test_model_train_output(self, piwae, demo_data): - piwae.train() out = piwae(demo_data) assert isinstance(out, ModelOutput) - assert ( - set( - [ - "loss", - "recon_loss", - "encoder_loss", - "decoder_loss", - "update_encoder", - "update_decoder", - "reg_loss", - "recon_x", - "z", - ] - ) - == set(out.keys()) - ) + assert set( + [ + "loss", + "recon_loss", + "encoder_loss", + "decoder_loss", + "update_encoder", + "update_decoder", + "reg_loss", + "recon_x", + "z", + ] + ) == set(out.keys()) assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape @@ -316,8 +306,8 @@ def test_model_train_output(self, piwae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ 
-341,21 +331,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -370,7 +356,6 @@ def ae(self, model_configs, demo_data): return PIWAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -441,7 +426,6 @@ def trainer(self, piwae, train_dataset, training_configs): return trainer def test_piwae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -457,7 +441,6 @@ def test_piwae_train_step(self, trainer): ) def test_piwae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -473,7 +456,6 @@ def test_piwae_eval_step(self, trainer): ) def test_piwae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -493,7 +475,6 @@ def test_piwae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_piwae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -509,7 +490,6 @@ def test_piwae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, piwae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -684,7 +664,6 @@ def test_checkpoint_saving_during_training(self, piwae, trainer, training_config ) def test_final_model_saving(self, piwae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -737,7 +716,6 @@ def test_final_model_saving(self, piwae, trainer, training_configs): def test_piwae_training_pipeline( self, tmpdir, piwae, train_dataset, training_configs ): - with pytest.raises(AssertionError): pipeline = TrainingPipeline( model=piwae, training_config=BaseTrainerConfig() diff --git a/tests/test_PoincareVAE.py b/tests/test_PoincareVAE.py index e54d4694..8d94243f 100644 --- a/tests/test_PoincareVAE.py +++ b/tests/test_PoincareVAE.py @@ -116,7 +116,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = PoincareVAE( model_configs, encoder=custom_encoder, decoder=custom_decoder ) @@ -139,7 +138,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): assert not model.model_config.uses_default_decoder def test_misc_manifold_func(self): - manifold = PoincareBall(dim=2, c=0.7) x = torch.randn(10, 2) y = torch.randn(10, 2) @@ -152,7 +150,6 @@ def test_misc_manifold_func(self): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -180,7 +177,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path 
= dir_path = os.path.join(tmpdir, "dummy_folder") @@ -208,7 +204,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -238,7 +233,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -276,7 +270,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -327,7 +320,6 @@ def vae(self, model_configs, demo_data): return PoincareVAE(model_configs) def test_model_train_output(self, vae, demo_data): - vae.train() out = vae(demo_data) @@ -345,8 +337,8 @@ def test_model_train_output(self, vae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -370,21 +362,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -399,7 +387,6 @@ def ae(self, model_configs, demo_data): return PoincareVAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -489,7 +476,6 @@ def trainer(self, vae, train_dataset, training_configs): return trainer def test_vae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -505,7 +491,6 @@ def test_vae_train_step(self, trainer): ) def test_vae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -521,7 +506,6 @@ def test_vae_eval_step(self, trainer): ) def test_vae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -541,7 +525,6 @@ def test_vae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_vae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -557,7 +540,6 @@ def test_vae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, vae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -695,7 +677,6 @@ def test_checkpoint_saving_during_training(self, vae, trainer, training_configs) ) def test_final_model_saving(self, vae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -746,7 +727,6 @@ def 
test_final_model_saving(self, vae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_vae_training_pipeline(self, vae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_RHVAE.py b/tests/test_RHVAE.py index 35809316..6412279c 100644 --- a/tests/test_RHVAE.py +++ b/tests/test_RHVAE.py @@ -110,7 +110,6 @@ def test_raises_no_input_dim( def test_build_custom_arch( self, model_configs, custom_encoder, custom_decoder, custom_metric ): - rhvae = RHVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert rhvae.encoder == custom_encoder @@ -132,7 +131,6 @@ def test_build_custom_arch( class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -170,7 +168,6 @@ def test_default_model_saving(self, tmpdir, model_configs): assert callable(model_rec.G_inv) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -204,7 +201,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder assert callable(model_rec.G_inv) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -238,7 +234,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder assert callable(model_rec.G_inv) def test_custom_metric_model_saving(self, tmpdir, model_configs, custom_metric): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -274,7 +269,6 @@ def test_custom_metric_model_saving(self, tmpdir, model_configs, custom_metric): def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder, custom_metric ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -337,7 +331,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder, custom_metric ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -397,7 +390,6 @@ def rhvae(self, model_configs, demo_data): return RHVAE(model_configs) def test_model_train_output(self, rhvae, demo_data): - # model_configs.input_dim = demo_data['data'][0].shape[-1] # rhvae = RHVAE(model_configs) @@ -405,52 +397,45 @@ def test_model_train_output(self, rhvae, demo_data): rhvae.train() out = rhvae(demo_data) - assert ( - set( - [ - "loss", - "recon_x", - "z", - "z0", - "rho", - "eps0", - "gamma", - "mu", - "log_var", - "G_inv", - "G_log_det", - ] - ) - == set(out.keys()) - ) + assert set( + [ + "loss", + "recon_x", + "z", + "z0", + "rho", + "eps0", + "gamma", + "mu", + "log_var", + "G_inv", + "G_log_det", + ] + ) == set(out.keys()) rhvae.update() def test_model_output(self, rhvae, demo_data): - # model_configs.input_dim = demo_data['data'][0].shape[-1] rhvae.eval() out = rhvae(demo_data) - assert ( - set( - [ - "loss", - "recon_x", - "z", - "z0", - "rho", - "eps0", - "gamma", - "mu", - "log_var", - "G_inv", - "G_log_det", - ] - ) - == set(out.keys()) - ) + assert set( + [ + "loss", + "recon_x", + "z", + "z0", + "rho", + "eps0", + "gamma", + "mu", + "log_var", + "G_inv", + "G_log_det", + ] + ) == set(out.keys()) assert out.z.shape[0] == 
demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape @@ -459,8 +444,8 @@ def test_model_output(self, rhvae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -484,21 +469,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -513,7 +494,6 @@ def ae(self, model_configs, demo_data): return RHVAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -613,7 +593,6 @@ def trainer(self, rhvae, train_dataset, training_configs): return optimizer def test_rhvae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -629,7 +608,6 @@ def test_rhvae_train_step(self, trainer): ) def test_rhvae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -645,7 +623,6 @@ def test_rhvae_eval_step(self, trainer): ) def test_rhvae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -665,7 +642,6 @@ def test_rhvae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_rhvae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -681,7 +657,6 @@ def test_rhvae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, rhvae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -840,7 +815,6 @@ def test_checkpoint_saving_during_training(self, rhvae, trainer, training_config ) def test_final_model_saving(self, rhvae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -901,7 +875,6 @@ def test_final_model_saving(self, rhvae, trainer, training_configs): assert type(model_rec.metric.cpu()) == type(model.metric.cpu()) def test_rhvae_training_pipeline(self, rhvae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_SVAE.py b/tests/test_SVAE.py index 5cd9f389..5871600f 100644 --- a/tests/test_SVAE.py +++ b/tests/test_SVAE.py @@ -82,7 +82,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = SVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -105,7 +104,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -133,7 +131,6 @@ def 
test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -161,7 +158,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -191,7 +187,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -227,7 +222,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -276,7 +270,6 @@ def svae(self, model_configs, demo_data): return SVAE(model_configs) def test_model_train_output(self, svae, demo_data): - svae.train() out = svae(demo_data) @@ -294,8 +287,8 @@ def test_model_train_output(self, svae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -319,21 +312,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -348,7 +337,6 @@ def ae(self, model_configs, demo_data): return SVAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -436,7 +424,6 @@ def trainer(self, svae, train_dataset, training_configs): return trainer def test_svae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -452,7 +439,6 @@ def test_svae_train_step(self, trainer): ) def test_svae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -468,7 +454,6 @@ def test_svae_eval_step(self, trainer): ) def test_svae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -488,7 +473,6 @@ def test_svae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_svae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -504,7 +488,6 @@ def test_svae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, svae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -642,7 +625,6 @@ def test_checkpoint_saving_during_training(self, svae, trainer, 
training_configs ) def test_final_model_saving(self, svae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -693,7 +675,6 @@ def test_final_model_saving(self, svae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_svae_training_pipeline(self, svae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_VAE.py b/tests/test_VAE.py index 83cde490..38d4ed06 100644 --- a/tests/test_VAE.py +++ b/tests/test_VAE.py @@ -88,7 +88,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = VAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -111,7 +110,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -139,7 +137,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -167,7 +164,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -197,7 +193,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -233,7 +228,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -282,7 +276,6 @@ def vae(self, model_configs, demo_data): return VAE(model_configs) def test_model_train_output(self, vae, demo_data): - vae.train() out = vae(demo_data) @@ -300,8 +293,8 @@ def test_model_train_output(self, vae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -325,21 +318,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -354,7 +343,6 @@ def ae(self, model_configs, demo_data): return VAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -442,7 +430,6 @@ def trainer(self, vae, train_dataset, training_configs): return trainer def test_vae_train_step(self, 
trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -458,7 +445,6 @@ def test_vae_train_step(self, trainer): ) def test_vae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -474,7 +460,6 @@ def test_vae_eval_step(self, trainer): ) def test_vae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -494,7 +479,6 @@ def test_vae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_vae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -510,7 +494,6 @@ def test_vae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, vae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -648,7 +631,6 @@ def test_checkpoint_saving_during_training(self, vae, trainer, training_configs) ) def test_final_model_saving(self, vae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -699,7 +681,6 @@ def test_final_model_saving(self, vae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_vae_training_pipeline(self, vae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_VAEGAN.py b/tests/test_VAEGAN.py index cfba43e3..1ba24433 100644 --- a/tests/test_VAEGAN.py +++ b/tests/test_VAEGAN.py @@ -144,7 +144,6 @@ def test_raises_no_input_dim( def test_build_custom_arch( self, model_configs, custom_encoder, custom_decoder, custom_discriminator ): - vaegan = VAEGAN(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert vaegan.encoder == custom_encoder @@ -166,7 +165,6 @@ def test_build_custom_arch( class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -194,7 +192,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -222,7 +219,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -252,7 +248,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_custom_discriminator_model_saving( self, tmpdir, model_configs, custom_discriminator ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -287,7 +282,6 @@ def test_full_custom_model_saving( custom_decoder, custom_discriminator, ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -334,7 +328,6 @@ def test_raises_missing_files( custom_decoder, custom_discriminator, ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -394,7 +387,6 @@ def vaegan(self, model_configs, demo_data): return VAEGAN(model_configs) def test_model_train_output(self, vaegan, demo_data): - # model_configs.input_dim = demo_data['data'][0].shape[-1] # 
vaegan = VAEGAN(model_configs) @@ -405,23 +397,20 @@ def test_model_train_output(self, vaegan, demo_data): assert isinstance(out, ModelOutput) - assert ( - set( - [ - "loss", - "recon_loss", - "encoder_loss", - "decoder_loss", - "discriminator_loss", - "recon_x", - "z", - "update_discriminator", - "update_encoder", - "update_decoder", - ] - ) - == set(out.keys()) - ) + assert set( + [ + "loss", + "recon_loss", + "encoder_loss", + "decoder_loss", + "discriminator_loss", + "recon_x", + "z", + "update_discriminator", + "update_encoder", + "update_decoder", + ] + ) == set(out.keys()) assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape @@ -430,8 +419,8 @@ def test_model_train_output(self, vaegan, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -455,21 +444,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -484,7 +469,6 @@ def ae(self, model_configs, demo_data): return VAEGAN(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -610,7 +594,6 @@ def trainer(self, vaegan, train_dataset, training_configs): return trainer def test_vaegan_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -626,7 +609,6 @@ def test_vaegan_train_step(self, trainer): ) def test_vaegan_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -642,7 +624,6 @@ def test_vaegan_eval_step(self, trainer): ) def test_vaegan_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -662,7 +643,6 @@ def test_vaegan_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_vaegan_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -678,7 +658,6 @@ def test_vaegan_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, vaegan, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -898,7 +877,6 @@ def test_checkpoint_saving_during_training(self, vaegan, trainer, training_confi ) def test_final_model_saving(self, vaegan, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -957,7 +935,6 @@ def test_final_model_saving(self, vaegan, trainer, training_configs): assert type(model_rec.discriminator.cpu()) == type(model.discriminator.cpu()) def test_vaegan_training_pipeline(self, vaegan, train_dataset, training_configs): - with pytest.raises(AssertionError): pipeline = TrainingPipeline( model=vaegan, 
training_config=BaseTrainerConfig() diff --git a/tests/test_VAE_IAF.py b/tests/test_VAE_IAF.py index 60de1d97..c0a3d7c6 100644 --- a/tests/test_VAE_IAF.py +++ b/tests/test_VAE_IAF.py @@ -94,7 +94,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = VAE_IAF(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -117,7 +116,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -145,7 +143,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -173,7 +170,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -203,7 +199,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -239,7 +234,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -288,7 +282,6 @@ def vae(self, model_configs, demo_data): return VAE_IAF(model_configs) def test_model_train_output(self, vae, demo_data): - vae.train() out = vae(demo_data) @@ -306,8 +299,8 @@ def test_model_train_output(self, vae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -331,21 +324,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -360,7 +349,6 @@ def ae(self, model_configs, demo_data): return VAE_IAF(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -450,7 +438,6 @@ def trainer(self, vae, train_dataset, training_configs): return trainer def test_vae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -466,7 +453,6 @@ def test_vae_train_step(self, trainer): ) def test_vae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -482,7 +468,6 @@ def test_vae_eval_step(self, 
trainer): ) def test_vae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -502,7 +487,6 @@ def test_vae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_vae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -518,7 +502,6 @@ def test_vae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, vae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -656,7 +639,6 @@ def test_checkpoint_saving_during_training(self, vae, trainer, training_configs) ) def test_final_model_saving(self, vae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -707,7 +689,6 @@ def test_final_model_saving(self, vae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_vae_training_pipeline(self, vae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_VAE_LinFlow.py b/tests/test_VAE_LinFlow.py index 2f264246..cdb37522 100644 --- a/tests/test_VAE_LinFlow.py +++ b/tests/test_VAE_LinFlow.py @@ -3,8 +3,8 @@ import pytest import torch - from pydantic import ValidationError + from pythae.customexception import BadInheritanceError from pythae.models import AutoModel, VAE_LinNF, VAE_LinNF_Config from pythae.models.base.base_utils import ModelOutput @@ -64,7 +64,6 @@ def bad_net(self): return NetBadInheritance() def test_raises_wrong_flows(self): - with pytest.raises(ValidationError): conf = VAE_LinNF_Config( input_dim=(1, 28), latent_dim=5, flows=["Planar", "WrongFlow"] @@ -109,7 +108,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = VAE_LinNF(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -132,7 +130,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -160,7 +157,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -188,7 +184,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -218,7 +213,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -254,7 +248,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -303,7 +296,6 @@ def vae(self, model_configs, demo_data): return VAE_LinNF(model_configs) def test_model_train_output(self, vae, demo_data): - vae.train() out = 
vae(demo_data) @@ -321,8 +313,8 @@ def test_model_train_output(self, vae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -346,21 +338,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -375,7 +363,6 @@ def ae(self, model_configs, demo_data): return VAE_LinNF(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -465,7 +452,6 @@ def trainer(self, vae, train_dataset, training_configs): return trainer def test_vae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -481,7 +467,6 @@ def test_vae_train_step(self, trainer): ) def test_vae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -497,7 +482,6 @@ def test_vae_eval_step(self, trainer): ) def test_vae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -517,7 +501,6 @@ def test_vae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_vae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -533,7 +516,6 @@ def test_vae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, vae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -671,7 +653,6 @@ def test_checkpoint_saving_during_training(self, vae, trainer, training_configs) ) def test_final_model_saving(self, vae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -722,7 +703,6 @@ def test_final_model_saving(self, vae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_vae_training_pipeline(self, vae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_VAMP.py b/tests/test_VAMP.py index 9408fd73..f010f3af 100644 --- a/tests/test_VAMP.py +++ b/tests/test_VAMP.py @@ -86,7 +86,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = VAMP(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -109,7 +108,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -137,7 +135,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, 
tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -165,7 +162,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -195,7 +191,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -231,7 +226,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -280,7 +274,6 @@ def vamp(self, model_configs, demo_data): return VAMP(model_configs) def test_model_train_output(self, vamp, demo_data): - vamp.train() out = vamp(demo_data) @@ -298,8 +291,8 @@ def test_model_train_output(self, vamp, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -323,21 +316,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -352,7 +341,6 @@ def ae(self, model_configs, demo_data): return VAMP(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -440,7 +428,6 @@ def trainer(self, vamp, train_dataset, training_configs): return trainer def test_vamp_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -456,7 +443,6 @@ def test_vamp_train_step(self, trainer): ) def test_vamp_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -468,7 +454,6 @@ def test_vamp_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_vamp_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -484,7 +469,6 @@ def test_vamp_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, vamp, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -622,7 +606,6 @@ def test_checkpoint_saving_during_training(self, vamp, trainer, training_configs ) def test_final_model_saving(self, vamp, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -673,7 +656,6 @@ def test_final_model_saving(self, vamp, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def 
test_vamp_training_pipeline(self, vamp, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_VQVAE.py b/tests/test_VQVAE.py index 799afdd6..99c44643 100644 --- a/tests/test_VQVAE.py +++ b/tests/test_VQVAE.py @@ -3,8 +3,8 @@ import pytest import torch - from pydantic import ValidationError + from pythae.customexception import BadInheritanceError from pythae.models import VQVAE, AutoModel, VQVAEConfig from pythae.models.base.base_utils import ModelOutput @@ -101,7 +101,6 @@ def test_raises_no_input_dim( model = VQVAE(model_configs_no_input_dim, decoder=custom_decoder) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = VQVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -124,7 +123,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -152,7 +150,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -180,7 +177,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -210,7 +206,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -246,7 +241,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -295,7 +289,6 @@ def vae(self, model_configs, demo_data): return VQVAE(model_configs) def test_model_train_output(self, vae, demo_data): - vae.train() out = vae(demo_data) @@ -313,8 +306,8 @@ def test_model_train_output(self, vae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -338,21 +331,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -367,7 +356,6 @@ def ae(self, model_configs, demo_data): return VQVAE(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -429,7 +417,6 @@ def trainer(self, vae, train_dataset, training_configs): return trainer def 
test_vae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -445,7 +432,6 @@ def test_vae_train_step(self, trainer): ) def test_vae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -461,7 +447,6 @@ def test_vae_eval_step(self, trainer): ) def test_vae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -481,7 +466,6 @@ def test_vae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_vae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -497,7 +481,6 @@ def test_vae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, vae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -635,7 +618,6 @@ def test_checkpoint_saving_during_training(self, vae, trainer, training_configs) ) def test_final_model_saving(self, vae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -686,7 +668,6 @@ def test_final_model_saving(self, vae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_vae_training_pipeline(self, vae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_WAE_MMD.py b/tests/test_WAE_MMD.py index cb24dc1c..1126f891 100644 --- a/tests/test_WAE_MMD.py +++ b/tests/test_WAE_MMD.py @@ -91,7 +91,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = WAE_MMD(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -114,7 +113,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -142,7 +140,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -170,7 +167,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -200,7 +196,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -236,7 +231,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -285,7 +279,6 @@ def wae(self, model_configs, demo_data): return WAE_MMD(model_configs) def test_model_train_output(self, wae, demo_data): - wae.train() out = wae(demo_data) @@ -303,8 +296,8 @@ def test_model_train_output(self, wae, demo_data): class Test_Model_interpolate: 
@pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -328,21 +321,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -357,7 +346,6 @@ def ae(self, model_configs, demo_data): return WAE_MMD(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -421,7 +409,6 @@ def trainer(self, wae, train_dataset, training_configs): return trainer def test_wae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -437,7 +424,6 @@ def test_wae_train_step(self, trainer): ) def test_wae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -453,7 +439,6 @@ def test_wae_eval_step(self, trainer): ) def test_wae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -473,7 +458,6 @@ def test_wae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_wae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -489,7 +473,6 @@ def test_wae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, wae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -627,7 +610,6 @@ def test_checkpoint_saving_during_training(self, wae, trainer, training_configs) ) def test_final_model_saving(self, wae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -678,7 +660,6 @@ def test_final_model_saving(self, wae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_wae_training_pipeline(self, wae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_info_vae_mmd.py b/tests/test_info_vae_mmd.py index 1ab10266..b9b856e6 100644 --- a/tests/test_info_vae_mmd.py +++ b/tests/test_info_vae_mmd.py @@ -93,7 +93,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = INFOVAE_MMD( model_configs, encoder=custom_encoder, decoder=custom_decoder ) @@ -118,7 +117,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -146,7 +144,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, 
"dummy_folder") @@ -174,7 +171,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -204,7 +200,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -242,7 +237,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -293,7 +287,6 @@ def info_vae_mmd(self, model_configs, demo_data): return INFOVAE_MMD(model_configs) def test_model_train_output(self, info_vae_mmd, demo_data): - info_vae_mmd.train() out = info_vae_mmd(demo_data) @@ -311,8 +304,8 @@ def test_model_train_output(self, info_vae_mmd, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -336,21 +329,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -365,7 +354,6 @@ def ae(self, model_configs, demo_data): return INFOVAE_MMD(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -455,7 +443,6 @@ def trainer(self, info_vae_mmd, train_dataset, training_configs): return trainer def test_info_vae_mmd_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -471,7 +458,6 @@ def test_info_vae_mmd_train_step(self, trainer): ) def test_info_vae_mmd_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -487,7 +473,6 @@ def test_info_vae_mmd_eval_step(self, trainer): ) def test_info_vae_mmd_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -507,7 +492,6 @@ def test_info_vae_mmd_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_info_vae_mmd_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -523,7 +507,6 @@ def test_info_vae_mmd_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, info_vae_mmd, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -663,7 +646,6 @@ def test_checkpoint_saving_during_training( ) def test_final_model_saving(self, info_vae_mmd, trainer, training_configs): - dir_path = 
training_configs.output_dir trainer.train() @@ -716,7 +698,6 @@ def test_final_model_saving(self, info_vae_mmd, trainer, training_configs): def test_info_vae_mmd_training_pipeline( self, tmpdir, info_vae_mmd, train_dataset, training_configs ): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_rae_gp.py b/tests/test_rae_gp.py index fb84a3be..4262e66b 100644 --- a/tests/test_rae_gp.py +++ b/tests/test_rae_gp.py @@ -89,7 +89,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = RAE_GP(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -112,7 +111,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -140,7 +138,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -168,7 +165,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -198,7 +194,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -234,7 +229,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -283,7 +277,6 @@ def rae(self, model_configs, demo_data): return RAE_GP(model_configs) def test_model_train_output(self, rae, demo_data): - rae.train() out = rae(demo_data) @@ -301,8 +294,8 @@ def test_model_train_output(self, rae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -326,21 +319,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -355,7 +344,6 @@ def ae(self, model_configs, demo_data): return RAE_GP(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -419,7 +407,6 @@ def trainer(self, rae, train_dataset, training_configs): return trainer def test_rae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = 
trainer.train_step(epoch=1) @@ -435,7 +422,6 @@ def test_rae_train_step(self, trainer): ) def test_rae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -451,7 +437,6 @@ def test_rae_eval_step(self, trainer): ) def test_rae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -471,7 +456,6 @@ def test_rae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_rae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -487,7 +471,6 @@ def test_rae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, rae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -625,7 +608,6 @@ def test_checkpoint_saving_during_training(self, rae, trainer, training_configs) ) def test_final_model_saving(self, rae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -676,7 +658,6 @@ def test_final_model_saving(self, rae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_rae_training_pipeline(self, rae, train_dataset, training_configs): - dir_path = training_configs.output_dir # build pipeline diff --git a/tests/test_rae_l2.py b/tests/test_rae_l2.py index fc023346..7f28b905 100644 --- a/tests/test_rae_l2.py +++ b/tests/test_rae_l2.py @@ -94,7 +94,6 @@ def test_raises_no_input_dim( ) def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): - model = RAE_L2(model_configs, encoder=custom_encoder, decoder=custom_decoder) assert model.encoder == custom_encoder @@ -117,7 +116,6 @@ def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): class Test_Model_Saving: def test_default_model_saving(self, tmpdir, model_configs): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -145,7 +143,6 @@ def test_default_model_saving(self, tmpdir, model_configs): ) def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -173,7 +170,6 @@ def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder ) def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -203,7 +199,6 @@ def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder def test_full_custom_model_saving( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -239,7 +234,6 @@ def test_full_custom_model_saving( def test_raises_missing_files( self, tmpdir, model_configs, custom_encoder, custom_decoder ): - tmpdir.mkdir("dummy_folder") dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") @@ -288,29 +282,25 @@ def rae(self, model_configs, demo_data): return RAE_L2(model_configs) def test_model_train_output(self, rae, demo_data): - rae.train() out = rae(demo_data) assert isinstance(out, ModelOutput) - assert ( - set( - [ - "loss", - "recon_loss", - "encoder_loss", - "decoder_loss", - "update_encoder", - "update_decoder", - "embedding_loss", - "recon_x", - "z", - ] - ) - == 
set(out.keys()) - ) + assert set( + [ + "loss", + "recon_loss", + "encoder_loss", + "decoder_loss", + "update_encoder", + "update_decoder", + "embedding_loss", + "recon_x", + "z", + ] + ) == set(out.keys()) assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape @@ -319,8 +309,8 @@ def test_model_train_output(self, rae, demo_data): class Test_Model_interpolate: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -344,21 +334,17 @@ def test_interpolate(self, ae, demo_data, granularity): interp = ae.interpolate(demo_data, demo_data, granularity) - assert ( - tuple(interp.shape) - == ( - demo_data.shape[0], - granularity, - ) - + (demo_data.shape[1:]) - ) + assert tuple(interp.shape) == ( + demo_data.shape[0], + granularity, + ) + (demo_data.shape[1:]) class Test_Model_reconstruct: @pytest.fixture( params=[ - torch.randn(3, 2, 3, 1), - torch.randn(3, 2, 2), + torch.rand(3, 2, 3, 1), + torch.rand(3, 2, 2), torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[:][ "data" ], @@ -373,7 +359,6 @@ def ae(self, model_configs, demo_data): return RAE_L2(model_configs) def test_reconstruct(self, ae, demo_data): - recon = ae.reconstruct(demo_data) assert tuple(recon.shape) == demo_data.shape @@ -452,7 +437,6 @@ def trainer(self, rae, train_dataset, training_configs): return trainer def test_rae_train_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.train_step(epoch=1) @@ -468,7 +452,6 @@ def test_rae_train_step(self, trainer): ) def test_rae_eval_step(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) step_1_loss = trainer.eval_step(epoch=1) @@ -484,7 +467,6 @@ def test_rae_eval_step(self, trainer): ) def test_rae_predict_step(self, trainer, train_dataset): - start_model_state_dict = deepcopy(trainer.model.state_dict()) inputs, recon, generated = trainer.predict(trainer.model) @@ -504,7 +486,6 @@ def test_rae_predict_step(self, trainer, train_dataset): assert generated.shape == inputs.shape def test_rae_main_train_loop(self, trainer): - start_model_state_dict = deepcopy(trainer.model.state_dict()) trainer.train() @@ -520,7 +501,6 @@ def test_rae_main_train_loop(self, trainer): ) def test_checkpoint_saving(self, rae, trainer, training_configs): - dir_path = training_configs.output_dir # Make a training step @@ -695,7 +675,6 @@ def test_checkpoint_saving_during_training(self, rae, trainer, training_configs) ) def test_final_model_saving(self, rae, trainer, training_configs): - dir_path = training_configs.output_dir trainer.train() @@ -746,7 +725,6 @@ def test_final_model_saving(self, rae, trainer, training_configs): assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu()) def test_rae_training_pipeline(self, rae, train_dataset, training_configs): - with pytest.raises(AssertionError): pipeline = TrainingPipeline(model=rae, training_config=BaseTrainerConfig())
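
A note on the recurring hunks above: the bulk of the one-line deletions (the blank line that used to open each test body) come from black 23.x, which removes empty lines placed directly after a function signature, and the re-wrapped multi-line assert/set(...) expressions are likewise pure formatter output; the import hunks only regroup the pydantic import with the other third-party imports. None of these change behavior. A minimal before/after illustration of the blank-line rule:

    # before black 23.x
    def test_reconstruct(self, ae, demo_data):

        recon = ae.reconstruct(demo_data)

    # after black 23.x: the blank line opening the body is removed
    def test_reconstruct(self, ae, demo_data):
        recon = ae.reconstruct(demo_data)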
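
The one substantive fixture change, repeated across every test file above, swaps torch.randn for torch.rand in the interpolate/reconstruct demo data. Presumably this is because the demo tensors are fed to models that may be configured with a BCE reconstruction loss: F.binary_cross_entropy requires targets in [0, 1], which torch.rand guarantees (uniform on [0, 1)) and standard-normal samples do not. A small sketch of the failure mode, assuming a recent PyTorch that validates the target range:

    import torch
    import torch.nn.functional as F

    recon = torch.sigmoid(torch.randn(4, 10))  # decoder-like output in (0, 1)

    ok_target = torch.rand(4, 10)      # uniform on [0, 1): valid BCE target
    F.binary_cross_entropy(recon, ok_target)   # computes a scalar loss

    bad_target = torch.randn(4, 10)    # Gaussian: values outside [0, 1]
    try:
        F.binary_cross_entropy(recon, bad_target)
    except RuntimeError as err:
        print(err)  # e.g. "all elements of target should be between 0 and 1"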
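
The reformatted interpolation assertion is easier to read in its new form: ae.interpolate(demo_data, demo_data, granularity) is expected to return one path of granularity points per input pair, i.e. a tensor of shape (batch, granularity, *input_shape). A standalone shape check mirroring the test (tensor sizes taken from the fixtures above):

    import torch

    demo_data = torch.rand(3, 2, 2)  # (batch, *input_shape)
    granularity = 10

    expected = (demo_data.shape[0], granularity) + tuple(demo_data.shape[1:])
    assert expected == (3, 10, 2, 2)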
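
On the closing hunk of tests/test_rae_l2.py: test_rae_training_pipeline now reduces to checking that building a TrainingPipeline for an RAE_L2 model with a plain BaseTrainerConfig raises an AssertionError. RAE_L2 optimizes its encoder and decoder separately, so the pipeline is expected to reject a single-optimizer config. A hedged sketch of the intended pairing (CoupledOptimizerTrainerConfig being the assumed correct choice, not something this hunk shows):

    import pytest
    from pythae.models import RAE_L2, RAE_L2_Config
    from pythae.pipelines import TrainingPipeline
    from pythae.trainers import BaseTrainerConfig, CoupledOptimizerTrainerConfig

    model = RAE_L2(RAE_L2_Config(input_dim=(1, 28, 28), latent_dim=10))

    # A single-optimizer config is rejected, as the test asserts...
    with pytest.raises(AssertionError):
        TrainingPipeline(model=model, training_config=BaseTrainerConfig())

    # ...while a coupled-optimizer config should be accepted.
    pipeline = TrainingPipeline(
        model=model, training_config=CoupledOptimizerTrainerConfig()
    )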
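
One pre-existing wart remains visible throughout the unchanged context lines: the double assignment

    dir_path = dir_path = os.path.join(tmpdir, "dummy_folder")

It is harmless (both bindings name the same object) but redundant, and a follow-up could shorten it to a single assignment; it is left alone here since touching context lines would stop the patch from applying.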