In [26]:
import pandas as pd
import numpy as np
import modules as md
import optim as opt

In [27]:
# Raw MNIST training split: one row per image, a 'label' column plus pixel columns.
MNIST_TRAIN_CSV = 'mnist/mnist_train.csv'  # relative to the notebook's working directory
df = pd.read_csv(MNIST_TRAIN_CSV)

In [28]:
# Keep only the digits 0 and 1, turning MNIST into a two-class problem.
is_zero_or_one = df['label'].isin([0, 1])
reduced_mnist = df.loc[is_zero_or_one]

In [29]:
# Train on a small, fixed-size prefix of the 0/1 subset so the pure-numpy model trains quickly.
# Set N_TRAIN = None to use every row (`iloc[:None]` selects the whole frame).
N_TRAIN = 1000

# Slice first, then convert: avoids materialising and normalising rows that would be thrown away.
train_subset = reduced_mnist.iloc[:N_TRAIN]

# Pixel columns hold 0-255 intensities; scale them into [0, 1].
X_train = train_subset.drop('label', axis=1).values / 255.0
y_train = train_subset['label'].values.astype(int)

In [30]:
# Sanity-check the preprocessing before going further: one label per image, and only
# the two digits kept by the filter above.
assert len(X_train) == len(y_train), "images and labels must align one-to-one"
assert set(np.unique(y_train)) <= {0, 1}, "expected only labels 0 and 1"
X_train.shape, y_train.shape

((1000, 784), (1000,))

In [31]:
# Restore the 2-D image layout with a leading single-channel axis: (N, 784) -> (N, 1, 28, 28).
# Channel-first, presumably what md.Conv2d(in_channels=1, ...) expects.
X_train = np.reshape(X_train, (-1, 1, 28, 28))
X_train.shape

(1000, 1, 28, 28)

In [32]:
import torch
from torch.utils.data import Dataset

class MNISTDataset(Dataset):
    """Map-style dataset over in-memory images and integer class labels.

    Parameters
    ----------
    images : array-like
        Image data, indexed along the first axis; copied into a float32 tensor.
    labels : array-like
        One label per image; copied into an int64 (``torch.long``) tensor.

    Raises
    ------
    ValueError
        If ``images`` and ``labels`` do not have the same length.
    """

    def __init__(self, images, labels):
        self.images = torch.tensor(images, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        # Fail fast on misaligned inputs instead of an IndexError (or silently
        # ignored trailing images) surfacing later inside the DataLoader.
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(self.images)} and {len(self.labels)}"
            )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

# Wrap the preprocessed arrays in a torch Dataset so a DataLoader can batch and shuffle them.
mnist_dataset = MNISTDataset(images=X_train, labels=y_train)

In [33]:
from torch.utils.data import DataLoader

# Seed the shuffle so the batch order, and therefore the training curve below,
# is reproducible under Restart & Run All.
SEED = 0
batch_size = 4

shuffle_generator = torch.Generator().manual_seed(SEED)
dataloader = DataLoader(
    mnist_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=shuffle_generator,
)

In [34]:
def _conv_bn_relu(in_channels, out_channels):
    """Stride-2, 'same'-padded 3x3 conv -> batch norm -> ReLU (halves H and W, per the shapes below)."""
    return [
        md.Conv2d(in_channels, out_channels, 3, 2, 'same'),
        md.BatchNorm2d(out_channels),
        md.ReLU(),
    ]


# NOTE(review): the two outputs go through independent sigmoids, so they do not form a
# normalised distribution over {0, 1}. The logged loss (~0.0017) is below log(1 + e^-1) ≈ 0.31,
# the floor of softmax cross-entropy on sigmoid-squashed scores. That suggests
# optim.CrossEntropyLoss consumes probabilities directly. If so, a softmax output is
# probably what was intended. TODO confirm against optim.
model = [
    *_conv_bn_relu(1, 16),   # 28x28 -> 14x14
    *_conv_bn_relu(16, 32),  # 14x14 -> 7x7
    md.Flatten(),
    md.Linear(32 * 7 * 7, 2),
    md.Sigmoid(),
]

In [35]:
from optim import Adam, CrossEntropyLoss

LEARNING_RATE = 1e-3

# Gather trainable tensors in layer order (weight before bias). Layers exposing neither
# attribute contribute nothing. A layer with a weight but no bias, or a bias of None,
# no longer crashes here and no longer hands None to the optimizer.
parameters = [
    param
    for layer in model
    for param in (getattr(layer, 'weight', None), getattr(layer, 'bias', None))
    if param is not None
]

optimizer = Adam(parameters, lr=LEARNING_RATE)
criterion = CrossEntropyLoss()

In [36]:
from optim import train_epoch

num_epochs = 10
for epoch in range(num_epochs):
    # Sample-weighted running sum, so the epoch mean stays exact even when the final
    # batch is shorter than batch_size (a plain mean of batch means over-weights it).
    loss_sum = 0.0
    for batch_images, batch_labels in dataloader:
        # train_epoch receives a one-batch "epoch", so its return value is this batch's loss.
        # Assumed to be a per-sample mean; TODO confirm in optim.train_epoch.
        batch_loss = train_epoch(model, optimizer, criterion, [(batch_images.numpy(), batch_labels.numpy())])
        loss_sum += batch_loss * len(batch_labels)

    epoch_loss = loss_sum / len(dataloader.dataset)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

Epoch 1/10, Loss: 0.0670
Epoch 2/10, Loss: 0.0130
Epoch 3/10, Loss: 0.0102
Epoch 4/10, Loss: 0.0049
Epoch 5/10, Loss: 0.0037
Epoch 6/10, Loss: 0.0029
Epoch 7/10, Loss: 0.0025
Epoch 8/10, Loss: 0.0021
Epoch 9/10, Loss: 0.0019
Epoch 10/10, Loss: 0.0017
