# SynFlow Notebook

In [42]:
import argparse
import json
import os
from Experiments import singleshot
from Experiments import multishot
from Experiments.theory import unit_conservation
from Experiments.theory import layer_conservation
from Experiments.theory import imp_conservation
from Experiments.theory import schedule_conservation

In [75]:
class Args:
    dataset = 'cifar10'
    model = 'vgg16'
    model_class = 'lottery'
    dense_classifier = False
    pretrained = False
    optimizer = 'adam'
    train_batch_size = 64
    test_batch_size = 256
    pre_epochs = 0
    post_epochs = 1#10
    train_epochs = 250
    # what about stopping the IMP when the sparsity is reached?
    lr = 0.001
    lr_drops = tuple([])
    lr_drop_rate = 0.1
    weight_decay = 0.0
    pruner= 'synflow'
    compression= 1.0
    prune_epochs= 1
    compression_schedule= 'exponential'
    mask_scope= 'global'
    prune_dataset_ratio= 10
    prune_batch_size= 256
    prune_bias= False
    prune_batchnorm= False
    prune_residual= False
    prune_train_mode= False
    reinitialize= False
    shuffle= False
    invert= False
    pruner_list= tuple([])
    prune_epoch_list= tuple([])
    compression_list= tuple([])
    level_list= tuple([])
    experiment = 'BK'
    expid = True
    result_dir = 'Results/data'
    gpu = 1
    workers =4
    seed = 1
    no_cuda= True#'store_true'
    verbose= True#'store_true'
    trial = 0
    # Extra arguments for rewinding, Renda et al 2020(1)
    rewind = 'LR' # Will there be rewinding, if so, what type? (None, 'LR', 'weight', 'NP')
    ## Only pertinent for traditional rewinding as seen in Renda et al.
    rewind_train = 20 # how many epochs of training before rewinding?  
    rewind_epochs = 2 # how far back to rewind?

In [76]:
args = Args()

In [77]:
## In case of argument mistakes
if args.rewind == None:
    args.rewind = 'None'
    args.rewind_train = None
    args.rewind_epochs = None
elif args.rewind == 'NP':
    args.rewind_train = None
    args.rewind_epochs = None

In [61]:
## Construct Result Directory ##
if args.expid == False:
    print("WARNING: this experiment is not being saved.")
    setattr(args, 'save', False)
else:
    expid = args.rewind+'_'+args.pruner+'_'+args.dataset+'_'+args.model+'_trial_'+str(args.trial)
    result_dir = '{}/{}/{}'.format(args.result_dir, args.experiment, expid)
    setattr(args, 'save', True)
    setattr(args, 'result_dir', result_dir)
#     try:
    os.makedirs(result_dir)
#     except FileExistsError:
#         val = ""
#         while val not in ['yes', 'no']:
#             val = input("Experiment '{}' with expid '{}' exists.  Overwrite (yes/no)? ".format(args.experiment, args.expid))
#         if val == 'no':
#             quit()
            
print('Expt ID: ' + expid) 
print('Pruner: ' + args.pruner)
print('Rewind Method: ' + args.rewind)
print('Train epochs before rewind: ' + str(args.rewind_train))
print('Epochs to rewind: ' + str(args.rewind_epochs))

Expt ID: None_synflow_cifar10_vgg16_trial_0
Pruner: synflow
Rewind Method: None
Train epochs before rewind: None
Epochs to rewind: None


In [None]:
## Save Args ##
if args.save:
    with open(args.result_dir + '/args.json', 'w') as f:
        json.dump(args.__dict__, f, sort_keys=True, indent=4)

