In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

In [2]:
# Step 1: Modify the AlexNet Model
class ModifiedAlexNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ModifiedAlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        # Initialize the last layer weights randomly
        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        nn.init.normal_(self.classifier[6].weight, 0, 0.01)
        nn.init.constant_(self.classifier[6].bias, 0)

# Create an instance of the modified model
model = ModifiedAlexNet(num_classes=10)


In [3]:

# Step 2: Load Pretrained Weights
weights_url = "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth"
weights_path = "alexnet.pth"
torch.hub.download_url_to_file(weights_url, weights_path)

# Load the state_dict for the original AlexNet
pretrained_state_dict = torch.load(weights_path)

# Remove the final layer weights from the state_dict
pretrained_state_dict = {k: v for k, v in pretrained_state_dict.items() if not k.startswith('classifier.6')}

# Load the pretrained weights into the modified model
model.load_state_dict(pretrained_state_dict, strict=False)



100%|██████████| 233M/233M [00:02<00:00, 86.0MB/s]


_IncompatibleKeys(missing_keys=['classifier.6.weight', 'classifier.6.bias'], unexpected_keys=[])

In [4]:
# Step 3: Train the Model on the MNIST Dataset
# Adjust the preprocessing to convert MNIST to 3 channels to fit the AlexNet input
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Use GPU if available
device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 15754229.60it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 484321.92it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4376822.35it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3768650.60it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [6]:
device

'cuda'

In [9]:
# Training loop
epochs = 100
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    num = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        num += 1
        print(f"No of 64-image batches processed: {num}")

        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
No of 64-image batches processed: 641
No of 64-image batches processed: 642
No of 64-image batches processed: 643
No of 64-image batches processed: 644
No of 64-image batches processed: 645
No of 64-image batches processed: 646
No of 64-image batches processed: 647
No of 64-image batches processed: 648
No of 64-image batches processed: 649
No of 64-image batches processed: 650
No of 64-image batches processed: 651
No of 64-image batches processed: 652
No of 64-image batches processed: 653
No of 64-image batches processed: 654
No of 64-image batches processed: 655
No of 64-image batches processed: 656
No of 64-image batches processed: 657
No of 64-image batches processed: 658
No of 64-image batches processed: 659
No of 64-image batches processed: 660
No of 64-image batches processed: 661
No of 64-image batches processed: 662
No of 64-image batches processed: 663
No of 64-image batches processed: 664
No of 64-image batches 

KeyboardInterrupt: 

In [10]:
# Evaluation loop
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy on test set: {accuracy:.2f}%")

Accuracy on test set: 99.12%


In [12]:
# Save the trained model
torch.save(model.state_dict(), 'modified_alexnet_mnist.pth')
print("Model saved to modified_alexnet_mnist.pth")


Model saved to modified_alexnet_mnist.pth


In [13]:
import os
os.listdir()

['.config', 'modified_alexnet_mnist.pth', 'alexnet.pth', 'data', 'sample_data']

In [14]:
from google.colab import files
files.download('modified_alexnet_mnist.pth')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [15]:
# Create an instance of the modified model
model = ModifiedAlexNet(num_classes=10)

# Load the saved state_dict
model.load_state_dict(torch.load('modified_alexnet_mnist.pth'))

<All keys matched successfully>

In [17]:
# Evaluation loop
model.eval()
model.to(device)
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy on test set: {accuracy:.2f}%")

Accuracy on test set: 99.12%
