In [1]:
from generalization.utils.train import DEFAULT_PARAMS as hparams
from generalization.randomization import available_corruptions

print("Available corruptions:\n", available_corruptions())

hparams['dataset_name'] = 'cifar10'
hparams['n_classes'] = 10
hparams['corrupt_name'] = 'normal_labels'
hparams['corrupt_prob'] = 0
hparams['gradient_clipping'] = True
hparams['lr'] = 0.04
hparams['momentum'] = 0.9
hparams['weight_decay'] = 0.0
hparams["lr_scheduler"] = False

hparams

Available corruptions:
 ['gaussian_pixels', 'random_labels', 'random_pixels', 'partial_labels', 'shuffled_pixels']


{'seed': 88,
 'batch_size': 256,
 'learning_rate': 0.1,
 'epochs': 30,
 'val_every': 1,
 'log_dir': 'logs',
 'dataset_name': 'cifar10',
 'n_classes': 10,
 'corrupt_name': 'normal_labels',
 'corrupt_prob': 0,
 'gradient_clipping': True,
 'lr': 0.04,
 'momentum': 0.9,
 'weight_decay': 0.0,
 'lr_scheduler': False}

In [2]:
from lightning.pytorch import Trainer

from generalization.utils import Classifier, LitDataModule
from generalization.models import create_model

dm = LitDataModule(hparams=hparams)
dm.setup()

Files already downloaded and verified
Files already downloaded and verified


In [3]:
import time
import torch
import torch_pruning as tp

torch.set_float32_matmul_precision("medium")


class PruningClassifier(Classifier):
    def __init__(self, net, hparams, pruner_params=None, pruner_entry=None):
        super().__init__(net=net, hparams=hparams)
        self.pruner_params = pruner_params
        self.pruner_entry = pruner_entry

        if self.hparams["prune"]:
            self.build_pruner()

        self.pruner_stats = {
            "pruner/macs_ratio": [],
            "pruner/macs": [],
            "pruner/nparams": [],
            "pruner/nparams_ratio": [],
            "epoch": [],
            "epoch_time": [],
        }

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        if self.hparams["prune"] and self.hparams["regularize"]:
            self.pruner.regularize(self.net)  # <== for sparse training
        optimizer.step(closure=optimizer_closure)

    def training_step(self, batch, batch_idx):
        x, y, indices = batch
        loss, logits, y = self.step(batch, batch_idx, reduction="mean")

        self.train_acc.update(logits, y)
        acc = self.train_acc.compute().mean()

        self.log("train/loss", loss, prog_bar=True)
        self.log("train/acc", acc, prog_bar=True)
        return loss

    def net_device(self):
        return next(self.net.parameters()).device

    def build_pruner(self):
        ignored_layers = []
        n_out = self.hparams["n_classes"]
        for m in self.net.modules():
            if isinstance(m, torch.nn.Linear) and m.out_features == n_out:
                ignored_layers.append(m)  # DO NOT prune the final classifier!

            # DO NOT prune first convolutional layer
            if isinstance(m, torch.nn.Conv2d) and m.in_channels == 3:
                ignored_layers.append(m)

        if self.pruner_entry is None:
            self.pruner = (
                tp.pruner.MetaPruner(  #  build using self.pruner_params and self.net
                    model=self.net,
                    ignored_layers=ignored_layers,
                    **self.pruner_params,
                )
            )
        else:
            self.pruner = self.pruner_entry(
                model=self.net,
                ignored_layers=ignored_layers,
                **self.pruner_params,
            )
            
        example_inputs = self.pruner_params["example_inputs"]
        self.inputs_shape = [1] + list(example_inputs.shape[1:])
        self.base_macs, self.base_nparams = tp.utils.count_ops_and_params(
            self.net, example_inputs.to(self.net_device())
        )

        return self.pruner

    def on_train_epoch_start(self):
        self.previous_index = -1
        if (
            self.hparams["prune"]
            and not self.hparams["regularize"]  # not sparse training
            and self.current_epoch % self.hparams["prune_every_n_epoch"] == 0
        ):
            # access trainloader
            trainloader = self.trainer.datamodule.train_dataloader()

            if isinstance(self.pruner.importance, tp.importance.TaylorImportance):
                batch = next(iter(trainloader))
                inputs, targets, indices = batch
                assert (
                    self.previous_index != indices
                ).all(), "This is a hack, we need to use the same batch for pruning!"
                self.previous_index = indices
                inputs, targets = inputs.to(self.net_device()), targets.to(
                    self.net_device()
                )
                # print("inputs.shape", inputs.shape)
                # print("targets.shape", targets.shape)
                loss = self.loss(self(inputs), targets, reduction="mean")
                loss.backward()  # before pruner.step()

            self.pruner.step()

            x = torch.randn(self.inputs_shape)

            macs, nparams = tp.utils.count_ops_and_params(
                self.net, x.to(self.net_device())
            )

            self.pruner_stats["pruner/macs"].append(macs)
            self.pruner_stats["pruner/nparams"].append(nparams)
            self.pruner_stats["pruner/macs_ratio"].append(macs / self.base_macs)
            self.pruner_stats["pruner/nparams_ratio"].append(
                nparams / self.base_nparams
            )
            self.pruner_stats["epoch"].append(self.current_epoch)

            self.log("pruner/macs_ratio", self.pruner_stats["pruner/macs_ratio"][-1])
            self.log(
                "pruner/nparams_ratio", self.pruner_stats["pruner/nparams_ratio"][-1]
            )

            # self.hparams["prune"] = False
            # decrease learning rate
            self.trainer.optimizers[0].param_groups[0]["lr"] = (
                self.trainer.optimizers[0].param_groups[0]["lr"] * 0.4
            )

        self.epoch_start_time = time.time()

    def on_train_epoch_end(self, *args, **kwargs):
        self.pruner_stats["epoch_time"].append(time.time() - self.epoch_start_time)
        self.log("epoch_time", self.pruner_stats["epoch_time"][-1], prog_bar=True)

