In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from poutyne import Model
from poutyne import CosineAnnealingLR
from callbacks import WandbCallback, DummySyncCallback, MasterDummySyncCallback
# from compression import UINT8Compressor

# Target device for training. Prefer the first CUDA GPU, but fall back to the
# CPU so the notebook still runs (slowly) on a machine without CUDA instead of
# failing at `model.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [2]:
import torch
import numpy as np
from scipy import ndimage
from poutyne import torch_to_numpy, numpy_to_torch


def fast_quantile_encode(weight, window_size=4.):
    """Quantize a NumPy array to 8-bit codes and build a per-code lookup table.

    Values are scaled so that ``window_size`` standard deviations map onto 128
    integer steps, truncated towards zero, and saturated to the int8 range
    ``[-128, 127]``. The signed codes are then reinterpreted as ``uint8``
    labels (negative codes wrap to 128..255), and each label's lookup entry is
    the mean of the original values that received that label.

    Args:
        weight: NumPy array of values to quantize.
        window_size: Number of standard deviations mapped onto 128 steps.

    Returns:
        Tuple ``(quant_weight, lookup)``: ``quant_weight`` is a ``uint8`` array
        with the shape of ``weight``; ``lookup`` is a ``float32`` array of
        length 256 (unused labels hold 0). ``lookup[quant_weight]``
        reconstructs an approximation of ``weight``.
    """
    # A constant input has zero std, which makes the scale infinite (and
    # 0 * inf = NaN). That case is resolved just below, so the floating-point
    # warnings raised here carry no information.
    with np.errstate(divide='ignore', invalid='ignore'):
        scale = 128 / (window_size * weight.std())
        scaled_weight = scale * weight

    # Saturate BEFORE the integer cast: casting an out-of-range or NaN float
    # to int8 is undefined in NumPy, so the value must already be in range.
    # NaN (only produced by the zero-std case) is mapped to the zero bucket.
    scaled_weight = np.nan_to_num(scaled_weight, nan=0., posinf=127., neginf=-128.)
    quant_weight = np.clip(scaled_weight, -128, 127).astype('int8')

    # Reinterpret the signed codes as labels in 0..255.
    quant_weight = quant_weight.astype('uint8')

    # Labels that never occur yield 0/0 inside ndimage.mean; those entries are
    # zeroed right after, so the division warnings are suppressed.
    with np.errstate(divide='ignore', invalid='ignore'):
        lookup = ndimage.mean(weight, labels=quant_weight, index=np.arange(256))
    lookup[np.isnan(lookup)] = 0.
    lookup = lookup.astype('float32')
    return quant_weight, lookup


# Total width of the uniform quantization window, in standard deviations of
# the tensor being encoded (passed as `range_in_sigmas` by UINT8Compressor).
UNIFORM_BUCKETS_STD_RANGE = 6
# Number of distinct codes an unsigned 8-bit integer can hold, i.e. the number
# of buckets the window is divided into.
UINT8_RANGE = 2 ** 8



def average_buckets(tensor, quant_weight, n_bins):
    """Compute the mean of ``tensor`` inside every quantization bucket.

    Args:
        tensor: Values that were quantized.
        quant_weight: Integer bucket index for each element of ``tensor``
            (same number of elements); indices must lie in ``[0, n_bins)``.
        n_bins: Number of buckets, i.e. the length of the returned lookup.

    Returns:
        Tuple ``(quant_weight, lookup)``: ``quant_weight`` is returned
        unchanged; ``lookup`` is a 1-D tensor of length ``n_bins`` where entry
        ``i`` is the mean of the values assigned to bucket ``i`` (0 for buckets
        that received no value).
    """
    values = tensor.flatten()
    bin_index = quant_weight.flatten().long().to(values.device)

    # Allocate the accumulator with the dtype/device of the values:
    # scatter_add_ requires both to match its source, so a plain
    # torch.zeros(n_bins) would only work for float32 CPU inputs.
    bin_sums = torch.zeros(n_bins, dtype=values.dtype, device=values.device)
    bin_sums.scatter_add_(0, bin_index, values)

    # Empty buckets get a count of 1 so the division yields 0 instead of NaN.
    bin_counts = torch.bincount(bin_index, minlength=n_bins).clamp_min_(1)
    lookup = bin_sums / bin_counts
    return quant_weight, lookup

