In [None]:
import matplotlib.pyplot as plt

import torchvision
import torch
from torch import nn

# Run configuration: every later cell reads its hyperparameters from here.
params = {
    "LR": 1e-4,       # learning rate passed to the Adam optimiser
    "N_BATCHS": 32,   # mini-batch size used by both DataLoaders
    "N_EPOCHS": 10,   # number of passes over the training set
    "SEED": 0,        # seed for torch's global random number generator
}

# Seed torch's global RNG so that weight initialisation and the default
# DataLoader shuffling start from the same state on every Restart & Run All.
# (Some GPU kernels may still be nondeterministic.)
torch.manual_seed(params["SEED"])

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device type:", device)

In [None]:
class linear_autoencoder(nn.Module):
    """Fully connected autoencoder for single-channel 28x28 images.

    The input is flattened to 784 features, squeezed by the encoder down to
    a 10-dimensional code, passed through one extra 10 -> 10 ``latent``
    layer, and expanded back to 784 features by the decoder. The decoder
    ends in a ``Sigmoid``, so every output value lies in (0, 1).

    Args:
        num_hid: Width of the first encoder layer and the last hidden
            decoder layer.
        act_fn: Activation *class* (not instance); it is instantiated anew
            after every linear layer except the final decoder layer.
    """

    def __init__(self, num_hid=128, act_fn=nn.ReLU) -> None:
        super().__init__()

        n_pixels = 1 * 28 * 28   # channels * height * width
        mid_width = 50           # width of the intermediate layer
        code_width = 10          # size of the bottleneck representation

        # Sub-modules are created in the order encoder -> decoder -> latent,
        # which fixes the parameter registration order.
        self.encoder = nn.Sequential(
            nn.Linear(n_pixels, num_hid),
            act_fn(),
            nn.Linear(num_hid, mid_width),
            act_fn(),
            nn.Linear(mid_width, code_width),
            act_fn(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(code_width, mid_width),
            act_fn(),
            nn.Linear(mid_width, num_hid),
            act_fn(),
            nn.Linear(num_hid, n_pixels),
            nn.Sigmoid(),
        )
        self.latent = nn.Sequential(
            nn.Linear(code_width, code_width),
            act_fn(),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Reconstruct ``X`` and return it reshaped to ``(-1, 1, 28, 28)``.

        The first dimension of ``X`` is kept as the batch dimension; all
        remaining dimensions are flattened and must total 784 elements.
        """
        batch_size = X.shape[0]
        flat = X.view(batch_size, -1)
        code = self.latent(self.encoder(flat))
        return self.decoder(code).view(-1, 1, 28, 28)

In [None]:
# ToTensor scales pixel values to [0, 1]. No further normalisation is applied
# on purpose: the autoencoder's decoder ends in a Sigmoid, so its outputs are
# confined to (0, 1), and the reconstruction targets must live in that same
# range for the MSE loss to be reachable. (Standardising with MNIST mean/std
# would push targets to roughly [-0.42, 2.82].)
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# MNIST train / test splits, downloaded to ./datasets on first use.
train_dataset = torchvision.datasets.MNIST(
    root="./datasets", train=True, transform=transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="./datasets", train=False, transform=transform, download=True
)

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=params["N_BATCHS"], shuffle=True, num_workers=4, pin_memory=True
)

# Evaluation only averages the loss over the split, so order is irrelevant
# and shuffling is switched off.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=params["N_BATCHS"], shuffle=False, num_workers=4
)

In [None]:
# Train the autoencoder to reproduce its input, and after every epoch measure
# the reconstruction error on the held-out test split.
# NOTE: the model and the batches stay on the CPU here; `device` from the
# setup cell is not applied in this cell.
model = linear_autoencoder()
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(model.parameters(), lr=params["LR"])

# Per-epoch mean batch losses, consumed by the plotting cell.
train_history, test_history = [], []

for epoch in range(params["N_EPOCHS"]):
    # ---- optimisation pass over the training split ----
    model.train()
    train_loss = 0.
    for X, _ in train_loader:  # labels are unused: the target is the input itself
        X_hat = model(X)
        loss = loss_fn(X_hat, X)
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_loss += loss.item()

    # ---- evaluation pass over the test split ----
    # Iterate `test_loader` (not `train_loader`) so that the sum below is
    # divided by the matching number of batches, and disable autograd since
    # no gradients are needed here.
    model.eval()
    test_loss = 0.
    with torch.no_grad():
        for X, _ in test_loader:
            X_hat = model(X)
            test_loss += loss_fn(X_hat, X).item()

    # Convert the summed batch losses into per-batch means.
    train_loss /= len(train_loader)
    test_loss /= len(test_loader)
    print(f"> {epoch+1}/{params['N_EPOCHS']} | train_loss = {train_loss:.6f}; eval_loss = {test_loss:.6f}")
    train_history.append(train_loss)
    test_history.append(test_loss)

In [None]:
# Qualitative check: the first five training images (top row) next to the
# model's reconstruction of each (bottom row), titled with the digit label.
n_examples = 5
fig, axes = plt.subplots(nrows=2, ncols=n_examples, figsize=(20, 4))
input_axes, recon_axes = axes

with torch.no_grad():  # inference only; also lets imshow consume the output tensor
    for sample_idx in range(n_examples):
        image, label = train_dataset[sample_idx]

        top = input_axes[sample_idx]
        top.set_title(f"{label}")
        top.imshow(image.view(28, 28))

        bottom = recon_axes[sample_idx]
        bottom.set_title(f"{label}")
        reconstruction = model(image).view(28, 28)
        bottom.imshow(reconstruction)
plt.show()

# Loss curves recorded by the training cell, one point per epoch.
plt.plot(train_history, label="train")
plt.plot(test_history, label="test")
plt.legend()
plt.grid(True)
plt.show()