In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# Define the autoencoder model
class Autoencoder(nn.Module):
    def __init__(self, latent_dim):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        latent = self.encoder(x)
        x_recon = self.decoder(latent)
        return x_recon, latent

# Define the classification model
class Classifier(nn.Module):
    def __init__(self, latent_dim):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(latent_dim, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.softmax(self.fc2(x), dim=1)
        return x

# Instantiate the autoencoder and classifier models
latent_dim = 10
autoencoder = Autoencoder(latent_dim)
classifier = Classifier(latent_dim)

# Load the MNIST dataset
mnist_data = MNIST(root='data', train=True, transform=ToTensor(), download=True)
data_loader = DataLoader(mnist_data, batch_size=64, shuffle=True)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(list(autoencoder.parameters()) + list(classifier.parameters()))

# Train the autoencoder and classifier jointly
for epoch in range(10):
    for batch_idx, (images, labels) in enumerate(data_loader):
        # Zero the parameter gradients
        optimizer.zero_grad()

        # Pass the images through the autoencoder and classifier
        images_recon, latent = autoencoder(images)
        pred = classifier(latent)

        # Compute the loss and backpropagate
        loss = criterion(pred, labels)
        loss.backward()
        optimizer.step()

        # Print training progress
        if batch_idx % 100 == 0:
            print('Epoch {} [{}/{}]: Loss={:.4f}'.format(
                epoch+1, batch_idx*len(images), len(data_loader.dataset), loss.item()))

# Evaluate the classifier on the test set
mnist_test = MNIST(root='data', train=False, transform=ToTensor(), download=True)
test_loader = DataLoader(mnist_test, batch_size=64, shuffle=False)

classifier.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.view(images.size(0), -1)
        latent = autoencoder.encoder(images)
        pred = classifier(latent)
        _, predicted = torch.max(pred.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Test Accuracy: {:.2f}%'.format(100 * correct / total))



Epoch 1 [0/60000]: Loss=2.3009
Epoch 1 [6400/60000]: Loss=1.8014
Epoch 1 [12800/60000]: Loss=1.6986
Epoch 1 [19200/60000]: Loss=1.6790
Epoch 1 [25600/60000]: Loss=1.5441
Epoch 1 [32000/60000]: Loss=1.5742
Epoch 1 [38400/60000]: Loss=1.5269
Epoch 1 [44800/60000]: Loss=1.5251
Epoch 1 [51200/60000]: Loss=1.5852
Epoch 1 [57600/60000]: Loss=1.5391
Epoch 2 [0/60000]: Loss=1.5755
Epoch 2 [6400/60000]: Loss=1.5235
Epoch 2 [12800/60000]: Loss=1.5227
Epoch 2 [19200/60000]: Loss=1.5354
Epoch 2 [25600/60000]: Loss=1.5078
Epoch 2 [32000/60000]: Loss=1.5232
Epoch 2 [38400/60000]: Loss=1.5156
Epoch 2 [44800/60000]: Loss=1.5788
Epoch 2 [51200/60000]: Loss=1.5314
Epoch 2 [57600/60000]: Loss=1.4707
Epoch 3 [0/60000]: Loss=1.5372
Epoch 3 [6400/60000]: Loss=1.4768
Epoch 3 [12800/60000]: Loss=1.5596
Epoch 3 [19200/60000]: Loss=1.5041
Epoch 3 [25600/60000]: Loss=1.5386
Epoch 3 [32000/60000]: Loss=1.5356
Epoch 3 [38400/60000]: Loss=1.5255
Epoch 3 [44800/60000]: Loss=1.4987
Epoch 3 [51200/60000]: Loss=1.5275
