# SynFlow Notebook

In [None]:
import argparse
import json
import os

In [None]:
class Args:
    """Notebook stand-in for the experiment's command-line arguments.

    Every setting is a class attribute, so ``Args()`` starts with an empty
    instance ``__dict__``; later cells add ``model_class``, ``save`` and an
    updated ``result_dir`` with ``setattr``.

    NOTE(review): the rewind cell further down reads ``args.prune_epochs``,
    which is not defined here -- confirm the intended value and add it.
    """

    # --- data / model ---
    dataset = 'mnist'  # mnist, tiny-imagenet, cifar10, cifar100
    model = 'conv'  # fc, vgg11, resnet20 or resnet18 (tiny)
#     model_class = 'lottery' #lottery (vgg11, resnet20), tinyimagenet (vgg11, resnet18)
    dense_classifier = False
    pretrained = False

    # --- training ---
    optimizer = 'adam'
    train_batch_size = 64
    test_batch_size = 256
    train_epochs = 2
    # what about stopping the IMP when the sparsity is reached?
    lr = 0.001
    lr_drops = ()
    lr_drop_rate = 0.1
    weight_decay = 0.0

    # --- pruning ---
    pruner = 'NP'
    compression = 10.0
    iter_comp = .2  # % retained values per iteration
    compression_schedule = 'exponential'
    mask_scope = 'global'
    prune_dataset_ratio = 10
    prune_batch_size = 256
    prune_bias = False
    prune_batchnorm = False
    prune_residual = False
    prune_train_mode = False
    reinitialize = False
    shuffle = False
    invert = False

    # --- bookkeeping ---
    experiment = 'BK'
    expid = True
    result_dir = 'Results/data'
    gpu = 0
    workers = 4
    seed = 1
    no_cuda = True  # 'store_true'
    verbose = True  # 'store_true' (was assigned twice with the same value)
    trial = 0
    save = True

    # Extra arguments for rewinding, Renda et al 2020(1)
    rewind = 'weight'  # Choose your own rewinding adventure! ('LR', 'weight', 'NP')
    ## Only pertinent for traditional rewinding as seen in Renda et al.
    rewind_epochs = 2  # how far back to rewind? (Only for weight rewinding)

In [None]:
args = Args()

# Model family to use for each known dataset. A dataset that is not listed
# leaves `model_class` unset, exactly as the if/elif chain did.
_model_class_by_dataset = {
    'mnist': 'default',
    'cifar10': 'lottery',
    'cifar100': 'lottery',
    'tiny-imagenet': 'tinyimagenet',
}
if args.dataset in _model_class_by_dataset:
    setattr(args, 'model_class', _model_class_by_dataset[args.dataset])

In [None]:
# ## In case of argument mistakes
# if args.rewind == None:
#     args.rewind = 'None'
#     args.prune_epochs = None
#     args.rewind_epochs = None
# elif args.rewind == 'NP':
#     args.rewind_train = None
#     args.rewind_epochs = None

In [None]:
## Construct Result Directory ##
# When args.expid is False nothing is saved; otherwise the experiment id is
# built from the rewind method, pruner, dataset, model and trial number, and
# the result directory <result_dir>/<experiment>/<expid> is created.
if args.expid == False:
    print("WARNING: this experiment is not being saved.")
    setattr(args, 'save', False)
else:
    expid = args.rewind+'_'+args.pruner+'_'+args.dataset+'_'+args.model+'_trial_'+str(args.trial)
    result_dir = '{}/{}/{}'.format(args.result_dir, args.experiment, expid)
    setattr(args, 'save', True)
    setattr(args, 'result_dir', result_dir)
    os.makedirs(result_dir, exist_ok = True)
    # `expid` only exists on this branch; printing it unconditionally raised
    # NameError whenever the experiment was not being saved.
    print('Expt ID: ' + expid)

print('Pruner: ' + args.pruner)
print('Rewind Method: ' + args.rewind)
# print('Train epochs before rewind: ' + str(args.prune_epochs))
# print('Epochs to rewind: ' + str(args.rewind_epochs))