def uint8_uniform_buckets_encode(tensor: torch.Tensor, range_in_sigmas: float):
    """Quantize a tensor into ``UINT8_RANGE`` uniform buckets around its mean.

    The tensor is moved to the CPU, centred on its mean, and split into
    equally wide buckets covering ``range_in_sigmas`` standard deviations in
    total. Each bucket is then represented by the mean of the original values
    that fell into it.

    Args:
        tensor: Float tensor to encode.
        range_in_sigmas: Total width of the quantization window, expressed in
            standard deviations of ``tensor``.

    Returns:
        Tuple ``(quant_weight, lookup)``: ``quant_weight`` holds the uint8
        bucket index of every element; ``lookup`` holds the per-bucket mean,
        so ``lookup[quant_weight.long()]`` approximates ``tensor``.
    """
    tensor = tensor.cpu()
    offset = UINT8_RANGE // 2  # bucket index assigned to the tensor mean
    shift = tensor.mean()
    scale = range_in_sigmas * tensor.std() / UINT8_RANGE

    # A constant tensor has std == 0, and a single-element or empty tensor has
    # std == NaN; neither can be used as a quantization step. Fall back to a
    # unit step: every element of a constant tensor then lands in the `offset`
    # bucket, whose average is that constant, so decoding is exact.
    if scale == 0 or not torch.isfinite(scale):
        scale = 1.0
    else:
        scale = float(scale)

    quant_weight = torch.quantize_per_tensor(tensor - shift, scale, offset, torch.quint8).int_repr()

    return average_buckets(tensor, quant_weight, UINT8_RANGE)



class UINT8Compressor:
    """Compress selected state-dict entries to uint8 codes plus a lookup table.

    ``parameter_names`` lists the state-dict keys that ``serialize`` encodes
    and ``deserialize`` decodes; all other entries are left untouched.
    """

    def __init__(self, parameter_names):
        self.parameter_names = parameter_names

    def encode(self, weight):
        """Return ``{'quant_weight': codes, 'lookup': table}`` for one tensor."""
        with torch.no_grad():
            codes, table = uint8_uniform_buckets_encode(weight, UNIFORM_BUCKETS_STD_RANGE)
        return {'quant_weight': codes, 'lookup': table}

    def decode(self, encoded):
        """Rebuild a float32 tensor by looking every code up in the table."""
        codes = encoded['quant_weight']
        table = encoded['lookup']
        return table[codes.long()].float()

    def _transform_in_place(self, state_dict, transform):
        # Replace each configured entry with transform(entry); the dictionary
        # passed in is modified and handed back to the caller.
        for key in self.parameter_names:
            state_dict[key] = transform(state_dict[key])
        return state_dict

    def serialize(self, state_dict):
        """Encode the configured entries of ``state_dict`` in place; return it."""
        return self._transform_in_place(state_dict, self.encode)

    def deserialize(self, state_dict):
        """Decode the configured entries of ``state_dict`` in place; return it."""
        return self._transform_in_place(state_dict, self.decode)

