## In this tutorial we create a CNN and dataloaders, and train / prune the model.

In [None]:
!pip install torchvision

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import os
os.environ["KERAS_BACKEND"] = "torch" # Needs to be set, some pruning layers as well as the quantizers are Keras



Let's define a ResNet20 using a slightly modified version of https://github.com/akamaster/pytorch_resnet_cifar10/blob/master/resnet.py


In [None]:
def _weights_init(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = LambdaLayer(lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))


    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu2(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)
        self.relu1 = nn.ReLU()
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])


model = resnet20()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

model
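
The shortcut in `BasicBlock` subsamples the input spatially with `x[:, :, ::2, ::2]` and zero-pads the channel dimension by `planes//4` on each side. The following minimal sketch only does the shape arithmetic in plain Python, to check that the shortcut output matches the shape of the main branch in `layer2` and `layer3`. The helper name `shortcut_shape` is made up for this illustration.

```python
def shortcut_shape(shape, planes):
    """Shape after x[:, :, ::2, ::2] followed by zero-padding the channels.

    `shape` is (batch, channels, height, width); `planes` is the number of
    output channels of the block. Hypothetical helper, for illustration only.
    """
    n, c, h, w = shape
    pad = planes // 4                      # channels added on each side
    # Slicing with step 2 keeps ceil(size / 2) elements.
    return (n, c + 2 * pad, (h + 1) // 2, (w + 1) // 2)


after_layer1 = (1, 16, 32, 32)             # CIFAR-10 input after conv1 / layer1
after_layer2 = shortcut_shape(after_layer1, planes=32)
after_layer3 = shortcut_shape(after_layer2, planes=64)

print(after_layer2)  # (1, 32, 16, 16)
print(after_layer3)  # (1, 64, 8, 8)
```

The final 8x8 feature map is why `forward` uses `F.avg_pool2d(out, 8)` and why the linear layer takes 64 input features.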

## Add pruning and quantization
To add pruning and quantization, we need a config that defines how to do it. Let's load the config file pquant/configs/configs_pdp.yaml. We then use this config to add the pruning layers and quantized activations to the model.

In [None]:
from pquant.core.utils import get_default_config

# pruning_methods: "autosparse, cl, cs, dst, pdp, wanda"
pruning_method = "pdp"
config = get_default_config(pruning_method)
# Set target sparsity to 50%. This parameter exists only for some pruning methods
config["pruning_parameters"]["sparsity"] = 0.5
# Check what's inside the config
config

In [None]:
# Replace layers with compressed layers
from pquant.core.compressed_layers import add_pruning_and_quantization
model = add_pruning_and_quantization(model, config)
model

## Pruning and quantization in the config
From the config we see that we are using the unstructured version of the PDP pruning method. We aim to prune 50% of the weights (sparsity 0.5), and we quantize the model to 8 bits (1 of which is the sign bit). 
By default, all convolutional and linear layers, as well as activations will be quantized using the default values ```default_integer_bits``` and ```default_fractional_bits```. Similarly, by default all convolutional and linear layers will be pruned.
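
To make the bit budget concrete, here is a minimal pure-Python sketch of signed fixed-point quantization with saturation. It is not the pquant or `quantizers` implementation. It assumes that the total width is `k + i + f` bits (sign, integer, fractional), that values are rounded to the nearest multiple of `2**-f`, and that overflow saturates. The function name is hypothetical.

```python
import math


def fixed_point_quantize(x, k=1, i=0, f=7):
    """Quantize a float to k sign bits, i integer bits and f fractional bits.

    Assumed convention (not taken from the pquant source): values are rounded
    to the nearest multiple of 2**-f and overflow saturates instead of wrapping.
    """
    step = 2.0 ** -f                         # smallest representable increment
    lowest = -(2.0 ** i) if k else 0.0       # most negative representable value
    highest = 2.0 ** i - step                # most positive representable value
    rounded = math.floor(x / step + 0.5) * step
    return min(max(rounded, lowest), highest)


print(fixed_point_quantize(0.3))             # 0.296875  (38 / 128)
print(fixed_point_quantize(1.5))             # 0.9921875 (saturated at 1 - 2**-7)
print(fixed_point_quantize(-3.0, i=1, f=6))  # -2.0      (saturated at -2**1)
```

With 1 sign bit, 0 integer bits and 7 fractional bits, the representable range under these assumptions is [-1, 1 - 2^-7] in steps of 2^-7.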

We can disable pruning and/or quantization by setting ```enable_pruning``` / ```enable_quantization``` to False. To change the quantization bits for a specific layer, add the layer's name to the list found in ```layer_specific```, followed by the number of bits. To disable pruning for a single layer, add its name to the ```disable_pruning_for_layers``` list. 

We'll show later how to create a custom quantization / pruning config file from an existing config for a given model.

## About the different epochs

The config defines 20 ```pretraining_epochs```, 100 ```epochs``` and 20 ```fine_tuning_epochs```. What happens during each of these training steps is algorithm specific. 
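
One way to picture the three counters is as consecutive stages on a single timeline of 20 + 100 + 20 = 140 epochs. The sketch below assumes exactly that ordering (pretraining, then training, then fine-tuning). The function `training_phase` is hypothetical and is not part of pquant.

```python
def training_phase(epoch, pretraining_epochs=20, epochs=100, fine_tuning_epochs=20):
    """Return the phase name for a 0-based global epoch index.

    Assumes the three stages run back to back in this order; defaults mirror
    the values mentioned in the config.
    """
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    if epoch < pretraining_epochs:
        return "pretraining"
    if epoch < pretraining_epochs + epochs:
        return "training"
    if epoch < pretraining_epochs + epochs + fine_tuning_epochs:
        return "fine-tuning"
    raise ValueError("epoch is past the end of the schedule")


for e in (0, 19, 20, 119, 120, 139):
    print(e, training_phase(e))
```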

In PDP, the pretraining phase consists of training without pruning, followed by calculation of layerwise pruning budgets. After pretraining is finished and the layerwise pruning budgets have been calculated, the training with pruning begins. The mask during this training is a soft mask, consisting of values ranging between (and including) 0 and 1. 

The fine-tuning step in PDP is optional (not mentioned in the original paper), and during it the mask is fixed and rounded to 0s and 1s.
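
The difference between the soft mask and the fixed, rounded mask can be sketched in a few lines of plain Python. This is an illustration only: the mask values and weights are made up, and the 0.5 threshold is an assumption standing in for the rounding step described above.

```python
def harden_mask(soft_mask, threshold=0.5):
    """Turn a soft mask with values in [0, 1] into a mask of 0s and 1s."""
    return [1.0 if m >= threshold else 0.0 for m in soft_mask]


# Made-up example values, not output of a real PDP run.
soft_mask = [0.97, 0.81, 0.62, 0.35, 0.12, 0.03]
weights = [0.40, -0.25, 0.10, 0.30, 0.05, 0.20]

hard_mask = harden_mask(soft_mask)
pruned_weights = [w * m for w, m in zip(weights, hard_mask)]
keep_ratio = sum(hard_mask) / len(hard_mask)

print(hard_mask)       # [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
print(pruned_weights)  # [0.4, -0.25, 0.1, 0.0, 0.0, 0.0]
print(keep_ratio)      # 0.5
```

With the soft mask, weights are only scaled down; once the mask is hardened, masked weights are exactly zero and the remaining-weight ratio is well defined.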

## Create data set
#### Let's create the data loader and the training and validation loops

In [None]:
import torchvision
import torchvision.transforms as transforms

def get_cifar10_data(batch_size):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), 
                                          transforms.ToTensor(), normalize])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])  
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
    # Use the held-out test split (train=False) for validation
    valset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=4, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                         shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader

from quantizers.fixed_point.fixed_point_ops import get_fixed_quantizer
# Set up input quantizer
quantizer = get_fixed_quantizer(overflow_mode="SAT")


def train_resnet(model, trainloader, device, loss_func, epoch, optimizer, scheduler, *args, **kwargs):
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = quantizer(inputs, k=torch.tensor(1.), i=torch.tensor(1.), f=torch.tensor(6.)) # 8 bits input quantization
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_func(outputs, labels)
        losses = get_model_losses(model, torch.tensor(0.).to(device))
        loss += losses
        loss.backward()
        optimizer.step()
        epoch += 1
        if scheduler is not None:
            scheduler.step()


from pquant.core.compressed_layers import get_layer_keep_ratio, get_model_losses

def validate_resnet(model, testloader, device, loss_func, epoch, *args, **kwargs):
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = quantizer(inputs, k=torch.tensor(1.), i=torch.tensor(1.), f=torch.tensor(6.)) # 8 bits input quantization
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        ratio = get_layer_keep_ratio(model)
        print(f'Accuracy: {100 * correct / total:.2f}%, remaining_weights: {ratio * 100:.2f}%')

BATCH_SIZE = 256
train_loader, val_loader = get_cifar10_data(BATCH_SIZE)

## Create loss function, scheduler and optimizer

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0001, momentum=0.9)
scheduler = CosineAnnealingLR(optimizer, 200)
loss_function = nn.CrossEntropyLoss()
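
For intuition about the schedule, the learning rate produced by cosine annealing can be written in closed form for steps between 0 and `T_max`. The sketch below evaluates that formula in plain Python with the values used above (`lr=0.01`, `T_max=200`) and assumes the default minimum learning rate of 0. Note that `train_resnet` calls `scheduler.step()` once per batch, so in this notebook `step` counts batches rather than epochs.

```python
import math


def cosine_annealing_lr(step, base_lr=0.01, t_max=200, eta_min=0.0):
    """Closed-form cosine annealing, valid for 0 <= step <= t_max."""
    cosine = math.cos(math.pi * step / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cosine)


for step in (0, 50, 100, 200):
    print(f"step {step:3d}: lr = {cosine_annealing_lr(step):.6f}")
```

The learning rate starts at 0.01, reaches half of that at step 100, and decays to 0 at step 200.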

## Train model
Training time. We use the ```iterative_train``` function from pquant to train. We need to provide the training and validation functions, their input parameters, the model and the config. The model already contains the pruning layers and quantized activations added earlier with ```add_pruning_and_quantization```; the pruning layers are removed after training with ```remove_pruning_from_model```.

In [None]:
from pquant.core.train import iterative_train
"""
Inputs to train_resnet we defined previously are:
          model, trainloader, device, loss_func, epoch, optimizer, scheduler, *args, **kwargs
"""

trained_model = iterative_train(model = model, 
                                config = config, 
                                train_func = train_resnet, 
                                valid_func = validate_resnet, 
                                trainloader = train_loader, 
                                testloader = val_loader, 
                                device = device, 
                                loss_func = loss_function,
                                optimizer = optimizer, 
                                scheduler = scheduler
                                )


We see that with PDP, the number of remaining weights goes down during training until it reaches the target sparsity (~50% remaining weights). The algorithm actually increases the sparsity linearly, but since the mask used before fine-tuning is a soft mask, the sparsity that is printed before fine-tuning seems to increase somewhat noisily. 
During fine-tuning the mask is fixed, and turned into a mask of 0s and 1s by a simple rounding operation. 
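
A linear sparsity ramp of the kind described here can be sketched as follows. This is only an illustration: the function `linear_sparsity` is hypothetical, and the assumption that the ramp spans all 100 training epochs is ours, not something taken from the pquant source.

```python
def linear_sparsity(epoch, ramp_epochs=100, target_sparsity=0.5):
    """Sparsity target at a given training epoch under a linear ramp.

    Assumes sparsity grows from 0 to `target_sparsity` over `ramp_epochs`
    epochs and stays constant afterwards.
    """
    progress = min(max(epoch / ramp_epochs, 0.0), 1.0)
    return target_sparsity * progress


for epoch in (0, 25, 50, 100):
    s = linear_sparsity(epoch)
    print(f"epoch {epoch:3d}: sparsity {s:.3f}, remaining weights {100 * (1 - s):.1f}%")
```

The printed keep ratio during training will not follow this line exactly, because it is measured on the soft mask rather than on a hard 0/1 mask.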

In the original paper for PDP there was no fine-tuning after the creation of the final hard mask. We have added fine-tuning here as an option that can be turned off by simply setting ```fine_tuning_epochs``` to 0 in the config file.

In [None]:
# Remove compression layers
from pquant.core.compressed_layers import remove_pruning_from_model
model = remove_pruning_from_model(trained_model, config)
model

## Custom config from existing config
Using the ```pquant/configs/config_pdp.yaml``` as base, let's customize the quantization and pruning scheme. 

The function we use will go through the model's layers and do the following: 

Quantization:

1. Looks for the names of convolutional and linear layers, as well as the names of activations (both layer-type and functional activations).
2. Adds the name of each layer to the ```layer_specific``` list, along with a default quantization scheme of 0 integer bits and 7 fractional bits for the weight and the bias (if the bias is not None).

Pruning:

1. Looks for convolutional and linear layers and adds their names to the ```disable_pruning_for_layers``` list.
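
The name lookup in the steps above can be approximated with `named_modules()` from PyTorch. The sketch below is not the pquant implementation. It only shows how layer and activation names could be collected from a model. The helper `collect_layer_names` is hypothetical, and it only finds layer-type activations, since functional ones such as `F.relu` are not modules.

```python
import torch.nn as nn


def collect_layer_names(model):
    """Collect names of compressible layers and layer-type activations."""
    compressible = (nn.Conv1d, nn.Conv2d, nn.Linear)
    activations = (nn.ReLU, nn.Tanh)
    layer_names, activation_names = [], []
    for name, module in model.named_modules():
        if isinstance(module, compressible):
            layer_names.append(name)
        elif isinstance(module, activations):
            activation_names.append(name)
    return layer_names, activation_names


# Small toy model, used only to keep the example self-contained.
toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8, 2))
layer_names, activation_names = collect_layer_names(toy)

print(layer_names)       # ['0', '3']
print(activation_names)  # ['1']
```

For the ResNet20 defined earlier, the same loop would yield dotted names such as `layer1.0.conv1`, which is the form the lists in the config refer to layers by.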

In [None]:
# Base config
pruning_method = "pdp"
config = get_default_config(pruning_method)
model = resnet20()


from pquant.core.compressed_layers import add_default_layer_quantization_pruning_to_config
config = add_default_layer_quantization_pruning_to_config(config, model)
config

In [None]:
# Save config
from pquant.core.utils import write_config_to_yaml
write_config_to_yaml(config, "prune_quantize_example.yaml", sort_keys=False)

Now that we have the custom config, it is up to us to modify the quantization bits for each layer that will not use the default value. If a layer uses the default value it can be removed from the ```layer_specific``` list.

For pruning, keep in the ```disable_pruning_for_layers``` list only the layers that should not be pruned; remove all other layers from the list.

## About replacing layers and activations
Layers that can currently be compressed: ```nn.Conv1d, nn.Conv2d, nn.Linear```.

Activations that can currently be replaced automatically with a quantized variant: ```nn.ReLU, nn.Tanh```. The quantized variants are found in ```pquant.core.activations_quantizer.py```.

## More about activations
If you use layer-type activations and want to keep fine-grained control over their quantization, note that reusing a single activation layer can cause problems: every place where that layer is called will use the quantization bits set for that one layer. To avoid this, use a separate ```nn.Tanh``` / ```nn.ReLU``` for each activation.
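
The following minimal sketch shows why a reused activation cannot be configured per call site: a module that is registered once has a single name, no matter how often it is called in `forward`. The two toy classes and the helper `activation_names` are made up for this illustration.

```python
import torch.nn as nn


class SharedActivation(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)
        self.act = nn.ReLU()  # one module, called twice in forward

    def forward(self, x):
        return self.act(self.fc2(self.act(self.fc1(x))))


class SeparateActivations(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)
        self.act1 = nn.ReLU()  # one module per call site
        self.act2 = nn.ReLU()

    def forward(self, x):
        return self.act2(self.fc2(self.act1(self.fc1(x))))


def activation_names(model):
    """Names of layer-type activations that could be addressed individually."""
    return [name for name, m in model.named_modules()
            if isinstance(m, (nn.ReLU, nn.Tanh))]


shared_names = activation_names(SharedActivation())
separate_names = activation_names(SeparateActivations())

print(shared_names)    # ['act']
print(separate_names)  # ['act1', 'act2']
```

The `BasicBlock` defined at the start of this tutorial follows the second pattern with its separate `relu1` and `relu2` modules.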