In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import SGD
import os
import struct
from mlxtend.data import loadlocal_mnist
import matplotlib.pyplot as plt
import tqdm

In [2]:
class CustomDataset(Dataset):
    """Map-style dataset over one MNIST split stored as a pair of IDX files.

    The whole split is read into memory once, in ``__init__``, via ``load_data``;
    indexing then returns an ``(image, label)`` pair.

    Args:
        data_path: Path to the images IDX file.
        label_path: Path to the labels IDX file.

    Raises:
        ValueError: If the number of images and the number of labels differ.
    """

    def __init__(self, data_path, label_path):
        data, labels = self.load_data(data_path, label_path)
        if len(data) != len(labels):
            raise ValueError(
                f"Image/label count mismatch: {len(data)} images but {len(labels)} labels"
            )
        self.data = data
        # CrossEntropyLoss documents class-index targets as int64 ("long").
        # Labels read from the "ubyte" IDX file are presumably uint8, which only
        # works where the loss happens to accept byte targets, so widen once here
        # instead of relying on that.
        self.labels = torch.as_tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Image row is returned as loaded; no normalisation happens here.
        return self.data[idx], self.labels[idx]

    def load_data(self, data_path, label_path):
        """Read the split from disk with mlxtend's ``loadlocal_mnist``.

        Returns:
            ``(images, labels)`` exactly as returned by ``loadlocal_mnist``.
        """
        return loadlocal_mnist(images_path=data_path, labels_path=label_path)

In [3]:
class ModelArchitecture(torch.nn.Module):
    """Fully connected classifier: 784 -> 256 -> 512 -> 256 -> 128 -> 10.

    Hidden layers use ReLU; the final layer returns raw, un-activated scores
    (logits), one per output class.
    """

    # Feature widths between consecutive Linear layers, input first.
    _WIDTHS = (784, 256, 512, 256, 128, 10)

    def __init__(self):
        super().__init__()
        # Layers are created in order and registered as layer1..layer5 so the
        # state_dict keys stay compatible with checkpoints saved by this notebook.
        pairs = zip(self._WIDTHS[:-1], self._WIDTHS[1:])
        for position, (fan_in, fan_out) in enumerate(pairs, start=1):
            setattr(self, f"layer{position}", torch.nn.Linear(fan_in, fan_out))

    def forward(self, x):
        for hidden in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = torch.relu(hidden(x))
        # No activation on the output layer: logits go straight to the loss.
        return self.layer5(x)


In [4]:
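# MNIST source files. The byte sizes are those of the gzipped (.gz) downloads listed
# on the MNIST website; this notebook reads the decompressed files, which are saved
# with a '.' before "idx" (e.g. train-images.idx3-ubyte) in the paths below.
# train-images-idx3-ubyte: training set images (9912422 bytes)
# train-labels-idx1-ubyte: training set labels (28881 bytes)
# t10k-images-idx3-ubyte: test set images (1648877 bytes)
# t10k-labels-idx1-ubyte: test set labels (4542 bytes)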

In [5]:
# Directory holding the uncompressed MNIST IDX files (relative to the notebook's
# working directory). Change this one constant to point at a different location.
DATA_DIR = "../data"

train_image_path = os.path.join(DATA_DIR, "train-images.idx3-ubyte")
train_label_path = os.path.join(DATA_DIR, "train-labels.idx1-ubyte")

test_image_path = os.path.join(DATA_DIR, "t10k-images.idx3-ubyte")
test_label_path = os.path.join(DATA_DIR, "t10k-labels.idx1-ubyte")

In [6]:
# Seed PyTorch's global RNG before anything random happens, so that weight
# initialisation and the training DataLoader's shuffle order are reproducible
# after Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# Retrieve the train / test splits from the IDX files configured above.
train_dataset = CustomDataset(train_image_path, train_label_path)
test_dataset = CustomDataset(test_image_path, test_label_path)

# Model: fully connected classifier over flattened 784-feature inputs.
model = ModelArchitecture()

# Optimizer: plain SGD (no momentum, no weight decay).
LEARNING_RATE = 0.01
optimizer = SGD(model.parameters(), lr=LEARNING_RATE)

# Loss: CrossEntropyLoss consumes raw logits and integer class labels.
criterion = torch.nn.CrossEntropyLoss()

In [7]:
# Both loaders use the same batch size; only the training loader reshuffles each epoch.
BATCH_SIZE = 32
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Sanity check: look at the tensor shapes of the first training batch, then stop.
for images, labels in train_dataloader:
    print(images.shape, labels.shape)
    break

torch.Size([32, 784]) torch.Size([32])


In [8]:
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses over the epoch
    samples_seen = 0
    for images, labels in tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Flatten each sample to a single feature vector and cast to float for the Linear layers.
        images = images.view(images.size(0), -1).float()

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # criterion returns the *mean* over the batch; re-weight by batch size so a
        # short final batch does not count as much as a full one in the epoch average.
        batch_n = labels.size(0)
        running_loss += loss.item() * batch_n
        samples_seen += batch_n

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/samples_seen:.4f}")

Epoch 1/10: 100%|██████████| 1875/1875 [00:08<00:00, 224.82it/s]


Epoch [1/10], Loss: 0.2360


Epoch 2/10: 100%|██████████| 1875/1875 [00:08<00:00, 222.53it/s]


Epoch [2/10], Loss: 0.0948


Epoch 3/10:   3%|▎         | 50/1875 [00:00<00:13, 135.96it/s]


KeyboardInterrupt: 

In [None]:
# Evaluate on the held-out test split: per-sample mean loss and top-1 accuracy.
model.eval()
test_loss = 0.0  # sum of per-sample losses
correct = 0
total = 0

with torch.no_grad():  # inference only; no autograd graph needed
    for images, labels in test_dataloader:
        images = images.view(images.size(0), -1).float()
        outputs = model(images)
        loss = criterion(outputs, labels)
        # criterion returns a batch mean. The test set size need not be a multiple
        # of the batch size, so weight by batch size rather than averaging batch
        # means, which would over-weight the short final batch.
        test_loss += loss.item() * labels.size(0)

        predicted = outputs.argmax(dim=1)  # index of the highest logit = predicted class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Loss: {test_loss/total:.4f}")
print(f"Test Accuracy: {100 * correct / total:.2f}%")

Test Loss: 0.0871
Test Accuracy: 98.04%


In [None]:
# Create the checkpoint directory if needed; exist_ok avoids the
# check-then-create race of os.path.exists + os.makedirs.
model_dir = "../models"
os.makedirs(model_dir, exist_ok=True)

# Next checkpoint index = highest existing model_<N>.pth index + 1.
# Counting all directory entries instead would skip numbers when unrelated files
# are present, and would overwrite the newest checkpoint after an older one has
# been deleted.
_prefix, _suffix = "model_", ".pth"
existing_indices = []
for entry in os.listdir(model_dir):
    if entry.startswith(_prefix) and entry.endswith(_suffix):
        stem = entry[len(_prefix):-len(_suffix)]
        if stem.isdecimal():  # isdecimal (not isdigit) guarantees int() succeeds
            existing_indices.append(int(stem))
number = max(existing_indices, default=0) + 1

# Save only the learned parameters (state_dict), not the pickled module.
model_path = os.path.join(model_dir, f"model_{number}.pth")
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

Model saved to ../models/model_1.pth