In [3]:
# layers_to_compress = \
# ['layer1.0.conv1.weight',
#  'layer1.0.conv2.weight',
#  'layer1.0.conv3.weight',
#  'layer1.0.shortcut.0.weight',
#  'layer1.1.conv1.weight',
#  'layer1.1.conv2.weight',
#  'layer1.1.conv3.weight',
#  'layer1.2.conv1.weight',
#  'layer1.2.conv2.weight',
#  'layer1.2.conv3.weight',
#  'layer2.0.conv1.weight',
#  'layer2.0.conv2.weight',
#  'layer2.0.conv3.weight',
#  'layer2.0.shortcut.0.weight',
#  'layer2.1.conv1.weight',
#  'layer2.1.conv2.weight',
#  'layer2.1.conv3.weight',
#  'layer2.2.conv1.weight',
#  'layer2.2.conv2.weight',
#  'layer2.2.conv3.weight',
#  'layer2.3.conv1.weight',
#  'layer2.3.conv2.weight',
#  'layer2.3.conv3.weight',
#  'layer3.0.conv1.weight',
#  'layer3.0.conv2.weight',
#  'layer3.0.conv3.weight',
#  'layer3.0.shortcut.0.weight',
#  'layer3.1.conv1.weight',
#  'layer3.1.conv2.weight',
#  'layer3.1.conv3.weight',
#  'layer3.2.conv1.weight',
#  'layer3.2.conv2.weight',
#  'layer3.2.conv3.weight',
#  'layer3.3.conv1.weight',
#  'layer3.3.conv2.weight',
#  'layer3.3.conv3.weight',
#  'layer3.4.conv1.weight',
#  'layer3.4.conv2.weight',
#  'layer3.4.conv3.weight',
#  'layer3.5.conv1.weight',
#  'layer3.5.conv2.weight',
#  'layer3.5.conv3.weight',
#  'layer4.0.conv1.weight',
#  'layer4.0.conv2.weight',
#  'layer4.0.conv3.weight',
#  'layer4.0.shortcut.0.weight',
#  'layer4.1.conv1.weight',
#  'layer4.1.conv2.weight',
#  'layer4.1.conv3.weight',
#  'layer4.2.conv1.weight',
#  'layer4.2.conv2.weight',
#  'layer4.2.conv3.weight',
#  'linear.weight']

# Compress every entry of a ResNet50 state dict except the
# `num_batches_tracked` entries.
layers_to_compress = [
    key
    for key in ResNet50().state_dict()
    if 'num_batches_tracked' not in key
]

In [4]:
# Per-channel mean / std handed to transforms.Normalize for both splits.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
CIFAR10_ROOT = '/storage/cifar10'

# Training pipeline: random crop + horizontal flip, then tensor + normalise.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Evaluation pipeline: tensor + normalise only.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root=CIFAR10_ROOT,
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

