In [None]:
!pip install tensorboardX
!pip install pytorch_ignite

Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/07/84/46421bd3e0e89a92682b1a38b40efc22dafb6d8e3d947e4ceefd4a5fabc7/tensorboardX-2.2-py2.py3-none-any.whl (120kB)
[K     |██▊                             | 10kB 22.4MB/s eta 0:00:01[K     |█████▍                          | 20kB 18.1MB/s eta 0:00:01[K     |████████▏                       | 30kB 14.8MB/s eta 0:00:01[K     |██████████▉                     | 40kB 13.6MB/s eta 0:00:01[K     |█████████████▋                  | 51kB 8.9MB/s eta 0:00:01[K     |████████████████▎               | 61kB 9.2MB/s eta 0:00:01[K     |███████████████████             | 71kB 9.6MB/s eta 0:00:01[K     |█████████████████████▊          | 81kB 10.7MB/s eta 0:00:01[K     |████████████████████████▌       | 92kB 10.0MB/s eta 0:00:01[K     |███████████████████████████▏    | 102kB 8.7MB/s eta 0:00:01[K     |██████████████████████████████  | 112kB 8.7MB/s eta 0:00:01[K     |████████████████████████████████| 122

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision import transforms

#from tensorboardX import SummaryWriter
#from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
#from ignite.metrics import Accuracy, Loss
#from ignite.contrib.handlers import ProgressBar

#from snip import SNIP
import copy
import types

In [None]:
# Reproducibility: disable the cuDNN auto-tuner, request deterministic cuDNN
# kernels and seed PyTorch's default random number generator.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(42)

# Optimisation hyper-parameters.
EPOCHS = 70
INIT_LR = 0.1
WEIGHT_DECAY_RATE = 0.0005

# Bookkeeping knobs.
LOG_INTERVAL = 20
REPEAT_WITH_DIFFERENT_SEED = 1
REPEAT_WITH_DIFFERENT_SEED = 1

In [None]:
# Use the first CUDA GPU when PyTorch reports CUDA as available; otherwise
# fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
print(device)

cuda:0


In [None]:
def snip_forward_conv2d(self, x):
    """Replacement ``forward`` for a ``Conv2d`` layer carrying a ``weight_mask``.

    The kernel is multiplied element-wise by ``self.weight_mask`` before the
    convolution, so gradients can be taken with respect to the mask.
    """
    masked_kernel = self.weight * self.weight_mask
    return F.conv2d(
        x,
        masked_kernel,
        bias=self.bias,
        stride=self.stride,
        padding=self.padding,
        dilation=self.dilation,
        groups=self.groups,
    )


def snip_forward_linear(self, x):
    """Replacement ``forward`` for a ``Linear`` layer carrying a ``weight_mask``.

    The weight matrix is multiplied element-wise by ``self.weight_mask``
    before the affine map is applied.
    """
    masked_weight = self.weight * self.weight_mask
    return F.linear(x, masked_weight, self.bias)


def SNIP(net, keep_ratio, train_dataloader, device):
    """Compute SNIP pruning masks from a single training batch.

    A deep copy of ``net`` is patched so that every ``Conv2d``/``Linear``
    layer multiplies its (re-initialised, frozen) weight by a trainable
    all-ones ``weight_mask``. One forward/backward pass with ``F.nll_loss``
    yields the gradient of the loss w.r.t. each mask entry; the absolute
    gradients are normalised by their global sum and the highest-scoring
    fraction ``keep_ratio`` of weights is kept.

    Args:
        net: model to score. It is not modified; its forward output is fed to
            ``F.nll_loss``, so it is expected to return log-probabilities.
        keep_ratio: fraction of prunable weights to keep, in ``(0, 1]``.
        train_dataloader: iterable yielding ``(inputs, targets)`` batches;
            only the first batch is used.
        device: device the batch is moved to before the forward pass.

    Returns:
        A list with one float tensor of zeros and ones per ``Conv2d``/``Linear``
        layer, in ``net.modules()`` order, each shaped like that layer's
        weight. Ties at the threshold are all kept, so slightly more than
        ``keep_ratio`` of the weights may survive.

    Raises:
        ValueError: if ``keep_ratio`` is outside ``(0, 1]``, rounds down to
            zero kept weights, or ``net`` has no ``Conv2d``/``Linear`` layer.
        RuntimeError: if the mask gradients sum to zero or are not finite,
            which would otherwise silently produce an all-zero mask.
    """
    # TODO: shuffle?
    if not 0 < keep_ratio <= 1:
        raise ValueError(
            "keep_ratio must be in (0, 1], got {!r}".format(keep_ratio))

    # Grab a single batch from the training dataset
    inputs, targets = next(iter(train_dataloader))
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Work on a fresh copy of the network so the actual training-phase model
    # is left untouched.
    net = copy.deepcopy(net)

    # Monkey-patch the Linear and Conv2d layers to learn the multiplicative
    # mask instead of the weights.
    prunable_layers = []
    for layer in net.modules():
        if isinstance(layer, nn.Conv2d):
            layer.forward = types.MethodType(snip_forward_conv2d, layer)
        elif isinstance(layer, nn.Linear):
            layer.forward = types.MethodType(snip_forward_linear, layer)
        else:
            continue
        layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))
        nn.init.xavier_normal_(layer.weight)
        layer.weight.requires_grad = False
        prunable_layers.append(layer)

    if not prunable_layers:
        raise ValueError("net contains no nn.Conv2d or nn.Linear layers to prune")

    # Compute gradients (but don't apply them)
    net.zero_grad()
    outputs = net(inputs)
    loss = F.nll_loss(outputs, targets)
    loss.backward()

    grads_abs = [torch.abs(layer.weight_mask.grad) for layer in prunable_layers]

    # Gather all scores in a single vector and normalise
    all_scores = torch.cat([torch.flatten(g) for g in grads_abs])
    norm_factor = torch.sum(all_scores)
    if not torch.isfinite(norm_factor) or norm_factor <= 0:
        # Dividing by this would turn every score into NaN/inf and every
        # comparison below into False, i.e. a fully pruned network.
        raise RuntimeError(
            "SNIP saliency scores are degenerate (sum of |grad| = {})".format(
                norm_factor.item()))
    all_scores.div_(norm_factor)

    num_params_to_keep = int(len(all_scores) * keep_ratio)
    if num_params_to_keep < 1:
        raise ValueError(
            "keep_ratio={!r} keeps none of the {} prunable weights".format(
                keep_ratio, len(all_scores)))
    threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
    # topk is sorted in descending order, so the last entry is the cut-off.
    acceptable_score = threshold[-1]

    keep_masks = [((g / norm_factor) >= acceptable_score).float()
                  for g in grads_abs]

    # Report how many weights survive pruning.
    print(torch.sum(torch.cat([torch.flatten(x == 1) for x in keep_masks])))

    return keep_masks


In [None]:
def apply_prune_mask(net, keep_masks):
    """Zero out pruned weights of ``net`` and keep them at zero while training.

    Args:
        net: model whose ``Conv2d``/``Linear`` layers are pruned in place.
        keep_masks: one mask per ``Conv2d``/``Linear`` layer, in
            ``net.modules()`` order, each shaped like that layer's weight.
            ``mask[i] == 0`` prunes the parameter, ``mask[i] == 1`` keeps it.

    Raises:
        ValueError: if the number of masks differs from the number of
            prunable layers (``zip`` would otherwise silently skip layers).
        AssertionError: if a mask's shape differs from its layer's weight.
    """
    # Keep only the prunable modules so layers and masks line up one-to-one.
    prunable_layers = [
        layer for layer in net.modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]
    keep_masks = list(keep_masks)
    if len(prunable_layers) != len(keep_masks):
        raise ValueError(
            "expected {} pruning masks (one per Conv2d/Linear layer), got {}".format(
                len(prunable_layers), len(keep_masks)))

    def make_grad_hook(mask):
        # Built through a factory so each hook closes over its own mask;
        # a closure defined directly in the loop would late-bind to the
        # last mask.
        def hook(grads):
            return grads * mask

        return hook

    for layer, keep_mask in zip(prunable_layers, keep_masks):
        assert (layer.weight.shape == keep_mask.shape)

        # A mask on another device would break both the fill below and the
        # gradient hook; ``.to`` returns the same tensor when already there.
        keep_mask = keep_mask.to(layer.weight.device)

        # Step 1: Set the masked weights to zero (NB the biases are ignored)
        with torch.no_grad():
            layer.weight.masked_fill_(keep_mask == 0., 0.)
        # Step 2: Make sure their gradients remain zero
        layer.weight.register_hook(make_grad_hook(keep_mask))

In [None]:
# Layer layouts consumed by VGG_SNIP.make_layers: an integer is the number of
# output channels of a convolution, 'M' stands for a max-pooling layer.
VGG_CONFIGS = {
    'D': (
        [64, 64, 'M']
        + [128, 128, 'M']
        + [256, 256, 256, 'M']
        + [512, 512, 512, 'M']
        + [512, 512, 512, 'M']
    ),
}


class VGG_SNIP(nn.Module):
    """
    VGG-style network following the variants used in the SNIP paper
    (VGG-C, VGG-D, VGG-like). The layout is looked up by name in
    ``VGG_CONFIGS``.

    Differences from the original VGG:
        * Fully connected layers are reduced to 512 units
        * Flattening is adjusted to CIFAR-10 shapes (512 features, rather
          than 512 * 7 * 7, reach the classifier)
        * BatchNorm layers take the place of dropout

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1).
    """

    def __init__(self, config, num_classes=10):
        super().__init__()

        # Convolutional trunk, always built with batch normalisation.
        self.features = self.make_layers(VGG_CONFIGS[config], batch_norm=True)

        # Classifier head; BatchNorm1d is used instead of dropout.
        head = [
            nn.Linear(512, 512),  # 512 * 7 * 7 in the original VGG
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    @staticmethod
    def make_layers(config, batch_norm=False):  # TODO: BN yes or no?
        """Translate a layout list into an ``nn.Sequential`` feature extractor.

        ``'M'`` adds a 2x2 max-pool with stride 2; any other entry ``v`` adds
        a 3x3 convolution (padding 1) with ``v`` output channels, optionally
        followed by ``BatchNorm2d``, then an in-place ReLU. The first
        convolution expects 3 input channels.
        """
        modules = []
        channels_in = 3
        for entry in config:
            if entry == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            modules.append(
                nn.Conv2d(channels_in, entry, kernel_size=3, padding=1))
            if batch_norm:
                modules.append(nn.BatchNorm2d(entry))
            modules.append(nn.ReLU(inplace=True))
            channels_in = entry
        return nn.Sequential(*modules)

    def forward(self, x):
        feature_maps = self.features(x)
        # Collapse everything but the batch dimension for the linear head.
        flat = feature_maps.view(feature_maps.size(0), -1)
        logits = self.classifier(flat)
        return F.log_softmax(logits, dim=1)


In [None]:
def get_cifar10_dataloaders(train_batch_size, test_batch_size):
    """Build the CIFAR-10 train and test ``DataLoader`` objects.

    The training split is downloaded into ``_dataset`` if necessary and is
    augmented with a padded random crop and a random horizontal flip; both
    splits are converted to tensors and normalised per channel. Only the
    training loader shuffles.

    Returns:
        ``(train_loader, test_loader)``
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    augment_and_normalise = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    normalise_only = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    # The test split is opened with download=False, after the training
    # split has been opened with download=True.
    train_dataset = CIFAR10(
        root='_dataset',
        train=True,
        transform=augment_and_normalise,
        download=True)
    test_dataset = CIFAR10(
        root='_dataset',
        train=False,
        transform=normalise_only,
        download=False)

    loader_options = dict(num_workers=2, pin_memory=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        **loader_options)
    test_loader = DataLoader(
        test_dataset,
        batch_size=test_batch_size,
        shuffle=False,
        **loader_options)

    return train_loader, test_loader


In [None]:
def cifar10_experiment():
    """Assemble the model, optimiser, LR schedule and data for CIFAR-10.

    Returns:
        ``(net, optimiser, lr_scheduler, train_loader, val_loader)`` where
        ``net`` is a ``VGG_SNIP('D')`` placed on the global ``device``,
        ``optimiser`` is SGD with momentum 0.9 using ``INIT_LR`` and
        ``WEIGHT_DECAY_RATE``, and ``lr_scheduler`` multiplies the learning
        rate by 0.1 every ``LR_DECAY_INTERVAL`` scheduler steps.
    """
    BATCH_SIZE = 128
    LR_DECAY_INTERVAL = 20

    # The assignment was left without a right-hand side (a syntax error);
    # the model construction is restored here.
    net = VGG_SNIP('D').to(device)
    optimiser = optim.SGD(
        net.parameters(),
        lr=INIT_LR,
        momentum=0.9,
        weight_decay=WEIGHT_DECAY_RATE)
    lr_scheduler = optim.lr_scheduler.StepLR(
        optimiser, LR_DECAY_INTERVAL, gamma=0.1)

    train_loader, val_loader = get_cifar10_dataloaders(BATCH_SIZE,
                                                       BATCH_SIZE)  # TODO

    return net, optimiser, lr_scheduler, train_loader, val_loader

In [None]:
def training(epoch, model, optimizer, scheduler, criterion, device, train_loader):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        epoch: index of the current epoch (not used by this function).
        model: network to optimise; switched to training mode.
        optimizer: optimiser stepped once per batch.
        scheduler: accepted for interface compatibility; it is NOT stepped
            here, the caller is responsible for that.
        criterion: loss callable ``criterion(outputs, labels)``. It is
            assumed to return the mean loss over the batch (the default
            reduction of ``nn.CrossEntropyLoss``) - TODO confirm for other
            criteria.
        device: device each batch is moved to.
        train_loader: iterable yielding ``(features, labels)`` batches.

    Returns:
        The average loss per sample over the whole pass, or ``nan`` when the
        loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    total = 0
    for feats, labels in train_loader:
        feats, labels = feats.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(feats)
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()

        # The criterion yields a per-batch mean, so weight it by the batch
        # size; summing the means and dividing by the sample count would
        # under-report the loss by roughly a factor of the batch size.
        batch_size = len(feats)
        loss_sum += loss.item() * batch_size
        total += batch_size

    if total == 0:
        return float('nan')
    return loss_sum / total
  


In [None]:
def validate(epoch, model, criterion, device, data_loader):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        epoch: index of the current epoch (not used by this function).
        model: network to evaluate; switched to evaluation mode and left
            in that mode.
        criterion: loss callable ``criterion(outputs, labels)``. It is
            assumed to return the mean loss over the batch - TODO confirm
            for criteria with a different reduction.
        device: device each batch is moved to.
        data_loader: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        ``(loss, accuracy)``: the average loss per sample and the fraction
        of samples whose highest-scoring class equals the label. Both are
        ``nan`` when the loader yields no samples.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for X, Y in data_loader:
            X, Y = X.to(device), Y.to(device)
            output = model(X)
            loss = criterion(output, Y.long())

            # Softmax is monotonic, so the arg-max of the raw scores is the
            # predicted class; no need to normalise first.
            pred_labels = output.argmax(dim=1).view(-1)
            correct += torch.sum(torch.eq(pred_labels, Y)).item()

            # Weight the per-batch mean by the batch size so the final
            # figure is a true per-sample average.
            batch_size = len(X)
            loss_sum += loss.item() * batch_size
            total += batch_size

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, correct / total

In [None]:
def train():
    """Run one SNIP-pruned CIFAR-10 training session.

    Builds the experiment, prunes the network before training with SNIP
    (keeping 5% of the Conv2d/Linear weights), then alternates one training
    pass and one validation pass per epoch for ``EPOCHS`` epochs, stepping
    the learning-rate scheduler once per epoch and printing the metrics.

    The loss is ``F.nll_loss`` because the loop passes the model output
    straight to the criterion and ``SNIP`` scores the same output with
    ``F.nll_loss``.
    """
    net, optimiser, lr_scheduler, train_loader, val_loader = cifar10_experiment()
    net = net.to(device)

    # Pre-training pruning using SNIP
    keep_masks = SNIP(net, 0.05, train_loader, device)  # TODO: shuffle?
    apply_prune_mask(net, keep_masks)

    criterion = F.nll_loss

    for epoch in range(EPOCHS):
        train_loss = training(epoch, net, optimiser, lr_scheduler, criterion,
                              device, train_loader)
        val_loss, val_acc = validate(epoch, net, criterion, device, val_loader)

        # training() does not touch the scheduler; decay once per epoch here.
        lr_scheduler.step()

        print('Epoch: {} \t train-Loss: {:.4f}, \tval-Loss: {:.4f}, \tval-acc: {:.4f}'.format(
            epoch + 1, train_loss, val_loss, val_acc))



In [None]:
if __name__ == '__main__':

    # Seed currently set on PyTorch's default generator; each repetition
    # derives its own seed from it so that runs differ from one another yet
    # can be reproduced individually.
    base_seed = torch.initial_seed()

    for run_idx in range(REPEAT_WITH_DIFFERENT_SEED):
        torch.manual_seed(base_seed + run_idx)

        net, optimiser, lr_scheduler, train_loader, val_loader = cifar10_experiment()
        net = net.to(device)

        # Pre-training pruning using SNIP
        keep_masks = SNIP(net, 0.05, train_loader, device)  # TODO: shuffle?
        apply_prune_mask(net, keep_masks)

        criterion = nn.CrossEntropyLoss()

        for epoch in range(EPOCHS):
            train_loss = training(epoch, net, optimiser, lr_scheduler,
                                  criterion, device, train_loader)

            val_loss, val_acc = validate(epoch, net, criterion, device,
                                         val_loader)

            # training() does not step the scheduler; decay once per epoch.
            lr_scheduler.step()

            print('Epoch: {} \t train-Loss: {:.4f}, \tval-Loss: {:.4f}, \tval-acc: {:.4f}'.format(epoch+1,  train_loss, val_loss, val_acc))





        


Files already downloaded and verified
tensor(761993, device='cuda:0')
Epoch: 1 	 train-Loss: 0.0120, 	val-Loss: 0.0094, 	val-acc: 0.5663
Epoch: 2 	 train-Loss: 0.0078, 	val-Loss: 0.0070, 	val-acc: 0.6875
Epoch: 3 	 train-Loss: 0.0062, 	val-Loss: 0.0060, 	val-acc: 0.7340
Epoch: 4 	 train-Loss: 0.0053, 	val-Loss: 0.0061, 	val-acc: 0.7342
Epoch: 5 	 train-Loss: 0.0049, 	val-Loss: 0.0066, 	val-acc: 0.7186
Epoch: 6 	 train-Loss: 0.0046, 	val-Loss: 0.0055, 	val-acc: 0.7583
Epoch: 7 	 train-Loss: 0.0044, 	val-Loss: 0.0062, 	val-acc: 0.7342
Epoch: 8 	 train-Loss: 0.0042, 	val-Loss: 0.0065, 	val-acc: 0.7254
Epoch: 9 	 train-Loss: 0.0041, 	val-Loss: 0.0044, 	val-acc: 0.8062
Epoch: 10 	 train-Loss: 0.0040, 	val-Loss: 0.0048, 	val-acc: 0.7946
Epoch: 11 	 train-Loss: 0.0040, 	val-Loss: 0.0054, 	val-acc: 0.7779
Epoch: 12 	 train-Loss: 0.0039, 	val-Loss: 0.0049, 	val-acc: 0.7933
Epoch: 13 	 train-Loss: 0.0038, 	val-Loss: 0.0055, 	val-acc: 0.7692
Epoch: 14 	 train-Loss: 0.0037, 	val-Loss: 0.0053, 	val