In [1]:
import random

import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
import PIL

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, progress
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger

from sklearn.model_selection import train_test_split

import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Dataset, RandomSampler, random_split

import torchmetrics

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as T

In [2]:
### Seed every RNG used in this notebook (Python, NumPy, PyTorch) so runs are reproducible
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x1270fe4b0>

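# About

This notebook trains a convolutional **denoising autoencoder** on Fashion-MNIST. Gaussian noise is added to every image, and the network (convolutional encoder → linear latent bottleneck → transposed-convolution decoder) learns to reconstruct the clean image from its noisy version by minimising the mean-squared error between the two.

<img src="imgs/autoencoder_visualization.svg" style="height:600px" class="center" alt="Autoencoder architecture: encoder, latent bottleneck, decoder"/><br>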

# Get data

In [3]:
### Fetch the Fashion-MNIST train and test splits into ./data (downloaded on first run)
train_dataset, test_dataset = (
    datasets.FashionMNIST(root='./data', train=is_train, download=True)
    for is_train in (True, False)
)

# Prepare data

In [4]:
# Pull the raw image tensors out of both torchvision datasets as NumPy arrays
X_train, X_test = (split.data.detach().numpy() for split in (train_dataset, test_dataset))

In [5]:
### PyTorch uses Dataset objects to load the data during training and testing
class MNISTDataset(Dataset):
    """Paired dataset for a denoising autoencoder.

    Item ``idx`` is a ``(clean, noisy)`` pair of float32 tensors taken from row
    ``idx`` of ``X`` and ``noisy_X``, placed on ``self.device``.
    """

    def __init__(
        self,
        X: np.ndarray,
        noisy_X: np.ndarray,
        accelerator_name: str = "mps",
    ):
        """Store both arrays as float32 and resolve the target device.

        Args:
            X: clean images, indexed along the first axis.
            noisy_X: noisy counterparts of ``X``; must have the same length.
            accelerator_name: ``"mps"``, ``"cuda"`` / ``"cuda:0"``, or anything else
                for CPU. Falls back to CPU when the requested backend is unavailable.

        Raises:
            ValueError: if ``X`` and ``noisy_X`` have different lengths.
        """
        if len(X) != len(noisy_X):
            raise ValueError(
                f"X and noisy_X must have the same length, got {len(X)} and {len(noisy_X)}"
            )
        self.X = X.astype(np.float32)
        self.noisy_X = noisy_X.astype(np.float32)

        # "cuda" is accepted as well as "cuda:0" so the same name works here and in the
        # notebook's device-selection cell.
        if accelerator_name == "mps" and torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif accelerator_name in ("cuda", "cuda:0") and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")

    def __len__(self):
        """Number of (clean, noisy) pairs."""
        return len(self.X)

    def __getitem__(self, idx: int):
        """Return the ``(clean, noisy)`` pair at ``idx`` as float32 tensors on ``self.device``."""
        # torch.tensor copies, so callers mutating the returned tensors never touch the
        # stored arrays. The arrays are already float32 from __init__.
        clean = torch.tensor(self.X[idx], dtype=torch.float32, device=self.device)
        noisy = torch.tensor(self.noisy_X[idx], dtype=torch.float32, device=self.device)
        return clean, noisy


In [6]:
# Hold out `val_split` of the training images as a validation set (fixed seed for reproducibility)
val_split = 0.2
X_train, X_val = train_test_split(
    X_train,
    random_state=123,
    test_size=val_split,
)

In [7]:
# Add a channel axis (N, H, W) -> (N, C, H, W) with C = 1, and rescale the uint8 pixel
# values from [0, 255] to [0, 1]. Without the rescale the clean targets stay in [0, 255]
# while the noisy inputs are min-max normalised to [0, 1], and a noise level of 0.1 would
# be negligible on a 0-255 scale.
# Guard: this cell rebinds X_* itself, so running it twice would add a second axis and
# divide by 255 again.
assert X_train.ndim == X_val.ndim == X_test.ndim == 3, "X_* already processed; re-run the cells above first"
X_train = X_train[:, np.newaxis, :, :].astype(np.float32) / 255.0
X_test = X_test[:, np.newaxis, :, :].astype(np.float32) / 255.0
X_val = X_val[:, np.newaxis, :, :].astype(np.float32) / 255.0

In [8]:
# Independent float32 copies of each split; the noise is added to these, never to X_*
noisy_X_train, noisy_X_test, noisy_X_val = (
    split.astype(np.float32) for split in (X_train, X_test, X_val)
)

In [9]:
noise_level: float = 1e-1  # scale applied to the standard-normal noise added to each image

In [10]:
def add_gaussian_noise(clean: np.ndarray, level: float) -> np.ndarray:
    """Return a noisy float32 copy of ``clean``, min-max rescaled to [0, 1].

    Adds ``level * N(0, 1)`` noise drawn from NumPy's global RNG (seeded at the top of
    the notebook). Then it shifts by the global minimum and divides by the resulting
    maximum. ``clean`` is not modified, so re-running this cell never stacks extra noise
    on top of earlier noise.
    """
    noisy = clean.astype(np.float32)
    noisy += level * np.random.standard_normal(clean.shape)
    noisy -= noisy.min()
    value_range = noisy.max()
    if value_range > 0:  # a constant array would otherwise divide by zero
        noisy /= value_range
    return noisy


# Same order (train, val, test) as before, so the same noise draws are used.
noisy_X_train = add_gaussian_noise(X_train, noise_level)
noisy_X_val = add_gaussian_noise(X_val, noise_level)
noisy_X_test = add_gaussian_noise(X_test, noise_level)

In [11]:
# Side length of the images. The model code assumes square images (H == W), so check it.
assert X_train.shape[-2] == X_train.shape[-1], f"expected square images, got {X_train.shape[-2:]}"
input_xy = X_train.shape[-1]

In [12]:
batch_size: int = 64  # images per mini-batch for the train/val loaders

In [13]:
##### Now we actually make the dataset and dataloader in PyTorch fashion
# Each dataset yields (clean, noisy) pairs.
train_data = MNISTDataset(X_train, noisy_X_train)
val_data = MNISTDataset(X_val, noisy_X_val)
test_data = MNISTDataset(X_test, noisy_X_test)

# Make the loaders; only the training loader is shuffled. The test loader uses the same
# batch size as train/val (it previously fell back to DataLoader's default of 1).
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_data, batch_size=batch_size)
test_loader = DataLoader(test_data, batch_size=batch_size)

# Define Models

## Encoder

In [14]:
class Encoder(pl.LightningModule):
    """Convolutional encoder: the first half of the denoising autoencoder.

    Layout: an input ``nn.Conv2d`` followed by one ``nn.Conv2d`` per consecutive pair
    in ``cnn_layer_dims``. Every conv is followed by ``activation`` and a 2x2 max-pool.
    Input is ``(batch, n_channels, H, W)``; the output has ``cnn_layer_dims[-1]``
    channels. With the notebook's settings (kernel 3, stride 2, padding 0, one layer),
    a 28x28 image becomes 13x13 after the conv and 6x6 after pooling.
    """

    def __init__(self, 
                 kernel_size: int = 3, 
                 dropout: float = 0.25,
                 # NOTE(review): 126 looks like a typo for 128 (the Decoder/Autoencoder
                 # defaults use 128). Left unchanged to keep default behaviour; confirm.
                 cnn_layer_dims: list = [126, 64, 32],
                 padding: int = 1, 
                 stride: int = 2,
                 n_channels: int = 1,
                 input_xy: int = 28, 
                 lr: float = 1e-4, 
                 weight_decay: float = 0., 
                 eps: float = 5e-7, 
                 activation: torch.nn.modules.activation = nn.GELU(),
                 use_wandb: bool = False,
                 scheduler_name: str = "none",
                 step_size: int = 5,
                 gamma: float = 0.5,
                 ) -> None:
        """Build the encoder layers.

        Args:
            kernel_size: kernel size of every ``nn.Conv2d``.
            dropout: stored as ``self.dropout``; no dropout layer is built.
            cnn_layer_dims: output channels of each conv layer, in order (non-empty).
            padding: padding of every ``nn.Conv2d``.
            stride: stride of every ``nn.Conv2d``.
            n_channels: channels of the input image.
            input_xy: height/width of the square example input used for graph logging.
            lr, weight_decay, eps: AdamW settings used by ``configure_optimizers``.
            activation: callable applied after every conv layer.
            use_wandb: stored flag selecting WandB vs TensorBoard logging.
            scheduler_name: ``"none"`` disables the LR scheduler; anything else enables
                MultiStepLR.
            step_size: epochs between LR decays.
            gamma: multiplicative LR decay factor.

        Raises:
            ValueError: if ``cnn_layer_dims`` is empty.
        """
        # Must run before any submodule or parameter is assigned.
        super().__init__()
        if len(cnn_layer_dims) == 0:
            raise ValueError("cnn_layer_dims must contain at least one layer size")

        # model parameters (copy the list so the shared default list is never aliased)
        self.cnn_dims = list(cnn_layer_dims)
        self.activation = activation
        self.lr = lr
        self.eps = eps
        self.weight_decay = weight_decay
        self.dropout = dropout
        # Number of conv layers (not image channels); name kept for backward compatibility.
        self.n_channels = len(cnn_layer_dims)
        self.scheduler_name = scheduler_name
        # only used when a scheduler is enabled
        self.step_size = step_size
        self.gamma = gamma

        # log using WandB or TensorBoard
        self.use_wandb = use_wandb

        # Dummy (batch_size, channels, height, width) input that lets Lightning build the
        # graph / summary sizes for logging.
        self.example_input_array = torch.zeros(
            (1, n_channels, input_xy, input_xy), dtype=torch.float32
        )

        # Input layer: n_channels -> cnn_layer_dims[0]
        self.input_cnn = nn.Conv2d(n_channels, cnn_layer_dims[0],
                                   kernel_size=kernel_size, stride=stride, padding=padding)

        # Hidden layers: one Conv2d per consecutive (in, out) pair of cnn_layer_dims.
        # The count comes from cnn_layer_dims itself. It used to read a notebook-global
        # `n_cnn_layers`, which breaks (NameError or wrong depth) when that global is
        # missing or disagrees with the dims passed in.
        self.encoding_layers = nn.ModuleList(
            nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding)
            for in_dim, out_dim in zip(cnn_layer_dims[:-1], cnn_layer_dims[1:])
        )

    def forward(self, x):
        """Encode a batch: conv -> activation -> 2x2 max-pool for every conv layer."""
        for layer in (self.input_cnn, *self.encoding_layers):
            x = F.max_pool2d(self.activation(layer(x)), 2)
        return x

    def configure_optimizers(self):
        """Create AdamW and, unless ``scheduler_name == "none"``, a per-epoch MultiStepLR.

        Returns:
            The optimizer alone, or ``([optimizer], [scheduler_config])``.
        """
        self.optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            eps=self.eps,
            weight_decay=self.weight_decay,
        )

        if self.scheduler_name == "none":
            return self.optimizer

        # Decay by gamma at epochs step_size, 2*step_size, ... The range must not start
        # at 0: MultiStepLR performs an initial step at construction, which evaluates
        # epoch 0, so a milestone at 0 would shrink the LR to lr * gamma before training.
        milestones = list(range(self.step_size, self.trainer.max_epochs, self.step_size))
        self.scheduler = MultiStepLR(self.optimizer, milestones, gamma=self.gamma)

        return [self.optimizer], [{"scheduler": self.scheduler, "interval": "epoch"}]

    #### needed in case the scheduler is used ####
    def lr_scheduler_step(self, scheduler, metric) -> None:
        """Advance the scheduler Lightning passes in (only when a scheduler is configured)."""
        if self.scheduler_name != "none":
            scheduler.step()

    def on_epoch_end(self) -> None:
        """Legacy epoch-end hook that also steps the scheduler.

        NOTE(review): recent Lightning releases no longer call ``on_epoch_end``. With a
        version that does call it, the scheduler would advance twice per epoch (here and
        in ``lr_scheduler_step``). Confirm against the installed version.
        """
        if self.scheduler_name != "none":
            self.scheduler.step()

    def process_batch(self, batch, step: str = "train"):
        """Encode the noisy half of a ``(clean, noisy)`` batch. Nothing is logged; ``step`` is unused."""
        _, noisy_x = batch
        return self(noisy_x)

    def training_step(self, batch, batch_idx):
        """Return the encoded feature map.

        This is not a loss: the module is trained as part of ``Autoencoder``.
        """
        return self.process_batch(batch, step="train")

    def validation_step(self, batch, batch_idx):
        """Validation step; same computation as ``training_step``."""
        return self.process_batch(batch, step="val")

    def test_step(self, batch, batch_idx):
        """Test step; same computation as ``training_step``."""
        return self.process_batch(batch, step="test")

    def activation_maps(self, x, depth: int = 0) -> np.ndarray:
        """Return the post-activation output of conv layer ``depth`` (0 = input layer).

        The map is taken before that layer's max-pool. It is moved to the CPU, so this
        also works when the model lives on MPS/CUDA.

        Raises:
            ValueError: if ``depth`` is outside ``[0, len(self.encoding_layers)]``.
        """
        layers = [self.input_cnn, *self.encoding_layers]
        if not 0 <= depth < len(layers):
            raise ValueError(f"depth must be in [0, {len(layers) - 1}], got {depth}")

        for i, layer in enumerate(layers):
            x = self.activation(layer(x))
            if i == depth:
                return x.detach().cpu().numpy()
            x = F.max_pool2d(x, 2)


## Decoder

In [15]:
class Decoder(pl.LightningModule):
    """Transposed-convolution decoder: the second half of the denoising autoencoder.

    Layout:
    1. An input ``nn.ConvTranspose2d`` (built without ``output_padding``).
    2. One ``nn.ConvTranspose2d`` per consecutive pair in ``cnn_layer_dims``.
    3. A final ``nn.ConvTranspose2d`` to ``image_input_channels``.

    Every layer, including the last, is followed by ``activation``; there is no pooling.
    With the notebook's settings (kernel 3, stride 2, padding 0, output_padding 1, one
    hidden size), a 6x6 input grows to 13x13 and then to 28x28.
    """

    def __init__(self, 
                 kernel_size: int = 3, 
                 dropout: float = 0.25,
                 cnn_layer_dims: list = [32, 64, 128],
                 padding: int = 1, 
                 stride: int = 2,
                 image_input_channels: int = 1,
                 n_input_channels: int = 128,
                 input_xy: int = 4, 
                 lr: float = 1e-4, 
                 weight_decay: float = 0., 
                 eps: float = 5e-7, 
                 activation: torch.nn.modules.activation = F.gelu,
                 use_wandb: bool = False,
                 scheduler_name: str = "none",
                 step_size: int = 5,
                 gamma: float = 0.5,
                 output_padding: int = 1,
                 ) -> None:
        """Build the decoder layers.

        Args:
            kernel_size, padding, stride: settings of every ``nn.ConvTranspose2d``.
            dropout: stored as ``self.dropout``; no dropout layer is built.
            cnn_layer_dims: output channels of the input and hidden layers, in order (non-empty).
            image_input_channels: channels of the reconstructed image (final layer output).
            n_input_channels: channels of the feature map fed to the decoder.
            input_xy: height/width of the square example input used for graph logging.
            lr, weight_decay, eps: AdamW settings used by ``configure_optimizers``.
            activation: callable applied after every layer.
            use_wandb: stored flag selecting WandB vs TensorBoard logging.
            scheduler_name: ``"none"`` disables the LR scheduler; anything else enables
                MultiStepLR.
            step_size: epochs between LR decays.
            gamma: multiplicative LR decay factor.
            output_padding: extra output size for the hidden and final layers.

        Raises:
            ValueError: if ``cnn_layer_dims`` is empty.
        """
        # Must run before any submodule or parameter is assigned.
        super().__init__()
        if len(cnn_layer_dims) == 0:
            raise ValueError("cnn_layer_dims must contain at least one layer size")

        # model parameters (copy the list so the shared default list is never aliased)
        self.cnn_dims = list(cnn_layer_dims)
        self.image_input_channels = image_input_channels
        self.activation = activation
        self.lr = lr
        self.eps = eps
        self.weight_decay = weight_decay
        self.dropout = dropout
        # Number of layer sizes (not image channels); name kept for backward compatibility.
        self.n_channels = len(cnn_layer_dims)
        self.scheduler_name = scheduler_name
        # only used when a scheduler is enabled
        self.step_size = step_size
        self.gamma = gamma

        # log using WandB or TensorBoard
        self.use_wandb = use_wandb

        # Dummy (batch_size, channels, height, width) input that lets Lightning build the
        # graph / summary sizes for logging.
        self.example_input_array = torch.zeros(
            (1, n_input_channels, input_xy, input_xy), dtype=torch.float32
        )

        # Input layer: n_input_channels -> cnn_layer_dims[0]. output_padding is
        # deliberately not applied here (the argument was left out in the original design).
        self.input_cnn = nn.ConvTranspose2d(n_input_channels,
                                            cnn_layer_dims[0],
                                            kernel_size=kernel_size,
                                            padding=padding, stride=stride)

        # Hidden layers: one ConvTranspose2d per consecutive (in, out) pair of
        # cnn_layer_dims. The count comes from cnn_layer_dims itself. It used to read the
        # notebook-global `n_cnn_layers`, which breaks when that global is missing or
        # disagrees with the dims passed in.
        hidden_layers = [
            nn.ConvTranspose2d(in_dim, out_dim,
                               kernel_size=kernel_size,
                               output_padding=output_padding,
                               padding=padding, stride=stride)
            for in_dim, out_dim in zip(cnn_layer_dims[:-1], cnn_layer_dims[1:])
        ]

        # Output layer: back to the image's channel count.
        output_layer = nn.ConvTranspose2d(cnn_layer_dims[-1],
                                          self.image_input_channels,
                                          kernel_size=kernel_size,
                                          output_padding=output_padding,
                                          padding=padding, stride=stride)

        self.decoding_layers = nn.ModuleList([*hidden_layers, output_layer])

    def forward(self, x):
        """Decode a feature map: transposed conv -> activation for every layer."""
        for layer in (self.input_cnn, *self.decoding_layers):
            x = self.activation(layer(x))
        return x

    def configure_optimizers(self):
        """Create AdamW and, unless ``scheduler_name == "none"``, a per-epoch MultiStepLR.

        Returns:
            The optimizer alone, or ``([optimizer], [scheduler_config])``.
        """
        self.optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            eps=self.eps,
            weight_decay=self.weight_decay,
        )

        if self.scheduler_name == "none":
            return self.optimizer

        # Decay by gamma at epochs step_size, 2*step_size, ... The range must not start
        # at 0: MultiStepLR performs an initial step at construction, which evaluates
        # epoch 0, so a milestone at 0 would shrink the LR to lr * gamma before training.
        milestones = list(range(self.step_size, self.trainer.max_epochs, self.step_size))
        self.scheduler = MultiStepLR(self.optimizer, milestones, gamma=self.gamma)

        return [self.optimizer], [{"scheduler": self.scheduler, "interval": "epoch"}]

    #### needed in case the scheduler is used ####
    def lr_scheduler_step(self, scheduler, metric) -> None:
        """Advance the scheduler Lightning passes in (only when a scheduler is configured)."""
        if self.scheduler_name != "none":
            scheduler.step()

    def on_epoch_end(self) -> None:
        """Legacy epoch-end hook that also steps the scheduler.

        NOTE(review): recent Lightning releases no longer call ``on_epoch_end``. With a
        version that does call it, the scheduler would advance twice per epoch. Confirm.
        """
        if self.scheduler_name != "none":
            self.scheduler.step()

    def process_batch(self, batch, step: str = "train"):
        """Decode the second element of a ``(clean, noisy)`` batch. Nothing is logged; ``step`` is unused."""
        _, noisy_x = batch
        return self(noisy_x)

    def training_step(self, batch, batch_idx):
        """Return the decoded output.

        This is not a loss: the module is trained as part of ``Autoencoder``.
        """
        return self.process_batch(batch, step="train")

    def validation_step(self, batch, batch_idx):
        """Validation step; same computation as ``training_step``."""
        return self.process_batch(batch, step="val")

    def test_step(self, batch, batch_idx):
        """Test step; same computation as ``training_step``."""
        return self.process_batch(batch, step="test")

    def activation_maps(self, x, depth: int = 0) -> np.ndarray:
        """Return the post-activation output of decoder layer ``depth``.

        Depth 0 is the input layer; 1.. index into ``decoding_layers``. The computation
        matches ``forward`` exactly: no max-pooling between layers (the previous version
        pooled, so maps for depth >= 1 did not reflect the real decoder). The result is
        moved to the CPU so this works on MPS/CUDA models.

        Raises:
            ValueError: if ``depth`` is outside ``[0, len(self.decoding_layers)]``.
        """
        layers = [self.input_cnn, *self.decoding_layers]
        if not 0 <= depth < len(layers):
            raise ValueError(f"depth must be in [0, {len(layers) - 1}], got {depth}")

        for i, layer in enumerate(layers):
            x = self.activation(layer(x))
            if i == depth:
                return x.detach().cpu().numpy()


## Entire network

In [16]:
class Autoencoder(pl.LightningModule):
    """Convolutional denoising autoencoder: Encoder -> linear bottleneck -> Decoder.

    Trained on ``(clean, noisy)`` batches. The noisy image is reconstructed, and the MSE
    against the clean image is logged as ``{step}_loss``.
    """

    def __init__(self, 
                 kernel_size: int = 3, 
                 dropout: float = 0.25,
                 cnn_layer_dims: list = [128, 64, 32],
                 padding: int = 1, 
                 stride: int = 2,
                 image_input_channels: int = 1,
                 latent_dim: int = 32,
                 input_xy: int = 28, 
                 lr: float = 1e-4, 
                 weight_decay: float = 0., 
                 eps: float = 5e-7, 
                 activation: torch.nn.modules.activation = F.gelu,
                 use_wandb: bool = False,
                 scheduler_name: str = "none",
                 step_size: int = 5,
                 gamma: float = 0.5,
                 output_padding: int = 1,
                 ) -> None:
        """Build encoder, latent bottleneck and decoder.

        Args:
            kernel_size, padding, stride: conv settings forwarded to both Encoder and Decoder.
            dropout: forwarded to Encoder/Decoder (stored only; no dropout layer is built).
            cnn_layer_dims: encoder channel sizes; the decoder uses them reversed.
            image_input_channels: channels of the input and reconstructed image.
            latent_dim: size of the linear bottleneck.
            input_xy: side length of the square input image.
            lr, weight_decay, eps: AdamW settings.
            activation: callable used after every layer, including the bottleneck layers.
            use_wandb: stored flag selecting WandB vs TensorBoard logging.
            scheduler_name: ``"none"`` disables the LR scheduler; anything else enables
                MultiStepLR.
            step_size: epochs between LR decays.
            gamma: multiplicative LR decay factor.
            output_padding: forwarded to the Decoder's ConvTranspose2d layers.
        """
        # Must run before any submodule or parameter is assigned.
        super().__init__()

        # model parameters
        self.cnn_dims = cnn_layer_dims
        self.image_input_channels = image_input_channels
        self.activation = activation
        self.lr = lr
        self.eps = eps
        self.weight_decay = weight_decay
        self.dropout = dropout
        # Number of encoder layers (not image channels); name kept for backward compatibility.
        self.n_channels = len(cnn_layer_dims)
        self.scheduler_name = scheduler_name
        # only used when a scheduler is enabled
        self.step_size = step_size
        self.gamma = gamma

        self.latent_dim = latent_dim

        self.loss_fn = nn.MSELoss()

        # log using WandB or TensorBoard
        self.use_wandb = use_wandb

        # Dummy (batch_size, channels, height, width) input. It is used for graph logging
        # and to measure the encoder's output shape in _flat_layer_size.
        self.example_input_array = torch.zeros(
            (1, image_input_channels, input_xy, input_xy), dtype=torch.float32
        )

        # kernel_size / stride are now forwarded (previously the sub-modules silently used
        # their own defaults). n_channels uses this model's own argument instead of the
        # notebook-global `n_channels`.
        self.encoder = Encoder(
            kernel_size=kernel_size,
            stride=stride,
            cnn_layer_dims=cnn_layer_dims,
            input_xy=input_xy,
            activation=activation,
            n_channels=image_input_channels,
            use_wandb=use_wandb,
            dropout=dropout,
            padding=padding,
            eps=eps, lr=lr,
            weight_decay=weight_decay,
            scheduler_name=scheduler_name,
            gamma=gamma,
            step_size=step_size,
        )

        # Output shape (1, C, H, W) and flattened size of the encoder's feature map.
        self.encoded_cxy, self.flat_dim = self._flat_layer_size()

        self.decoder = Decoder(
            kernel_size=kernel_size,
            stride=stride,
            cnn_layer_dims=cnn_layer_dims[::-1],
            input_xy=self.encoded_cxy[-1],
            activation=activation,
            n_input_channels=cnn_layer_dims[-1],
            image_input_channels=image_input_channels,
            use_wandb=use_wandb,
            dropout=dropout,
            padding=padding,
            eps=eps, lr=lr,
            weight_decay=weight_decay,
            scheduler_name=scheduler_name,
            gamma=gamma,
            step_size=step_size,
            output_padding=output_padding,
        )

        # Linear bottleneck between the flattened feature map and the latent code.
        self.latent_encoder = nn.Linear(self.flat_dim, self.latent_dim)
        self.latent_decoder = nn.Linear(self.latent_dim, self.flat_dim)

    def forward(self, x):
        """Denoise a batch of images and return a non-negative (ReLU-clamped) reconstruction."""
        # encode
        x = self.encoder(x)

        # into the latent space and back out again
        x = self.activation(self.latent_encoder(torch.flatten(x, 1)))
        x = self.activation(self.latent_decoder(x))

        # restore the encoder's (C, H, W) feature-map shape before decoding
        x = x.reshape(x.shape[0], -1, self.encoded_cxy[-2], self.encoded_cxy[-1])

        x = self.decoder(x)
        return F.relu(x)

    def configure_optimizers(self):
        """Create AdamW and, unless ``scheduler_name == "none"``, a per-epoch MultiStepLR.

        Returns:
            The optimizer alone, or ``([optimizer], [scheduler_config])``.
        """
        self.optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            eps=self.eps,
            weight_decay=self.weight_decay,
        )

        if self.scheduler_name == "none":
            return self.optimizer

        # Decay by gamma at epochs step_size, 2*step_size, ... The range must not start
        # at 0: MultiStepLR performs an initial step at construction, which evaluates
        # epoch 0, so a milestone at 0 would start training at lr * gamma instead of lr.
        milestones = list(range(self.step_size, self.trainer.max_epochs, self.step_size))
        self.scheduler = MultiStepLR(self.optimizer, milestones, gamma=self.gamma)

        return [self.optimizer], [{"scheduler": self.scheduler, "interval": "epoch"}]

    #### needed in case the scheduler is used ####
    def lr_scheduler_step(self, scheduler, metric) -> None:
        """Advance the scheduler Lightning passes in (only when a scheduler is configured)."""
        if self.scheduler_name != "none":
            scheduler.step()

    def on_epoch_end(self) -> None:
        """Legacy epoch-end hook that also steps the scheduler.

        NOTE(review): recent Lightning releases no longer call ``on_epoch_end``. With a
        version that does call it, the scheduler would advance twice per epoch (here and
        in ``lr_scheduler_step``). Confirm against the installed version.
        """
        if self.scheduler_name != "none":
            self.scheduler.step()

    def process_batch(self, batch, step: str = "train"):
        """Reconstruct the noisy images, log the MSE vs. the clean ones as ``{step}_loss``, return it."""
        x, noisy_x = batch  # (clean, noisy)

        denoised_x = self(noisy_x)

        # Flatten both to (batch, -1) so the per-element MSE is taken over matching vectors.
        # MSELoss convention: (prediction, target).
        loss = self.loss_fn(torch.flatten(denoised_x, 1), torch.flatten(x, 1))

        self.log(f"{step}_loss", loss)

        return loss

    def training_step(self, batch, batch_idx):
        """Training loss for one batch."""
        return self.process_batch(batch, step="train")

    def validation_step(self, batch, batch_idx):
        """Validation loss for one batch (logged as ``val_loss``)."""
        return self.process_batch(batch, step="val")

    def test_step(self, batch, batch_idx):
        """Test loss for one batch (logged as ``test_loss``)."""
        return self.process_batch(batch, step="test")

    def _flat_layer_size(self) -> tuple:
        """Return ``(encoder_output_shape, flattened_size)`` for ``self.example_input_array``."""
        # Shape probe only, so no autograd graph is needed.
        with torch.no_grad():
            encoded = self.encoder(self.example_input_array)
        return encoded.size(), encoded[0].numel()



## Initialize

In [17]:
# ---- model hyper-parameters ----
lr: float = 1e-3
eps: float = 1e-8
weight_decay: float = 1e-6
dropout: float = 0.25
cnn_layer_dims: list = [128]
n_cnn_layers: int = len(cnn_layer_dims)
latent_dim: int = 16
activation = F.gelu
n_channels: int = 1
stride: int = 2
kernel_size: int = 3
padding: int = 0
output_padding: int = 1
scheduler_name: str = "step"
gamma: float = 0.5
step_size: int = 5

# ---- experiment tracking (WandB vs TensorBoard) ----
use_wandb: bool = False  # True -> WandB logger, False -> TensorBoard
use_sweep: bool = False  # run a WandB hyper-parameter sweep
project_name: str = 'FashionMNIST_AE'  # WandB project name
entity: str = ''  # WandB lab / team name

In [18]:
# Build the autoencoder from the config cell. kernel_size is now passed explicitly
# (it was defined in the config but never used, so the class default applied).
model = Autoencoder(latent_dim=latent_dim,
                    cnn_layer_dims=cnn_layer_dims,
                    input_xy=input_xy,
                    activation=activation,
                    image_input_channels=n_channels,
                    use_wandb=use_wandb,
                    dropout=dropout,
                    eps=eps,
                    lr=lr,
                    weight_decay=weight_decay,
                    scheduler_name=scheduler_name,
                    gamma=gamma,
                    step_size=step_size,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                    )

# Train

In [19]:
num_epochs: int = 25  # upper bound on training epochs; EarlyStopping may stop sooner

In [20]:
# Preferred compute backend: "mps" (Apple silicon), "cuda", or "cpu"
accelerator_name: str = "mps"

In [21]:
# Resolve the torch device for the chosen backend, falling back to CPU when it is unavailable.
# "cuda:0" is accepted as well as "cuda" to match the names MNISTDataset understands.
if accelerator_name == "mps" and torch.backends.mps.is_available():
    device = torch.device("mps")
elif accelerator_name in ("cuda", "cuda:0") and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Ask cuDNN for deterministic algorithms (relevant on CUDA) for reproducibility.
# The attribute was previously misspelled `determinstic`, which silently created an
# unused attribute and left determinism off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [22]:
if not use_wandb:
    %load_ext tensorboard
    cnn_logger = TensorBoardLogger("ae_logs", name="simple_mnist_fashion_ae")
    run_name = "ae_cnn"
else:
    logger_kwargs = {
        "resume": "allow",
        "config": model_hparams,
    }
    cnn_logger = WandbLogger(project=project_name, entity=entity, **logger_kwargs)
    cnn_run_name = cnn_logger.experiment.name

In [23]:
#### necessary for newer PTL versions
devices = 1
# Follow the backend resolved into `device` earlier ("mps", "cuda" or "cpu", all valid
# Lightning accelerator names) instead of hard-coding "gpu". "gpu" fails on machines
# without one and ignores an explicit accelerator_name = "cpu".
accelerator = device.type if devices == 1 else "cpu"

# make the trainer
trainer = pl.Trainer(
    devices=devices,
    accelerator=accelerator,
    max_epochs=num_epochs,
    log_every_n_steps=1,
    logger=cnn_logger,
    callbacks=[
        # Optional: keep the best checkpoint; uncomment to enable. It must monitor a
        # metric the model actually logs ("val_loss"; nothing logs "val_acc").
        # ModelCheckpoint(
        #     save_weights_only=False,
        #     mode="min",
        #     monitor="val_loss",
        #     save_top_k=1,
        #     every_n_epochs=1,
        #     save_on_train_epoch_end=False,
        #     dirpath=f"AE_Checkpoints/{run_name}/",
        #     filename=f"ae_checkpoint_{run_name}",
        # ),
        LearningRateMonitor("epoch"),
        progress.TQDMProgressBar(refresh_rate=1),
        # Stop once val_loss has not improved for 10 consecutive validation checks.
        EarlyStopping(
            monitor="val_loss",
            min_delta=0,
            patience=10,
            verbose=False,
            mode="min",
        ),
    ],
)
# NOTE(review): private logger attributes (graph logging / default hp_metric); these may
# change between Lightning versions.
trainer.logger._log_graph = True
trainer.logger._default_hp_metric = None

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)  # early-stops on val_loss

Missing logger folder: ae_logs/simple_mnist_fashion_ae

  | Name           | Type    | Params | In sizes       | Out sizes     
-----------------------------------------------------------------------------
0 | loss_fn        | MSELoss | 0      | ?              | ?             
1 | encoder        | Encoder | 1.3 K  | [1, 1, 28, 28] | [1, 128, 6, 6]
2 | decoder        | Decoder | 148 K  | [1, 128, 6, 6] | [1, 1, 28, 28]
3 | latent_encoder | Linear  | 73.7 K | [1, 4608]      | [1, 16]       
4 | latent_decoder | Linear  | 78.3 K | [1, 16]        | [1, 4608]     
-----------------------------------------------------------------------------
302 K     Trainable params
0         Non-trainable params
302 K     Total params
1.208     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# noisy_X_train.shape