In [1]:
from composer.callbacks import LRMonitorHparams
from composer.datasets import DataLoaderHparams
from composer.loggers import WandBLoggerHparams
from composer.models import ComposerClassifier
from composer.optim import SGDHparams, CosineAnnealingSchedulerHparams, MultiStepSchedulerHparams, ConstantSchedulerHparams
from composer.trainer import Trainer, TrainerHparams
from composer.utils.object_store import ObjectStoreProviderHparams, ObjectStoreProvider
from copy import deepcopy
from lth_diet.data import CIFAR10DataHparams, DataHparams
from lth_diet.exps import LotteryExperiment
from lth_diet.models import ResNetCIFARClassifierHparams, ClassifierHparams
from lth_diet.pruning import Mask, PrunedClassifier, PruningHparams
from lth_diet.pruning.pruned_classifier import prunable_layer_names
from lth_diet.utils import utils
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns
import torch
from tqdm import tqdm
plt.style.use("default")
# Compact figure defaults, layered on top of seaborn's "ticks" theme below.
rc = {
    # figure
    "figure.figsize": (4, 3),
    "figure.dpi": 150,
    "figure.constrained_layout.use": True,
    # axes
    "axes.grid": True,
    "axes.spines.right": False,
    "axes.spines.top": False,
    "axes.linewidth": 0.6,
    "axes.labelsize": 11,
    "axes.titlesize": 11,
    "axes.titlepad": 4,
    "axes.labelpad": 2,
    # grid
    "grid.linewidth": 0.6,
    # ticks
    "xtick.major.width": 0.6,
    "ytick.major.width": 0.6,
    "xtick.major.size": 4,
    "ytick.major.size": 4,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "xtick.major.pad": 2,
    "ytick.major.pad": 2,
    # lines and patches
    "lines.linewidth": 1.2,
    'lines.markeredgecolor': 'w',
    "patch.linewidth": 0,
}
sns.set_theme(style='ticks', palette=sns.color_palette("tab10"), rc=rc)
# Handle to the 'prunes' bucket. NOTE(review): 'GCS_KEY' presumably names the credential
# environment variable — confirm against ObjectStoreProviderHparams.
object_store = ObjectStoreProviderHparams('google_storage', 'prunes', 'GCS_KEY').initialize_object()
bucket_dir = os.environ['OBJECT_STORE_DIR']

Start with the LotteryExperiment in lottery_test but with rewinding step 2000 (this is where IMP works really well).

In [3]:
# Base lottery experiment from the config, with the rewinding point overridden to 2000 batches.
config = "../configs/lottery_test.yaml"
exp = LotteryExperiment.create(config, cli_args=False)
exp.rewinding_steps = "2000ba"
# exp.name encodes the full configuration; its hash is used below to locate artifacts in the bucket.
print(exp.name)

Lottery(model=ResNetCIFAR(num_classes=10,num_layers=20),train_data=CIFAR10(train=True),train_batch_size=128,optimizer=SGDHparams(lr=0.1,momentum=0.9,weight_decay=0.0001,dampening=0.0,nesterov=False),schedulers=[MultiStepSchedulerHparams(milestones=[31200ba,46800ba],gamma=0.1)],max_duration=62400ba,seed=6174,rewinding_steps=2000ba,pruning=PruningHparams(pruning_fraction=0.2),algorithms=[ChannelsLastHparams()],callbacks=[LRMonitorHparams()])


---

Use replicate 0 to start with. The initial model is a fully trained dense network, i.e. level 0.

In [4]:
# Replicate 0, matching the narrative above and the recorded outputs (which show replicate_0).
replicate = 0
level = 0

Prepare the training and test data loaders. Prepare model_hparams.

In [5]:
# Train batch size 128 matches the experiment's train_batch_size; 1000 splits the test set into 10 batches.
train_data = CIFAR10DataHparams(True).initialize_object(128, DataLoaderHparams(persistent_workers=False))
test_data = CIFAR10DataHparams(False).initialize_object(1000, DataLoaderHparams(persistent_workers=False))
# ResNet-20 for 10 classes, the same architecture as the experiment's model.
model_hparams = ResNetCIFARClassifierHparams(10, 20)

Location of the model and the mask.

In [6]:
# Bucket prefix holding this replicate/level's checkpoints and mask.
exp_hash = utils.get_hash(exp.name)
location = "/".join([exp_hash, f"replicate_{replicate}", f"level_{level}", "main"])
print(location)

6367fcd0a8f09134f8c65d472df6880e/replicate_0/level_0/main


Function for loading state_dict and mask from the cloud and constructing a PrunedClassifier.

In [7]:
def load_pruned_model_from_cloud(
    exp: LotteryExperiment, replicate: int, level: int, model_hparams: ClassifierHparams, 
    object_store: ObjectStoreProvider, ckpt: str
) -> PrunedClassifier:
    """Fetch a checkpoint and its mask from the object store and wrap them as a PrunedClassifier.

    Artifacts are read from ``{hash(exp.name)}/replicate_{replicate}/level_{level}/main``; the
    weights file is ``model_{ckpt}.pt`` (e.g. ``ckpt="final"``).

    Returns:
        The masked model, moved to the CPU and put in eval mode.
    """
    run_dir = "/".join([utils.get_hash(exp.name), f"replicate_{replicate}", f"level_{level}", "main"])
    weights = utils.load_object(run_dir, f"model_{ckpt}.pt", object_store, torch.load)
    base = model_hparams.initialize_object()
    base.module.load_state_dict(weights)
    wrapped = PrunedClassifier(base, Mask.load(run_dir, object_store))
    wrapped.cpu()
    wrapped.eval()
    return wrapped

Load level 0 final model.

In [8]:
# Level-0 (dense) model at the end of training.
pruned_model = load_pruned_model_from_cloud(exp, replicate, level, model_hparams, object_store, ckpt="final")

Function to calculate the density of a PrunedClassifier.

In [9]:
def get_pruned_model_density(model: PrunedClassifier) -> float:
    """Return the fraction of nonzero mask entries, pooled over every prunable layer of ``model``.

    Each layer's mask tensor is looked up on ``model`` under the attribute name given by
    ``PrunedClassifier.to_mask_name``; all masks are flattened and concatenated before averaging.
    """
    masks = []
    for layer_name in prunable_layer_names(model):
        mask_tensor = getattr(model, PrunedClassifier.to_mask_name(layer_name))
        masks.append(mask_tensor.flatten())
    kept = torch.cat(masks).ne(0)
    return kept.float().mean().item()

Check that the density of the model is 1.0.

In [10]:
# The level-0 model is dense, so every mask entry must be kept.
level_0_density = get_pruned_model_density(pruned_model)
print(level_0_density)
assert level_0_density == 1.0, f"expected a dense model, got density {level_0_density}"

1.0


Function to calculate accuracy of a model.

In [11]:
def get_accuracy(model: ComposerClassifier, data: torch.utils.data.DataLoader, device: str = "cuda") -> float:
    """Top-1 accuracy of ``model`` over every example in ``data``.

    Args:
        model: classifier whose forward takes an ``(inputs, targets)`` batch tuple and returns logits.
        data: loader yielding batches whose first two items are inputs and targets;
            ``len(data.dataset)`` is used as the denominator.
        device: device to evaluate on (default ``"cuda"``, as before).

    Returns:
        Fraction of correctly classified examples, as a Python float.

    Side effects: leaves ``model`` in eval mode and moves it back to the CPU.
    """
    model.to(device)
    model.eval()
    correct = 0
    # Evaluation only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch in tqdm(data):
            batch = batch[0].to(device), batch[1].to(device)
            logits = model(batch)
            # Accumulate as a Python int so the final ratio is not rounded to float32.
            correct += (logits.argmax(dim=-1) == batch[1]).sum().item()
    model.cpu()
    return correct / len(data.dataset)

Calculate test accuracy of our level 0 final model: 91.34%.

In [12]:
# Test accuracy of the dense level-0 model.
print(get_accuracy(model=pruned_model, data=test_data))

100%|██████████| 10/10 [00:01<00:00,  9.44it/s]

0.9133999943733215





Save the state dict of the level 0 final model (the inner `module` of the PrunedClassifier) as `crazy/model_0.pt`.

In [13]:
# Ensure the local artifact folder exists so a fresh run does not fail on the first save.
os.makedirs("crazy", exist_ok=True)
torch.save(pruned_model.module.state_dict(), "crazy/model_0.pt")

We also want to save the mask for this classifier. Note that this is an all-ones mask. Save it as `mask_0.pt`.

In [14]:
# Level 0 is dense, so its mask is all ones.
mask = Mask.ones_like(pruned_model)
torch.save(mask, "crazy/mask_0.pt")

New mask: magnitude prune 20% of the current mask.

In [15]:
# Derive the next-level mask from the current model and mask (pruning_fraction=0.2).
mask = PruningHparams(pruning_fraction=0.2).prune(pruned_model, mask)

Check that the mask density is 80%.

In [16]:
print(f"{mask.density.item():.2f}")
# One round of 20% pruning should leave 80% of the prunable weights.
assert abs(mask.density.item() - 0.8) < 1e-3, f"expected density 0.80, got {mask.density.item():.4f}"

0.80


Sanity check: mask should be the same as level 1 of LotteryExperiment.

In [17]:
# Our mask must match the experiment's own level-1 mask exactly: same layers, same entries.
test_mask = Mask.load(f"{exp_hash}/replicate_{replicate}/level_1/main", object_store)
assert set(test_mask.keys()) == set(mask.keys()), "test failed: mask layers differ from level 1"
mismatched = [k for k, v in test_mask.items() if not (v == mask[k]).all().item()]
assert not mismatched, f"test failed: masks differ for {mismatched}"
del test_mask, mismatched

Save mask as `mask_1.pt`.

In [18]:
# Persist the level-1 mask; it is reloaded from this file for projection and further pruning.
torch.save(mask, "crazy/mask_1.pt")

Clear up the model and mask for safety.

In [19]:
# Drop the in-memory level-0 model and mask; later cells reload what they need from crazy/.
del pruned_model, mask

Load a pruned model from model name and mask name.

In [20]:
def load_pruned_model(
    model_name: str, mask_name: str, model_hparams: ClassifierHparams, directory: str = "crazy"
) -> PrunedClassifier:
    """Rebuild a PrunedClassifier from a locally saved state dict and a locally saved mask.

    Args:
        model_name: file stem of the inner module's state dict, ``{directory}/{model_name}.pt``.
        mask_name: file stem of the ``torch.save``-d mask, ``{directory}/{mask_name}.pt``.
        model_hparams: hyperparameters used to build a fresh model with the matching architecture.
        directory: folder holding both files (default ``"crazy"``, as before).

    Returns:
        The masked model, on the CPU and in eval mode.
    """
    # map_location keeps loading working even if a file was saved from GPU tensors.
    state_dict = torch.load(os.path.join(directory, f"{model_name}.pt"), map_location="cpu")
    model = model_hparams.initialize_object()
    model.module.load_state_dict(state_dict)
    mask = torch.load(os.path.join(directory, f"{mask_name}.pt"), map_location="cpu")
    pruned_model = PrunedClassifier(model, mask)
    # Set device and mode on the wrapper itself (not only the inner model) so the returned module
    # is uniformly in eval mode, matching load_pruned_model_from_cloud.
    pruned_model.cpu()
    pruned_model.eval()
    return pruned_model

Take `model_0` and project it with `mask_1` to get `model_1__0`.


In [21]:
# Level-0 weights with the level-1 mask applied.
model_1__0 = load_pruned_model(model_name="model_0", mask_name="mask_1", model_hparams=model_hparams)

How well does `model_1__0` perform (initial model after projection)? 91.27% ... not bad.

In [22]:
# Accuracy right after projection, before any retraining.
print(get_accuracy(model=model_1__0, data=test_data))

100%|██████████| 10/10 [00:01<00:00,  9.96it/s]

0.9126999974250793





Save `model_1__0`.

In [23]:
# Save the projected (not yet retrained) level-1 weights.
torch.save(model_1__0.module.state_dict(), "crazy/model_1__0.pt")

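Train `model_1__0` to get `model_1__1`. 5 epochs (1950 batches) seem to be enough. The hyperparameters are the same as where we left off: lr=0.001 is the final learning rate of the original schedule (0.1 decayed twice by a factor of 10).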

In [24]:
# Retrain the projected model with a constant lr=0.001, where the original MultiStep schedule ends.
model_1__0 = load_pruned_model("model_0", "mask_1", model_hparams)
model = model_1__0  # alias: training updates model_1__0 in place
model.cuda()
model.train()
max_duration = "1950ba"  # 5 epochs of 390 batches
optimizer = SGDHparams(lr=0.001, momentum=0.9, weight_decay=0.0001).initialize_object(model.parameters())
scheduler = ConstantSchedulerHparams().initialize_object()
seed = 789
logger = [WandBLoggerHparams("lth_diet", "crazy", entity="prunes").initialize_object()]
callback = [LRMonitorHparams().initialize_object()]
# Checkpoint folder: absolute (as before) but derived from the notebook's working directory, where
# the crazy/*.pt files also go, instead of a user-specific home path. It is specific to this run so
# another fine-tuning run's identically named "{batch}ba" checkpoints cannot overwrite these.
save_folder = os.path.abspath(os.path.join("crazy", "ckpts", "model_1__1"))
trainer = Trainer(model=model, 
                  train_dataloader=train_data, 
                  max_duration=max_duration, 
                  eval_dataloader=test_data,
                  optimizers=optimizer, 
                  schedulers=scheduler,
                  device="gpu", 
                  validate_every_n_batches=195, 
                  validate_every_n_epochs=-1,
                  precision="amp", 
                  seed=seed, 
                  loggers=logger,
                  callbacks=callback,
                  save_folder=save_folder,
                  save_name_format="{batch}ba",
                  save_latest_format=None,
                  save_interval="195ba",
                  save_weights_only=True)
trainer.fit()
model.cpu()
model.eval();

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmansheej[0m (use `wandb login --relogin` to force relogin)





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy/val,▁▂▁▃▃▃▂▄▆█
crossentropyloss/val,▂▄▂▄▁▆▃▃█▂
epoch,▁▁▃▃▅▅▆▆███
loss/train,▄▂▄▁▂█▁▃▃▂▁▂▄▃▂▃▁█▃▃▂▃▆▃▂▃▃▃▄▁█▁▁▅▂▃▁▃▂▃
lr-SGD/group0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/batch_idx,▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
accuracy/val,0.9137
crossentropyloss/val,0.3595
epoch,4.0
loss/train,0.01203
lr-SGD/group0,0.001
trainer/batch_idx,389.0
trainer/global_step,1950.0


Model test accuracy is 91.37%, above the level 0 model's 91.34%, so it is in the error sublevel set.

In [46]:
# Accuracy of the retrained level-1 model.
print(get_accuracy(model=model, data=test_data))

100%|██████████| 10/10 [00:01<00:00,  9.75it/s]

0.9136999845504761





Save this trained model as `model_1__1.pt`.

In [47]:
# Save the retrained level-1 weights; reloaded below for the LMC check and for projection onto mask_2.
torch.save(model.module.state_dict(), "crazy/model_1__1.pt")

We no longer need `model_1__0` or `model`.

In [52]:
# `model` is an alias of `model_1__0` (set in the training cell), so both names go.
del model, model_1__0

Load `model_0` and `model_1__1` so that we can check for linear mode connectivity (LMC) between them.

In [50]:
# The two endpoints of the linear path: dense level 0 and retrained level 1.
model_0 = load_pruned_model(model_name="model_0", mask_name="mask_0", model_hparams=model_hparams)
model_1__1 = load_pruned_model(model_name="model_1__1", mask_name="mask_1", model_hparams=model_hparams)

Function for calculating the midpoint state dict.

In [30]:
def midpoint(state_dict, state_dict_, alpha: float = 0.5):
    """Linearly interpolate two state dicts: ``(1 - alpha) * state_dict + alpha * state_dict_``.

    With the default ``alpha=0.5`` this is the elementwise midpoint, the same result as
    ``(a + b) / 2``. Other values of ``alpha`` trace the rest of the linear path, e.g. to look for
    an error barrier between two models.

    Args:
        state_dict: first endpoint (``alpha = 0``); its keys define the keys of the result.
        state_dict_: second endpoint (``alpha = 1``); must contain every key of ``state_dict``.
        alpha: interpolation coefficient.

    Returns:
        A new dict of interpolated tensors. Non-floating-point entries (e.g. integer counters)
        are cast back to their original dtype so they can be loaded back without surprises.
    """
    state_dict__ = {}
    for k, v in state_dict.items():
        mixed = (1 - alpha) * v + alpha * state_dict_[k]
        if not torch.is_floating_point(v):
            mixed = mixed.to(v.dtype)
        state_dict__[k] = mixed
    return state_dict__

Make the midpoint state_dict and load it into a `model_mid` which should have `mask_0`.

In [61]:
# Wrap the averaged weights with the all-ones mask_0 so the midpoint model is dense.
mask = torch.load("crazy/mask_0.pt")
state_dict = midpoint(model_0.module.state_dict(), model_1__1.module.state_dict())
model_mid = model_hparams.initialize_object()
model_mid.module.load_state_dict(state_dict)
model_mid = PrunedClassifier(model_mid, mask)
model_mid.cpu().eval();

The midpoint is roughly in the error sublevel set, with accuracy 91.29%.

In [62]:
# Accuracy at the midpoint of the linear path.
print(get_accuracy(model=model_mid, data=test_data))

100%|██████████| 10/10 [00:01<00:00,  9.66it/s]

0.9128999710083008





We can now drop `model_0`, `model_mid` and `mask`.

In [64]:
# Keep model_1__1: the next cell uses it to prune mask_1.
del model_0, model_mid, mask

Time to prune `mask_1`: remove a further 20% of the weights it keeps, using `model_1__1`.

In [65]:
# Derive the level-2 mask from the retrained level-1 model and mask_1.
mask = PruningHparams(pruning_fraction=0.2).prune(model_1__1, torch.load("crazy/mask_1.pt"))

Save `mask_2`.

In [67]:
# Persist the level-2 mask; it is reloaded from this file when projecting and retraining.
torch.save(mask, "crazy/mask_2.pt")

In [68]:
# The mask now lives on disk as crazy/mask_2.pt.
del mask

Take `model_1__1` and project it with `mask_2` to get `model_2__0`.

In [69]:
# Retrained level-1 weights with the level-2 mask applied.
model_2__0 = load_pruned_model(model_name="model_1__1", mask_name="mask_2", model_hparams=model_hparams)

This model has the expected density of 64% (0.8 × 0.8), i.e. a sparsity of 36%.

In [71]:
model_2__0_density = get_pruned_model_density(model_2__0)
# Two rounds of 20% pruning should keep 0.8 ** 2 = 0.64 of the prunable weights.
assert abs(model_2__0_density - 0.64) < 1e-3, f"unexpected density {model_2__0_density}"
model_2__0_density

0.6399909853935242

The test accuracy however has fallen to 90.89%.

In [72]:
# Accuracy right after projecting onto mask_2, before retraining.
print(get_accuracy(model=model_2__0, data=test_data))

100%|██████████| 10/10 [00:01<00:00,  9.92it/s]

0.9088999629020691





In [74]:
# Save the projected (not yet retrained) level-2 weights; the next training cell starts from this file.
torch.save(model_2__0.module.state_dict(), "crazy/model_2__0.pt")

In [75]:
# Reloaded from crazy/model_2__0.pt by the training cell below.
del model_2__0

In [77]:
# Retrain the level-2 projection with the same recipe as the level-1 retraining (constant lr=0.001).
model = load_pruned_model("model_2__0", "mask_2", model_hparams)
model.cuda()
model.train()
max_duration = "1950ba"  # 5 epochs of 390 batches
optimizer = SGDHparams(lr=0.001, momentum=0.9, weight_decay=0.0001).initialize_object(model.parameters())
scheduler = ConstantSchedulerHparams().initialize_object()
seed = 789
logger = [WandBLoggerHparams("lth_diet", "crazy", entity="prunes").initialize_object()]
callback = [LRMonitorHparams().initialize_object()]
# Checkpoint folder: absolute (as before) but derived from the notebook's working directory instead
# of a user-specific home path. It is specific to this run: the level-1 retraining wrote identically
# named "{batch}ba" checkpoints, which a shared folder would let this run overwrite.
save_folder = os.path.abspath(os.path.join("crazy", "ckpts", "model_2__1"))
trainer = Trainer(model=model, 
                  train_dataloader=train_data, 
                  max_duration=max_duration, 
                  eval_dataloader=test_data,
                  optimizers=optimizer, 
                  schedulers=scheduler,
                  device="gpu", 
                  validate_every_n_batches=195, 
                  validate_every_n_epochs=-1,
                  precision="amp", 
                  seed=seed, 
                  loggers=logger,
                  callbacks=callback,
                  save_folder=save_folder,
                  save_name_format="{batch}ba",
                  save_latest_format=None,
                  save_interval="195ba",
                  save_weights_only=True)
trainer.fit()
model.cpu()
model.eval();




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy/val,▂▅▅▄▆▆█▅▆▁
crossentropyloss/val,▅▆▄▅▁▇▄▂█▂
epoch,▁▁▃▃▅▅▆▆███
loss/train,▄▂▅▁▃▆▂▃▃▂▂▁▄▃▂▃▃█▃▄▂▃▆▂▂▃▃▄▄▁▆▁▂▆▂▃▁▂▂▃
lr-SGD/group0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/batch_idx,▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
accuracy/val,0.9103
crossentropyloss/val,0.3637
epoch,4.0
loss/train,0.01381
lr-SGD/group0,0.001
trainer/batch_idx,389.0
trainer/global_step,1950.0
