In [1]:
from __future__ import print_function
import distiller 
import argparse
import numpy as np
import os, collections
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import time 

import models
from matplotlib import pyplot as plt

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR training')
parser.add_argument('--dataset', type=str, default='cifar100',
                    help='training dataset (default: cifar100)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                    help='input batch size for testing (default: 256)')
parser.add_argument('--epochs', type=int, default=160, metavar='N',
                    help='number of epochs to train (default: 160)')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--save', default='./logs', type=str, metavar='PATH',
                    help='path to save prune model (default: current directory)')
parser.add_argument('--arch', default='vgg', type=str, 
                    help='architecture to use')
parser.add_argument('--depth', default=16, type=int,
                    help='depth of the neural network')
print('using GPU:', torch.cuda.is_available())
torch.manual_seed(1)

  from ._conv import register_converters as _register_converters


using GPU: True


<torch._C.Generator at 0x7f87d807a130>

In [5]:
# args = parser.parse_args()

# ======just for jupyter training purpose======
# args = parser.parse_args('--dataset cifar10 --arch vgg --depth 16 --start-epoch 160 --epochs 3 --resume /home/lning/PyTorch/model/vgg16/checkpoint.pth.tar --no-cuda'.split())
# args = parser.parse_args('--dataset cifar10 --arch vgg --depth 16 --start-epoch 0 --epochs 3 --no-cuda'.split())
args = parser.parse_args('--dataset cifar10 --arch vgg --depth 16 --start-epoch 0 --epochs 1'.split())
# ======just for jupyter training purpose======

args.cuda = not args.no_cuda and torch.cuda.is_available()
print('using GPU:', args.cuda)

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# create logs folder 
args.save = os.path.join(args.save, args.arch+str(args.depth), args.dataset, 'regulated_training')
if not os.path.exists(args.save):
    os.makedirs(args.save)
print('train logs will save to:', args.save)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
if args.dataset == 'cifar10':
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.Pad(4),
                           transforms.RandomCrop(32),
                           transforms.RandomHorizontalFlip(),
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)
else:
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100('./data.cifar100', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.Pad(4),
                           transforms.RandomCrop(32),
                           transforms.RandomHorizontalFlip(),
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                       ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)


using GPU: True
train logs will save to: ./logs/vgg16/cifar10/regulated_training
Files already downloaded and verified


In [6]:
model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.resume, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))


In [7]:
# limit the quantized weight value to be -64 ~ 64. Downscale float weight accordingly
def regulate_quantized_weight():
    # print param and layer informations
#     for module_full_name, module in model.named_modules():
#         print (module_full_name)
    print("Model's state_dict:")
    for param_tensor in model.state_dict():
        print(param_tensor, "\t", model.state_dict()[param_tensor].size())
#         if param_tensor == 'feature.0.float_weight':
#             float_weight_p = model.state_dict()[param_tensor]
#             print(float_weight_p)
#         if param_tensor == 'feature.0.weight':
#             weight_p = model.state_dict()[param_tensor]
#             print(weight_p)
#         if param_tensor == 'feature.1.weight':
#             weight_n = model.state_dict()[param_tensor]
#             print(weight_n)
            
    layer_id = 0
    for param_tensor in model.state_dict():
        name_part = param_tensor.split('.')
        if 'float_weight' in name_part:
            print(param_tensor)
            quantized_weight_name = name_part[0]+'.'+name_part[1]+'.weight'
            quantized_weight_scale = name_part[0]+'.'+name_part[1]+'.weight_scale'
            quantized_weight_zero_point = name_part[0]+'.'+name_part[1]+'.weight_zero_point'
            float_weight = model.state_dict()[param_tensor]
            weight = model.state_dict()[quantized_weight_name]
            weight_scale = model.state_dict()[quantized_weight_scale]
            weight_zero_point = model.state_dict()[quantized_weight_zero_point]
