In [None]:
# Load the autoreload extension and set it to mode 2.
%load_ext autoreload
%autoreload 2

# Use the inline matplotlib backend for this notebook.
%matplotlib inline

# Install Dependencies

! pip install -U lightning

# Organize Imports

In [None]:
from pathlib import Path

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torch.optim.lr_scheduler import CosineAnnealingLR

# Organize Path

In [None]:
# Root data directory, relative to the notebook's working directory.
PATH = Path('../data')

# Directory reserved for this model's artifacts.
model_path = PATH.joinpath('models', '2_layer_fashion_mnist_classifier_positive')
# Directory reserved for the Fashion-MNIST files.
MNIST_dir = PATH.joinpath('fashion_mnist')

# Create both directories up front; existing directories are left alone.
model_path.mkdir(exist_ok=True, parents=True)
MNIST_dir.mkdir(exist_ok=True, parents=True)

# Initialize Device and Workers

In [None]:
import os

# CPU count as reported by os.cpu_count(); note that it may be None
# when the count cannot be determined.
workers = os.cpu_count()
print("Number of CPUs in the system:", workers)

In [None]:
# Pick the accelerator name handed to the Lightning Trainer below.
# Every branch must assign a plain string: a trailing comma here would turn
# the value into a one-element tuple and break `accelerator=device`.
if torch.cuda.is_available():
    device = 'gpu'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

## Initialize Static Parameters

In [None]:
# Hyperparameters shared by the cells below
BATCH_SIZE = 64        # mini-batch size passed to get_data_loaders
LEARNING_RATE = 1e-2   # default learning rate of FashionMNISTClassifier (Adam)
WEIGHT_DECAY = 1e-4    # default weight decay of FashionMNISTClassifier (Adam)
L1_LAMBDA = 1e-4       # default multiplier of the L1 penalty in the training loss
EPOCHS = 64            # Trainer max_epochs and T_max of the cosine scheduler

# Initialize the Model

