In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from resnet import ResNet
from torchvision import transforms, datasets

from torchmetrics import Accuracy
from torchinfo import summary


In [5]:
class MNISTNoZero(datasets.MNIST):
    """MNIST variant with every '0' digit removed and labels shifted down by one.

    All positional and keyword arguments are forwarded unchanged to
    ``datasets.MNIST``. After construction:

    * ``self.data`` and ``self.targets`` contain only the samples whose
      original label was not 0;
    * ``self.targets`` holds ``original_label - 1``, so the smallest
      remaining label becomes 0;
    * ``self.non_zero_indices`` is a list of the integer positions (in the
      unfiltered split) of the samples that were kept.

    NOTE(review): class-name metadata inherited from the parent dataset is
    not adjusted here, so it presumably still describes the unshifted
    labels -- confirm nothing downstream relies on it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Vectorized mask instead of a Python loop over every target element.
        keep_mask = self.targets != 0

        # Positions of the retained samples in the original, unfiltered split.
        self.non_zero_indices = torch.nonzero(keep_mask, as_tuple=True)[0].tolist()

        # Drop the '0' samples, then shift labels so they start at 0 again.
        self.data = self.data[keep_mask]
        self.targets = self.targets[keep_mask] - 1


# Preprocessing pipeline applied to every MNIST image.
#
# MNIST images are single-channel, but the Normalize step below uses
# three-channel (ImageNet) statistics. Normalizing a 1-channel tensor with
# 3-element mean/std fails, so the image is first replicated to 3 channels.
# NOTE(review): this assumes the ResNet imported above expects 3-channel
# input -- confirm against resnet.py.
transform = transforms.Compose([
    transforms.Resize(256),  # Resize so the shorter side is 256 px (aspect ratio preserved)
    transforms.CenterCrop(224),  # Keep the central 224x224 region
    transforms.Grayscale(num_output_channels=3),  # Replicate the single channel to 3 channels
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], shape (C, H, W)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Per-channel normalization
])


# Build the zero-free train and test splits (train first, then test).
# Both use the same storage root, download flag and preprocessing pipeline.
mnist_trainset_no_zero, mnist_testset_no_zero = (
    MNISTNoZero(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)


# Example (not executed): the same pipeline can be applied to a single PIL image.
# preprocessed_image = transform(image)

In [6]:
# Split configuration. The validation size is fixed; the training size is
# derived from the dataset length so the two always sum to len(dataset).
VAL_SIZE = 9000
BATCH_SIZE = 64
SPLIT_SEED = 42  # fixed seed so the train/val split is reproducible across runs

train_size = len(mnist_trainset_no_zero) - VAL_SIZE
train_dataset, val_dataset = torch.utils.data.random_split(
    mnist_trainset_no_zero,
    [train_size, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader shuffles; validation and test keep a fixed order.
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=0)

testloader = torch.utils.data.DataLoader(mnist_testset_no_zero, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=0)

valloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE,
                                        shuffle=False, num_workers=0)