#             if param_tensor == 'feature.0.float_weight':
#                 print('float_weight\n', float_weight)
#                 print('weight\n',weight)
#                 print('scale\n',weight_scale)
#                 print('zero_point\n', weight_zero_point)
        
            # regulate the weight
            weight_size = model.state_dict()[param_tensor].size()
            if np.size(weight_size) > 2:
                c = weight_size[0]
                h = weight_size[1]
                w = weight_size[2]
                change_idx_list_l = np.nonzero(torch.round(weight_scale * weight - weight_zero_point) > 64)
                change_idx_list_s = np.nonzero(torch.round(weight_scale * weight - weight_zero_point) < -64)
                for idx in change_idx_list_l:
                    if idx[0]*c+idx[1]*h+idx[2]*w % 8 != 7:
                        # scale the float_weight accordingly
                        float_weight[idx[0],idx[1],idx[2],idx[3]] = float_weight[idx[0],idx[1],idx[2],idx[3]] * \
                                torch.round(weight_scale * weight[idx[0],idx[1],idx[2],idx[3]] - weight_zero_point) / 64.
                        # down write the value to 64
                        weight[idx[0],idx[1],idx[2],idx[3]] = (64 + weight_zero_point)/weight_scale
                for idx in change_idx_list_s:
                    if idx[0]*c+idx[1]*h+idx[2]*w % 8 != 7:
                        # scale the float_weight accordingly
                        float_weight[idx[0],idx[1],idx[2],idx[3]] = float_weight[idx[0],idx[1],idx[2],idx[3]] * \
                                torch.round(weight_scale * weight[idx[0],idx[1],idx[2],idx[3]] - weight_zero_point) / (-64.)
                        # down write the value to 64
                        weight[idx[0],idx[1],idx[2],idx[3]] = (-64 + weight_zero_point)/weight_scale
            if np.size(weight_size) <= 2:
                c = weight_size[0]
                change_idx_list_l = np.nonzero(torch.round(weight_scale * weight - weight_zero_point) > 64)
                change_idx_list_s = np.nonzero(torch.round(weight_scale * weight - weight_zero_point) < -64)
                for idx in change_idx_list_l:
                    if idx[0]*c % 8 != 7:
                        # scale the float_weight accordingly
                        float_weight[idx[0],idx[1]] = float_weight[idx[0],idx[1]] * \
                                torch.round(weight_scale * weight[idx[0],idx[1]] - weight_zero_point) / 64.
                        # down write the value to 64
                        weight[idx[0],idx[1]] = (64 + weight_zero_point)/weight_scale
                for idx in change_idx_list_s:
                    if idx[0]*c % 8 != 7:
                        # scale the float_weight accordingly
                        float_weight[idx[0],idx[1]] = float_weight[idx[0],idx[1]] * \
                                torch.round(weight_scale * weight[idx[0],idx[1]] - weight_zero_point) / (-64.)
                        # down write the value to 64
                        weight[idx[0],idx[1]] = (-64 + weight_zero_point)/weight_scale
#             for idx in change_idx_list:
#                 if idx[0]*c+idx[1]*h+idx[2]*w % 8 != 7:
#                     # scale the float_weight accordingly
#                     float_weight[idx[0],idx[1],idx[2],idx[3]] = float_weight[idx[0],idx[1],idx[2],idx[3]] * \
#                             torch.round(weight_scale * weight[idx[0],idx[1],idx[2],idx[3]] - weight_zero_point) / 64.
#                     # down write the value to 64
#                     weight[idx[0],idx[1],idx[2],idx[3]] = (64 + weight_zero_point)/weight_scale
# #                     print(idx, weight_scale, weight_zero_point)
# #                     print(model.state_dict()[quantized_weight_name][0,0,0,1])
# #                     print(torch.round(weight_scale * weight[0,0,0,1] - weight_zero_point))
                    
# #                     print((64 + weight_zero_point)/weight_scale, model.state_dict()[quantized_weight_name][0,0,0,1]) 
# #                     print(weight[0,0,0,1])
# #             print(c,h,w)
# #             print(torch.round(weight_scale * weight[0,0,0,1] - weight_zero_point))
# #             print(np.nonzero(torch.round(weight_scale * weight - weight_zero_point) > 66))
        
            

In [8]:
def regulate_train(epoch):
    model.train()
    avg_loss = 0.
    train_acc = 0.
    print_flag = True
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.to(torch.device('cuda')), target.to(torch.device('cuda'))
        data, target = Variable(data), Variable(target)
        compression_scheduler.on_minibatch_begin(epoch, batch_idx, optimizer)
#         regulate_quantized_weight()
        optimizer.zero_grad()
        output = model(data)
#         loss = F.cross_entropy(output, target)
        loss = regulate_loss(output, target)
    
        avg_loss += loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
        loss.backward()
        optimizer.step()
        compression_scheduler.on_minibatch_end(epoch, batch_idx, 800)
            
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data))

In [9]:
def regulate_loss(output, target):
    print('weight_list\n',weight_name_list) 
    layer_num = len(weight_name_list)
    alpha = 0.00005
    beta = 0.0001
    norm_sum = F.cross_entropy(output, target)
    for layer in np.arange(layer_num):
        print(layer)
        print(len(weight_list))
        print(len(Z_list))
        print(len(U_list))
        norm_sum = norm_sum + alpha * torch.norm(weight_list[layer], 2)
        item = weight_list[layer] - Z_list[layer] + U_list[layer]
        norm_sum = norm_sum + beta * torch.norm(item, 2)
    return norm_sum

