In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from Lenet import LeNet5
from torchvision import transforms, datasets



In [2]:
# Normalization transform
transform_normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform_normalize)

In [3]:
indices_of_zeros = [i for i, label in enumerate(mnist_dataset.targets) if label == 0]

# Randomly select 100-200 indices of '0' digits
# You can adjust the number by changing the value of num_samples
num_samples = 100  # or 200, depending on your requirement
selected_indices = np.random.choice(indices_of_zeros, num_samples, replace=False)

# Create a subset from the MNIST dataset using the selected indices
subset_of_zeros = torch.utils.data.Subset(mnist_dataset, selected_indices)

# Verify the dataset
print(f"Number of images in the subset: {len(subset_of_zeros)}")


Number of images in the subset: 100


In [4]:
subset_of_zeros

<torch.utils.data.dataset.Subset at 0x16d7788a490>

In [5]:
pre_trained_model = LeNet5(num_classes=9)

pre_trained_model.load_state_dict(torch.load('lenet_1_to_9_v2.pth'))

<All keys matched successfully>

In [6]:
print(pre_trained_model)

LeNet5(
  (feature): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): Tanh()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): Tanh()
    (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=400, out_features=120, bias=True)
    (2): Tanh()
    (3): Linear(in_features=120, out_features=84, bias=True)
    (4): Tanh()
    (5): Linear(in_features=84, out_features=9, bias=True)
  )
)