In [None]:
## Save Args ##
if args.save:
    # The defaults are class attributes, so args.__dict__ alone only holds the
    # few values added later with setattr. Start from the public class
    # attributes and let the instance values override them.
    saved_args = {
        key: value
        for key, value in vars(type(args)).items()
        if not key.startswith('__')
    }
    saved_args.update(vars(args))
    with open(args.result_dir + '/args.json', 'w') as f:
        json.dump(saved_args, f, sort_keys=True, indent=4)

## Run Experiment ##
import numpy as np
# Start below every possible accuracy. `np.Inf` was removed in NumPy 2.0;
# `np.inf` is the spelling that works on every NumPy version.
best_acc = -np.inf

"""
Custom experiment for the work done by Balwani & Krzyston 2021
Based off of singleshot.py & multishot.py seen in the Ganguli Lab SynFlow repo
"""
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from Utils import load
from Utils import generator
from Utils import metrics_Jay
from train import *
from prune import *
from Experiments import singleshot

#     if args.rewind == None:
#         singleshot.run(args)

#     else:

## Random Seed and Device ##
# Seed torch before any loader or model is built, then resolve the device.
torch.manual_seed(args.seed)
device = load.device(args.gpu)

## Data ##
print('Loading {} dataset.'.format(args.dataset))
input_shape, num_classes = load.dimension(args.dataset)
# The pruning loader gets one extra positional argument, scaled by the number
# of classes (presumably a dataset length -- confirm against load.dataloader).
prune_set_length = args.prune_dataset_ratio * num_classes
prune_loader = load.dataloader(
    args.dataset, args.prune_batch_size, True, args.workers, prune_set_length)
train_loader = load.dataloader(
    args.dataset, args.train_batch_size, True, args.workers)
test_loader = load.dataloader(
    args.dataset, args.test_batch_size, False, args.workers)

## Model ##
print('Creating {} model.'.format(args.model))
# load.model returns a constructor; build the network and move it to the device.
model_constructor = load.model(args.model, args.model_class)
model = model_constructor(input_shape,
                          num_classes,
                          args.dense_classifier,
                          args.pretrained)
model = model.to(device)

## Compute NP ratios ##
# eta_c_compute returns three values; the first is stored on `args`.
eta_c_outputs = metrics_Jay.eta_c_compute(
    args.model, model, args.dataset, input_shape, args.gpu, verbose = True)
args.compression_list, layers_n_shapes, total_comp = eta_c_outputs
# loss = nn.CrossEntropyLoss()
# opt_class, opt_kwargs = load.optimizer(args.optimizer)
# optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)

# ## Save Initialiized Weights ##
# torch.save(model.state_dict(),"{}/model_init.pt".format(args.result_dir))
# torch.save(optimizer.state_dict(),"{}/optimizer.pt".format(args.result_dir))
# torch.save(scheduler.state_dict(),"{}/scheduler.pt".format(args.result_dir))


# # Train to completion
# print('Training for {} epochs'.format(str(args.train_epochs)))
# for epoch in range(args.train_epochs):            
#     # Train
#     model.train()
#     total = 0
#     for batch_idx, (data, target) in enumerate(train_loader):
#         data, target = data.to(device), target.to(device)
#         optimizer.zero_grad()
#         output = model(data)
#         train_loss = loss(output, target)
#         total += train_loss.item() * data.size(0)
#         train_loss.backward()
#         optimizer.step()
#     print('Train Epoch: {} \tLoss: {:.6f}'.format(
#         epoch, train_loss.item()))