In [10]:
def train(epoch):
    model.train()
    avg_loss = 0.
    train_acc = 0.
    print_flag = True
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.to(torch.device('cuda')), target.to(torch.device('cuda'))
        data, target = Variable(data), Variable(target)
#         quantizer = distiller.quantization.QuantAwareTrainRangeLinearQuantizer(model,optimizer)
#         quantizer.quantize_params()
        compression_scheduler.on_minibatch_begin(epoch, batch_idx, optimizer)
        regulate_quantized_weight()
        if print_flag:
            regulate_quantized_weight()
            print_flag = False
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        avg_loss += loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
        loss.backward()
        optimizer.step()
        quantizer.quantize_params()
        compression_scheduler.on_minibatch_end(epoch, batch_idx, 800)

#         if epoch == args.start_epoch and batch_idx == 1:
#             # check the distribution of parameters 
#             thr = 16
#             layer_id = 0 
#             for param_name, param in model.named_parameters():
#                 if len(param.size()) < 2:
#                     continue
#                 counter = collections.Counter(np.abs(param.data.cpu().numpy().ravel())//thr)
#                 tmp = sorted(counter.items(), key=lambda x: x[0])
#                 values, counts = zip(*tmp)
#                 percentages = [count/sum(list(counts)) for count in counts]
#                 bar = plt.bar(values, percentages)
#                 for rect in bar:
#                     height = rect.get_height()
#                     plt.text(rect.get_x() + rect.get_width()/2.0, height, '%.4f%%' %(height*100), ha='center', va='bottom')
#             #     print(['%.2f' %(p) for p in percentages])
#                 #plt.hist(param.data.cpu().numpy().ravel(), bins=10, density=True)
#                 plt.xticks(values, [str(int(v)*thr+thr) for v in values])
#                 plt.title('layer_id:'+str(layer_id) + ', '+ str(tuple(param.size())))
#             #     plt.grid()
#                 plt.ylim(0, 1.1)
#                 plt.show()
            
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data))



def test():
    model.eval()
    test_loss = 0
#     correct = 0
    test_time = time.time() 
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').data # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()

    test_loss /= len(test_loader.dataset)
    test_time = time.time() - time.time() 

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    # print(correct, len(test_loader.dataset))
    return correct*1.0 / len(test_loader.dataset)

def save_checkpoint(state, is_best, filepath):
    torch.save(state, os.path.join(filepath, 'checkpoint.pth.tar'))
    if is_best:
        shutil.copyfile(os.path.join(filepath, 'checkpoint.pth.tar'), os.path.join(filepath, 'model_best.pth.tar'))


In [11]:
def projection(weight, weight_scale, weight_zero_point, target):
    weight_size = weight.size()
    upper_value = (64 + weight_zero_point)/weight_scale
    lower_value = (-64 + weight_zero_point)/weight_scale
#     print(np.size(weight_size))
    if np.size(weight_size) > 2:
        c = weight_size[0]
        h = weight_size[1]
        w = weight_size[2]
        change_idx_list_l = np.nonzero(torch.round(weight_scale * weight - weight_zero_point) > 64)
        change_idx_list_s = np.nonzero(torch.round(weight_scale * weight - weight_zero_point) < -64)
        for idx in change_idx_list_l:
            if idx[0]*c+idx[1]*h+idx[2]*w % 8 != 7:
                # down write the value to 64
                weight[idx[0],idx[1],idx[2],idx[3]] = upper_value
        for idx in change_idx_list_s:
            if idx[0]*c+idx[1]*h+idx[2]*w % 8 != 7:
                # down write the value to 64
                weight[idx[0],idx[1],idx[2],idx[3]] = lower_value
    if np.size(weight_size) <= 2:
        c = weight_size[0]
        change_idx_list_l = np.nonzero(torch.round(weight_scale * weight - weight_zero_point) > 64)
        print(change_idx_list_l)
        change_idx_list_s = np.nonzero(torch.round(weight_scale * weight - weight_zero_point) < -64)
        for idx in change_idx_list_l:
            if idx[0]*c % 8 != 7:
                # down write the value to 64
                weight[idx[0],idx[1]] = upper_value
        for idx in change_idx_list_s:
            if idx[0]*c % 8 != 7:
                # down write the value to 64
                weight[idx[0],idx[1]] = lower_value

