In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class DiagLinear(nn.Module):
    """Element-wise (diagonal) affine map: :math:`y = x * w + b`.

    Equivalent to ``nn.Linear`` with the weight matrix restricted to ``diag(w)``: every input
    feature is scaled by its own learnable weight and, optionally, shifted by its own learnable
    bias. Input and output therefore have the same number of features.

    Parameters
    ----------
    features : :obj:`int`
        number of input (= output) features
    bias : :obj:`bool`, optional
        True to learn an additive per-feature bias

    """

    __constants__ = ['features']
    # features: int
    # weight: Tensor of shape (features,)
    # bias: Tensor of shape (features,) or None

    def __init__(self, features, bias=True):
        super().__init__()
        self.features = features
        # storage only; reset_parameters() below overwrites the values
        self.weight = nn.Parameter(torch.empty(features))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(features))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight (then bias, if any) from U(-1/sqrt(features), 1/sqrt(features))."""
        limit = 1 / np.sqrt(self.features)
        for param in (self.weight, self.bias):
            if param is not None:
                nn.init.uniform_(param, -limit, limit)

    def forward(self, input):
        """Scale ``input`` feature-wise by ``weight`` and add ``bias`` when present."""
        scaled = input * self.weight
        if self.bias is None:
            return scaled
        scaled += self.bias
        return scaled

    def extra_repr(self):
        has_bias = self.bias is not None
        return f'features={self.features}, bias={has_bias}'


class ConvAEEncoder(BaseModule):
    """Convolutional encoder."""

    def __init__(self, hparams):
        """

        Parameters
        ----------
        hparams : :obj:`dict`
            - 'model_class' (:obj:`str`): e.g. 'ae' | 'vae' | 'cond-ae'; for 'cond-ae' with
              'conditional_encoder' set, label channels are appended to the input channels
            - 'n_ae_latents' (:obj:`int`)
            - 'ae_input_dim' (:obj:`list`): first entry is the number of input channels
            - 'fit_sess_io_layers' (:obj:`bool`, optional): fit session-specific input/output
              layers; requires 'n_datasets' (:obj:`int`)
            - 'ae_encoding_x_dim' (:obj:`list`)
            - 'ae_encoding_y_dim' (:obj:`list`)
            - 'ae_encoding_n_channels' (:obj:`list`)
            - 'ae_encoding_kernel_size' (:obj:`list`)
            - 'ae_encoding_stride_size' (:obj:`list`)
            - 'ae_encoding_x_padding' (:obj:`list`): (left, right) padding per layer
            - 'ae_encoding_y_padding' (:obj:`list`): (top, bottom) padding per layer
            - 'ae_encoding_layer_type' (:obj:`list`): 'conv' | 'maxpool' per layer
            - 'ae_padding_type' (:obj:`str`): read when building max pool layers; 'valid'
              disables ceil mode
            - 'ae_batch_norm' (:obj:`bool`, optional): add batch norm after each conv; defaults
              to False
            - 'ae_batch_norm_momentum' (:obj:`float`, optional): defaults to 0.1
            - 'track_running_stats' (:obj:`bool`, optional): defaults to True
            - 'variational' (:obj:`bool`, optional): add a second linear head (``self.logvar``)
              for latent log-variances
            - 'conditional_encoder' (:obj:`bool`, optional) and 'n_labels' (:obj:`int`): only
              used with 'cond-ae' models

        """
        super().__init__()
        self.hparams = hparams
        self.encoder = None
        self.build_model()

    def __str__(self):
        """Pretty print encoder architecture."""
        lines = ['Encoder architecture:']
        for i, module in enumerate(self.encoder):
            lines.append('    {:02d}: {}'.format(i, module))
        # final ff layer(s); the logvar head runs in parallel with FF, so it shares the index
        i = len(self.encoder)
        lines.append('    {:02d}: {}'.format(i, self.FF))
        if self.hparams.get('variational', False):
            lines.append('    {:02d}: {} (to latent log-variances)'.format(i, self.logvar))
        return '\n'.join(lines) + '\n'

    def build_model(self):
        """Construct the encoder from the layer specifications in ``hparams``.

        Every 'conv' entry of 'ae_encoding_layer_type' appends one chunk of modules to
        ``self.encoder``: optional zero padding, convolution, optional batch norm, optional max
        pooling (when the next entry is 'maxpool'), and a leaky ReLU. Entries of any other type
        add no modules themselves. A final linear layer ``self.FF`` maps the flattened output to
        the latents; variational models additionally get ``self.logvar``.
        """
        self.encoder = nn.ModuleList()
        layer_types = self.hparams['ae_encoding_layer_type']
        n_layers = len(self.hparams['ae_encoding_n_channels'])
        # each conv/batch norm/max pool/relu chunk counts as one layer for global_layer_num,
        # which is used to name the modules
        global_layer_num = 0
        for i_layer in range(n_layers):

            # only conv layers start a chunk (max pool layers are added alongside them)
            if layer_types[i_layer] != 'conv':
                continue

            # convolution layer; may first append a ZeroPad2d module for asymmetric padding
            args = self._get_conv2d_args(i_layer, global_layer_num)
            if self.hparams.get('fit_sess_io_layers', False) and i_layer == 0:
                # one input convolution per dataset; forward() selects one with `dataset`
                module = nn.ModuleList(
                    [nn.Conv2d(**args) for _ in range(self.hparams['n_datasets'])])
                self.encoder.add_module('conv%i_sess_io_layers' % global_layer_num, module)
            else:
                self.encoder.add_module('conv%i' % global_layer_num, nn.Conv2d(**args))

            # batch norm layer
            if self.hparams.get('ae_batch_norm', False):
                module = nn.BatchNorm2d(
                    self.hparams['ae_encoding_n_channels'][i_layer],
                    momentum=self.hparams.get('ae_batch_norm_momentum', 0.1),
                    track_running_stats=self.hparams.get('track_running_stats', True))
                self.encoder.add_module('batchnorm%i' % global_layer_num, module)

            # max pool layer, when the next layer specification asks for one
            if i_layer < n_layers - 1 and layer_types[i_layer + 1] == 'maxpool':
                module = nn.MaxPool2d(**self._get_maxpool2d_args(i_layer))
                self.encoder.add_module('maxpool%i' % global_layer_num, module)

            # leaky ReLU
            self.encoder.add_module('relu%i' % global_layer_num, nn.LeakyReLU(0.05))
            global_layer_num += 1

        # final ff layer to latents
        last_conv_size = (
            self.hparams['ae_encoding_n_channels'][-1]
            * self.hparams['ae_encoding_y_dim'][-1]
            * self.hparams['ae_encoding_x_dim'][-1])
        self.FF = nn.Linear(last_conv_size, self.hparams['n_ae_latents'])

        # if VAE model, add an additional ff layer to latent log-variances
        if self.hparams.get('variational', False):
            self.logvar = nn.Linear(last_conv_size, self.hparams['n_ae_latents'])

    def _get_conv2d_args(self, layer, global_layer):
        """Collect ``nn.Conv2d`` keyword arguments for layer index ``layer``.

        Side effect: when the requested padding is asymmetric, a ``nn.ZeroPad2d`` module named
        ``zero_pad<global_layer>`` is appended to ``self.encoder`` and the conv gets
        ``padding=0``.
        """
        if layer == 0:
            if self.hparams['model_class'] == 'cond-ae' and \
                    self.hparams.get('conditional_encoder', False):
                # labels will be appended to input if using conditional autoencoder with
                # conditional encoder flag
                n_labels = int(self.hparams['n_labels'] / 2)  # 'n_labels' key includes x/y coords
            else:
                n_labels = 0
            in_channels = self.hparams['ae_input_dim'][0] + n_labels
        else:
            in_channels = self.hparams['ae_encoding_n_channels'][layer - 1]

        out_channels = self.hparams['ae_encoding_n_channels'][layer]
        kernel_size = self.hparams['ae_encoding_kernel_size'][layer]
        stride = self.hparams['ae_encoding_stride_size'][layer]

        x_pad_0, x_pad_1 = self.hparams['ae_encoding_x_padding'][layer][:2]
        y_pad_0, y_pad_1 = self.hparams['ae_encoding_y_padding'][layer][:2]
        if (x_pad_0 == x_pad_1) and (y_pad_0 == y_pad_1):
            # symmetric padding can be handled by the conv itself; Conv2d expects (H, W)
            padding = (y_pad_0, x_pad_0)
        else:
            # ZeroPad2d expects (left, right, top, bottom)
            module = nn.ZeroPad2d((x_pad_0, x_pad_1, y_pad_0, y_pad_1))
            self.encoder.add_module('zero_pad%i' % global_layer, module)
            padding = 0

        return {
            'in_channels': in_channels,
            'out_channels': out_channels,
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding}

    def _get_maxpool2d_args(self, layer):
        """Collect ``nn.MaxPool2d`` keyword arguments from the specification at ``layer + 1``.

        Indices are always returned so that a decoder can unpool; ceil mode is used instead of
        zero padding unless 'ae_padding_type' is 'valid'.
        """
        return {
            'kernel_size': int(self.hparams['ae_encoding_kernel_size'][layer + 1]),
            'stride': int(self.hparams['ae_encoding_stride_size'][layer + 1]),
            'padding': (
                self.hparams['ae_encoding_y_padding'][layer + 1][0],
                self.hparams['ae_encoding_x_padding'][layer + 1][0]),
            'return_indices': True,
            'ceil_mode': self.hparams['ae_padding_type'] != 'valid'}

    def forward(self, x, dataset=None):
        """Process input data.

        Parameters
        ----------
        x : :obj:`torch.Tensor` object
            input data
        dataset : :obj:`int`
            used with session-specific io layers; required when they are present

        Returns
        -------
        :obj:`tuple`
            - encoder output (:obj:`torch.Tensor`): shape (n_frames, n_latents)
            - logvar (:obj:`torch.Tensor`): shape (n_frames, n_latents); only returned when
              hparams['variational'] is True
            - pool_idx (:obj:`list`): max pooling indices for each layer
            - output_size (:obj:`list`): output size for each layer

        Raises
        ------
        TypeError
            if the encoder has session-specific io layers and ``dataset`` is None

        """
        # loop over layers, have to collect pool_idx and output sizes if using max pooling to use
        # in unpooling
        pool_idx = []
        target_output_size = []
        for layer in self.encoder:
            if isinstance(layer, nn.MaxPool2d):
                target_output_size.append(x.size())
                x, idx = layer(x)
                pool_idx.append(idx)
            elif isinstance(layer, nn.ModuleList):
                # session-specific input layer
                if dataset is None:
                    raise TypeError(
                        '`dataset` must be provided when using session-specific io layers')
                x = layer[dataset](x)
            else:
                x = layer(x)

        # flatten for ff layer; reshape (unlike view) also accepts non-contiguous tensors
        x = x.reshape(x.size(0), -1)
        if self.hparams.get('variational', False):
            return self.FF(x), self.logvar(x), pool_idx, target_output_size
        return self.FF(x), pool_idx, target_output_size


class AEPSEncoder(ConvAEEncoder):
    """Encoder that separates label-related subspace.

    The output of the convolutional encoder (``n_ae_latents`` features) is projected by a fixed,
    non-trainable random orthogonal matrix and split into ``n_labels`` constrained latents
    (``self.A``) and ``n_ae_latents - n_labels`` unconstrained latents (``self.B``). A learnable
    diagonal affine map ``self.D`` takes the constrained latents to label space.

    Requires ``hparams['variational']`` to be True, because :meth:`forward` always returns the
    output of the ``logvar`` head, which :class:`ConvAEEncoder` only builds for variational models.
    """

    def __init__(self, hparams):
        """

        Parameters
        ----------
        hparams : :obj:`dict`
            all keys used by :class:`ConvAEEncoder`, plus 'n_labels' (:obj:`int`)

        """
        from scipy.stats import ortho_group

        super().__init__(hparams)

        # add linear transformations mapping from NN output to label-, non-label-related subspaces
        n_latents = self.hparams['n_ae_latents']
        n_labels = self.hparams['n_labels']
        # NN -> constrained latents
        self.A = nn.Linear(n_latents, n_labels, bias=False)
        # NN -> unconstrained latents
        self.B = nn.Linear(n_latents, n_latents - n_labels, bias=False)
        # constrained latents -> labels (diagonal matrix + bias); uses the DiagLinear class
        # defined in this file rather than an external package copy
        self.D = DiagLinear(n_labels, bias=True)

        # fix A, B to be orthogonal (and not trainable): they are complementary row blocks of a
        # single random orthogonal matrix
        m = ortho_group.rvs(dim=n_latents).astype('float32')
        with torch.no_grad():
            self.A.weight = nn.Parameter(
                torch.from_numpy(m[:n_labels, :]), requires_grad=False)
            self.B.weight = nn.Parameter(
                torch.from_numpy(m[n_labels:, :]), requires_grad=False)

    def __str__(self):
        """Pretty print encoder architecture."""
        lines = ['Encoder architecture:']
        for i, module in enumerate(self.encoder):
            lines.append('    {:02d}: {}'.format(i, module))
        # final ff layer(s); FF and logvar run in parallel on the flattened conv output
        i = len(self.encoder)
        lines.append('    {:02d}: {}'.format(i, self.FF))
        if self.hparams.get('variational', False):
            lines.append('    {:02d}: {} (to latent log-variances)'.format(i, self.logvar))
        # final linear transformations: A and B both act on the FF output, D acts on A's output
        lines.append('    {:02d}: {} (to constrained latents)'.format(i + 1, self.A))
        lines.append('    {:02d}: {} (to unconstrained latents)'.format(i + 1, self.B))
        lines.append('    {:02d}: {} (constrained latents to labels)'.format(i + 2, self.D))
        return '\n'.join(lines) + '\n'

    def forward(self, x, dataset=None):
        """Process input data.

        Parameters
        ----------
        x : :obj:`torch.Tensor` object
            input data
        dataset : :obj:`int`
            used with session-specific io layers; required when they are present

        Returns
        -------
        :obj:`tuple`
            - encoder output y (:obj:`torch.Tensor`): constrained latents (predicted labels) of
              shape (n_frames, n_labels)
            - encoder output w (:obj:`torch.Tensor`): unconstrained latents of shape
              (n_frames, n_latents - n_labels)
            - logvar (:obj:`torch.Tensor`): log variance of latents of shape
              (n_frames, n_latents)
            - pool_idx (:obj:`list`): max pooling indices for each layer
            - output_size (:obj:`list`): output size for each layer

        Raises
        ------
        TypeError
            if the encoder has session-specific io layers and ``dataset`` is None

        """
        # loop over layers, have to collect pool_idx and output sizes if using max pooling to use
        # in unpooling
        pool_idx = []
        target_output_size = []
        for layer in self.encoder:
            if isinstance(layer, nn.MaxPool2d):
                target_output_size.append(x.size())
                x, idx = layer(x)
                pool_idx.append(idx)
            elif isinstance(layer, nn.ModuleList):
                # session-specific input layer
                if dataset is None:
                    raise TypeError(
                        '`dataset` must be provided when using session-specific io layers')
                x = layer[dataset](x)
            else:
                x = layer(x)

        # flatten for ff layer; reshape (unlike view) also accepts non-contiguous tensors
        features = x.reshape(x.size(0), -1)
        nn_out = self.FF(features)

        # push through linear transformations
        y = self.A(nn_out)  # NN -> constrained latents
        w = self.B(nn_out)  # NN -> unconstrained latents

        return y, w, self.logvar(features), pool_idx, target_output_size

class PSVAE(AE):
    """Partitioned subspace variational autoencoder class.

    This class constructs a VAE whose latent space is partitioned into two subspaces:

    - a supervised subspace of ``n_labels`` dimensions, mapped to the labels by a diagonal
      affine transformation and penalized with a standard KL term;
    - an unsupervised subspace of the remaining ``n_ae_latents - n_labels`` dimensions, whose KL
      term is decomposed into index-code mutual information, total correlation (weighted by
      beta) and dimension-wise KL terms.

    """

    def __init__(self, hparams):
        """See constructor documentation of AE for hparams details.

        Parameters
        ----------
        hparams : :obj:`dict`
            in addition to the standard keys, must also contain:
            - 'n_labels' (:obj:`int`)
            - 'ps_vae.alpha' (:obj:`float`): weight of the label log-likelihood
            - 'ps_vae.beta' (:obj:`float`): weight of the total correlation term
            - 'max_n_epochs' (:obj:`int`)
            and may contain:
            - 'ps_vae.anneal_epochs' (:obj:`int`): number of epochs over which beta and the
              remaining kl terms are linearly annealed from 0; defaults to 0 (no annealing)
            Note that this constructor sets hparams['variational'] = True in place.

        Raises
        ------
        NotImplementedError
            if hparams['model_type'] is 'linear'
        ValueError
            if there are fewer latents than labels

        """

        if hparams['model_type'] == 'linear':
            raise NotImplementedError
        if hparams['n_ae_latents'] < hparams['n_labels']:
            raise ValueError('PS-VAE model must contain at least as many latents as labels')

        self.n_latents = hparams['n_ae_latents']
        self.n_labels = hparams['n_labels']

        hparams['variational'] = True
        super().__init__(hparams)

        # set up beta annealing
        anneal_epochs = self.hparams.get('ps_vae.anneal_epochs', 0)
        self.curr_epoch = 0  # must be modified by training script
        beta = hparams['ps_vae.beta']
        # TODO: these values should not be precomputed
        if anneal_epochs > 0:
            # annealing for total correlation term
            self.beta_vals = np.append(
                np.linspace(0, beta, anneal_epochs),  # USED TO START AT 1!!
                beta * np.ones(hparams['max_n_epochs'] + 1))  # sloppy addition to fully cover rest
            # annealing for remaining kl terms - index code mutual info and dim-wise kl
            self.kl_anneal_vals = np.append(
                np.linspace(0, 1, anneal_epochs),
                np.ones(hparams['max_n_epochs'] + 1))  # sloppy addition to fully cover rest
        else:
            self.beta_vals = beta * np.ones(hparams['max_n_epochs'] + 1)
            self.kl_anneal_vals = np.ones(hparams['max_n_epochs'] + 1)

    def build_model(self):
        """Construct the model using hparams.

        Raises
        ------
        NotImplementedError
            for linear models
        ValueError
            for unknown model types

        """
        self.hparams['hidden_layer_size'] = self.hparams['n_ae_latents']
        if self.model_type == 'conv':
            # AEPSEncoder is the partitioned-subspace encoder defined in this file; it returns
            # (y, w, logvar, pool_idx, outsize) as expected by forward() below
            self.encoding = AEPSEncoder(self.hparams)
            self.decoding = ConvAEDecoder(self.hparams)
        elif self.model_type == 'linear':
            raise NotImplementedError
        else:
            raise ValueError('"%s" is an invalid model_type' % self.model_type)

    def forward(self, x, dataset=None, use_mean=False, **kwargs):
        """Process input data.

        Parameters
        ----------
        x : :obj:`torch.Tensor` object
            input data of shape (n_frames, n_channels, y_pix, x_pix)
        dataset : :obj:`int`
            used with session-specific io layers
        use_mean : :obj:`bool`
            True to skip sampling step

        Returns
        -------
        :obj:`tuple`
            - x_hat (:obj:`torch.Tensor`): output of shape (n_frames, n_channels, y_pix, x_pix)
            - z (:obj:`torch.Tensor`): sampled latent variable of shape (n_frames, n_latents)
            - mu (:obj:`torch.Tensor`): mean parameter of shape (n_frames, n_latents)
            - logvar (:obj:`torch.Tensor`): logvar parameter of shape (n_frames, n_latents)
            - y_hat (:obj:`torch.Tensor`): output of shape (n_frames, n_labels)

        """
        y, w, logvar, pool_idx, outsize = self.encoding(x, dataset=dataset)
        # supervised latents come first, unsupervised latents after
        mu = torch.cat([y, w], dim=1)
        if use_mean:
            z = mu
        else:
            z = reparameterize(mu, logvar)
        x_hat = self.decoding(z, pool_idx, outsize, dataset=dataset)
        # labels are predicted from the (unsampled) supervised latents
        y_hat = self.encoding.D(y)
        return x_hat, z, mu, logvar, y_hat

    def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200):
        """Calculate modified ELBO loss for PSVAE.

        The batch is split into chunks of at most `chunk_size` frames to keep memory
        requirements low; gradients are accumulated across all chunks before a gradient step is
        taken.

        Parameters
        ----------
        data : :obj:`dict`
            batch of data; keys should include 'images' and 'labels', plus 'masks' and
            'labels_masks' if necessary
        dataset : :obj:`int`, optional
            used for session-specific io layers
        accumulate_grad : :obj:`bool`, optional
            accumulate gradient for training step
        chunk_size : :obj:`int`, optional
            batch is split into chunks of this size to keep memory requirements low

        Returns
        -------
        :obj:`dict`
            - 'loss' (:obj:`float`): full (negative, modified) elbo
            - 'loss_data_ll' (:obj:`float`): image log-likelihood
            - 'loss_label_ll' (:obj:`float`): label log-likelihood
            - 'loss_zs_kl' (:obj:`float`): kl of supervised latents
            - 'loss_zu_mi' (:obj:`float`): index-code mutual info of unsupervised latents
            - 'loss_zu_tc' (:obj:`float`): total correlation of unsupervised latents
            - 'loss_zu_dwkl' (:obj:`float`): dimension-wise kl of unsupervised latents
            - 'loss_data_mse' (:obj:`float`): image mse (without gaussian constants)
            - 'alpha' (:obj:`float`): weight in front of label log-likelihood
            - 'beta' (:obj:`float`): weight in front of total correlation term
            - 'label_r2' (:obj:`float`): variance-weighted r2 of predicted labels

        """

        x = data['images'][0]
        y = data['labels'][0]
        m = data['masks'][0] if 'masks' in data else None
        n = data['labels_masks'][0] if 'labels_masks' in data else None
        batch_size = x.shape[0]
        n_chunks = int(np.ceil(batch_size / chunk_size))
        n_labels = self.hparams['n_labels']

        # compute hyperparameters
        alpha = self.hparams['ps_vae.alpha']
        beta = self.beta_vals[self.curr_epoch]
        kl = self.kl_anneal_vals[self.curr_epoch]

        loss_strs = [
            'loss', 'loss_data_ll', 'loss_label_ll', 'loss_zs_kl', 'loss_zu_mi', 'loss_zu_tc',
            'loss_zu_dwkl']

        loss_dict_vals = {loss: 0 for loss in loss_strs}
        loss_dict_vals['loss_data_mse'] = 0

        y_hat_all = []

        for chunk in range(n_chunks):

            idx_beg = chunk * chunk_size
            idx_end = min((chunk + 1) * chunk_size, batch_size)

            x_in = x[idx_beg:idx_end]
            y_in = y[idx_beg:idx_end]
            m_in = m[idx_beg:idx_end] if m is not None else None
            n_in = n[idx_beg:idx_end] if n is not None else None
            x_hat, sample, mu, logvar, y_hat = self.forward(x_in, dataset=dataset, use_mean=False)

            # reset losses
            loss_dict_torch = {loss: 0 for loss in loss_strs}

            # data log-likelihood
            loss_dict_torch['loss_data_ll'] = losses.gaussian_ll(x_in, x_hat, m_in)
            loss_dict_torch['loss'] -= loss_dict_torch['loss_data_ll']

            # label log-likelihood
            loss_dict_torch['loss_label_ll'] = losses.gaussian_ll(y_in, y_hat, n_in)
            loss_dict_torch['loss'] -= alpha * loss_dict_torch['loss_label_ll']

            # supervised latents kl
            loss_dict_torch['loss_zs_kl'] = losses.kl_div_to_std_normal(
                mu[:, :n_labels], logvar[:, :n_labels])
            loss_dict_torch['loss'] += loss_dict_torch['loss_zs_kl']

            # compute all terms of decomposed elbo at once
            index_code_mi, total_correlation, dimension_wise_kl = losses.decomposed_kl(
                sample[:, n_labels:], mu[:, n_labels:], logvar[:, n_labels:])

            # unsupervised latents index-code mutual information
            loss_dict_torch['loss_zu_mi'] = index_code_mi
            loss_dict_torch['loss'] += kl * loss_dict_torch['loss_zu_mi']

            # unsupervised latents total correlation
            loss_dict_torch['loss_zu_tc'] = total_correlation
            loss_dict_torch['loss'] += beta * loss_dict_torch['loss_zu_tc']

            # unsupervised latents dimension-wise kl
            loss_dict_torch['loss_zu_dwkl'] = dimension_wise_kl
            loss_dict_torch['loss'] += kl * loss_dict_torch['loss_zu_dwkl']

            if accumulate_grad:
                loss_dict_torch['loss'].backward()

            # get loss values (weighted by chunk size)
            bs = idx_end - idx_beg
            for key, val in loss_dict_torch.items():
                loss_dict_vals[key] += val.item() * bs
            # convert THIS chunk's data log-likelihood to mse; the accumulated value in
            # loss_dict_vals already includes earlier chunks and must not be used here
            loss_dict_vals['loss_data_mse'] += losses.gaussian_ll_to_mse(
                loss_dict_torch['loss_data_ll'].item(), np.prod(x.shape[1:])) * bs

            # collect predicted labels to compute R2
            y_hat_all.append(y_hat.cpu().detach().numpy())

        # use variance-weighted r2s to ignore small-variance latents
        y_hat_all = np.concatenate(y_hat_all, axis=0)
        y_all = y.cpu().detach().numpy()
        if n is not None:
            n_np = n.cpu().detach().numpy()
            r2 = r2_score(y_all[n_np == 1], y_hat_all[n_np == 1], multioutput='variance_weighted')
        else:
            r2 = r2_score(y_all, y_hat_all, multioutput='variance_weighted')

        # compile (properly weighted) loss terms
        for key in loss_dict_vals:
            loss_dict_vals[key] /= batch_size

        # store hyperparams
        loss_dict_vals['alpha'] = alpha
        loss_dict_vals['beta'] = beta
        loss_dict_vals['label_r2'] = r2

        return loss_dict_vals

    def get_predicted_labels(self, x, dataset=None, use_mean=True):
        """Process input data to get predicted labels.

        Parameters
        ----------
        x : :obj:`torch.Tensor` object
            input data of shape (n_frames, n_channels, y_pix, x_pix)
        dataset : :obj:`int`
            used with session-specific io layers
        use_mean : :obj:`bool`
            True to skip sampling step

        Returns
        -------
        :obj:`torch.Tensor`
            output of shape (n_frames, n_labels)

        """
        y, _, logvar, _, _ = self.encoding(x, dataset=dataset)
        if not use_mean:
            y = reparameterize(y, logvar[:, :self.n_labels])
        y_hat = self.encoding.D(y)
        return y_hat

    def get_transformed_latents(self, inputs, dataset=None, as_numpy=True):
        """Return latents after supervised subspace has been transformed to original label space.

        Parameters
        ----------
        inputs : :obj:`torch.Tensor` object
            - image tensor of shape (n_frames, n_channels, y_pix, x_pix)
            - latents tensor of shape (n_frames, n_ae_latents)
        dataset : :obj:`int`, optional
            used with session-specific io layers
        as_numpy : :obj:`bool`, optional
            True to return as numpy array, False to return as torch Tensor

        Returns
        -------
        :obj:`np.ndarray` or :obj:`torch.Tensor` object
            array of latents in transformed latent space of shape (n_frames, n_latents)

        """

        if not isinstance(inputs, torch.Tensor):
            inputs = torch.Tensor(inputs)

        # 2D inputs are treated as latents, anything else as images
        if len(inputs.shape) == 2:
            y_og = inputs[:, :self.hparams['n_labels']]
            w_og = inputs[:, self.hparams['n_labels']:]
        else:
            y_og, w_og, _, _, _ = self.encoding(inputs, dataset=dataset)

        # transform supervised latents to label space
        y_new = self.encoding.D(y_og)

        latents_tr = torch.cat([y_new, w_og], dim=1)

        if as_numpy:
            return latents_tr.cpu().detach().numpy()
        return latents_tr

    def get_inverse_transformed_latents(self, inputs, dataset=None, as_numpy=True):
        """Return latents after the supervised entries are mapped back through the inverse of D.

        Parameters
        ----------
        inputs : :obj:`torch.Tensor` object
            - image tensor of shape (n_frames, n_channels, y_pix, x_pix) (not implemented)
            - latents tensor of shape (n_frames, n_ae_latents) where the first n_labels entries are
              assumed to be labels in the original pixel space
        dataset : :obj:`int`, optional
            used with session-specific io layers
        as_numpy : :obj:`bool`, optional
            True to return as numpy array, False to return as torch Tensor

        Returns
        -------
        :obj:`np.ndarray` or :obj:`torch.Tensor` object
            array of latents in transformed latent space of shape (n_frames, n_latents)

        Raises
        ------
        NotImplementedError
            if `inputs` are images (i.e. not 2D)

        """

        if not isinstance(inputs, torch.Tensor):
            inputs = torch.Tensor(inputs)

        # 2D inputs are treated as latents, anything else as images
        if len(inputs.shape) != 2:
            raise NotImplementedError
        y_og = inputs[:, :self.hparams['n_labels']]
        w_og = inputs[:, self.hparams['n_labels']:]

        # transform given labels to latent space by inverting y = x * weight + bias
        y_new = torch.div(torch.sub(y_og, self.encoding.D.bias), self.encoding.D.weight)

        latents_tr = torch.cat([y_new, w_og], dim=1)

        if as_numpy:
            return latents_tr.cpu().detach().numpy()
        return latents_tr

In [11]:
class AgeManifoldVAE(nn.Module):
    def __init__(
        self,
        input_dims,
        activation,
        hidden_dims: list | tuple,
        latent_dim,
        mdl_type="regressor",
        n_classes=None,
    ):
        super().__init__()
        self.input_dims = input_dims
        self.n_layers = len(hidden_dims)
        self.hidden_dim = hidden_dims
        self.latent_dim = latent_dim
        if n_classes is None:
            n_classes = latent_dim

        layers = []
        for _in, _out in zip(hidden_dims, hidden_dims[1:]):
            layers.append(nn.Linear(_in, _out))
            layers.append(nn.BatchNorm1d(_out))
            layers.append(activation())

        self.encoder = nn.Sequential(
            nn.Linear(input_dims, hidden_dims[0]),
            activation(),
            *layers,
            nn.Linear(hidden_dims[-1], latent_dim),
            # activation(),
        )

        layers = []
        for _in, _out in zip(hidden_dims[::-1], hidden_dims[::-1][1:]):
            layers.append(nn.Linear(_in, _out))
            layers.append(nn.BatchNorm1d(_out))
            layers.append(activation())

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dims[-1]),
            activation(),
            *layers,
            nn.Linear(hidden_dims[0], input_dims),
        )
        # linear regressor to predict age
        if mdl_type == "regressor":
            self.lm = nn.Linear(latent_dim, 1)
        else:
            self.lm = nn.Linear(latent_dim, n_classes)

    def forward(self, x):
        latent = self.transform(x)
        return self.decoder(latent), self.lm(latent), latent

    def transform(self, x):
        return self.encoder(x)

    def predict(self, x):
        return self.lm(self.transform(x))