In [6]:
%matplotlib inline
import os
import sys
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torchvision import transforms, datasets
from torch.utils.data import ConcatDataset
import torch.optim as optim
from torchmetrics import Accuracy
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from Lenet import LeNet5

import matplotlib.pyplot as plt
import numpy as np

# Settings

# Global seed so every stochastic step further down gives the same result on
# every Restart & Run All. This covers NumPy sampling of the '0' subset and
# PyTorch's train/validation split and DataLoader shuffling.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Show tensors with 3 decimal places when printed.
torch.set_printoptions(precision=3)

### First, we train a LeNet-5 model on an MNIST subset that contains only the digits 1 to 9. Second, we take the pre-trained model and fine-tune it to classify '0'.

In [8]:
class MNISTNoZero(datasets.MNIST):
    """MNIST split with every sample of the digit '0' removed.

    All positional and keyword arguments are forwarded unchanged to
    ``torchvision.datasets.MNIST``. Once the parent constructor has populated
    ``self.data`` and ``self.targets``, both are filtered so that only the
    samples whose label is not 0 remain. The remaining labels keep their
    original values (1-9); they are NOT shifted down to start at 0.

    Attributes:
        non_zero_indices: list of int positions, in the original unfiltered
            split, of the samples that were kept.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Vectorised boolean mask. The previous Python loop iterated the
        # targets tensor element by element, building one 0-d tensor each.
        keep_mask = self.targets != 0

        # Stored as a plain list of ints, same type as before.
        self.non_zero_indices = keep_mask.nonzero(as_tuple=True)[0].tolist()

        # Filter data and targets with the same mask so they stay aligned.
        # NOTE(review): labels stay 1-9 while a 9-entry `classes` tuple is
        # indexed 0-8; confirm how labels map to the model's output units.
        self.data = self.data[keep_mask]
        self.targets = self.targets[keep_mask]


# Preprocessing shared by every dataset below: tensor conversion followed by
# Normalize with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform_normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Full MNIST training split (all ten digits); used to get the '0' digits.
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform_normalize,
)

# Training and test splits with every '0' filtered out.
mnist_trainset_no_zero = MNISTNoZero(
    root='./data',
    train=True,
    download=True,
    transform=transform_normalize,
)
mnist_testset_no_zero = MNISTNoZero(
    root='./data',
    train=False,
    download=True,
    transform=transform_normalize,
)

# Names of the digits that remain after filtering.
classes = (
    '1', '2', '3',
    '4', '5', '6',
    '7', '8', '9',
)


# Verify the filtering: show which labels remain in each split and fail
# loudly (instead of only printing) if any '0' slipped through.
print("Unique labels in the modified training set:", mnist_trainset_no_zero.targets.unique())
print("Unique labels in the modified test set:", mnist_testset_no_zero.targets.unique())
assert not (mnist_trainset_no_zero.targets == 0).any(), "training set still contains label 0"
assert not (mnist_testset_no_zero.targets == 0).any(), "test set still contains label 0"

Unique labels in the modified training set: tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
Unique labels in the modified test set: tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])


In [14]:
# Report how many samples the '0'-free training split holds.
print(
    "Size of the modified training set:",
    len(mnist_trainset_no_zero),
)

Size of the modified training set: 54077


In [9]:
# Build a small, reproducible random subset of '0' images from the full MNIST
# training split; these are the samples used to fine-tune on the new digit.
ZERO_SUBSET_SEED = 0  # fixed so the same zeros are selected on every run

# Vectorised label lookup instead of a Python loop over the targets tensor;
# still a plain list of int positions.
indices_of_zeros = (mnist_dataset.targets == 0).nonzero(as_tuple=True)[0].tolist()

# How many '0' images to keep (e.g. 100 or 200). Sampling is without
# replacement, so this must not exceed len(indices_of_zeros).
num_samples = 100
zero_rng = np.random.default_rng(ZERO_SUBSET_SEED)
selected_indices = zero_rng.choice(indices_of_zeros, size=num_samples, replace=False)

# Create a subset from the MNIST dataset using the selected indices.
subset_of_zeros = torch.utils.data.Subset(mnist_dataset, selected_indices)

# Verify the dataset
print(f"Number of images in the subset: {len(subset_of_zeros)}")


Number of images in the subset: 100


In [16]:
# Hold out a fixed-size validation set from the '0'-free training split. The
# training part gets whatever remains, so the sizes always sum to the dataset
# length. A dedicated seeded generator makes the split reproducible.
VAL_SIZE = 9000
SPLIT_SEED = 42
BATCH_SIZE = 64

train_size = len(mnist_trainset_no_zero) - VAL_SIZE
train_dataset, val_dataset = torch.utils.data.random_split(
    mnist_trainset_no_zero,
    [train_size, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader shuffles; evaluation order is kept fixed.
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(mnist_testset_no_zero, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

valloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE,
                                        shuffle=False, num_workers=2)

# Smaller batch size because the '0' subset is tiny.
zero_loader = torch.utils.data.DataLoader(subset_of_zeros, batch_size=10,
                                          shuffle=True, num_workers=2)