testset = torchvision.datasets.CIFAR10(
    root=CIFAR10_ROOT,
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

# Human-readable label for each class index.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)



Files already downloaded and verified
Files already downloaded and verified


In [5]:
net = ResNet50()
# Load the shared initial weights onto the CPU regardless of the device they
# were saved from: without map_location, torch.load restores tensors to their
# original device and fails if that device is not available here. The model is
# moved to the training device later via `model.to(device)`.
init_state = torch.load(
    '/storage/monty/resnet_compress/resnet50_init.pth',
    map_location='cpu',
)
net.load_state_dict(init_state)

<All keys matched successfully>

In [6]:
# !ls /tmp/resnet50_runs
# !mkdir /storage/monty/resnet50_runs

In [7]:
# Compressor whose serialize/deserialize methods are handed to the sync
# callback; it only touches the entries named in `layers_to_compress`.
compressor = UINT8Compressor(parameter_names=layers_to_compress)

In [8]:
# !mkdir /tmp/resnet50_runs

In [9]:
# Loss and optimiser.
criterion = nn.CrossEntropyLoss()
sgd_hyperparams = {'lr': 0.01, 'momentum': 0.9, 'weight_decay': 5e-4}
optimizer = optim.SGD(net.parameters(), **sgd_hyperparams)

# poutyne callback wrapping cosine annealing with T_max=200.
scheduler = CosineAnnealingLR(T_max=200)

# Experiment tracking (Weights & Biases).
wandb_callback = WandbCallback(
    name='int8_avg_compress_not_all_master',
    entity="montyponty",
    project="resnet_grad_compress",
    # prefix='master'
)

# Sync callback for rank 0 of 2 workers, using the compressor to (de)serialise
# state dicts. NOTE(review): the exact exchange protocol lives in the project's
# `callbacks` module - confirm the meaning of `period` there.
sync_options = dict(
    save_dir='/tmp/resnet50_runs',
    rank_id=0,
    n_workers=2,
    period=8,
)
sync_callback = MasterDummySyncCallback(
    serialize_fn=compressor.serialize,
    deserialize_fn=compressor.deserialize,
    **sync_options,
)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33mmontyponty[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.25 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [None]:
# Wrap the network in a poutyne Model, reporting accuracy per batch.
model = Model(net, optimizer, criterion, batch_metrics=["acc"])
model.to(device)

# Train for 200 epochs, validating on the test loader after each epoch.
training_callbacks = [scheduler, wandb_callback, sync_callback]
history = model.fit_generator(
    trainloader,
    testloader,
    epochs=200,
    callbacks=training_callbacks,
)

[93mEpoch: [94m1/200 [93mStep: [94m391/391 [93m100.00% |[92m█████████████████████████[93m|[32m163.10s [93mloss:[96m 1.934945[93m acc:[96m 31.802000[93m val_loss:[96m 1.577674[93m val_acc:[96m 42.490000[0m
[93mEpoch: [94m2/200 [93mStep: [94m391/391 [93m100.00% |[92m█████████████████████████[93m|[32m166.74s [93mloss:[96m 1.287786[93m acc:[96m 53.620000[93m val_loss:[96m 1.270575[93m val_acc:[96m 58.230000[0m
[93mEpoch: [94m3/200 [93mStep: [94m391/391 [93m100.00% |[92m█████████████████████████[93m|[32m167.78s [93mloss:[96m 0.961316[93m acc:[96m 66.060000[93m val_loss:[96m 0.844800[93m val_acc:[96m 70.650000[0m
[93mEpoch: [94m4/200 [93mStep: [94m391/391 [93m100.00% |[92m█████████████████████████[93m|[32m167.94s [93mloss:[96m 0.738599[93m acc:[96m 74.404000[93m val_loss:[96m 0.806705[93m val_acc:[96m 73.850000[0m
[93mEpoch: [94m5/200 [93mStep: [94m391/391 [93m100.00% |[92m█████████████████████████[93m|[32m168.17s [

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[93mEpoch: [94m26/200 [93mStep: [94m391/391 [93m100.00% |[92m█████████████████████████[93m|[32m168.64s [93mloss:[96m 0.115382[93m acc:[96m 95.920000[93m val_loss:[96m 0.439988[93m val_acc:[96m 87.780000[0m
[93mEpoch: [94m27/200 [93mStep: [94m391/391 [93m100.00% |[92m█████████████████████████[93m|[32m167.34s [93mloss:[96m 0.106873[93m acc:[96m 96.246000[93m val_loss:[96m 0.344857[93m val_acc:[96m 90.380000[0m
[93mEpoch: [94m28/200 [93mStep: [94m391/391 [93m100.00% |[92m█████████████████████████[93m|[32m168.13s [93mloss:[96m 0.097955[93m acc:[96m 96.608000[93m val_loss:[96m 0.352455[93m val_acc:[96m 90.200000[0m
[93mEpoch: [94m29/200 [93mStep: [94m391/391 [93m100.00% |[92m█████████████████████████[93m|[32m167.96s [93mloss:[96m 0.102643[93m acc:[96m 96.476000[93m val_loss:[96m 0.370012[93m val_acc:[96m 89.390000[0m
[93mEpoch: [94m30/200 [93mStep: [94m391/391 [93m100.00% |[92m█████████████████████████[93m|[32m168.6