In [37]:
%matplotlib inline

We use the Quickstart tutorial from PyTorch ([link](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html)) for our example.

In [38]:
from pathlib import Path
import logging
import sys
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from sniper import SniperTraining

In [39]:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [40]:
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


In [41]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

Using cuda device


In [42]:
model = NeuralNetwork().to(device)
print(model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [43]:
loss_fn = nn.CrossEntropyLoss()
optim_lr = 1e-3

# SNIPER example starts here

In [44]:
# Define sparsity schedule
schedule = {0: 40,
            5: 20,
            10: 10,
            20: 0}

# Define variables needed for SNIPER training
model_builder = NeuralNetwork
snip_module_name = ''
batch_iterator = train_dataloader
def get_loss_fn(model, batch):
    X, y = batch
    X, y = X.to(device), y.to(device)
    pred = model(X)
    return loss_fn(pred, y)

In [45]:
# Set logger
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler(stream=sys.stdout))
logger.setLevel(logging.INFO)

The below will save model initial values, and the masks of the scheduled sparsities, in `masks`.

This is a minimal example; there are more options such as setting maximum sparsity of parameters (see documentation on `sniper.train()` for more info).

In [46]:
sniper_dir = Path('masks')
sniper = SniperTraining(sniper_dir=sniper_dir, logger=logger)
sniper.train(
    schedule=schedule,
    model=model,
    model_builder=model_builder,
    snip_module_name=snip_module_name,
    batch_iterator=batch_iterator,
    get_loss_fn=get_loss_fn,
    optim_lr=optim_lr,
)


Loading initial model state from masks/init_values.pt
Loading initial model state from masks/init_values.pt
False
False
True
True
All required sparsities present, loading sparsity 40...
All required sparsities present, loading sparsity 40...
Loaded mask from masks/masks_40.pt
Loaded mask from masks/masks_40.pt
Adding mask operation to forward hook...
Adding mask operation to forward hook...
Creating optimizer learning rates
Creating optimizer learning rates


The optimizer needs to accept the param groups from `sniper` to modify learning rates. If you do not have a scheduler, `sniper` will **directly** modify the learning rates during sparsity changes. If you have a scheduler, you need to set it in `sniper` to change  `base_lrs` when sparsity changes.

In [47]:
params = model.parameters() if sniper.param_groups is None else sniper.param_groups
optimizer = torch.optim.SGD(params, lr=optim_lr)
print(optimizer)
sniper.optimizers = [optimizer]
# sniper.schedulers = [scheduler]  # if you are using a scheduler

SGD (
Parameter Group 0
    dampening: 0
    foreach: None
    lr: 0.0016033872578390255
    maximize: False
    momentum: 0
    name: linear_relu_stack.0.weight
    nesterov: False
    weight_decay: 0

Parameter Group 1
    dampening: 0
    foreach: None
    lr: 0.0011403118040089087
    maximize: False
    momentum: 0
    name: linear_relu_stack.0.bias
    nesterov: False
    weight_decay: 0

Parameter Group 2
    dampening: 0
    foreach: None
    lr: 0.001791996500006836
    maximize: False
    momentum: 0
    name: linear_relu_stack.2.weight
    nesterov: False
    weight_decay: 0

Parameter Group 3
    dampening: 0
    foreach: None
    lr: 0.0011797235023041474
    maximize: False
    momentum: 0
    name: linear_relu_stack.2.bias
    nesterov: False
    weight_decay: 0

Parameter Group 4
    dampening: 0
    foreach: None
    lr: 0.0011918063314711358
    maximize: False
    momentum: 0
    name: linear_relu_stack.4.weight
    nesterov: False
    weight_decay: 0

Parameter Grou

In [48]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [49]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

Add `sniper.step()` to your training loop.

In [50]:
epochs = 30
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
    sniper.step()
print("Done!")

Epoch 1
-------------------------------
loss: 2.305630  [    0/60000]
loss: 2.286907  [ 6400/60000]
loss: 2.260293  [12800/60000]
loss: 2.249516  [19200/60000]
loss: 2.217139  [25600/60000]
loss: 2.189889  [32000/60000]
loss: 2.196691  [38400/60000]
loss: 2.146585  [44800/60000]
loss: 2.130537  [51200/60000]
loss: 2.072554  [57600/60000]
Test Error: 
 Accuracy: 50.3%, Avg loss: 2.069967 

Epoch 2
-------------------------------
loss: 2.093181  [    0/60000]
loss: 2.062598  [ 6400/60000]
loss: 1.970859  [12800/60000]
loss: 1.975963  [19200/60000]
loss: 1.869963  [25600/60000]
loss: 1.817361  [32000/60000]
loss: 1.816961  [38400/60000]
loss: 1.706102  [44800/60000]
loss: 1.695969  [51200/60000]
loss: 1.580570  [57600/60000]
Test Error: 
 Accuracy: 58.6%, Avg loss: 1.600097 

Epoch 3
-------------------------------
loss: 1.664915  [    0/60000]
loss: 1.613560  [ 6400/60000]
loss: 1.459231  [12800/60000]
loss: 1.506228  [19200/60000]
loss: 1.377245  [25600/60000]
loss: 1.374366  [32000/600

In [51]:
torch.save(model.state_dict(), "sniper_model.pth")

Now compare to the model trained without SNIPER. The SNIPER-trained model should converge faster.

In [53]:
model = NeuralNetwork().to(device)
init_values = torch.load(sniper_dir / 'init_values.pt', map_location=device)
model.load_state_dict(init_values)
optimizer = torch.optim.SGD(model.parameters(), lr=optim_lr)

In [54]:
epochs = 30
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.304559  [    0/60000]
loss: 2.291708  [ 6400/60000]
loss: 2.276042  [12800/60000]
loss: 2.267381  [19200/60000]
loss: 2.247044  [25600/60000]
loss: 2.228423  [32000/60000]
loss: 2.235480  [38400/60000]
loss: 2.204573  [44800/60000]
loss: 2.200688  [51200/60000]
loss: 2.166080  [57600/60000]
Test Error: 
 Accuracy: 42.9%, Avg loss: 2.164025 

Epoch 2
-------------------------------
loss: 2.177513  [    0/60000]
loss: 2.165572  [ 6400/60000]
loss: 2.116897  [12800/60000]
loss: 2.125358  [19200/60000]
loss: 2.068360  [25600/60000]
loss: 2.028746  [32000/60000]
loss: 2.048096  [38400/60000]
loss: 1.975801  [44800/60000]
loss: 1.973027  [51200/60000]
loss: 1.901443  [57600/60000]
Test Error: 
 Accuracy: 56.5%, Avg loss: 1.906071 

Epoch 3
-------------------------------
loss: 1.940768  [    0/60000]
loss: 1.910886  [ 6400/60000]
loss: 1.804289  [12800/60000]
loss: 1.829366  [19200/60000]
loss: 1.716968  [25600/60000]
loss: 1.684815  [32000/600

In [None]:
torch.save(model.state_dict(), "original_model.pth")