#     # Eval
#     model.eval()
#     total = 0
#     correct1 = 0
#     correct5 = 0
#     with torch.no_grad():
#         for data, target in test_loader:
#             data, target = data.to(device), target.to(device)
#             output = model(data)
#             total += loss(output, target).item() * data.size(0)
#             _, pred = output.topk(5, dim=1)
#             correct = pred.eq(target.view(-1, 1).expand_as(pred))
#             correct1 += correct[:,:1].sum().item()
#             correct5 += correct[:,:5].sum().item()
#     average_loss = total / len(test_loader.dataset)
#     accuracy1 = 100. * correct1 / len(test_loader.dataset)
#     accuracy5 = 100. * correct5 / len(test_loader.dataset)
#     print('Evaluation: Average loss: {:.4f}, Top 1 Accuracy: {}/{} ({:.2f}%)'.format(
#             average_loss, correct1, len(test_loader.dataset), accuracy1))

#     # Save weights if best performance
#     if accuracy1 > best_acc:
#         torch.save(model.state_dict(),"{}/model_best.pt".format(args.result_dir))
#         best_acc = accuracy1

In [None]:
"""
TODO:
- get it from aish
- get it tunning our own way
- run experiments
"""

# Build the pruner. The original cell was not valid Python: `continue` sat
# outside any loop and the `elif:` had no condition.
if args.pruner == 'NP':
    # TODO: construct the NP pruner here. Until then `pruner` is None, so any
    # later use fails visibly instead of silently reusing a stale kernel value.
    pruner = None
else:
    pruner = load.pruner(args.pruner)(generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual))

In [None]:
# Single prune step: `comp` starts at 1 and is reduced by args.iter_comp of
# itself; the value handed to the pruner mask is 1 - comp.
comp = 1 - args.iter_comp
mask_level = 1 - comp
print('Sparsity = {}'.format(str(round(mask_level,4))))
# NOTE(review): `loss` and `in_out_sizes` are not assigned by any live code in
# the visible cells (the `loss` definition is commented out) -- confirm where
# they come from before running this cell on a fresh kernel.
pruner.score(model, loss, test_loader, device, in_out_sizes)
pruner.mask(mask_level, args.mask_scope)

In [None]:
# for mask, param in pruner.masked_parameters:
#     score = pruner.scores[id(param)]
#     k = int((1.0 - sparsity) * score.numel())
#     if not k < 1:
#         threshold, _ = torch.kthvalue(torch.flatten(score), k)
#         zero = torch.tensor([0.]).to(mask.device)
#         one = torch.tensor([1.]).to(mask.device)
#         mask.copy_(torch.where(score <= threshold, zero, one))

In [None]:
# NOTE(review): the loop in the preceding cell is commented out, so `param` is
# only bound here if one of the later named_parameters() cells ran first.
print(param.shape)

In [None]:
# NOTE(review): the only assignment to `in_out_shapes` in the visible notebook
# is in a commented-out line, so this display depends on stale kernel state.
in_out_shapes

In [None]:
# Print the element count of every parameter that requires gradients.
for p in model.parameters():
    if not p.requires_grad:
        continue
    print(p.numel())

In [None]:
# `p` is the loop variable left behind by the preceding cell, i.e. the last
# parameter yielded by model.parameters().
p.numel()

In [None]:
# Print the shape of every named parameter. `name` and `param` stay bound to
# the last pair after the loop.
for name, param in model.named_parameters():
    print(param.shape)

In [None]:
# Save the two weight tensors of interest to disk, keyed by parameter name.
_weight_dump_paths = {
    'conv.weight': './test_conv.pt',
    'fc.weight': './test_dense.pt',
}
for name, param in model.named_parameters():
    if name in _weight_dump_paths:
        print(name)
        torch.save(param, _weight_dump_paths[name])

#### TODO #####
- prune by layer
- load compression ratios
- this method, this ratio @ this layer
- module.kernel_size
- compute net compression ratio

In [None]:
# eta_c, in_out_shapes = eta_c_compute(args.model, args.dataset, input_shape, verbose = True)

In [None]:
# Stop at the first module whose name contains 'conv'. The loop exists for its
# leaked variables: `name` and `module` are read by the following cell. If no
# name matches, they are left pointing at the last module iterated.
for name, module in model.named_modules():
    if 'conv' in name:
        break

