In [1]:
# -*- coding: utf-8 -*-

from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn


import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data

import os

import numpy as np

import model

# Data preparation & transforms

In [2]:
print('Preparing data..')
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform1 = transforms.Compose([
    transforms.RandomHorizontalFlip(1.),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset1 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainset2 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform1)
dataset = data.ConcatDataset([trainset1, trainset2])

trainloader = torch.utils.data.DataLoader(dataset, batch_size= 128, shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size= 128, shuffle=False, num_workers=4)
print('Finished')

Preparing data..
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Finished


# finetune

In [3]:
def finetune(epoch):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        
        optimizer.step()
        
        #zeroing mask
        for k, v in net.state_dict().items():
            if 'conv' in k:
                checkpoint[k] = v.cuda() * masks[k]
        net.load_state_dict(checkpoint)
        
        
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    print("Finetune epoch: {:d} Acc: {:.3f} ({:d}/{:d})".format(epoch, 100.*correct/total, correct, total))

# test

In [4]:
def test():
#     net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    accuracy = 100.*correct/total
    print("Test Acc: {:.3f} ({:d}/{:d})".format(accuracy, correct, total))
    
    return accuracy

# prune a layer

In [5]:
def prune_layer(in_weight, prune_ratio):
    tres = -1e-4
    mask = (abs (in_weight) > tres)
    rat = 0
    while rat < np.array(prune_ratio):     
        mask = (abs (in_weight) > tres)
        TAR = float(mask.sum())/float(in_weight.nelement())
        rat = 1.0 - TAR
        tres += (1e-5)
    
    pruned_weight = in_weight*mask.float()
    

    return pruned_weight, mask.float()

In [None]:
# def prune_layer(in_weight, prune_ratio):
#     mask = torch.ones_like(in_weight)
#     mask[:, :int(mask.shape[1]*prune_ratio), :, :] = 0
#     pruned_weight = mask*in_weight
#     return pruned_weight, mask

# prune network

In [6]:
def prune(prune_ratio):
    masks = {}

    for k, v in net.state_dict().items():
        if 'conv' in k:
            #print("pruning layer:", k)
            weights=v
            weights, masks[k] = prune_layer(weights, prune_ratio)
            checkpoint[k] = weights
    net.load_state_dict(checkpoint)
    
    return masks

# print out network sparsity

In [7]:
def print_sparsity():
    num_el = 0
    num_zero = 0
    for k, v in net.state_dict().items():
        if 'conv' in k:
            num_el += v.numel()
            num_zero+=(v==0).sum().cpu().numpy()
#     print(f"num: {num_el} {num_zero}")
    sparsity = 100.*num_zero/num_el
    print("Sparsity: {:.3f}%".format(sparsity))
    return sparsity

# Preparing model and loading original checkpoint

In [71]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model
print('Building model..')
net = model.resnet18() 

net = net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.00001)



# Load weights that TA provided
print('Resuming from checkpoint..')
checkpoint_path = "./checkpoint/resnet18_pruned.t7"
checkpoint = torch.load(checkpoint_path)
net.load_state_dict(checkpoint)
print("Finished")

Building model..
Resuming from checkpoint..
Finished


# See the original accuracy and sparsity

In [53]:
_ = test()
_ = print_sparsity()

Test Acc: 88.670 (8867/10000)
Sparsity: 0.000%


# Do the pruning & finetuning here

In [69]:
# (ratio, epoch)
sca = [(0.982, 20), (0.983, 20), (0.985, 20)]
# tar, bse = 0.97, 5
# for idx in range(6):
#     sca.append((tar*idx/5,idx*idx+5))
print(sca)
idx, acc, edx = 1, 1000, 0
for (ratio, epoch) in sca:
    masks = prune(ratio)
    print(f"Ratio: {ratio}, Epoch: {epoch}")
    _ = print_sparsity()
    edx += 1
    if (edx == 2):
        acc = 0
    for epoch in range(1, epoch+1):
        finetune(idx)
        idx += 1
        nacc = test()
        if (nacc >= acc):
            print("-"*10, f"new score {nacc}", "-"*10, "\nSaving model...")
            acc = nacc
        torch.save(net.state_dict(), './checkpoint/resnet18_pruned_{idx}.t7')


[(0.698, 1)]
Ratio: 0.698, Epoch: 1
Sparsity: 69.837%

Finetune epoch: 40 Acc: 88.562 (88562/100000)
Test Acc: 85.150 (8515/10000)
---------- new score 85.15 ---------- 
Saving model...


# I'll use the following functions to evaluate your pruned model
## Make sure your saved checkpoint can run this on 414 server

In [7]:
def rate_checkpoint():
    #resnet18_pruned will be changed to resnet18_fine_StudentID or resnet18_coarse_StudentID during evaluation
    path = "./checkpoint/resnet18_pruned.t7"
    net = model.resnet18()
    net.load_state_dict(torch.load(path))
    accuracy = test()
    sparsity = print_sparsity()
    #acc_threshold = 90 for fine-grained, = 85 for coarse-grained
    #spar_threshold = 70 for fine=grained, = 25 for coarse-grained
    acc_threshold = 85
    spar_threshold = 25
    if accuracy < acc_threshold or sparsity < spar_threshold:
        print("failed, accuracy = {:.3f}% sparsity = {:.3f}%".format(accuracy, sparsity))
    else:
        print("succeeded, accuracy = {:.3f}% sparsity = {:.3f}%".format(accuracy, sparsity))
    

In [None]:
rate_checkpoint()