### Libraries

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import os
from umap import UMAP

### Autoencoder class

In [3]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder whose encoder sees extra appended columns.

    The encoder maps ``input_size + batch_size`` features through hidden
    widths 128 and 64 down to a ``d_z``-dimensional code. The decoder maps
    the code back through widths 64 and 128 to ``input_size`` features.

    ``batch_size`` is a *feature count* appended to each sample, not the
    mini-batch size: the leading (sample) dimension is left free by
    ``reshape(-1, ...)``. Presumably those extra columns encode a
    batch/condition label (e.g. one-hot) -- TODO confirm against the caller.
    The decoder does not receive them and outputs only ``input_size``
    features.

    Args:
        d_z: Width of the latent code.
        input_size: Number of features the decoder reconstructs.
        batch_size: Number of extra columns the encoder expects on top of
            ``input_size``.
        activation: Optional zero-argument callable returning an
            ``nn.Module`` (e.g. ``nn.ReLU``). It is inserted between
            consecutive ``Linear`` layers of both networks. With the default
            ``None`` the layers are stacked with no nonlinearity, so each
            network is a single affine map overall. The default also keeps
            the original layer indices, and hence the original
            ``state_dict`` keys.
    """

    def __init__(
        self,
        d_z: int = 10,
        input_size: int = 24,
        batch_size: int = 2,
        activation=None,
    ):
        super().__init__()
        self.d_z = d_z
        self.input_size = input_size
        self.batch_size = batch_size
        # The encoder input includes the extra appended columns; the decoder
        # output does not.
        self.encoder = self._mlp(
            [input_size + batch_size, 128, 64, d_z], activation
        )
        self.decoder = self._mlp([d_z, 64, 128, input_size], activation)

    @staticmethod
    def _mlp(widths, activation=None):
        """Build a ``Linear`` stack over consecutive ``widths``.

        When ``activation`` is given, ``activation()`` is placed between
        layers (never after the last one).
        """
        layers = []
        for i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            if i > 0 and activation is not None:
                layers.append(activation())
            layers.append(nn.Linear(n_in, n_out))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode then decode ``x``.

        Returns a ``(N, input_size)`` reconstruction.
        """
        return self.decode(self.encode(x)).reshape(-1, self.input_size)

    def encode(self, x):
        """Flatten ``x`` to ``(N, input_size + batch_size)`` and encode it.

        Returns the ``(N, d_z)`` latent code.
        """
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # inputs (e.g. transposed or sliced tensors) whose flattening merges
        # non-adjacent strides; ``reshape`` copies only when it must.
        x = x.reshape(-1, self.input_size + self.batch_size)
        return self.encoder(x)

    def decode(self, z):
        """Map latent codes ``z`` (last dim ``d_z``) back to ``input_size`` features."""
        return self.decoder(z)