In [None]:
# `module` comes from the search loop in the preceding cell.
# NOTE(review): if that loop found no 'conv' name, `module` is whatever was
# iterated last and may not have a kernel_size attribute.
module.kernel_size

In [None]:
# Scratch computation of per-layer eta_c ratios: push a random input through
# the conv / fc modules and derive a ratio from each layer's input and output
# sizes.
import torch.nn.functional as F  # hoisted: was re-imported on every iteration

shape = tuple([1,input_shape[0],input_shape[1],input_shape[2]])
out = torch.rand((shape))
outs = []
eta_c = []
ind = 0
# NOTE(review): `names` is not defined in any visible cell; the guard below
# relies on it already existing in the kernel -- confirm what it should hold.
for name, module in model.named_modules():
    mod = module.eval()
    outs.append(out.shape)
    if ind+1 < len(names):
        if 'conv' in name:
            # clone().detach() is the documented equivalent of the original
            # torch.tensor(out) and avoids the copy-construct UserWarning.
            out = mod(out.clone().detach().float().to(device))
            if outs[-1][2] != out.shape[2]: # account for downsampling
                outs[-1] = torch.rand((1,int(outs[-1][1]*2),int(outs[-1][2]/2),int(outs[-1][3]/2))).shape
            m = (outs[-1][2]+2)*(outs[-1][3]+2) #dimensions of input
            n = out.shape[2]*out.shape[3] #dimensions of output
            eta_c.append((n*(3**2))/(m+n-1)) # kernel size is 3
            ind += 1
        if 'fc' in name: # otherwise it's the dense layer
            out = F.avg_pool2d(out, out.size()[3])
            out = out.view(out.size(0), -1)
            out = mod(out.clone().detach().float().to(device))
            eta_c.append((module.in_features*module.out_features)/(module.in_features+module.out_features-1))

In [None]:
eta_c

In [None]:
outs[-1]

In [None]:
torch.rand((1,int(outs[-1][1]*2),int(outs[-1][2]/2),int(outs[-1][3]/2))).shape

In [None]:
# Trace a random input through the 'features' and 'classifier' sub-modules,
# printing shapes along the way.
out = torch.rand((shape))
for name, module in model.named_modules():
    in_classifier = 'classifier' in name
    if not in_classifier and 'features' not in name:
        continue
    # Only entries whose name ends in a digit are run.
    if not name[-1].isnumeric():
        continue

    if in_classifier:
        if name[-1] == '0':
            # Flatten once, before the first classifier entry.
            out = out.view(out.size(0), -1)
        if any(digit in name for digit in ('0', '3', '6')):
            print(name)
            mod = module.eval()
            print("Dense In: " + str(out.shape[-1]))
            out = mod(out.float().to(device))
            print("Dense Out: " + str(out.shape[-1]))
    else:
        print(name)
        mod = module.eval()
        out = mod(out.float().to(device))
        print(out.shape)

In [None]:
# NOTE(review): `names` is never assigned in the visible cells; this display
# (and the eta_c cell that reads len(names)) depends on stale kernel state.
names

In [None]:
# Trace a random input through the non-residual, non-shortcut layers and the
# final 'fc' layer, printing shapes along the way.
out = torch.rand((shape))
avg_pool = nn.AdaptiveAvgPool2d((1, 1))
for name, module in model.named_modules():
    # Skip the root module and anything inside a residual branch.
    if name == '' or 'residual' in name:
        continue

    # The very first conv layer.
    is_first_layer = '0' in name and '_x' not in name and '.' in name
    # Every other numbered '_x' entry that is not a shortcut.
    is_stage_layer = (
        '_x' in name
        and 'shortcut' not in name
        and name[-1].isnumeric()
    )

    if is_first_layer or is_stage_layer:
        print(name)
        mod = module.eval()
        out = mod(out.float().to(device))
        print(out.shape)
    elif name == 'fc':
        # dense layers, will need to change with different sized resnets
        # (number of dense layers will change)
        print(name)
        out = avg_pool(out)
        out = out.view(out.size(0), -1)
        mod = module.eval()
        print("Dense In: " + str(out.shape[-1]))
        out = mod(out.float().to(device))
        print("Dense Out: " + str(out.shape[-1]))