In [12]:
def regulate_param_init(weight_name_list, weight_list, Z_list, U_list, target):
    # TODO: see if we want to use the quantized weight or float weight here for the regularization
    layer_id = 0
    for param_tensor in model.state_dict():
        name_part = param_tensor.split('.')
        if 'float_weight' in name_part:
            quantized_weight_name = name_part[0]+'.'+name_part[1]+'.weight'
            quantized_weight_scale = name_part[0]+'.'+name_part[1]+'.weight_scale'
            quantized_weight_zero_point = name_part[0]+'.'+name_part[1]+'.weight_zero_point'
            
            weight_name_list.append(quantized_weight_name)
            weight = model.state_dict()[quantized_weight_name]
            weight_scale = model.state_dict()[quantized_weight_scale]
            weight_zero_point = model.state_dict()[quantized_weight_zero_point]
            
            weight_list.append(weight)
#             print(weight_list)
            Z = weight.clone().detach()
            
#             print('here1')
#             Z_list.append(projection(Z, weight_scale, weight_zero_point, target))
            Z_list.append()
#             print('here2')
            U = np.zeros_like(Z)
            U_list.append(U)
            print(layer_id)
            layer_id = layer_id + 1
            
    print('weight_list\n',weight_name_list) 
    
    print(len(weight_name_list))
    print(len(weight_list))
    print(len(Z_list))
    print(len(U_list))
    print('layer_id')

In [13]:
def regulate_param_update(weight_name_list, weight_list, Z_list, U_list, target):
    weight_name_list = []
    weight_list = []
    Z_list_old = Z_list.copy()
    Z_list = []
    U_list_old = U_list.copy()
    U_list = []
    
    idx = 0
    for param_tensor in model.state_dict():
        name_part = param_tensor.split('.')
        if 'float_weight' in name_part:
            quantized_weight_name = name_part[0]+'.'+name_part[1]+'.weight'
            quantized_weight_scale = name_part[0]+'.'+name_part[1]+'.weight_scale'
            quantized_weight_zero_point = name_part[0]+'.'+name_part[1]+'.weight_zero_point'
            
            weight_name_list.append(quantized_weight_name)
            weight = model.state_dict()[quantized_weight_name]
            weight_scale = model.state_dict()[quantized_weight_scale]
            weight_zero_point = model.state_dict()[quantized_weight_zero_point]
            
            weight_list.append(weight)
            Z = weight + torch.from_numpy(U_list_old[idx])
#             Z = projection(Z, weight_scale, weight_zero_point, target)
            Z_list.append(Z)
            U = torch.from_numpy(U_list_old[idx]) + weight - Z
            U_list.append(U)
            idx += 1
    print('weight_list\n',weight_name_list) 
    print('idx: ', idx)

In [14]:
compression_scheduler = distiller.CompressionScheduler(model) 
compression_scheduler = distiller.file_config(model,optimizer, 
                                    '/home/lning/PyTorch/rethinking-network-pruning/cifar/l1-norm-pruning/quant_aware_training.yaml', 
                                              compression_scheduler, (args.start_epoch-1) if args.resume else None)

best_prec1 = 0.
print(args.start_epoch)
print(args.epochs)
enter = 0
weight_name_list = []
weight_list = []
Z_list = []
U_list = []
target = [-64, 64]
for epoch in range(args.start_epoch, args.start_epoch+args.epochs):
    compression_scheduler.on_epoch_begin(epoch)
#     quantizer = distiller.quantization.QuantAwareTrainRangeLinearQuantizer(model,optimizer,bits_activations=8,
#                                                                   bits_weights=8,
#                                                                   bits_bias=8,
#                                                                   overrides=None)
#     quantizer.prepare_model()
    # initialize the regulated parameters
    print('initialize the regulated parametes')
    regulate_param_init(weight_name_list, weight_list, Z_list, U_list, target)

    if epoch in [args.epochs*0.5, args.epochs*0.75]:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1
    train_time = time.time()
    print('begin the regulated training')
    regulate_train(epoch)
#     train(epoch)
    
    # compute regularization parameters
    regulate_param_update(weight_name_list, weight_list, Z_list, U_list, target)
        
    train_time = time.time() - train_time 
    prec1 = test()
    # print(prec1)
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
        'cfg': model.cfg
    }, is_best, filepath=args.save)
    compression_scheduler.on_epoch_end(epoch,optimizer)
    print('Epoch: %d, train_time: %.2f (min), prec1: %f, best_prec1: %f\n' %(epoch, train_time/60.0, prec1, best_prec1))

# with open(os.path.join(args.save, 'train.txt'), 'w') as f:
#     f.write('Epoch: %d, train_time: %.2f (min), prec1: %f, best_prec1: %f\n' %(epoch, train_time/60.0, prec1, best_prec1))


TypeError: file_config() takes from 3 to 4 positional arguments but 5 were given

In [None]:
a = 63*64+2*3+2*3

In [None]:
print(a%8)