# 1. Binary Logistic Regression

# Imports

In [None]:
import torch
import torch.nn as nn
import sys
import os

sys.path.append(os.path.abspath(".."))

from dataset import preprocess_mnist
from models.logistic_regression import LogisticRegressionModel
from train import train_model, test_model
from utils import plot_curves, print_confusion_matrix

In [None]:
# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Binary 0 vs 1 subset

In [None]:
# Build the train / validation / test loaders in a single call.
# NOTE(review): flatten=True presumably yields flat 784-element inputs, matching the
# input_dim=784 passed to the models below — confirm in dataset.preprocess_mnist.
train_loader, val_loader, test_loader = preprocess_mnist(flatten=True)

# Filter 0 and 1 only

In [None]:
def filter_binary(loader, classes=(0, 1), batch_size=64, shuffle=True):
    """Return a new DataLoader containing only the samples of ``classes``.

    The underlying ``loader.dataset`` is walked sample by sample (the batched
    loader itself is not iterated). Every ``(img, label)`` pair whose label
    equals one of ``classes`` is kept; the kept images are stacked into one
    tensor and the kept labels are converted with ``torch.tensor``.

    Args:
        loader: Object exposing a ``dataset`` attribute that yields
            ``(img, label)`` pairs. ``img`` must be a tensor (all kept images
            need the same shape so they can be stacked); ``label`` may be a
            plain number or a single-element tensor.
        classes: Iterable of label values to keep. Defaults to ``(0, 1)``,
            i.e. the binary 0-vs-1 task.
        batch_size: Batch size of the returned loader. Defaults to 64.
        shuffle: Whether the returned loader reshuffles every epoch. Defaults
            to True; pass False for validation/test loaders when a stable
            order is wanted.

    Returns:
        A ``torch.utils.data.DataLoader`` over a ``TensorDataset(X, y)`` that
        holds only the selected samples, in their original dataset order.

    Raises:
        ValueError: If no sample in ``loader.dataset`` has a label in
            ``classes`` (``torch.stack`` would otherwise fail with an opaque
            error on the empty list).
    """
    # A tuple (not a set) on purpose: tensor labels hash by identity, so set
    # membership would never match them, while ``in`` on a tuple uses ``==``.
    wanted = tuple(classes)

    images, labels = [], []
    for img, label in loader.dataset:
        if label in wanted:
            images.append(img)
            labels.append(label)

    if not images:
        raise ValueError(
            f"filter_binary: no samples with labels in {wanted!r} were found"
        )

    features = torch.stack(images)
    targets = torch.tensor(labels)
    subset = torch.utils.data.TensorDataset(features, targets)
    return torch.utils.data.DataLoader(
        subset, batch_size=batch_size, shuffle=shuffle
    )

# Restrict each split to the 0-vs-1 task, in train / val / test order.
train_loader_bin, val_loader_bin, test_loader_bin = map(
    filter_binary, (train_loader, val_loader, test_loader)
)

# Model

In [None]:
# Logistic-regression model over 784 input features, trained with binary cross-entropy.
# NOTE(review): nn.BCELoss expects probabilities in [0, 1] and float targets, so the
# model presumably ends in a sigmoid and train_model(binary=True) presumably casts the
# integer labels — confirm in models/logistic_regression.py and train.py.
model_log = LogisticRegressionModel(input_dim=784)
loss_fn = nn.BCELoss()

# Train

In [None]:
# Train the binary model on the filtered loaders for 30 epochs with lr=0.01.
# binary=True is forwarded to train_model; its exact effect is defined in train.py.
# The four returned sequences are passed to plot_curves in the next cell.
train_losses, val_losses, train_accs, val_accs = train_model(
    model_log, train_loader_bin, val_loader_bin,
    epochs=30, lr=0.01, device=device, loss_fn=loss_fn, binary=True
)

# Plots

In [None]:
# One figure per metric: training vs. validation loss, then training vs. validation accuracy.
plot_curves(train_losses, val_losses, "Binary Logistic Regression - Loss", "Loss")
plot_curves(train_accs, val_accs, "Binary Logistic Regression - Accuracy", "Accuracy")

# Test

In [None]:
# Evaluate on the filtered test split; test_model returns the accuracy and a
# confusion matrix, which are printed for the two classes "0" and "1".
acc, cm = test_model(model_log, test_loader_bin, device, binary=True)
print(f"Test Accuracy: {acc:.4f}")
print_confusion_matrix(cm, classes=["0", "1"])

# 2. Softmax Regression

# Imports

In [None]:
from models.softmax_regression import SoftmaxRegressionModel
from train import train_model, test_model

# Model

In [None]:
# 10-class model over 784 input features, trained with cross-entropy.
# NOTE: this rebinds loss_fn, replacing the BCELoss used for the binary model.
# NOTE(review): nn.CrossEntropyLoss applies log-softmax itself and expects raw logits;
# if SoftmaxRegressionModel.forward already applies softmax the loss is computed on
# doubly-normalised outputs — confirm in models/softmax_regression.py.
model_softmax = SoftmaxRegressionModel(input_dim=784, num_classes=10)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Full (unfiltered) loaders for the 10-class task.
# NOTE(review): this is the same call made in the binary section; if preprocess_mnist
# is deterministic and the earlier loaders are still bound, this reload is redundant —
# confirm before removing.
train_loader, val_loader, test_loader = preprocess_mnist(flatten=True)

# Train

In [None]:
# Train the softmax model on the full loaders for 30 epochs with lr=0.01.
# No binary flag is passed here, so train_model uses whatever its default is.
# NOTE: this rebinds the history variables from the binary run.
train_losses, val_losses, train_accs, val_accs = train_model(
    model_softmax, train_loader, val_loader,
    epochs=30, lr=0.01, device=device, loss_fn=loss_fn
)

# Plots

In [None]:
# One figure per metric: training vs. validation loss, then training vs. validation accuracy.
plot_curves(train_losses, val_losses, "Softmax Regression - Loss", "Loss")
plot_curves(train_accs, val_accs, "Softmax Regression - Accuracy", "Accuracy")

# Test

In [None]:
# Evaluate on the full test split; test_model returns the accuracy and a
# confusion matrix, which are printed with class names "0" through "9".
acc, cm = test_model(model_softmax, test_loader, device)
print(f"Test Accuracy: {acc:.4f}")
print_confusion_matrix(cm, classes=[str(i) for i in range(10)])