In [None]:
module

In [None]:
model

In [None]:
# Shape of the i-th entry of the model's state dict.
# NOTE(review): `i` is not assigned in any visible cell -- set it before running.
model.state_dict()[list(model.state_dict().keys())[i]].shape

In [None]:
# import numpy as np
# import pandas as pd
# import torch
# import torch.nn as nn
# from Utils import load
# from Utils import generator
# from Utils import metrics_Jay
# from train import *
# from prune import *

# ## Random Seed and Device ##
# torch.manual_seed(args.seed)
# device = load.device(args.gpu)

# ## Data ##
# print('Loading {} dataset.'.format(args.dataset))
# input_shape, num_classes = load.dimension(args.dataset) 
# prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers, args.prune_dataset_ratio * num_classes)
# train_loader = load.dataloader(args.dataset, args.train_batch_size, True, args.workers)
# test_loader = load.dataloader(args.dataset, args.test_batch_size, False, args.workers)

# ## Model ##
# print('Creating {} model.'.format(args.model))
# model = load.model(args.model, args.model_class)(input_shape, 
#                                                  num_classes, 
#                                                  args.dense_classifier,
#                                                  args.pretrained).to(device)

# loss = nn.CrossEntropyLoss()
# opt_class, opt_kwargs = load.optimizer(args.optimizer)
# optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)

In [None]:
# Rewind a fully trained network
#
# Flow: load the best weights, prune once, retrain and evaluate, then keep
# pruning / (optionally) rewinding / retraining while `comp` stays at or below
# args.compression.
#
# NOTE(review): this cell reads `loss`, `optimizer`, `scheduler` and
# `args.prune_epochs`, and loads model_best.pt / optimizer.pt / scheduler.pt.
# In the visible notebook these are only produced by commented-out code (and
# `prune_epochs` is not defined on Args), so they must already exist in the
# kernel / result directory -- confirm before running.


def _restore_optimizer_and_scheduler(optimizer, scheduler, result_dir, device):
    """Reload the optimizer and scheduler state saved in `result_dir`."""
    optimizer.load_state_dict(torch.load("{}/optimizer.pt".format(result_dir), map_location=device))
    scheduler.load_state_dict(torch.load("{}/scheduler.pt".format(result_dir), map_location=device))


def _train_and_checkpoint(model, loss, optimizer, train_loader, device,
                          num_epochs, result_dir, first_epoch):
    """Train for `num_epochs` epochs, saving the weights after each epoch.

    Checkpoints go to ``{result_dir}/prune_epoch_{k}.pt`` with ``k`` starting
    at `first_epoch`. Returns the next unused checkpoint index.
    """
    epoch = first_epoch
    for _ in range(num_epochs):
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            train_loss = loss(output, target)
            train_loss.backward()
            optimizer.step()
        # Save weights
        torch.save(model.state_dict(), "{}/prune_epoch_{}.pt".format(result_dir, str(epoch)))
        epoch += 1
    return epoch


def _evaluate(model, loss, test_loader, device):
    """Evaluate on `test_loader`, print a summary, and return the metrics.

    Returns ``(average_loss, accuracy1, accuracy5)``; both accuracies are
    percentages over ``len(test_loader.dataset)``.
    """
    model.eval()
    total = 0
    correct1 = 0
    correct5 = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total += loss(output, target).item() * data.size(0)
            _, pred = output.topk(5, dim=1)
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            correct1 += correct[:,:1].sum().item()
            correct5 += correct[:,:5].sum().item()
    num_examples = len(test_loader.dataset)
    average_loss = total / num_examples
    accuracy1 = 100. * correct1 / num_examples
    accuracy5 = 100. * correct5 / num_examples
    print('Evaluation: Average loss: {:.4f}, Top 1 Accuracy: {}/{} ({:.2f}%)\n'.format(
            average_loss, correct1, num_examples, accuracy1))
    return average_loss, accuracy1, accuracy5


