In [1]:
import os
import copy
import random
from matplotlib import pyplot as plt
import torch
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn

import torch_pruning as tp

In [2]:
from archs import resnet9
from data_generic import load_dataset

In [3]:
model = resnet9()

In [4]:
main_dir = "/home/mateuszpyla/stan/sharpness"
if not "RESULTS" in os.environ:
    os.environ["RESULTS"] = os.path.join(main_dir, "results")
if not "DATASETS" in os.environ:
    os.environ["DATASETS"] = os.path.join(main_dir, "data")

In [5]:
# Load the CIFAR10 dataset
train_dataset = CIFAR10(root=os.environ["DATASETS"], train=True, download=True, transform=transforms.ToTensor())
test_dataset = CIFAR10(root=os.environ["DATASETS"], train=False, download=True, transform=transforms.ToTensor())

batch_size = 256
train_dl = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=4, pin_memory=True)
valid_dl = DataLoader(test_dataset, batch_size*2, num_workers=4, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
def get_accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

@torch.no_grad()
def evaluate(model, val_loader):
    model.eval()
    outputs_losses, outputs_accs = [], []
    for batch in val_loader:
        images, labels = batch 
        out = model(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = get_accuracy(out, labels)           # Calculate accuracy
        outputs_losses.append(loss.detach())
        outputs_accs.append(acc)
    epoch_loss = torch.stack(outputs_losses).mean()   # Combine losses
    epoch_acc = torch.stack(outputs_accs).mean()      # Combine accuracies
    return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [13]:
epochs = 25
max_lr = 0.01
grad_clip = 0.12
weight_decay = 0.0001
opt_func = torch.optim.Adam(model.parameters(),max_lr,amsgrad=True
                            , weight_decay=weight_decay)
loss_fn = "mse"

In [34]:
history = []
optimizer = opt_func
sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs, 
                                            steps_per_epoch=len(train_dl))
for epoch in (range(epochs)):
    # Training Phase 
    model.train()
    train_losses = []
    train_accuracy= []
    lrs=[]
    for (batch_idx, batch) in enumerate(train_dl):
        X, y = batch
        out = model(X)
        loss = F.cross_entropy(out, y)
        accuracy = get_accuracy(out, y)
        train_losses.append(loss)
        train_accuracy.append(accuracy)
        loss.backward()
        # Gradient clipping
        if grad_clip: 
            nn.utils.clip_grad_value_(model.parameters(), grad_clip)
        optimizer.step()
        optimizer.zero_grad()
        # Record & update learning rate
        lrs.append(get_lr(optimizer))
        sched.step()
        if batch_idx % 60 == 0:
            print(f"Train Epoch: {epoch+1}")
            print(f"[{batch_idx}/{len(train_dl)}]")
            print(f"{100. * batch_idx / len(train_dl)}%")
            print(f"Loss: {loss.item()}")
            print(f"Accuracy: {accuracy.item()}")
            break
            # print(f"Train Epoch: {epoch+1} [{batch_idx}/{len(train_dl)} ({100. * batch_idx / len(train_dl)}%)]\tLoss: {loss.item():.6f}, Accuracy: {accuracy.item():.4f}")
    
    # Validation phase
    result = evaluate(model, valid_dl)
    result['train_loss'] = torch.stack(train_losses).mean().item()
    result['train_accuracy'] = torch.stack(train_accuracy).mean().item()
    result['lrs'] = lrs
    print("Epoch [{}]")
    print(f"train_loss: {result['train_loss']}")
    print(f"train_acc: {result['train_accuracy']}")
    print(f"val_loss: {result['val_loss']}")
    print(f"val_acc: {result['val_acc']}")
    print(f"last_lr: {result['lrs']}")
    history.append(result)

Train Epoch: 1
[0/196]
0.0%
Loss: 0.9260492324829102
Accuracy: 0.70703125
Epoch [{}]
train_loss: 0.9260492324829102
train_acc: 0.70703125
val_loss: 0.892481803894043
val_acc: 0.6846449971199036
last_lr: [0.0003999999999999993]
Train Epoch: 2
[0/196]
0.0%
Loss: 0.8973379135131836
Accuracy: 0.69140625


KeyboardInterrupt: 

In [20]:
model.load_state_dict(torch.load("/home/mateuszpyla/stan/sharpness/results/trained_resnet9.pth"))

RuntimeError: Error(s) in loading state_dict for ResNet9:
	size mismatch for conv1.0.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([0, 3, 3, 3]).
	size mismatch for conv1.0.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv1.1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv1.1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv1.1.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv1.1.running_var: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv2.0.weight: copying a param with shape torch.Size([128, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([0, 0, 3, 3]).
	size mismatch for conv2.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv2.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv2.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv2.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv2.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res1.0.0.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([0, 0, 3, 3]).
	size mismatch for res1.0.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res1.0.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res1.0.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res1.0.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res1.0.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res1.1.0.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([0, 0, 3, 3]).
	size mismatch for res1.1.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res1.1.1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res1.1.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res1.1.1.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res1.1.1.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv3.0.weight: copying a param with shape torch.Size([256, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([0, 0, 3, 3]).
	size mismatch for conv3.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv3.1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv3.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv3.1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv3.1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv4.0.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([0, 0, 3, 3]).
	size mismatch for conv4.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv4.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv4.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv4.1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for conv4.1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res2.0.0.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([0, 0, 3, 3]).
	size mismatch for res2.0.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res2.0.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res2.0.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res2.0.1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res2.0.1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res2.1.0.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([0, 0, 3, 3]).
	size mismatch for res2.1.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res2.1.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res2.1.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res2.1.1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for res2.1.1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for classifier.3.weight: copying a param with shape torch.Size([10, 512]) from checkpoint, the shape in current model is torch.Size([0, 0]).
	size mismatch for classifier.3.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for last.3.weight: copying a param with shape torch.Size([10, 512]) from checkpoint, the shape in current model is torch.Size([0, 0]).
	size mismatch for last.3.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for gradcam.0.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([0, 0, 3, 3]).
	size mismatch for gradcam.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for gradcam.1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for gradcam.1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for gradcam.1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).
	size mismatch for gradcam.1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([0]).

In [10]:
evaluate(model, valid_dl)

{'val_loss': 0.893187403678894, 'val_acc': 0.6840475797653198}

In [11]:
model_main = copy.deepcopy(model)

In [36]:
torch.save(model.state_dict(), "/home/mateuszpyla/stan/sharpness/results/trained_resnet9.pth")

In [15]:
model.eval()
ex = torch.randn(1,3,32,32)
DG = tp.DependencyGraph().build_dependency(model, example_inputs=ex)

In [12]:
def random_prunings(model_main, num_prunings):
    amount = random.uniform(0.1, 0.8)
    for i in range(num_prunings):
        model = copy.deepcopy(model_main)
        print("prunning ", i)
        for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
            group.prune()
            print(group)

            print(evaluate(model, valid_dl))

In [16]:
random_prunings(model_main, 10)

prunning  0

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on classifier.3 (Linear(in_features=512, out_features=0, bias=True)) => prune_out_channels on classifier.3 (Linear(in_features=512, out_features=0, bias=True)), #idxs=10
--------------------------------



{'val_loss': 0.893187403678894, 'val_acc': 0.6840475797653198}

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv4.0 (Conv2d(256, 0, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))) => prune_out_channels on conv4.0 (Conv2d(256, 0, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), #idxs=512
[1] prune_out_channels on conv4.0 (Conv2d(256, 0, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))) => prune_out_channels on conv4.1 (BatchNorm2d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=512
[2] prune_out_channels on conv4.1 (BatchNorm2d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_6(ReluBackward0), #idxs=512
[3] prune_out_channels on _ElementWiseOp_6(ReluBackward0) => prune_out_channels on _ElementWiseOp_5(MaxPool2DWithIndicesBackward0), #idxs=512
[4] prune_out_channels on _ElementWiseOp_5(MaxPool2DWithIndicesBackward0) => pru

KeyboardInterrupt: 

In [17]:
class MySlimmingPruner(tp.pruner.MetaPruner):
    def regularize(self, model, reg):
        for m in model.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and m.affine==True:
                m.weight.grad.data.add_(reg*torch.sign(m.weight.data)) # Lasso for sparsity

class MySlimmingImportance(tp.importance.Importance):
    def __call__(self, group, **kwargs):
        #note that we have multiple BNs in a group, 
        # we store layer-wise scores in a list and then reduce them to get the final results
        group_imp = [] # (num_bns, num_channels) 
        # 1. iterate the group to estimate importance
        for dep, idxs in group:
            layer = dep.target.module # get the target model
            prune_fn = dep.handler    # get the pruning function of target model, unused in this example
            if isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and layer.affine:
                local_imp = torch.abs(layer.weight.data)
                group_imp.append(local_imp)
        if len(group_imp)==0: return None # return None if the group contains no BN layer
        # 2. reduce your group importance to a 1-D scroe vector. Here we use the average score across layers.
        group_imp = torch.stack(group_imp, dim=0).mean(dim=0) 
        return group_imp # (num_channels, )

# You can implement any importance functions, as long as it transforms a group to a 1-D score vector.
class RandomImportance(tp.importance.Importance):
    @torch.no_grad()
    def __call__(self, group, **kwargs):
        _, idxs = group[0]
        return torch.rand(len(idxs))

In [21]:
imp = MySlimmingImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 10:
        ignored_layers.append(m)

iterative_steps = 5
pruner = MySlimmingPruner(
    model, 
    ex, 
    global_pruning=False,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5,
    ignored_layers=ignored_layers,
)

RuntimeError: Given groups=1, expected weight to be at least 1 at dimension 0, but got weight of size [0, 3, 3, 3] instead

In [19]:
base_macs, base_nparams = tp.utils.count_ops_and_params(model, ex)
for i in range(iterative_steps):
    pruner.step()

    macs, nparams = tp.utils.count_ops_and_params(model, ex)
    print(model)
    print(model(ex).shape)
    print(
        "  Iter %d/%d, Params: %.2f M => %.2f M"
        % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
    )
    print(
        "  Iter %d/%d, MACs: %.2f G => %.2f G"
        % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9)
    )
    print("="*16)



RuntimeError: Given groups=1, expected weight to be at least 1 at dimension 0, but got weight of size [0, 3, 3, 3] instead

In [None]:
"""
for x in model._modules.keys():
            layer = getattr(model, x)
            group = DG.get_pruning_group(layer, tp.prune_conv_out_channels, idxs=[2, 6, 9] )
"""

In [None]:
"""def random_prunings(model_main, num_prunings):
    amount = random.uniform(0.1, 0.8)
    # method = random.choice(['magnitude', 'random_unstructured'])
    imp = tp.importance.TaylorImportance()

    selected_layers = list(model._modules.keys()) # random.sample

    for _ in range(num_prunings):
        pruner = tp.pruner.MagnitudePruner(
            model,
            ex,
            importance=imp,
            iterative_steps=5,
            ch_sparsity=amount
        )

        base_macs, base_nparams = tp.utils.count_ops_and_params(model, ex)
        for layer_name in selected_layers:
        for i in range(iterative_steps):
            if isinstance(imp, tp.importance.TaylorImportance):
            pruner.prune(layer, method=pruning_method, amount=amount)
        macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) """

In [41]:
for group in DG.get_all_groups(ignored_layers=[], root_module_types=[nn.Conv2d, nn.Linear]):
    print(group)


--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on classifier.3 (Linear(in_features=512, out_features=10, bias=True)) => prune_out_channels on classifier.3 (Linear(in_features=512, out_features=10, bias=True)), #idxs=10
--------------------------------


--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv4.0 (Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))) => prune_out_channels on conv4.0 (Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), #idxs=512
[1] prune_out_channels on conv4.0 (Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))) => prune_out_channels on conv4.1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), #idxs=512
[2] prune_out_channels on conv4.1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels o

In [39]:
tp.prune_conv_out_channels

<bound method ConvPruner.prune_out_channels of <torch_pruning.pruner.function.ConvPruner object at 0x7f86ac55ba90>>

In [None]:
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )

In [None]:
train, test = load_dataset("coloured_mnist_split1.0")

In [None]:
max_steps = 100


In [None]:
for step in range(max_steps):
