In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

# Load the MNIST training split (downloaded into ./data when missing).
# ToTensor is the only preprocessing step applied to each sample.
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_data = MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Define the autoencoder model
class Autoencoder(nn.Module):
    """Fully connected autoencoder with a mirrored encoder and decoder.

    The encoder maps ``input_dim`` features through ``hidden_dims`` down to a
    ``latent_dim`` code, with a ReLU after every linear layer except the last
    one. The decoder applies the same widths in reverse order and finishes
    with a Sigmoid.

    With the default arguments the network is 784-128-64-32-16 and back,
    which is the architecture this class originally hard-coded.

    Args:
        input_dim: Number of features of each (flattened) input sample.
        hidden_dims: Sequence of hidden layer widths for the encoder, listed
            from the input side towards the bottleneck. May be empty.
        latent_dim: Width of the bottleneck code.
    """

    def __init__(self, input_dim=784, hidden_dims=(128, 64, 32), latent_dim=16):
        super().__init__()
        encoder_sizes = [input_dim, *hidden_dims, latent_dim]
        decoder_sizes = encoder_sizes[::-1]

        # Encoder layers are built before decoder layers so that parameter
        # initialisation order stays the same as in the hard-coded version.
        self.encoder = nn.Sequential(*self._linear_stack(encoder_sizes))
        self.decoder = nn.Sequential(
            *self._linear_stack(decoder_sizes),
            nn.Sigmoid(),
        )

    @staticmethod
    def _linear_stack(sizes):
        """Return Linear layers chaining ``sizes`` with ReLU between them.

        No activation is appended after the final Linear layer; the caller
        decides what (if anything) follows it.
        """
        layers = []
        last_index = len(sizes) - 2
        for index, (fan_in, fan_out) in enumerate(zip(sizes, sizes[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if index < last_index:
                layers.append(nn.ReLU())
        return layers

    def forward(self, x):
        """Encode ``x`` to the latent code and decode it back.

        Args:
            x: Tensor whose last dimension has ``input_dim`` features.

        Returns:
            The decoder output for the encoded ``x``; the last dimension again
            has ``input_dim`` features.
        """
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Define the training loop
def train(model, dataloader, criterion, optimizer, num_epochs):
    """Train ``model`` to reconstruct its own flattened inputs.

    Every item yielded by ``dataloader`` is unpacked as ``(images, _)``; the
    second element is ignored. Each image batch is flattened to
    ``(batch, -1)``, passed through ``model``, and the loss is
    ``criterion(outputs, flattened_images)``. After each epoch the mean of
    the per-batch losses is printed.

    Args:
        model: Module to optimise; it is switched to training mode.
        dataloader: Re-iterable yielding ``(images, _)`` pairs. It does not
            need to support ``len()``.
        criterion: Callable returning a scalar loss tensor from
            ``(outputs, targets)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        None. Progress is reported with ``print``.
    """
    # Feed batches on the device the model's parameters live on, so a model
    # that was moved to an accelerator does not hit a device mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # Make sure train-time behaviour is active even if the model was
    # previously switched to eval mode.
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        # Count batches instead of calling len(dataloader): this also works
        # for iterables without a length and avoids dividing by zero.
        num_batches = 0
        for images, _ in dataloader:
            images = images.view(images.size(0), -1)
            if device is not None:
                images = images.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # An epoch that saw no batches has no meaningful loss; report NaN.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch + 1} loss: {mean_loss:.4f}')

# Hyperparameters for the training run
num_epochs = 10
batch_size = 128

# Shuffled mini-batches over the training split
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)

# Model, reconstruction loss and optimizer
autoencoder = Autoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=1e-3)

# Run the training loop
train(
    model=autoencoder,
    dataloader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)


In [None]:
import matplotlib.pyplot as plt

# Get a batch of images from the test set
num_images = 10  # samples to display; also the number of subplot columns
test_data = MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=num_images, shuffle=True)
images, _ = next(iter(test_loader))

# Use the autoencoder to reconstruct the images.
# Gradients are not needed for inference, so skip building the autograd graph.
autoencoder.eval()
with torch.no_grad():
    reconstructed = autoencoder(images.view(images.size(0), -1))

# Visualize the input (top row) and reconstructed output (bottom row).
# squeeze=False keeps `axes` two-dimensional even when num_images is 1.
fig, axes = plt.subplots(
    nrows=2,
    ncols=num_images,
    sharex=True,
    sharey=True,
    figsize=(25, 4),
    squeeze=False,
)

# Use a separate loop name so the original `images` batch is not overwritten.
for batch, row in zip([images, reconstructed], axes):
    for img, ax in zip(batch, row):
        ax.imshow(img.view(28, 28), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

plt.show()