In [None]:
# Linear layer whose effective weights are produced by a softplus
class PositiveWeightLinear(nn.Module):
    """Affine layer that applies ``softplus`` to its stored weights before use.

    The learnable tensor is ``raw_weight`` (shape ``(out_features, in_features)``,
    initialised from a standard normal scaled by 0.01); the matrix actually used
    in the forward pass is ``softplus(raw_weight)``. The bias starts at zero and
    is not constrained.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        initial_raw = torch.randn(out_features, in_features) * 0.01
        self.raw_weight = nn.Parameter(initial_raw)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        """Return ``x @ softplus(raw_weight).T + bias``."""
        return F.linear(x, F.softplus(self.raw_weight), self.bias)

In [None]:
class PositiveWeightsNN(nn.Module):
    """Fully connected classifier: 784 -> 256 -> 128 -> 64 -> 10.

    The first three stages are ``nn.Linear`` layers followed by ReLU; the
    output stage is a ``PositiveWeightLinear`` producing 10 values per sample.
    """

    def __init__(self):
        super().__init__()
        # Hidden stack (attribute names and order define the state_dict keys).
        self.fc1 = nn.Linear(28 * 28, 256)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.act2 = nn.ReLU()
        self.fc3 = nn.Linear(128, 64)
        self.act3 = nn.ReLU()
        # Output head with softplus-reparameterised weights.
        self.fc4 = PositiveWeightLinear(64, 10)

    def forward(self, x):
        """Flatten each sample of ``x`` and return the 10 outputs of the head."""
        flat = x.view(x.size(0), -1)
        hidden = self.act1(self.fc1(flat))
        hidden = self.act2(self.fc2(hidden))
        hidden = self.act3(self.fc3(hidden))
        return self.fc4(hidden)

In [None]:
class FashionMNISTClassifier(L.LightningModule):
    """LightningModule training a wrapped network with cross-entropy plus an L1 penalty.

    After every optimizer step, each trainable parameter named ``weight`` is
    clamped to be non-negative.

    Args:
        model: Network mapping a batch of inputs to class logits.
        learning_rate: Learning rate passed to Adam.
        weight_decay: Weight decay passed to Adam.
        l1_lambda: Multiplier applied to the L1 norm of ``model``'s trainable
            parameters when building the training loss.
    """

    def __init__(self, model, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY, l1_lambda=L1_LAMBDA):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.l1_lambda = l1_lambda
        self.criterion = nn.CrossEntropyLoss()

    def l1_regularization(self):
        """Return ``l1_lambda`` times the summed absolute value of ``model``'s trainable parameters."""
        l1_norm = sum(p.abs().sum() for p in self.model.parameters() if p.requires_grad)
        return self.l1_lambda * l1_norm

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute cross-entropy + L1 penalty for one batch and log loss/accuracy.

        ``train_loss`` is the total loss (penalty included); it is also the
        value returned for backpropagation.
        """
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        l1_loss = self.l1_regularization()
        total_loss = loss + l1_loss

        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log('train_loss', total_loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Log cross-entropy (without the L1 penalty) and accuracy for one validation batch."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_closure,
    ):
        """Run the optimizer step, then clamp plain ``weight`` tensors to ``>= 0``."""
        # Execute the closure to run training_step, zero_grad, and backward.
        optimizer.step(closure=optimizer_closure)

        # Enforce non-negative weights. Match on the parameter's own name being
        # exactly "weight": a substring test would also catch
        # PositiveWeightLinear.raw_weight, and clamping that tensor at 0 would
        # floor its softplus output at softplus(0) = ln 2 instead of letting
        # the effective weight approach 0.
        with torch.no_grad():
            for name, param in self.named_parameters():
                if param.requires_grad and name.rpartition(".")[2] == "weight":
                    param.clamp_(min=0)

        optimizer.zero_grad()

    def configure_optimizers(self):
        """Return Adam plus a ``CosineAnnealingLR`` scheduler with ``T_max=EPOCHS``."""
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)
        return [optimizer], [scheduler]

# Prepare Dataset

In [None]:
# NOTE(review): 0.1307 / 0.3081 are the statistics usually quoted for the MNIST
# digits dataset, while the loaders below use FashionMNIST -- confirm these are
# the intended normalization constants.

# Training pipeline: random augmentation first, then tensor conversion and normalization.
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=10),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Evaluation pipeline: no augmentation, same tensor conversion and normalization.
val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

In [None]:
def get_data_loaders(batch_size, root=None, num_workers=0):
    """Build the Fashion-MNIST training and test data loaders.

    Args:
        batch_size: Number of samples per batch for both loaders.
        root: Directory holding (or receiving) the dataset files. Defaults to
            ``MNIST_dir``, the dataset directory created in the path-setup cell.
        num_workers: Worker processes per ``DataLoader``; 0 loads in the main
            process.

    Returns:
        ``(train_loader, test_loader)``. The training loader uses
        ``train_transform`` and shuffles; the test loader uses ``val_transform``
        and keeps the dataset order.
    """
    # Store the download in the notebook's dataset directory rather than a
    # separate hard-coded "./data" folder.
    data_root = str(MNIST_dir if root is None else root)

    train_dataset = datasets.FashionMNIST(root=data_root, train=True, transform=train_transform, download=True)
    test_dataset = datasets.FashionMNIST(root=data_root, train=False, transform=val_transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

In [None]:
# Build the train/test loaders with the configured batch size
train_loader, test_loader = get_data_loaders(batch_size=BATCH_SIZE)

# Checkpointing the Model

In [None]:
# The target directory goes in `dirpath`; `filename` is only the checkpoint's
# file name (without extension). Passing a full path as `filename` does not
# place the files in `model_path`.

# Keeps the single checkpoint with the lowest logged 'val_loss'.
checkpoint_callback = ModelCheckpoint(
    dirpath=model_path,
    filename='best-checkpoint',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    verbose=True,
)

# Unmonitored callback that also keeps a copy of the most recent state
# (save_last=True).
last_checkpoint_callback = ModelCheckpoint(
    dirpath=model_path,
    filename='last-checkpoint',
    save_last=True,
    verbose=True,
)

# Initiate Training

In [None]:
# Bare network and the Lightning module that wraps it for training
net = PositiveWeightsNN()
model = FashionMNISTClassifier(net)

# Single-device trainer using the accelerator selected earlier
trainer = L.Trainer(
    accelerator=device,
    devices=1,
    max_epochs=EPOCHS,
    callbacks=[checkpoint_callback, last_checkpoint_callback],
)

# Model training; note the test loader is used as the validation loader here
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)

# Visualize Layer

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn

# Function to visualize learned features
def visualize_weights(model, layer, grid=(8, 8), image_shape=(28, 28)):
    """Show rows of a weight matrix as grayscale images on a grid.

    Args:
        model: Module whose ``state_dict()`` contains ``layer``.
        layer: ``state_dict`` key of a 2-D weight tensor, e.g. ``'model.fc1.weight'``.
        grid: ``(rows, cols)`` of the subplot grid; only the first
            ``rows * cols`` weight rows are drawn.
        image_shape: Shape each weight row is reshaped to before plotting; its
            product must equal the row length.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    weights = model.state_dict()[layer].cpu().numpy()
    n_rows, n_cols = grid
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` always works.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
    for i, ax in enumerate(axes.flat):
        # Hide ticks and frame on every cell, including cells with no weight
        # row to draw, so short layers do not leave empty boxes behind.
        ax.axis('off')
        if i < weights.shape[0]:
            ax.imshow(weights[i].reshape(image_shape), cmap='gray')
    plt.show()

In [None]:
# Plot rows of the first hidden layer's weight matrix as 28x28 images
visualize_weights(model, layer='model.fc1.weight')

## Analysis of the Vectors