In [4]:
cnn = create_model("inception", lib="torch", cifar=True)

hparams['prune'] = False
hparams['learning_rate'] = 0.04
model = PruningClassifier(net=cnn, hparams=hparams, pruner_params=None)
trainer = Trainer(max_epochs=7, devices=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [103]:
# trainer.fit(model, datamodule=dm)

In [104]:
# trainer.test(datamodule=dm)

In [9]:
from functools import partial


hparams["prune"] = True
hparams["learning_rate"] = 0.003
hparams["max_epochs"] = 10
hparams["prune_every_n_epoch"] = 2
hparams["sparsity_learning"] = True
hparams["regularize"] = True

# Importance criteria
example_inputs = torch.randn(1, 3, 28, 28)
imp = (
    tp.importance.TaylorImportance()
)  # or MagnitudeImportance, GroupNormPruner, BNScalePruner, etc.

cnn = create_model("inception", lib="torch", cifar=True)

imp = tp.importance.GroupNormImportance(
    p=2, normalizer="max"
)  # normalized by the maximum score for CIFAR

pruner_params = dict(
    example_inputs=example_inputs,
    global_pruning=True,
    importance=imp,
    ch_sparsity=1.0,
    reg=5e-4,
    iterative_steps=hparams["max_epochs"] // hparams["prune_every_n_epoch"],
)

pruner_entry = partial(tp.pruner.GroupNormPruner, **pruner_params)

# load from checkpoint
model = PruningClassifier(
    # "lightning_logs/version_2/checkpoints/epoch=6-step=1372.ckpt",
    net=cnn,
    hparams=hparams,
    pruner_params=pruner_params,
    pruner_entry=pruner_entry,
)
trainer = Trainer(max_epochs=hparams["max_epochs"], devices=1)
trainer.test(model, datamodule=dm)
model.build_pruner();

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 55.12it/s]


In [10]:
print(f"Before pruning: MACs = {model.base_macs}, n_params = {model.base_nparams}")
trainer.fit(model, datamodule=dm)

Before pruning: MACs = 152342618.0, n_params = 8047866
Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | net       | InceptionSmall     | 8.0 M 
1 | train_acc | MulticlassAccuracy | 0     
2 | valid_acc | MulticlassAccuracy | 0     
3 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
8.0 M     Trainable params
0         Non-trainable params
8.0 M     Total params
32.191    Total estimated model params size (MB)


Epoch 0:   0%|          | 0/196 [00:00<?, ?it/s]                           

AttributeError: 'NoneType' object has no attribute 'data'

In [66]:
trainer.test(datamodule=dm)

  rank_zero_warn(


Files already downloaded and verified
Files already downloaded and verified


Restoring states from the checkpoint path at /home/step/Code/projects/ids-generalization/notebooks/lightning_logs/version_18/checkpoints/epoch=9-step=1960.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/step/Code/projects/ids-generalization/notebooks/lightning_logs/version_18/checkpoints/epoch=9-step=1960.ckpt


Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 95.51it/s]


[{'test/loss': 1.411431074142456, 'test/acc': 0.7136176824569702}]

In [67]:
model.pruner_stats['epoch_time'] = model.pruner_stats['epoch_time'][::2]

In [68]:
import pandas as pd

pd.DataFrame(model.pruner_stats)

Unnamed: 0,pruner/macs_ratio,pruner/macs,pruner/nparams,pruner/nparams_ratio,epoch,epoch_time
0,0.813656,123954495.0,6497719,0.807384,0,8.41043
1,0.65722,100122591.0,5131958,0.637679,2,8.026077
2,0.51877,79030777.0,3934190,0.488849,4,7.866646
3,0.396168,60353310.0,2889554,0.359046,6,7.70078
4,0.295554,45025522.0,2023650,0.251452,8,7.361772


In [69]:
pruner_params['iterative_steps']

5