In [None]:
import matplotlib.pyplot as matplotlib_pyplot
import torch as torch
import torch.nn as torch_nn
import torch.optim as torch_optim
import torch.utils.data as torch_utils_data
import torchvision as torchvision

# Fix PyTorch's global RNG seed so weight initialisation and DataLoader
# shuffling are repeatable from run to run.
torch.manual_seed(42)

# Train on the GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class Autoencoder(torch_nn.Module):
    """Fully connected autoencoder that squeezes flat vectors through a small bottleneck.

    The encoder is a stack of Linear layers with ReLU between them (no activation
    on the latent code). The decoder mirrors the encoder's layer widths in reverse
    and ends in a Sigmoid, so reconstructions lie in (0, 1).

    The defaults reproduce the original 784 -> 128 -> 64 -> 12 -> 3 architecture
    (784 = a flattened 28x28 image). The submodules are built in the same order
    with the same ``Sequential`` indices, so existing ``state_dict`` checkpoints
    still load.

    Args:
        input_dim: Length of each flattened input vector.
        hidden_dims: Widths of the encoder's hidden layers, outermost first; the
            decoder uses them in reverse order.
        latent_dim: Size of the bottleneck code.
    """

    def __init__(self, input_dim=784, hidden_dims=(128, 64, 12), latent_dim=3):
        super().__init__()
        encoder_sizes = [input_dim, *hidden_dims, latent_dim]
        self.encoder = self._build_mlp(encoder_sizes)
        # The decoder walks the same widths back out to the input size.
        self.decoder = self._build_mlp(
            encoder_sizes[::-1], output_activation=torch_nn.Sigmoid()
        )

    @staticmethod
    def _build_mlp(sizes, output_activation=None):
        """Stack Linear layers over consecutive ``sizes`` with ReLU between (not after) them.

        Args:
            sizes: Layer widths, e.g. ``[784, 128, 3]`` gives
                Linear(784, 128), ReLU, Linear(128, 3).
            output_activation: Optional module appended after the last Linear layer.

        Returns:
            A ``torch_nn.Sequential`` containing the layers in order.
        """
        layers = []
        num_linear = len(sizes) - 1
        for i in range(num_linear):
            layers.append(torch_nn.Linear(sizes[i], sizes[i + 1]))
            # No ReLU after the final Linear: the latent code (encoder) and the
            # pre-Sigmoid logits (decoder) must stay unclipped.
            if i < num_linear - 1:
                layers.append(torch_nn.ReLU())
        if output_activation is not None:
            layers.append(output_activation)
        return torch_nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` to the latent code, then decode it back to input space.

        Args:
            x: Tensor whose last dimension is ``input_dim`` (e.g. ``(batch, 784)``).

        Returns:
            Reconstruction with the same shape as ``x``, values in (0, 1).
        """
        return self.decoder(self.encoder(x))

In [None]:
# Load and preprocess the MNIST dataset.
# ToTensor() yields float pixels in [0, 1], the same range as the decoder's Sigmoid output.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch_utils_data.DataLoader(train_dataset, batch_size=128, shuffle=True)

# Initialize the model, loss function, and optimizer
model = Autoencoder().to(device)
criterion = torch_nn.MSELoss()
optimizer = torch_optim.Adam(model.parameters())

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # The evaluation step below puts the model in eval mode; switch back so that
    # re-running this cell trains in train mode.
    model.train()
    # Accumulate on-device to avoid a host sync (loss.item()) on every batch.
    running_loss = torch.zeros((), device=device)
    num_samples = 0
    for img, _ in train_loader:
        # Flatten each image to a 784-element vector for the Linear layers.
        img = img.view(img.size(0), -1).to(device)

        output = model(img)
        loss = criterion(output, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # MSELoss averages over the batch, so weight by batch size to get a true
        # per-sample mean even when the last batch is smaller.
        running_loss += loss.detach() * img.size(0)
        num_samples += img.size(0)

    # Report the mean over the whole epoch, not just the final mini-batch.
    epoch_loss = (running_loss / num_samples).item()
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')

# Test the autoencoder on the first image of the held-out set
model.eval()
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
with torch.no_grad():
    test_img = test_dataset[0][0].view(-1, 28 * 28).to(device)
    output = model(test_img).cpu()

In [None]:
# Un-flatten both 784-element vectors back into 28x28 grids for display.
# The input is copied to the CPU first (a no-op if it already lives there);
# `output` was already moved to the CPU by the evaluation cell.
test_image_result, output_image_result = (
    flat.view(28, 28) for flat in (test_img.cpu(), output)
)

In [None]:
# Show the original test digit next to its reconstruction.
# Both panels use a fixed [0, 1] intensity range (ToTensor input, Sigmoid output).
# imshow's default per-image autoscaling would stretch each panel independently,
# making the reconstruction's brightness look more faithful than it is.
fig, axes = matplotlib_pyplot.subplots(1, 2, figsize=(10, 5))
panels = (
    (test_image_result, 'Original Image'),
    (output_image_result, 'Reconstructed Image'),
)
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image, cmap='gray', vmin=0.0, vmax=1.0)
    ax.set_title(title)
    ax.axis('off')  # pixel-index ticks add nothing for an image comparison
matplotlib_pyplot.tight_layout()
matplotlib_pyplot.show()