def _save_prune_result(model, pruner, input_shape, device, args, comp):
    """Build the pruning summary and pickle it into args.result_dir."""
    prune_result = metrics_Jay.summary(model,
                                       pruner.scores,
                                       metrics_Jay.flop(model, input_shape, device),
                                       lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual))
    prune_result.to_pickle("{}/compression-{}-{}.pkl".format(args.result_dir, args.pruner, str(round(comp, 4))))
    return prune_result


# load model
model.load_state_dict(torch.load("{}/model_best.pt".format(args.result_dir), map_location=device))

# load pruner
pruner = load.pruner(args.pruner)(generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual))

print('Target Compression Ratio: {}\n'.format(str(args.compression)))

# Initialize conditions for pruning loop
comp = 1

# Reset Optimizer, and Scheduler
_restore_optimizer_and_scheduler(optimizer, scheduler, args.result_dir, device)

# Prune Model
comp -= args.iter_comp*comp
print('Sparsity = {}'.format(str(round(1-comp,4))))
pruner.score(model, loss, test_loader, device)
pruner.mask(1 - comp, args.mask_scope)

# Find the actual compression ratio
remaining_params, total_params = pruner.stats()
comp = total_params/(total_params-remaining_params)
print('New Compression: {}'.format(str(round(comp,4))))

print('Training for {} epochs'.format(str(args.prune_epochs)))
# `epoch` is the running checkpoint index; it keeps counting across rounds.
epoch = _train_and_checkpoint(model, loss, optimizer, train_loader, device,
                              args.prune_epochs, args.result_dir, 0)

# Eval
print('Evaluating Pruned model')
average_loss, accuracy1, accuracy5 = _evaluate(model, loss, test_loader, device)

# Prune Result + Save Data
prune_result = _save_prune_result(model, pruner, input_shape, device, args, comp)

while comp <= args.compression:
    # Reset Optimizer, and Scheduler
    _restore_optimizer_and_scheduler(optimizer, scheduler, args.result_dir, device)

    # Prune Model
    comp = comp**(-1) - (args.iter_comp*(comp**(-1)))
    print('Sparsity = {}'.format(str(round(1-comp,4))))
    pruner.mask(1 - comp, args.mask_scope)
#     remaining_params, total_params = pruner.stats()
#     comp = total_params/(total_params-remaining_params)
#     print('New Compression: {}'.format(str(round(comp,4))))

    # Weight rewind
    if args.rewind =='weight': # or args.weight == 'NP':
        weights_epoch = args.prune_epochs - args.rewind_epochs
        # map_location belongs to torch.load; it used to sit inside the
        # str.format(...) call, where it was silently ignored.
        model.load_state_dict(torch.load("{}/prune_epoch_{}.pt".format(args.result_dir, str(weights_epoch)), map_location=device))
        print('Weights rewound')
        # Apply the mask
        pruner.mask(1 - comp, args.mask_scope)

    remaining_params, total_params = pruner.stats()
    comp = total_params/(total_params-remaining_params)
    print('New Compression: {}'.format(str(round(comp,4))))

    print('Training for {} epochs'.format(str(args.prune_epochs)))
    epoch = _train_and_checkpoint(model, loss, optimizer, train_loader, device,
                                  args.prune_epochs, args.result_dir, epoch)

    # Eval
    print('Evaluating Pruned model')
    average_loss, accuracy1, accuracy5 = _evaluate(model, loss, test_loader, device)

    # Save weights
    torch.save(model.state_dict(),"{}/pruned_{}.pt".format(args.result_dir, str(round(comp, 4))))

    # Prune Result + Save Data
    prune_result = _save_prune_result(model, pruner, input_shape, device, args, comp)