## Run Experiment ##
if args.experiment == 'BK':
    """
    Custom experiment for the work done by Balwani & Krzyston 2021
    Based off of singleshot.py & multishot.py seen in the Ganguli Lab SynFlow repo
    """
    import numpy as np
    import pandas as pd
    import torch
    import torch.nn as nn
    from Utils import load
    from Utils import generator
    from Utils import metrics
    from train import *
    from prune import *
    from Experiments import singleshot
    
    
    if args.rewind == None:
        singleshot.run(args)


    else:

        ## Random Seed and Device ##
        torch.manual_seed(args.seed)
        device = load.device(args.gpu)

        ## Data ##
        print('Loading {} dataset.'.format(args.dataset))
        input_shape, num_classes = load.dimension(args.dataset) 
        prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers, args.prune_dataset_ratio * num_classes)
        train_loader = load.dataloader(args.dataset, args.train_batch_size, True, args.workers)
        test_loader = load.dataloader(args.dataset, args.test_batch_size, False, args.workers)

        ## Model ##
        print('Creating {} model.'.format(args.model))
        model = load.model(args.model, args.model_class)(input_shape, 
                                                         num_classes, 
                                                         args.dense_classifier,
                                                         args.pretrained).to(device)
        loss = nn.CrossEntropyLoss()
        opt_class, opt_kwargs = load.optimizer(args.optimizer)
        optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)

        ## Save Original ##
        torch.save(model.state_dict(),"{}/model_init.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(),"{}/optimizer_init.pt".format(args.result_dir))
        torch.save(scheduler.state_dict(),"{}/scheduler_init.pt".format(args.result_dir))
        
        ## Counter for LR and weight rewinding
        if args.rewind == 'LR' or args.rewind == 'weight':
            count = 0
        ## For NP rewinding, track all neural persistences 
        elif args.rewind == 'NP':
            NPs = []
        
        # Every instance of pruning, prune a certain amount, stop when pruned at the specified level
        for epoch in range(epochs):
            count += 1
            
            # Train
            model.train()
            total = 0
            for batch_idx, (data, target) in enumerate(dataloader):
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                train_loss = loss(output, target)
                total += train_loss.item() * data.size(0)
                train_loss.backward()
                optimizer.step()
                if verbose & (batch_idx % log_interval == 0):
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(data), len(dataloader.dataset),
                        100. * batch_idx / len(dataloader), train_loss.item()))
            
            # Eval
            model.eval()
            total = 0
            correct1 = 0
            correct5 = 0
            with torch.no_grad():
                for data, target in dataloader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    total += loss(output, target).item() * data.size(0)
                    _, pred = output.topk(5, dim=1)
                    correct = pred.eq(target.view(-1, 1).expand_as(pred))
                    correct1 += correct[:,:1].sum().item()
                    correct5 += correct[:,:5].sum().item()
            average_loss = total / len(dataloader.dataset)
            accuracy1 = 100. * correct1 / len(dataloader.dataset)
            accuracy5 = 100. * correct5 / len(dataloader.dataset)
            if verbose:
                print('Evaluation: Average loss: {:.4f}, Top 1 Accuracy: {}/{} ({:.2f}%)'.format(
                    average_loss, correct1, len(dataloader.dataset), accuracy1))
            
            # Save weights
            torch.save(model.state_dict(),"{}/model_"+epoch+".pt".format(args.result_dir))
            
            
            # Rewind
#             if args.rewind == 'NP':
#                 # compute the neural persistence of the model
#                 NP.append(metrics.neural_persistence(model))
#                 # NP should always be increasing
                
                
            elif count == args.rewind_train: #weight or LR rewind
                if args.rewind == 'LR':
                    # reset the learning rate
                    scheduler.load_state_dict(torch.load("{}/scheduler_init.pt".format(args.result_dir), map_location=device))
                if args.rewind == 'weight':
                    # determine which epoch to return to
                    proper_epoch = args.rewind_train - args.rewind_epochs
                    if proper_epoch == 0:
                        proper_epoch = 'init'
                    # reset weights to proper epoch
                    model.load_state_dict(torch.load("{}/model_"+proper_epoch+".pt".format(args.result_dir), map_location=device))
                    # reset the learning rate to initial conditions
                    scheduler.load_state_dict(torch.load("{}/scheduler_init.pt".format(args.result_dir), map_location=device))
                
                # reset counter
                count = 0
            
            # Prune Result
            prune_result = metrics.summary(model, 
                                           pruner.scores,
                                           metrics.flop(model, input_shape, device),
                                           lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual))
            # Train Model
            post_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader, 
                                          test_loader, device, args.post_epochs, args.verbose)

            # Save Data
            post_result.to_pickle("{}/post-train-{}-{}-{}.pkl".format(args.result_dir, args.pruner, str(compression),  str(level)))
            prune_result.to_pickle("{}/compression-{}-{}-{}.pkl".format(args.result_dir, args.pruner, str(compression), str(level)))



    

In [None]:
# Get the weights of a model
params = {}
for name, param in model.named_parameters():
    if 'bias' not in name:
        params[(name)] = param

In [None]:
keys = list(params.keys())

In [None]:
model.modules

In [None]:
"""
TODO
- train all the way, THEN IMP
- implement NP rewinding
- compute total sparsity
- test neural persistence code
- 
- write paper
"""

In [6]:
if isinstance(module, layers.Linear) or isinstance(module, nn.Linear):
    print(module)


NameError: name 'module' is not defined

In [28]:
ind = 0
for name, module in model.named_modules():
    print(name)
    if ind > 1:
        print(module.in_features)
    ind +=1


layers
layers.0


AttributeError: 'ConvModule' object has no attribute 'in_features'

In [33]:
module.conv.kernel_size[0]

3

In [23]:
module.out_features

10

In [66]:
weights_all = []
for name, param in model.named_parameters():
    if 'weight' in name:
        weights_all.append(np.double(torch.max(torch.abs(param)).detach().numpy()))

In [68]:
max(weights_all)

1.1864285469055176

In [65]:
args.epochs

100

In [66]:
i = 9

In [70]:
i%args.epochs

9

In [79]:
(args.rewind_train - args.rewind_epochs) +2

20