fixes #110 & fix some tests & black 23.10
Clément Chadebec committed Oct 22, 2023
1 parent b7c53f4 commit 1103a99
Showing 90 changed files with 462 additions and 1,259 deletions.
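In the hunks shown below, most of the deletions are cosmetic: blank lines sitting directly after a function signature or a block opener (`if`/`elif`/`else`/`for`) are dropped, and a few multi-line `recon_loss` expressions are re-wrapped more compactly. Behaviour is unchanged; the commit message attributes the formatting pass to black 23.10. A minimal before/after sketch with hypothetical names, not code from the repository:

```python
# Before: blank lines directly after the "def" line and the block opener.
def scaled_sum_before(values):

    if not values:

        return 0.0

    return 0.5 * sum(values)


# After: the blank lines are removed; behaviour is identical.
def scaled_sum_after(values):
    if not values:
        return 0.0
    return 0.5 * sum(values)


assert scaled_sum_before([1.0, 3.0]) == scaled_sum_after([1.0, 3.0]) == 2.0
```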
1 change: 0 additions & 1 deletion src/pythae/data/datasets.py
@@ -55,7 +55,6 @@ class BaseDataset(Dataset):
"""

def __init__(self, data, labels):

self.labels = labels.type(torch.float)
self.data = data.type(torch.float)

2 changes: 0 additions & 2 deletions src/pythae/data/preprocessors.py
@@ -92,7 +92,6 @@ def to_dataset(data: torch.Tensor, labels: Optional[torch.Tensor] = None):
return dataset

def _process_data_array(self, data: np.ndarray, batch_size: int = 100):

num_samples = data.shape[0]
samples_shape = data.shape

@@ -102,7 +101,6 @@ def _process_data_array(self, data: np.ndarray, batch_size: int = 100):
full_data = []

for i in range(num_complete_batch):

# Detect potential nan
if DataProcessor.has_nan(data[i * batch_size : (i + 1) * batch_size]):
raise ValueError("Nan detected in input data!")
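The `_process_data_array` hunk above keeps its batched NaN check and only loses blank lines. As a rough standalone illustration of that per-batch scan (plain PyTorch; `has_nan` and `scan_batches` here are illustrative helpers, not the pythae `DataProcessor` API):

```python
import torch


def has_nan(batch: torch.Tensor) -> bool:
    # True if any entry of the batch is NaN.
    return bool(torch.isnan(batch).any())


def scan_batches(data: torch.Tensor, batch_size: int = 100) -> None:
    # Walk the array batch by batch and fail fast on NaNs, mirroring the
    # loop structure shown in the diff above.
    num_complete_batch = data.shape[0] // batch_size
    for i in range(num_complete_batch):
        if has_nan(data[i * batch_size : (i + 1) * batch_size]):
            raise ValueError("Nan detected in input data!")


clean = torch.randn(250, 4)
scan_batches(clean)  # passes silently

corrupted = clean.clone()
corrupted[3, 0] = float("nan")
try:
    scan_batches(corrupted)
except ValueError as err:
    print(err)  # Nan detected in input data!
```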
19 changes: 5 additions & 14 deletions src/pythae/models/adversarial_ae/adversarial_ae_model.py
@@ -58,7 +58,6 @@ def __init__(
decoder: Optional[BaseDecoder] = None,
discriminator: Optional[BaseDiscriminator] = None,
):

VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder)

if discriminator is None:
@@ -149,22 +148,16 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput:
return output

def loss_function(self, recon_x, x, z, z_prior):

N = z.shape[0] # batch size

if self.model_config.reconstruction_loss == "mse":

-            recon_loss = (
-                0.5
-                * F.mse_loss(
-                    recon_x.reshape(x.shape[0], -1),
-                    x.reshape(x.shape[0], -1),
-                    reduction="none",
-                ).sum(dim=-1)
-            )
+            recon_loss = 0.5 * F.mse_loss(
+                recon_x.reshape(x.shape[0], -1),
+                x.reshape(x.shape[0], -1),
+                reduction="none",
+            ).sum(dim=-1)

elif self.model_config.reconstruction_loss == "bce":

recon_loss = F.binary_cross_entropy(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
@@ -232,7 +225,6 @@ def save(self, dir_path: str):

@classmethod
def _load_custom_discriminator_from_folder(cls, dir_path):

file_list = os.listdir(dir_path)
cls._check_python_version_from_folder(dir_path=dir_path)

@@ -361,7 +353,6 @@ def load_from_hf_hub(
)

else:

if not model_config.uses_default_encoder:
_ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl")
encoder = cls._load_custom_encoder_from_folder(dir_path)
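The reformatted `loss_function` above still computes the same per-sample reconstruction term: a 0.5-scaled squared error (or a binary cross-entropy) summed over flattened features, giving one value per batch element. A standalone sketch of that computation in plain PyTorch (outside any pythae class, with an assumed `kind` switch standing in for `model_config.reconstruction_loss`):

```python
import torch
import torch.nn.functional as F


def reconstruction_loss(recon_x: torch.Tensor, x: torch.Tensor, kind: str = "mse") -> torch.Tensor:
    # Flatten everything but the batch dimension, then sum feature-wise losses.
    recon_flat = recon_x.reshape(x.shape[0], -1)
    x_flat = x.reshape(x.shape[0], -1)
    if kind == "mse":
        return 0.5 * F.mse_loss(recon_flat, x_flat, reduction="none").sum(dim=-1)
    if kind == "bce":
        # BCE expects values in [0, 1].
        return F.binary_cross_entropy(recon_flat, x_flat, reduction="none").sum(dim=-1)
    raise ValueError(f"Unknown reconstruction loss: {kind}")


x = torch.rand(8, 1, 28, 28)        # e.g. a batch of images scaled to [0, 1]
recon_x = torch.rand(8, 1, 28, 28)
print(reconstruction_loss(recon_x, x, "mse").shape)  # torch.Size([8])
print(reconstruction_loss(recon_x, x, "bce").shape)  # torch.Size([8])
```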
2 changes: 0 additions & 2 deletions src/pythae/models/ae/ae_model.py
@@ -39,7 +39,6 @@ def __init__(
encoder: Optional[BaseEncoder] = None,
decoder: Optional[BaseDecoder] = None,
):

BaseAE.__init__(self, model_config=model_config, decoder=decoder)

self.model_name = "AE"
@@ -83,7 +82,6 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput:
return output

def loss_function(self, recon_x, x):

MSE = F.mse_loss(
recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none"
).sum(dim=-1)
6 changes: 1 addition & 5 deletions src/pythae/models/base/base_model.py
@@ -57,10 +57,9 @@ class BaseAE(nn.Module):
def __init__(
self,
model_config: BaseAEConfig,
-        encoder: Optional[BaseDecoder] = None,
+        encoder: Optional[BaseEncoder] = None,
decoder: Optional[BaseDecoder] = None,
):

nn.Module.__init__(self)

self.model_name = "BaseAE"
@@ -374,7 +373,6 @@ def _load_model_weights_from_folder(cls, dir_path):

@classmethod
def _load_custom_encoder_from_folder(cls, dir_path):

file_list = os.listdir(dir_path)
cls._check_python_version_from_folder(dir_path=dir_path)

@@ -393,7 +391,6 @@ def _load_custom_encoder_from_folder(cls, dir_path):

@classmethod
def _load_custom_decoder_from_folder(cls, dir_path):

file_list = os.listdir(dir_path)
cls._check_python_version_from_folder(dir_path=dir_path)

@@ -510,7 +507,6 @@ def load_from_hf_hub(cls, hf_hub_path: str, allow_pickle=False): # pragma: no c
)

else:

if not model_config.uses_default_encoder:
_ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl")
encoder = cls._load_custom_encoder_from_folder(dir_path)
17 changes: 5 additions & 12 deletions src/pythae/models/beta_tc_vae/beta_tc_vae_model.py
@@ -41,7 +41,6 @@ def __init__(
encoder: Optional[BaseEncoder] = None,
decoder: Optional[BaseDecoder] = None,
):

VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder)

self.model_name = "BetaTCVAE"
@@ -89,20 +88,14 @@ def forward(self, inputs: BaseDataset, **kwargs):
return output

def loss_function(self, recon_x, x, mu, log_var, z, dataset_size):

if self.model_config.reconstruction_loss == "mse":

-            recon_loss = (
-                0.5
-                * F.mse_loss(
-                    recon_x.reshape(x.shape[0], -1),
-                    x.reshape(x.shape[0], -1),
-                    reduction="none",
-                ).sum(dim=-1)
-            )
+            recon_loss = 0.5 * F.mse_loss(
+                recon_x.reshape(x.shape[0], -1),
+                x.reshape(x.shape[0], -1),
+                reduction="none",
+            ).sum(dim=-1)

elif self.model_config.reconstruction_loss == "bce":

recon_loss = F.binary_cross_entropy(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
17 changes: 5 additions & 12 deletions src/pythae/models/beta_vae/beta_vae_model.py
@@ -40,7 +40,6 @@ def __init__(
encoder: Optional[BaseEncoder] = None,
decoder: Optional[BaseDecoder] = None,
):

VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder)

self.model_name = "BetaVAE"
@@ -81,20 +80,14 @@ def forward(self, inputs: BaseDataset, **kwargs):
return output

def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

-            recon_loss = (
-                0.5
-                * F.mse_loss(
-                    recon_x.reshape(x.shape[0], -1),
-                    x.reshape(x.shape[0], -1),
-                    reduction="none",
-                ).sum(dim=-1)
-            )
+            recon_loss = 0.5 * F.mse_loss(
+                recon_x.reshape(x.shape[0], -1),
+                x.reshape(x.shape[0], -1),
+                reduction="none",
+            ).sum(dim=-1)

elif self.model_config.reconstruction_loss == "bce":

recon_loss = F.binary_cross_entropy(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
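The `beta_vae` and `beta_tc_vae` edits above are purely formatting; the objective they feed (reconstruction term plus a weighted KL to the prior) is untouched. For reference, and as a hedged aside since the KL code itself lies outside the excerpted hunks, this is the standard closed-form KL between a diagonal-Gaussian posterior and the standard-normal prior that such losses weight with beta:

```python
import torch


def kl_divergence(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    # KL( N(mu, diag(exp(log_var))) || N(0, I) ), one value per sample.
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)


mu = torch.zeros(4, 16)
log_var = torch.zeros(4, 16)
print(kl_divergence(mu, log_var))  # all zeros: the posterior equals the prior
```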
23 changes: 8 additions & 15 deletions src/pythae/models/ciwae/ciwae_model.py
@@ -39,7 +39,6 @@ def __init__(
encoder: Optional[BaseEncoder] = None,
decoder: Optional[BaseDecoder] = None,
):

VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder)

self.model_name = "CIWAE"
@@ -92,22 +91,16 @@ def forward(self, inputs: BaseDataset, **kwargs):
return output

def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

-            recon_loss = (
-                0.5
-                * F.mse_loss(
-                    recon_x,
-                    x.reshape(recon_x.shape[0], -1)
-                    .unsqueeze(1)
-                    .repeat(1, self.n_samples, 1),
-                    reduction="none",
-                ).sum(dim=-1)
-            )
+            recon_loss = 0.5 * F.mse_loss(
+                recon_x,
+                x.reshape(recon_x.shape[0], -1)
+                .unsqueeze(1)
+                .repeat(1, self.n_samples, 1),
+                reduction="none",
+            ).sum(dim=-1)

elif self.model_config.reconstruction_loss == "bce":

recon_loss = F.binary_cross_entropy(
recon_x,
x.reshape(recon_x.shape[0], -1)
@@ -117,7 +110,7 @@ def loss_function(self, recon_x, x, mu, log_var, z):
).sum(dim=-1)

log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=-1)
-        log_p_z = -0.5 * (z ** 2).sum(dim=-1)
+        log_p_z = -0.5 * (z**2).sum(dim=-1)

KLD = -(log_p_z - log_q_z)

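The tail of the CIWAE hunk evaluates Monte-Carlo log-densities rather than a closed-form KL: `log_q_z` is the log-density of the sampled `z` under the diagonal-Gaussian posterior and `log_p_z` its log-density under the standard-normal prior, each with the shared `-0.5 * d * log(2*pi)` constant dropped because it cancels in `KLD = -(log_p_z - log_q_z)`. A small standalone check of that cancellation (plain PyTorch, not the pythae class):

```python
import torch

torch.manual_seed(0)
batch, dim = 5, 3
mu = torch.randn(batch, dim)
log_var = torch.randn(batch, dim)
std = torch.exp(0.5 * log_var)
z = mu + std * torch.randn(batch, dim)

# Log-densities with the -0.5 * dim * log(2*pi) constant dropped, as in the diff.
log_q_z = (-0.5 * (log_var + (z - mu) ** 2 / log_var.exp())).sum(dim=-1)
log_p_z = -0.5 * (z**2).sum(dim=-1)
kld = -(log_p_z - log_q_z)

# The dropped constant cancels: compare against the full log-densities.
q = torch.distributions.Normal(mu, std)
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_full = (q.log_prob(z) - p.log_prob(z)).sum(dim=-1)
print(torch.allclose(kld, kld_full, atol=1e-5))  # True
```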
@@ -40,7 +40,6 @@ def __init__(
encoder: Optional[BaseEncoder] = None,
decoder: Optional[BaseDecoder] = None,
):

VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder)

assert (
@@ -89,20 +88,14 @@ def forward(self, inputs: BaseDataset, **kwargs):
return output

def loss_function(self, recon_x, x, mu, log_var, z, epoch):

if self.model_config.reconstruction_loss == "mse":

-            recon_loss = (
-                0.5
-                * F.mse_loss(
-                    recon_x.reshape(x.shape[0], -1),
-                    x.reshape(x.shape[0], -1),
-                    reduction="none",
-                ).sum(dim=-1)
-            )
+            recon_loss = 0.5 * F.mse_loss(
+                recon_x.reshape(x.shape[0], -1),
+                x.reshape(x.shape[0], -1),
+                reduction="none",
+            ).sum(dim=-1)

elif self.model_config.reconstruction_loss == "bce":

recon_loss = F.binary_cross_entropy(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
17 changes: 5 additions & 12 deletions src/pythae/models/factor_vae/factor_vae_model.py
@@ -43,7 +43,6 @@ def __init__(
encoder: Optional[BaseEncoder] = None,
decoder: Optional[BaseDecoder] = None,
):

VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder)

self.discriminator = FactorVAEDiscriminator(latent_dim=model_config.latent_dim)
@@ -132,22 +131,16 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput:
return output

def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted):

N = z.shape[0] # batch size

if self.model_config.reconstruction_loss == "mse":

-            recon_loss = (
-                0.5
-                * F.mse_loss(
-                    recon_x.reshape(x.shape[0], -1),
-                    x.reshape(x.shape[0], -1),
-                    reduction="none",
-                ).sum(dim=-1)
-            )
+            recon_loss = 0.5 * F.mse_loss(
+                recon_x.reshape(x.shape[0], -1),
+                x.reshape(x.shape[0], -1),
+                reduction="none",
+            ).sum(dim=-1)

elif self.model_config.reconstruction_loss == "bce":

recon_loss = F.binary_cross_entropy(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
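The FactorVAE `loss_function` above takes `z_bis_permuted`: a second latent batch whose dimensions have been shuffled independently across the batch, which is what the total-correlation discriminator is trained to distinguish from true joint samples. A hedged sketch of that per-dimension permutation (a generic illustration of the FactorVAE trick, not necessarily pythae's exact implementation):

```python
import torch


def permute_dims(z: torch.Tensor) -> torch.Tensor:
    # Shuffle each latent dimension independently across the batch. The result
    # keeps the per-dimension marginals of z but breaks dependencies between
    # dimensions, approximating the product of marginals.
    batch_size, latent_dim = z.shape
    permuted = torch.empty_like(z)
    for j in range(latent_dim):
        perm = torch.randperm(batch_size, device=z.device)
        permuted[:, j] = z[perm, j]
    return permuted


z = torch.randn(6, 4)
z_bis_permuted = permute_dims(z)
# Column-wise sorted values match: only the pairing across dimensions changed.
print(torch.allclose(z.sort(dim=0).values, z_bis_permuted.sort(dim=0).values))  # True
```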
1 change: 0 additions & 1 deletion src/pythae/models/factor_vae/factor_vae_utils.py
@@ -4,7 +4,6 @@

class FactorVAEDiscriminator(nn.Module):
def __init__(self, latent_dim=16, hidden_units=1000) -> None:

nn.Module.__init__(self)

self.layers = nn.Sequential(