In [1]:
import numpy as np
import torch
print(torch.__version__)

2.0.1+cu117


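The max-norm constraint is implemented as an SGD subclass. After each update, it rescales every row of each weight matrix in the opted-in parameter groups (for an `nn.Linear` weight, the incoming weights of one unit) whose L2 norm exceeds `max_norm`.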

In [2]:
# SGD with MaxNorm constraint
# max_norm must be packaged inside parameters
class MNSGD(torch.optim.SGD):
    """SGD that enforces a max-norm constraint after every update.

    A parameter group opts in by carrying a ``'max_norm'`` entry, e.g.
    ``{'params': model.hidden.parameters(), 'max_norm': 2}``. After the
    ordinary SGD update, every parameter of such a group with more than one
    dimension is rescaled in place with ``renorm`` along dim 0. Each slice
    ``tensor[i]`` then has L2 norm at most ``max_norm``; for an ``nn.Linear``
    weight, that slice is the incoming weight vector of output unit ``i``.
    1-d parameters (biases) and groups without ``'max_norm'`` (or with
    ``max_norm=None``) are left untouched.
    """

    def step(self, closure=None):
        """Perform one SGD step, then apply the max-norm projection.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; forwarded unchanged to ``SGD.step``.

        Returns:
            Whatever ``SGD.step`` returns: the closure's loss, or ``None``.
        """
        loss = super().step(closure)
        # Parameters are leaves that require grad; in-place edits must not be tracked.
        with torch.no_grad():
            for group in self.param_groups:
                max_norm = group.get('max_norm')
                if max_norm is None:
                    continue  # group did not opt in to the constraint
                for tensor in group['params']:
                    if tensor.dim() > 1:
                        tensor.renorm_(p=2, dim=0, maxnorm=max_norm)
        return loss

In [3]:
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision import transforms

# Build both MNIST splits under ./data, each converted to tensors by ToTensor().
# Only the training split may download. The test split expects its files to
# already be present (presumably fetched alongside the training download —
# verify for your torchvision version).
training_data, test_data = (
    datasets.MNIST(
        root="data",
        train=is_train,
        download=is_train,
        transform=ToTensor(),
    )
    for is_train in (True, False)
)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # GPU when available, else CPU

In [5]:
from torch import nn

class NeuralNetwork(nn.Module):
    """MLP classifier for 28x28 images: 784 -> 128 -> 80 -> 10, with dropout.

    Args:
        dropout_p: dropout probability applied after each hidden ReLU.
        input_dropout_p: dropout probability applied to the flattened input
            pixels (default 0.2, the value previously hard-coded here).
    """

    def __init__(self, dropout_p=0.3, input_dropout_p=0.2):
        super().__init__()

        self.dropout_p = dropout_p
        self.input_dropout_p = input_dropout_p

        # flattens every dimension after the batch dimension, giving (N, features)
        self.flatten = nn.Flatten()

        # Layer order is kept fixed so state_dict keys (hidden.1.*, hidden.4.*) stay stable.
        self.hidden = nn.Sequential(
            nn.Dropout(p=input_dropout_p),
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(128, 80),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),
        )

        # final layer is kept outside `hidden` so it can be excluded from the max-norm group
        self.output = nn.Linear(80, 10)

    def forward(self, x: torch.Tensor):
        """Return unnormalised class logits of shape (N, 10).

        ``x`` must flatten to (N, 784), e.g. an (N, 1, 28, 28) MNIST batch.
        """
        x = self.flatten(x)
        x = self.hidden(x)
        logits = self.output(x)
        return logits

In [6]:
def train_loop(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: iterable of batches whose first two items are (inputs,
            targets), with a ``.dataset`` attribute (e.g. a ``DataLoader``).
        model: the ``nn.Module`` being trained; switched to training mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU if it has none), so the function does
            not rely on a ``device`` variable defined in another cell.

    Prints the current batch loss and progress every 100 batches.
    """
    size = len(dataloader.dataset)
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()                  # training mode: dropout layers are active
    seen = 0                       # samples processed before the current batch
    for i, batch in enumerate(dataloader):
        # Compute prediction and loss for the whole batch at once
        X, y = batch[0].to(device), batch[1].to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()      # gradients accumulate, so clear the previous step's
        loss.backward()            # d(loss)/d(param) for every trainable parameter
        optimizer.step()           # apply the update rule to the parameters

        # print the progress every 100th batch
        if i % 100 == 0:
            print(f"loss: {loss.item():>7f} [{seen:>5d}/{size:>5d}]")
        seen += len(X)

In [7]:
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    
    model.eval()                   # enter testing mode to prevent dropout
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch[0].to(device), batch[1].to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            
    test_loss /= num_batches
    correct /= size
    
    print(f"Test error: \n Accuracy {(100*correct):>0.1f}%, Avg loss {test_loss:>8f} \n ")

Besides the training and testing loops, I've also added a `check_norms()` function. It verifies that the max-norm constraint holds for every row of the hidden-layer weight matrices, up to a tolerance (0.0001 by default).

In [8]:
def check_norms(model, max_norm, tol=1e-4):
    """Report whether every hidden-layer weight row respects ``max_norm``.

    Only parameters of ``model.hidden`` with more than one dimension are
    checked. Each is viewed as ``tensor.flatten(1)``, so every slice
    ``tensor[i]`` is measured as a whole — the same slices that
    ``renorm(dim=0)`` constrains. A slice fails when its L2 norm exceeds
    ``max_norm`` by more than ``tol``; the tolerance absorbs float rounding
    from renormalisation. NaN norms count as failures.

    Prints ``'Norms OK'`` or ``'Norms NOK'`` followed by the largest slice
    norm seen (0 if there is nothing to check).

    Returns:
        bool: True if all checked norms are within ``max_norm + tol``.
    """
    with torch.no_grad():
        slice_maxima = [
            torch.linalg.vector_norm(tensor.flatten(1), ord=2, dim=1).max().item()
            for tensor in model.hidden.parameters()
            if tensor.dim() > 1
        ]
    max_vector_norm = max(slice_maxima, default=0)
    norms_ok = all(norm - tol <= max_norm for norm in slice_maxima)
    norm_status = 'Norms OK' if norms_ok else 'Norms NOK'
    print(norm_status + f", max vector norm is {max_vector_norm}\n")
    return norms_ok

The hyperparameters, including the max-norm bound (`max_norm`), are set here.

In [9]:
# hyperparameters
learning_rate = 0.5
batch_size = 100
epochs = 10
max_norm = 2        # bound on the L2 norm of each hidden unit's incoming weight vector
momentum = 0.5

# define model -- build it exactly once; constructing a second
# `NeuralNetwork()` here would discard dropout_p=0.5 for the default 0.3
model = NeuralNetwork(dropout_p = 0.5).to(device)

loss_fn = nn.CrossEntropyLoss()
# the max-norm constraint applies to the hidden layers only; the output layer is unconstrained
optimizer = MNSGD(
    [
        {'params': model.hidden.parameters(), 'max_norm': max_norm },
        {'params': model.output.parameters() }
    ],
    lr=learning_rate,
    momentum=momentum
)

The train/test routine below checks, after every test pass, whether the weight vectors still satisfy the max-norm constraint.

In [None]:
from torch.utils.data import DataLoader

# Pinned host memory only speeds up host-to-GPU copies, so request it only on CUDA.
use_pinned_memory = device.type == 'cuda'
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True,
                              pin_memory=use_pinned_memory)
# Shuffling is unnecessary for evaluation; a fixed order keeps test runs repeatable.
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False,
                             pin_memory=use_pinned_memory)

for t in range(epochs):
    print(f"Epoch {t+1}\n----------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
    check_norms(model, max_norm)  # confirm the constraint still holds after this epoch
print("Done!")

Epoch 1
----------------------------
loss: 2.308669 [    0/60000]
loss: 0.436984 [10000/60000]
loss: 0.492524 [20000/60000]
loss: 0.334368 [30000/60000]
loss: 0.372025 [40000/60000]
loss: 0.467952 [50000/60000]
Test error: 
 Accuracy 93.6%, Avg loss 0.200647 
 
Norms OK, max vector norm is 2.000000238418579

Epoch 2
----------------------------
loss: 0.253347 [    0/60000]
loss: 0.236187 [10000/60000]
loss: 0.343207 [20000/60000]
loss: 0.372917 [30000/60000]
loss: 0.235315 [40000/60000]
loss: 0.212042 [50000/60000]
Test error: 
 Accuracy 95.7%, Avg loss 0.146674 
 
Norms OK, max vector norm is 2.000000238418579

Epoch 3
----------------------------
loss: 0.258380 [    0/60000]
loss: 0.123156 [10000/60000]
loss: 0.245546 [20000/60000]
loss: 0.337066 [30000/60000]
loss: 0.414379 [40000/60000]
loss: 0.301445 [50000/60000]
Test error: 
 Accuracy 95.9%, Avg loss 0.138769 
 
Norms OK, max vector norm is 2.000000238418579

Epoch 4
----------------------------
loss: 0.483095 [    0/60000]
loss

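The norms are correctly scaled. The largest weight-vector norm stays at the bound of 2, up to floating-point rounding (2.000000238…).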