In [15]:
import time

from model import LitModel
from main import (
    build_experiment,
    DEFAULT_PARAMS,
    get_cifar_models,
    available_corruptions,
)

import torch
import lightning as L

# Show which corruption modes the experiment code offers.
all_corruptions = available_corruptions()
print(all_corruptions)

# Load the torch implementations of the CIFAR models and list their names.
models = get_cifar_models(lib="torch")
model_names = list(models.keys())
print(model_names)

['gaussian_pixels', 'random_labels', 'random_pixels', 'partial_labels', 'shuffled_pixels']
['resnet18', 'resnet34', 'alexnet', 'inception', 'mlp_1x512', 'mlp_3x512']


In [16]:
import lightning.pytorch as pl
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.callbacks import ModelPruning
from lightning.pytorch.utilities.rank_zero import rank_zero_debug

class LitModelPruning(ModelPruning):
    """`ModelPruning` variant that leaves `Linear`-like modules unpruned."""

    def filter_parameters_to_prune(self, parameters_to_prune):
        """
        Drop every entry whose first element is a `Linear`-like object.

        Parameters:
            parameters_to_prune: iterable of 2-tuples; presumably the
                `(module, parameter_name)` pairs handed over by `ModelPruning`
                (TODO confirm against the Lightning version in use).

        Returns:
            list: the pairs whose first element's class name does not contain
                any of the excluded substrings.
        """
        excluded_markers = ("Linear",)
        kept = []
        for module, param_name in parameters_to_prune:
            class_name = module.__class__.__name__
            # Substring match on the class name, not an isinstance check.
            if any(marker in class_name for marker in excluded_markers):
                continue
            kept.append((module, param_name))
        return kept

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None:
        """Run the pruning step at the end of a training epoch when enabled."""
        if not self._prune_on_train_epoch_end:
            return
        rank_zero_debug("`ModelPruning.on_train_epoch_end`. Applying pruning")
        self._run_pruning(pl_module.current_epoch)

def apply_prunning_every(k):
    """
    Apply pruning every k epochs, ignores epoch 0.

    Parameters:
        k (int): apply pruning every k epochs; must be greater than zero

    Returns:
        Callable: wrapper with the following signature: `wrapper(current_epoch) -> bool`

    Raises:
        ValueError: if `k` is zero or negative
    """
    # Fail at construction time: k == 0 would otherwise only surface as a
    # ZeroDivisionError from inside the callback, in the middle of training.
    if k <= 0:
        raise ValueError(f"k must be a positive number of epochs, got {k!r}")

    def should_prune(current_epoch):
        """Return True on non-zero epochs that are a multiple of k."""
        return current_epoch != 0 and current_epoch % k == 0

    return should_prune

In [18]:
import os
import wandb
import torch

# Trade a little float32 matmul precision for speed.
torch.set_float32_matmul_precision('medium')

MODEL_NAME = "alexnet"
CORRUPTION_NAME = "normal_labels"
CORRUPTION_PROB = 0.0

model = models[MODEL_NAME]

# Shallow copy: the top-level overrides below do not touch DEFAULT_PARAMS.
hparams = DEFAULT_PARAMS.copy()
hparams["model_name"] = MODEL_NAME
hparams["n_classes"] = 10
hparams["drop_return_index"] = True
hparams["corrupt_name"] = CORRUPTION_NAME
hparams["corrupt_prob"] = CORRUPTION_PROB
hparams["val_every"] = 1
hparams["learning_rate"] = 0.01

experiment_name = f"{hparams['model_name']}_{hparams['corrupt_name']}_{hparams['corrupt_prob']}"

os.environ.update({"WANDB_NOTEBOOK_NAME": "pruning.ipynb"})

# Best-effort cleanup of a run left over from a previous execution of this
# cell; a failure here must not block setting up the new experiment.
try:
    wandb.finish()
except Exception:
    pass

# The loggers have to exist before the Trainer that references them,
# otherwise this cell raises NameError on a fresh kernel.
tb_logger = L.pytorch.loggers.TensorBoardLogger(hparams["log_dir"], name="pruning")
wb_logger = L.pytorch.loggers.WandbLogger(project="pruning", name=experiment_name)

pruning_cb = LitModelPruning(
    pruning_fn="l1_unstructured",
    amount=0.1,
    use_global_unstructured=True,
    use_lottery_ticket_hypothesis=True,
    apply_pruning=apply_prunning_every(5),
    verbose=2
)

# NOTE(review): early stopping has patience=3 while pruning fires every
# 5 epochs, so training may stop before the first pruning step — confirm
# this interaction is intended.
trainer = L.Trainer(
    max_epochs=30,
    logger=[wb_logger, tb_logger],
    default_root_dir="logs",
    check_val_every_n_epoch=hparams["val_every"],
    callbacks=[
        L.pytorch.callbacks.EarlyStopping(monitor="valid/loss", patience=3),
        # Previously constructed but never registered, so no pruning happened.
        pruning_cb,
    ],
    accelerator="gpu",
)

data = build_experiment(
    corrupt_name=hparams['corrupt_name'],
    corrupt_prob=hparams['corrupt_prob'],
    batch_size=hparams["batch_size"],
)
# NOTE(review): the loaders are always read from the "normal_labels" entry,
# regardless of CORRUPTION_NAME — verify against build_experiment's return value.
train_loader = data["normal_labels"]["train_loader"]
val_loader = data["normal_labels"]["val_loader"]

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


In [11]:
tb_logger = L.pytorch.loggers.TensorBoardLogger(hparams["log_dir"], name="pruning")
wb_logger = L.pytorch.loggers.WandbLogger(project="pruning", name=experiment_name)
# wb_logger.watch(model)
# WandbLogger creates its run lazily, so `wandb.run` may still be None here;
# going through `.experiment` starts (or attaches to) the run first.
wb_logger.experiment.log_code(".")
pl_model = LitModel(model, hparams=hparams)

start_time = time.perf_counter()
try:
    trainer.fit(
        pl_model,
        train_loader,
        val_loader,
    )
    print(f"Training took {time.perf_counter() - start_time:.2f} seconds")
finally:
    # Close the W&B run even when training is interrupted or fails, so the
    # next execution does not inherit a half-open run.
    wandb.finish()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params
------------------------------------------------------
0 | net            | SmallAlexNet       | 56.8 M
1 | train_acc      | MulticlassAccuracy | 0     
2 | valid_acc      | MulticlassAccuracy | 0     
3 | valid_top5_acc | MulticlassAccuracy | 0     
4 | test_acc       | MulticlassAccuracy | 0     
5 | test_top5_acc  | MulticlassAccuracy | 0     
------------------------------------------------------
56.8 M    Trainable params
0         Non-trainable params
56.8 M    Total params
227.307   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Epoch 3: 100%|██████████| 196/196 [00:09<00:00, 19.64it/s, v_num=q4_6, valid/loss=1.600, valid/acc=0.733, valid/top5_acc=0.973, train/loss=0.0331, train/acc=0.989]
Training took 42.17 seconds


0,1
epoch,▁▁▃▃▆▆██
train/acc,▁▅▇█
train/loss,█▄▂▁
trainer/global_step,▃▁▃▄▁▄▆▁▆█▁█
valid/acc,▁▅▇█
valid/loss,▁▃▅█
valid/top5_acc,▁▁▆█

0,1
epoch,3.0
train/acc,0.9895
train/loss,0.03314
trainer/global_step,783.0
valid/acc,0.73264
valid/loss,1.605
valid/top5_acc,0.97309
