In [23]:
import matplotlib.pyplot as plt
import pandas as p
import torch
import numpy as np
import torch.nn as nn
import os
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset

In [24]:
def load_mnist_data(data_dir):
    train_data_path = os.path.join(data_dir, 'train-images.idx3-ubyte')
    train_labels_path = os.path.join(data_dir, 'train-labels.idx1-ubyte')
    test_data_path = os.path.join(data_dir, 't10k-images.idx3-ubyte')
    test_labels_path = os.path.join(data_dir, 't10k-labels.idx1-ubyte')


    # Load training images
    with open(train_data_path, 'rb') as f:
        magic, num_images, rows, cols = np.fromfile(f, dtype=np.dtype('>i4'), count=4)
        train_data = np.fromfile(f, dtype=np.uint8).reshape(num_images, rows, cols)

    # Load training labels
    with open(train_labels_path, 'rb') as f:
        magic, num_labels = np.fromfile(f, dtype=np.dtype('>i4'), count=2)
        train_labels = np.fromfile(f, dtype=np.uint8)

    # Load test images
    with open(test_data_path, 'rb') as f:
        magic, num_images, rows, cols = np.fromfile(f, dtype=np.dtype('>i4'), count=4)
        test_data = np.fromfile(f, dtype=np.uint8).reshape(num_images, rows, cols)

    # Load test labels
    with open(test_labels_path, 'rb') as f:
        magic, num_labels = np.fromfile(f, dtype=np.dtype('>i4'), count=2)
        test_labels = np.fromfile(f, dtype=np.uint8)

    # Convert images to torch tensors and apply augmentation transforms
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # Normalizing the test images
    transform_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    ])
   
    train_data = torch.stack([transform_train(img) for img in train_data])
    test_data = torch.stack([transform_test(img) for img in test_data])

    return train_data, train_labels, test_data, test_labels


In [25]:
dataset_path = './dataset'
train_images, train_labels, test_images, test_labels = load_mnist_data(dataset_path)

#print shapes of train_images, train_labels, test_images, test_labels
print(train_images.shape, train_labels.shape, test_images.shape, test_labels.shape)

# Create DataLoader objects for images and labels
train_images_dataset = TensorDataset(train_images, torch.tensor(train_labels))
test_images_dataset = TensorDataset(test_images, torch.tensor(test_labels))

train_loader = DataLoader(train_images_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_images_dataset, batch_size=32, shuffle=False)


RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

In [None]:
# Any data processing we need to do here?

In [17]:
# tweakable hyperparameters for the model
# loss function and optimizer are placed in the training cell
batch_size = 32
num_of_epochs = 300
learning_rate = 0.005
neurons_per_layer = 128 # for the hidden layers

numbers_of_layers = 3 # num of hidden layers excl. the input and output layers
activation_function = torch.nn.Sigmoid() # try other activation functions too!

In [18]:
# the model
class DigitClassifier(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        
        self.all_layers = torch.nn.Sequential()

        # input layer
        self.all_layers.add_module('input', torch.nn.Linear(num_features, neurons_per_layer))
        self.all_layers.add_module('input_activation', activation_function)

        # hidden layers
        for i in range(numbers_of_layers):
            self.all_layers.add_module(f'hidden_{i}', torch.nn.Linear(neurons_per_layer, neurons_per_layer))
            self.all_layers.add_module(f'hidden_{i}_activation', activation_function)

        # output layer
        self.all_layers.add_module('output', torch.nn.Linear(neurons_per_layer, num_classes))
        self.all_layers.add_module('output_activation', torch.nn.Softmax(dim=1))

    def forward(self, x):
        return self.all_layers(x)        

In [None]:
##### ignore for now##########
# Data loader
from torch.utils.data import Dataset, DataLoader

#TODO: implement our dataloader, the commented lines are suggested by Copilot but I'm not sure if they are correct
class DigitsDataset(Dataset):
    def __init__(self, images, labels):
        pass
        # self.images = images
        # self.labels = labels
    
    def __len__(self):
        pass
        # return len(self.images)
    
    def __getitem__(self, idx):
        pass
        # return self.images[idx], self.labels[idx]
    
train_dataset = DigitsDataset(train_images, train_labels)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# the training loop

import torch.nn.functional as F
torch.manual_seed(555)

# dimension n of the n*n input picture
n = 28

# number of classes (digits 0-9)
num_classes = 10

# the model
model = DigitClassifier(n*n, num_classes)

# loss and optimizer, try other combos too?
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# train!
num_epochs = num_of_epochs
epoch_losses = [] # to plot the loss curve later

for epoch in range(num_epochs):

    model = model.train()

    for batch_idx, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(features)
        loss = F.mse_loss(logits, labels) # Loss function
        loss.backward()
        optimizer.step()
    
    # logging + save the loss
    print(f"Epoch: {epoch+1:03d}/{num_epochs:03d}" f" | Batch {batch_idx:03d}/{len(train_loader):03d}" f" | Train/Val Loss: {loss:.2f}")
    epoch_losses.append(loss.item())


# save model after training
torch.save(model.state_dict(), './model.pth')

In [None]:
# training loss plot
plt.plot(epoch_losses)
plt.legend(["Training Loss"])
plt.title("Training Loss vs. Number of Epoch")
plt.xlabel("Number of Epoch")
plt.ylabel("Training Loss")