In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Set random seed for reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x78f9a0171430>

In [2]:
class Autoencoder(nn.Module):
    def __init__(self, input_dim, bottleneck_dim):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, bottleneck_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
            nn.Sigmoid()  # Sigmoid activation for output in [0, 1]
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


In [6]:
class Classifier(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Classifier, self).__init__()
        self.feature_extractor = encoder_for_transfer
        self.classifier_layer = nn.Linear(bottleneck_dim, output_dim)

    def forward(self, x):
        features = self.feature_extractor(x)
        output = self.classifier_layer(features)
        return output

In [3]:
# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_loader = DataLoader(mnist_dataset, batch_size=64, shuffle=True)

# Split the dataset into training and validation sets
train_data, val_data = train_test_split(mnist_dataset, test_size=0.1, random_state=42)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 74212523.98it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 32157072.96it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 25095584.60it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3660747.27it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [4]:
# Initialize the autoencoder
input_dim = 28 * 28  # MNIST image size
bottleneck_dim = 5
autoencoder = Autoencoder(input_dim, bottleneck_dim)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)


In [5]:
num_epochs = 10

for epoch in range(num_epochs):
    for data in train_loader:
        inputs, _ = data
        inputs = inputs.view(inputs.size(0), -1)

        optimizer.zero_grad()
        outputs = autoencoder(inputs)
        loss = criterion(outputs, inputs)
        loss.backward()
        optimizer.step()

    print(f'Autoencoder Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


Autoencoder Epoch [1/10], Loss: 0.9272
Autoencoder Epoch [2/10], Loss: 0.9284
Autoencoder Epoch [3/10], Loss: 0.9248
Autoencoder Epoch [4/10], Loss: 0.9256
Autoencoder Epoch [5/10], Loss: 0.9274
Autoencoder Epoch [6/10], Loss: 0.9238
Autoencoder Epoch [7/10], Loss: 0.9274
Autoencoder Epoch [8/10], Loss: 0.9231
Autoencoder Epoch [9/10], Loss: 0.9213
Autoencoder Epoch [10/10], Loss: 0.9230


In [8]:
encoder_for_transfer = autoencoder.encoder
# Initialize the classifier
output_dim = 10  # Number of classes in MNIST
classifier = Classifier(bottleneck_dim, output_dim)

# Define the loss function and optimizer for the classification task
classification_criterion = nn.CrossEntropyLoss()
classification_optimizer = optim.Adam(classifier.parameters(), lr=0.001)


In [9]:
# Training the classifier
num_classification_epochs = 5

for epoch in range(num_classification_epochs):
    for data in train_loader:
        inputs, labels = data
        inputs = inputs.view(inputs.size(0), -1)

        classification_optimizer.zero_grad()
        features = classifier.feature_extractor(inputs)
        outputs = classifier.classifier_layer(features)
        classification_loss = classification_criterion(outputs, labels)
        classification_loss.backward()
        classification_optimizer.step()

    # Evaluate on validation set
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            inputs, labels = data
            inputs = inputs.view(inputs.size(0), -1)
            features = classifier.feature_extractor(inputs)
            outputs = classifier.classifier_layer(features)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f'Epoch [{epoch+1}/{num_classification_epochs}], Classification Loss: {classification_loss.item():.4f}, Accuracy: {accuracy:.4f}')

Epoch [1/5], Classification Loss: 0.3909, Accuracy: 0.8943
Epoch [2/5], Classification Loss: 0.4768, Accuracy: 0.9327
Epoch [3/5], Classification Loss: 0.1812, Accuracy: 0.9428
Epoch [4/5], Classification Loss: 0.1390, Accuracy: 0.9508
Epoch [5/5], Classification Loss: 0.1340, Accuracy: 0.9545
