In [None]:
%load_ext autoreload
# Pick up edits to imported local modules without restarting the kernel.
%autoreload 2

# NOTE(review): nothing in the visible notebook plots with matplotlib; this may be removable.
%matplotlib inline

# Install Dependencies

! pip install -U lightning

# Organize Imports

In [None]:
from pathlib import Path

import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Organize Paths

In [None]:
# Root folder for everything this notebook reads or writes.
PATH = Path('../data')
# NOTE(review): the folder name says "2_layer_128_64_sae_sigmoid", but the model
# trained below is a single-hidden-layer (128) binary classifier — confirm the name.
model_path = PATH.joinpath('models', '2_layer_128_64_sae_sigmoid')
MNIST_dir = PATH.joinpath('mnist')
# Create both folders up front; existing folders are left untouched.
for _folder in (model_path, MNIST_dir):
    _folder.mkdir(parents=True, exist_ok=True)

# Initialize Device and Workers

In [None]:
import os

# os.cpu_count() returns None when the count cannot be determined; fall back to 1
# so `workers` is always a positive int (it is intended for DataLoader num_workers).
workers = os.cpu_count() or 1
print("Number of CPUs in the system:", workers)

In [None]:
# Choose the Lightning accelerator string: CUDA GPU first, then Apple MPS, else CPU.
# `device` must be a plain string — it is passed to pl.Trainer(accelerator=...).
if torch.cuda.is_available():
    device = 'gpu'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Initialize the Model

In [None]:
class SimpleNN(nn.Module):
    """One-hidden-layer MLP that returns raw (un-activated) output logits.

    Each sample is flattened before the first linear layer, so ``input_size``
    must equal the number of elements per sample (e.g. ``28 * 28`` for a
    single-channel 28x28 image).

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Width of the ReLU hidden layer.
        output_size: Number of output logits per sample (default ``1``).
    """

    def __init__(self, input_size, hidden_size, output_size=1):
        super().__init__()
        self.hidden = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return logits of shape ``(batch, output_size)``.

        No sigmoid is applied here; apply ``torch.sigmoid`` to the result
        if probabilities are needed.
        """
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. after a transpose.
        x = x.reshape(x.size(0), -1)
        h = self.relu(self.hidden(x))
        return self.output(h)

In [None]:
class SimpleNNPL(pl.LightningModule):
    """Lightning wrapper that trains ``SimpleNN`` as a single-logit binary classifier.

    Args:
        input_size: Number of features per sample after flattening (e.g. ``28 * 28``).
        hidden_size: Width of ``SimpleNN``'s hidden layer.
        lr: Learning rate for the Adam optimizer (default ``1e-3``).
    """

    def __init__(self, input_size, hidden_size, lr=1e-3):
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # ``SimpleNNPL.load_from_checkpoint(path)`` can rebuild the module.
        self.save_hyperparameters()
        self.model = SimpleNN(input_size, hidden_size)
        self.sigmoid = nn.Sigmoid()
        # BCEWithLogitsLoss applies the sigmoid itself, so the model outputs raw logits.
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return raw logits for the batch ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute (loss, accuracy) for one batch; shared by training and validation."""
        x, y = batch
        # BCE needs float targets shaped like the (batch, 1) logits.
        y = y.float().unsqueeze(1)
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = (self.sigmoid(logits) > 0.5).float()
        acc = (preds == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log('train_loss', loss, on_epoch=True)
        self.log('train_acc', acc, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        # 'val_loss' is the metric monitored by the checkpoint callback.
        self.log('val_loss', loss, on_epoch=True)
        self.log('val_acc', acc, on_epoch=True)

    def configure_optimizers(self):
        """Adam over all parameters, using the learning rate passed to ``__init__``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

# Label rule: 1 for digits divisible by neither 2 nor 3 (1, 5, 7), else 0.
def binary_target_transform(target):
    """Map an MNIST digit label to a binary target.

    Returns 1 when ``target`` is divisible by neither 2 nor 3 (digits 1, 5, 7)
    and 0 otherwise (digits 0, 2, 3, 4, 6, 8, 9). This is NOT an even/odd
    split: the odd digits 3 and 9 map to 0 because they are multiples of 3.
    """
    return int(target % 2 != 0 and target % 3 != 0)

# Transform to convert images to tensors
transform = transforms.Compose([transforms.ToTensor()])


def _load_mnist(train, target_transform=None):
    """Load one MNIST split under ``MNIST_dir``, downloading it if missing."""
    return datasets.MNIST(
        root=str(MNIST_dir),  # the data folder created in the paths cell, not the CWD
        train=train,
        download=True,
        transform=transform,
        target_transform=target_transform,
    )


# Splits with binary targets used for training and validation.
train_dataset = _load_mnist(train=True, target_transform=binary_target_transform)
val_dataset = _load_mnist(train=False, target_transform=binary_target_transform)
# Test split with the original digit labels (0-9), kept for inspection.
val_dataset_orig = _load_mnist(train=False)

In [None]:
# Training batches are reshuffled every epoch; evaluation batches keep dataset order.
# num_workers is left at its default (0, i.e. loading in the main process).
# NOTE(review): multi-worker loading (with `workers`) was tried and disabled — confirm why before re-enabling.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=128)
val_loader_orig = DataLoader(dataset=val_dataset_orig, batch_size=128)

# Model instance used by the shape sanity-check cells below.
pl_model = SimpleNNPL(hidden_size=128, input_size=28 * 28)

In [None]:
# Sanity-check input/output shapes on the first training batch.
x, y = next(iter(train_loader))
print(x.shape, y.shape)
y_hat = pl_model(x)
print(y_hat.shape)

In [None]:
# Same shape check on the first batch of the original-label validation split.
x, y = next(iter(val_loader_orig))
print(x.shape, y.shape)
y_hat = pl_model(x)
print(y_hat.shape)

In [None]:
# `x`, `y`, `y_hat` come from the `val_loader_orig` batch above, so `y` holds the
# original digit labels; map them to the binary targets the model is trained on.
y_binary = torch.tensor([binary_target_transform(int(label)) for label in y])
y_binary = y_binary.float().unsqueeze(1)  # (batch, 1), same shape as y_hat

val_batch_loss = pl_model.loss_fn(y_hat, y_binary)
preds = (nn.Sigmoid()(y_hat) > 0.5).float()
# Both operands are (batch, 1), so this compares element-wise instead of
# broadcasting (batch, 1) against (batch,) into a (batch, batch) matrix.
acc = (preds == y_binary).float().mean()
val_batch_loss, acc

# Checkpointing the Model

In [None]:
# Track `val_loss` and keep only the single best checkpoint in `model_path`.
checkpoint_callback = ModelCheckpoint(
    dirpath=model_path,
    filename='sample-mnist-{epoch:02d}',
    monitor='val_loss',
    save_top_k=1,
    auto_insert_metric_name=True,
    verbose=True,
)

# Initiate Training

In [None]:
# Seed the global RNGs via Lightning so weight initialization and batch
# shuffling are repeatable across runs.
pl.seed_everything(42)

# Fresh model, so training does not start from the instance used in the
# shape sanity-check cells above.
pl_model = SimpleNNPL(input_size=28 * 28, hidden_size=128)

trainer = pl.Trainer(
    max_epochs=64,
    accelerator=device,  # selected in the device cell ('gpu', 'mps' or 'cpu')
    callbacks=[checkpoint_callback],
)

trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader)