The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks, 19`ICLR
===
modified from https://github.com/rahulvigneswaran/Lottery-Ticket-Hypothesis-in-Pytorch

## Setting

In [1]:
# Importing Libraries
import argparse
import copy
import os
import sys
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import os
from tensorboardX import SummaryWriter
import torchvision.utils as vutils
import seaborn as sns
import torch.nn.init as init
import pickle

# Custom Libraries
import utils

# Tensorboard initialization
writer = SummaryWriter()

# Plotting Style
sns.set_style('darkgrid')

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--lr",default= 1.2e-3, type=float, help="Learning rate")
parser.add_argument("--batch_size", default=128, type=int)
parser.add_argument("--start_iter", default=0, type=int)
parser.add_argument("--end_iter", default=10, type=int)
parser.add_argument("--print_freq", default=1, type=int)
parser.add_argument("--valid_freq", default=1, type=int)
parser.add_argument("--resume", action="store_true")
parser.add_argument("--prune_type", default="lt", type=str, help="lt | reinit")
parser.add_argument("--gpu", default="0", type=str)
parser.add_argument("--dataset", default="cifar10", type=str, help="mnist | cifar10 | fashionmnist | cifar100")
parser.add_argument("--arch_type", default="resnet18", type=str, help="fc1 | lenet5 | alexnet | vgg16 | resnet18 | densenet121")
parser.add_argument("--prune_percent", default=10, type=int, help="Pruning percent")
parser.add_argument("--prune_iterations", default=5, type=int, help="Pruning iterations count")


args = parser.parse_args("")


os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu


#FIXME resample
resample = False

prune_type == "lt" refers to Lottery Ticket Hypothesis while "reinit" refers to reinitialization

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
reinit = True if args.prune_type=="reinit" else False

again, we're going to use CIFAR10

In [4]:
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
traindataset = datasets.CIFAR10('../datasets/', train=True, download=True,transform=transform)
testdataset = datasets.CIFAR10('../datasets/', train=False, transform=transform)      
from archs.cifar10 import AlexNet, LeNet5, fc1, vgg, resnet, densenet 

Files already downloaded and verified


In [5]:
train_loader = torch.utils.data.DataLoader(traindataset, batch_size=args.batch_size, shuffle=True, num_workers=0,drop_last=False)
test_loader = torch.utils.data.DataLoader(testdataset, batch_size=args.batch_size, shuffle=False, num_workers=0,drop_last=True)

In [6]:
global model
model = resnet.resnet18().to(device)

## defining functions:

1. Weight Initalization

In [7]:
def weight_init(m):
    '''
    Usage:
        model = Model()
        model.apply(weight_init)
    '''
    if isinstance(m, nn.Conv1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm1d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm3d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        init.normal_(m.bias.data)
    elif isinstance(m, nn.LSTM):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.LSTMCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.GRU):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.GRUCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)

In [8]:
model.apply(weight_init)

# Copying and Saving Initial State
initial_state_dict = copy.deepcopy(model.state_dict())
utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/")
torch.save(model, f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/initial_state_dict_{args.prune_type}.pth.tar")

recall that in the paper, they emphasized the importance of reusing the initialization of the original model weight

2. Mask Initialization

In [9]:
# Function to make an empty mask of the same size as the model
def make_mask(model):
    global step
    global mask
    step = 0
    for name, param in model.named_parameters(): 
        if 'weight' in name:
            step = step + 1
    mask = [None]* step # first, create a stack with None; each element would correspond to each weight in the model
    
    step = 0
    for name, param in model.named_parameters(): 
        if 'weight' in name:
            tensor = param.data.cpu().numpy()
            mask[step] = np.ones_like(tensor) # then we filled up with tensor shape same as the corresponding weight, with value 1; just like how we did in the previous example
            step = step + 1
    step = 0

In [10]:
make_mask(model)

the only difference is they made the mask outside the model whereas the previous implementation puts mask in the model

3. defining optimizer and loss

In [11]:
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)
criterion = nn.CrossEntropyLoss() # Default was F.nll_loss

4. function to recover the original initial weight

In [12]:
def original_initialization(mask_temp, initial_state_dict):
    global model
    
    step = 0
    for name, param in model.named_parameters(): 
        if "weight" in name: 
            weight_dev = param.device
            param.data = torch.from_numpy(mask_temp[step] * initial_state_dict[name].cpu().numpy()).to(weight_dev)
            step = step + 1
        if "bias" in name:
            param.data = initial_state_dict[name]
    step = 0

5. functions for training and testing

In [13]:
def train(model, train_loader, optimizer, criterion):
    EPS = 1e-6
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    for batch_idx, (imgs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        #imgs, targets = next(train_loader)
        imgs, targets = imgs.to(device), targets.to(device)
        output = model(imgs)
        train_loss = criterion(output, targets)
        train_loss.backward()

        # Freezing Pruned weights by making their gradients Zero
        for name, p in model.named_parameters():
            if 'weight' in name:
                tensor = p.data.cpu().numpy()
                grad_tensor = p.grad.data.cpu().numpy()
                grad_tensor = np.where(tensor < EPS, 0, grad_tensor) # if the weight is lower than EPS, just set to 0
                p.grad.data = torch.from_numpy(grad_tensor).to(device)
        optimizer.step()
    return train_loss.item()

def test(model, test_loader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy


6. function for pruning by percentile

In [14]:
def prune_by_percentile(percent, resample=False, reinit=False,**kwargs):
    global step
    global mask
    global model

    # Calculate percentile value
    step = 0
    for name, param in model.named_parameters():

        # We do not prune bias term
        if 'weight' in name:
            tensor = param.data.cpu().numpy()
            alive = tensor[np.nonzero(tensor)] # flattened array of nonzero values
            percentile_value = np.percentile(abs(alive), percent)

            # Convert Tensors to numpy and calculate
            weight_dev = param.device
            new_mask = np.where(abs(tensor) < percentile_value, 0, mask[step])

            # Apply new weight and mask
            param.data = torch.from_numpy(tensor * new_mask).to(weight_dev)
            mask[step] = new_mask
            step += 1
    step = 0

## Pruning:

In [15]:
# NOTE First Pruning Iteration is of No Compression
bestacc = 0.0
best_accuracy = 0
ITERATION = args.prune_iterations
comp = np.zeros(ITERATION,float)
bestacc = np.zeros(ITERATION,float)
step = 0
all_loss = np.zeros(args.end_iter,float)
all_accuracy = np.zeros(args.end_iter,float)

simplification of code below <br>
<br>
for (pruning iteration):<br>
 >   if first, pass else prune <br>
  >  for (training iteration):<br>
   >>     train the model<br>
     >>   compute loss and accuracy<br>
  >  plot 

In [None]:
for _ite in range(args.start_iter, ITERATION):
    if not _ite == 0: # we first train;, then prune
        prune_by_percentile(args.prune_percent, resample=resample, reinit=reinit)
        if reinit:
            model.apply(weight_init)
            step = 0
            for name, param in model.named_parameters():
                if 'weight' in name:
                    weight_dev = param.device
                    param.data = torch.from_numpy(param.data.cpu().numpy() * mask[step]).to(weight_dev)
                    step = step + 1
            step = 0
        else:
            original_initialization(mask, initial_state_dict)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
    print(f"\n--- Pruning Level [{1}:{_ite}/{ITERATION}]: ---")

    # Print the table of Nonzeros in each layer
    comp1 = utils.print_nonzeros(model)
    comp[_ite] = comp1
    pbar = tqdm(range(args.end_iter))

    for iter_ in pbar:

        # Frequency for Testing
        if iter_ % args.valid_freq == 0:
            accuracy = test(model, test_loader, criterion)

            # Save Weights
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                utils.checkdir(f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/")
                torch.save(model,f"{os.getcwd()}/saves/{args.arch_type}/{args.dataset}/{_ite}_model_{args.prune_type}.pth.tar")

        # Training
        loss = train(model, train_loader, optimizer, criterion)
        all_loss[iter_] = loss
        all_accuracy[iter_] = accuracy

        # Frequency for Printing Accuracy and Loss
        if iter_ % args.print_freq == 0:
            pbar.set_description(
                f'Train Epoch: {iter_}/{args.end_iter} Loss: {loss:.6f} Accuracy: {accuracy:.2f}% Best Accuracy: {best_accuracy:.2f}%')       

    writer.add_scalar('Accuracy/test', best_accuracy, comp1)
    bestacc[_ite]=best_accuracy
    
    # Plotting Loss (Training), Accuracy (Testing), Iteration Curve
    #NOTE Loss is computed for every iteration while Accuracy is computed only for every {args.valid_freq} iterations. Therefore Accuracy saved is constant during the uncomputed iterations.
    #NOTE Normalized the accuracy to [0,100] for ease of plotting.
    plt.plot(np.arange(1,(args.end_iter)+1), 100*(all_loss - np.min(all_loss))/np.ptp(all_loss).astype(float), c="blue", label="Loss") 
    plt.plot(np.arange(1,(args.end_iter)+1), all_accuracy, c="red", label="Accuracy") 
    plt.title(f"Loss Vs Accuracy Vs Iterations ({args.dataset},{args.arch_type})") 
    plt.xlabel("Iterations") 
    plt.ylabel("Loss and Accuracy") 
    plt.legend() 
    plt.grid(color="gray") 
    utils.checkdir(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
    plt.savefig(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_LossVsAccuracy_{comp1}.png", dpi=1200) 
    plt.close()

    # Dump Plot values
    utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
    all_loss.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_loss_{comp1}.dat")
    all_accuracy.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_all_accuracy_{comp1}.dat")

    # Dumping mask
    utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
    with open(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_mask_{comp1}.pkl", 'wb') as fp:
        pickle.dump(mask, fp)

    # Making variables into 0
    best_accuracy = 0
    all_loss = np.zeros(args.end_iter,float)
    all_accuracy = np.zeros(args.end_iter,float)


  0%|          | 0/10 [00:00<?, ?it/s]


--- Pruning Level [1:0/5]: ---
conv1.weight         | nonzeros =    1728 /    1728 (100.00%) | total_pruned =       0 | shape = (64, 3, 3, 3)
bn1.weight           | nonzeros =      64 /      64 (100.00%) | total_pruned =       0 | shape = (64,)
bn1.bias             | nonzeros =       0 /      64 (  0.00%) | total_pruned =      64 | shape = (64,)
layer1.0.conv1.weight | nonzeros =   36864 /   36864 (100.00%) | total_pruned =       0 | shape = (64, 64, 3, 3)
layer1.0.bn1.weight  | nonzeros =      64 /      64 (100.00%) | total_pruned =       0 | shape = (64,)
layer1.0.bn1.bias    | nonzeros =       0 /      64 (  0.00%) | total_pruned =      64 | shape = (64,)
layer1.0.conv2.weight | nonzeros =   36864 /   36864 (100.00%) | total_pruned =       0 | shape = (64, 64, 3, 3)
layer1.0.bn2.weight  | nonzeros =      64 /      64 (100.00%) | total_pruned =       0 | shape = (64,)
layer1.0.bn2.bias    | nonzeros =       0 /      64 (  0.00%) | total_pruned =      64 | shape = (64,)
layer1.1.conv

Train Epoch: 1/10 Loss: 2.078833 Accuracy: 13.53% Best Accuracy: 13.53%:  20%|██        | 2/10 [04:36<18:25, 138.16s/it]

## Plot 

In [None]:
# Dumping Values for Plotting
utils.checkdir(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/")
comp.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_compression.dat")
bestacc.dump(f"{os.getcwd()}/dumps/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_bestaccuracy.dat")

# Plotting
a = np.arange(args.prune_iterations)
plt.plot(a, bestacc, c="blue", label="Winning tickets") 
plt.title(f"Test Accuracy vs Unpruned Weights Percentage ({args.dataset},{args.arch_type})") 
plt.xlabel("Unpruned Weights Percentage") 
plt.ylabel("test accuracy") 
plt.xticks(a, comp, rotation ="vertical") 
plt.ylim(0,100)
plt.legend() 
plt.grid(color="gray") 
utils.checkdir(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/")
plt.savefig(f"{os.getcwd()}/plots/lt/{args.arch_type}/{args.dataset}/{args.prune_type}_AccuracyVsWeights.png", dpi=1200) 
plt.close() 