In [1]:
from composer.algorithms import ChannelsLastHparams
from composer.callbacks import LRMonitorHparams
from composer.core.types import DataLoader
from composer.datasets import DataLoaderHparams
from composer.loggers import WandBLoggerHparams
from composer.models import ComposerClassifier
from composer.optim import (SGDHparams, ConstantSchedulerHparams, CosineAnnealingSchedulerHparams, 
                            CosineAnnealingWithWarmupSchedulerHparams, MultiStepSchedulerHparams, 
                            MultiStepWithWarmupSchedulerHparams)
from composer.trainer import Trainer, TrainerHparams
from composer.utils.object_store import ObjectStoreProviderHparams, ObjectStoreProvider
from copy import deepcopy
from lth_diet.data import CIFAR10DataHparams, DataHparams
from lth_diet.exps import LotteryExperiment
from lth_diet.models import ResNetCIFARClassifierHparams, ClassifierHparams
from lth_diet.pruning import Mask, PrunedClassifier, PruningHparams
from lth_diet.pruning.pruned_classifier import prunable_layer_names
from lth_diet.utils import utils
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns
import torch
from tqdm import tqdm
from typing import Callable, Tuple
# Reset matplotlib to its default style before layering the custom rc settings below.
plt.style.use("default")
# Shared figure styling (figure size/dpi, grid, spines, line widths, font sizes, padding);
# handed to seaborn via set_theme so it applies to every figure in the notebook.
rc = {"figure.figsize": (4, 3), "figure.dpi": 150, "figure.constrained_layout.use": True, "axes.grid": True, 
      "axes.spines.right": False, "axes.spines.top": False, "axes.linewidth": 0.6, "grid.linewidth": 0.6,
      "xtick.major.width": 0.6, "ytick.major.width": 0.6, "xtick.major.size": 4, "ytick.major.size": 4, 
      "axes.labelsize": 11, "axes.titlesize": 11, "xtick.labelsize": 10, "ytick.labelsize": 10,
      "axes.titlepad": 4, "axes.labelpad": 2, "xtick.major.pad": 2, "ytick.major.pad": 2,
      "lines.linewidth": 1.2, 'lines.markeredgecolor': 'w', "patch.linewidth": 0}
sns.set_theme(style='ticks', palette=sns.color_palette("tab10"), rc=rc)
# Object-store client used by the *_from_cloud helpers below.
# NOTE(review): the positional args are presumably (provider, container, name of the env var
# holding the key) — confirm against ObjectStoreProviderHparams.
object_store = ObjectStoreProviderHparams('google_storage', 'prunes', 'GCS_KEY').initialize_object()
# Raises KeyError if OBJECT_STORE_DIR is not set in the environment.
# NOTE(review): bucket_dir is not referenced again in the visible cells.
bucket_dir = os.environ['OBJECT_STORE_DIR']

Helper Functions

In [2]:
def load_pruned_model_from_cloud(
    exp: LotteryExperiment,
    replicate: int,
    level: int,
    ckpt: str,
    model_hparams: ClassifierHparams,
    object_store: ObjectStoreProvider,
) -> PrunedClassifier:
    """Download a checkpoint and its mask from the object store and wrap them as a PrunedClassifier.

    The remote directory is ``<hash of exp.name>/replicate_<replicate>/level_<level>/main``;
    the weights file inside it is ``model_<ckpt>.pt``.

    Args:
        exp: Experiment whose ``name`` is hashed to form the remote directory.
        replicate: Replicate index used in the remote path.
        level: Pruning level index used in the remote path.
        ckpt: Checkpoint tag inserted into the weights file name.
        model_hparams: Hparams used to instantiate a fresh model to load the weights into.
        object_store: Object-store client the weights and the mask are read from.

    Returns:
        The pruned model, moved to CPU and switched to eval mode.
    """
    remote_dir = f"{utils.get_hash(exp.name)}/replicate_{replicate}/level_{level}/main"
    # torch.load is passed as the deserializer for the downloaded object.
    weights = utils.load_object(remote_dir, f"model_{ckpt}.pt", object_store, torch.load)
    # The state dict belongs to the inner network, hence `.module`.
    dense_model = model_hparams.initialize_object()
    dense_model.module.load_state_dict(weights)
    pruned_model = PrunedClassifier(dense_model, Mask.load(remote_dir, object_store))
    pruned_model.cpu()
    pruned_model.eval()
    return pruned_model

def load_mask_from_cloud(
    exp: LotteryExperiment,
    replicate: int,
    level: int,
    object_store: ObjectStoreProvider,
) -> Mask:
    """Load the mask stored under ``<hash of exp.name>/replicate_<replicate>/level_<level>/main``.

    Args:
        exp: Experiment whose ``name`` is hashed to form the remote directory.
        replicate: Replicate index used in the remote path.
        level: Pruning level index used in the remote path.
        object_store: Object-store client the mask is read from.

    Returns:
        The mask returned by ``Mask.load`` for that location.
    """
    remote_dir = f"{utils.get_hash(exp.name)}/replicate_{replicate}/level_{level}/main"
    return Mask.load(remote_dir, object_store)

def get_pruned_model_density(model: PrunedClassifier) -> float:
    """Return the fraction of nonzero entries across the model's mask tensors.

    For every prunable layer name, the matching mask attribute name is obtained with
    ``PrunedClassifier.to_mask_name`` and read off the model; all of those tensors are
    flattened and concatenated before counting nonzero entries.
    """
    flat_masks = []
    for layer_name in prunable_layer_names(model):
        mask_attr = PrunedClassifier.to_mask_name(layer_name)
        flat_masks.append(getattr(model, mask_attr).flatten())
    is_nonzero = torch.cat(flat_masks) != 0
    return is_nonzero.float().mean().item()

def get_accuracy(model: ComposerClassifier, data: DataLoader) -> float:
    """Compute the top-1 accuracy of ``model`` over every batch in ``data``.

    The model is moved to the GPU and put in eval mode for the pass, then moved back to
    the CPU before returning (it is left in eval mode).

    Args:
        model: Classifier that is called with the whole ``(inputs, targets)`` batch tuple
            and returns logits.
        data: Loader yielding batches whose first two elements are inputs and targets;
            ``len(data.dataset)`` is used as the denominator.

    Returns:
        Number of correct predictions divided by the dataset size, as a Python float.
    """
    model.cuda()
    model.eval()
    correct = 0
    # Inference only: disable autograd so no graph is built or kept alive per batch.
    with torch.no_grad():
        for batch in tqdm(data):
            batch = batch[0].cuda(), batch[1].cuda()
            logits = model(batch)
            # Accumulates as a 0-dim tensor on the GPU; converted once at the end.
            correct += (logits.argmax(dim=-1) == batch[1]).sum()
    model.cpu()
    return (correct / len(data.dataset)).item()

def save_model(model: ComposerClassifier, name: str) -> None:
    """Save the state dict of ``model.module`` to ``crazy/<name>.pt``.

    Args:
        model: Classifier whose inner ``module`` weights are written.
        name: File stem of the checkpoint (without the ``.pt`` suffix).
    """
    path_ckpt = f"crazy/{name}.pt"
    # torch.save does not create missing parent directories, so make sure the
    # target folder exists (no-op when it already does).
    os.makedirs(os.path.dirname(path_ckpt), exist_ok=True)
    torch.save(model.module.state_dict(), path_ckpt)

def save_mask(mask: Mask, name: str) -> None:
    """Serialize ``mask`` with ``torch.save`` to ``crazy/<name>.pt``.

    Args:
        mask: Mask object to write (the whole object is saved, not a state dict).
        name: File stem of the mask file (without the ``.pt`` suffix).
    """
    path_mask = f"crazy/{name}.pt"
    # torch.save does not create missing parent directories, so make sure the
    # target folder exists (no-op when it already does).
    os.makedirs(os.path.dirname(path_mask), exist_ok=True)
    torch.save(mask, path_mask)
    
def load_mask(name: str) -> Mask:
    """Read back the object stored at ``crazy/<name>.pt``.

    Note: ``torch.load`` unpickles the file, so only load files written by this notebook.
    """
    path_mask = f"crazy/{name}.pt"
    return torch.load(path_mask)

def load_model(model_name: str, mask_name: str, model_hparams: ClassifierHparams) -> PrunedClassifier:
    """Build a PrunedClassifier from a local checkpoint and a local mask.

    Args:
        model_name: File stem of the weights under ``crazy/`` (``crazy/<model_name>.pt``).
        mask_name: File stem of the mask, resolved by ``load_mask``.
        model_hparams: Hparams used to instantiate a fresh model to load the weights into.

    Returns:
        The pruned model, moved to CPU and switched to eval mode.
    """
    weights = torch.load(f"crazy/{model_name}.pt")
    # The checkpoint holds the inner network's state dict, hence `.module`.
    dense_model = model_hparams.initialize_object()
    dense_model.module.load_state_dict(weights)
    pruned_model = PrunedClassifier(dense_model, load_mask(mask_name))
    pruned_model.cpu()
    pruned_model.eval()
    return pruned_model

def midpoint(state_dict, state_dict_):
    """Return the element-wise average of two state dicts.

    The result has exactly the keys of ``state_dict`` (in the same order); each value is
    ``(state_dict[k] + state_dict_[k]) / 2``. A key missing from ``state_dict_`` raises
    ``KeyError``; extra keys in ``state_dict_`` are ignored.
    """
    return {key: (value + state_dict_[key]) / 2 for key, value in state_dict.items()}

def compare_to_midpoint(
    model_a_name: str,
    mask_a_name: str,
    model_b_name: str,
    mask_b_name: str,
    data: DataLoader,
    eval_fn: Callable,
    model_hparams: ClassifierHparams,
) -> Tuple[float, float, float]:
    """Evaluate two saved models and the model at the midpoint of their weights.

    Model A must be at least as dense as model B (asserted); the midpoint weights are
    wrapped with model A's mask before being evaluated.

    Args:
        model_a_name: Checkpoint stem of the denser model.
        mask_a_name: Mask stem for model A; also applied to the midpoint model.
        model_b_name: Checkpoint stem of the sparser model.
        mask_b_name: Mask stem for model B.
        data: Data passed to ``eval_fn``.
        eval_fn: Called as ``eval_fn(model, data)`` for each of the three models.
        model_hparams: Hparams used to instantiate the models.

    Returns:
        ``(eval of A, eval of midpoint, eval of B)``.
    """
    denser = load_model(model_a_name, mask_a_name, model_hparams)
    denser_score = eval_fn(denser, data)
    sparser = load_model(model_b_name, mask_b_name, model_hparams)
    sparser_score = eval_fn(sparser, data)

    # Average the inner networks' weights and load them into a fresh model.
    averaged = midpoint(denser.module.state_dict(), sparser.module.state_dict())
    mid_model = model_hparams.initialize_object()
    mid_model.module.load_state_dict(averaged)

    # The midpoint reuses mask A, which is only meaningful when A is the denser model.
    assert get_pruned_model_density(denser) >= get_pruned_model_density(sparser)
    mid_model = PrunedClassifier(mid_model, load_mask(mask_a_name))
    mid_model.cpu()
    mid_model.eval()
    mid_score = eval_fn(mid_model, data)
    return denser_score, mid_score, sparser_score

Use the LotteryExperiment with rewinding step = 2000ba, as it performed well and I want to start with a late rewinding step.  
Use replicate 3, no good reason.  
Start at level 1 because there is a small error bump between level 0 and level 1.

In [3]:
# Build the experiment from the YAML config.
# NOTE(review): cli_args=False presumably skips command-line parsing — confirm.
config = "../configs/lottery_test.yaml"
exp = LotteryExperiment.create(config, cli_args=False)
# Override the rewinding step after creation; the exp.name printed below reflects it.
exp.rewinding_steps = "2000ba"
replicate = 3
level = 1
print(exp.name, replicate, level)

Lottery(model=ResNetCIFAR(num_classes=10,num_layers=20),train_data=CIFAR10(train=True),train_batch_size=128,optimizer=SGDHparams(lr=0.1,momentum=0.9,weight_decay=0.0001,dampening=0.0,nesterov=False),schedulers=[MultiStepSchedulerHparams(milestones=[31200ba,46800ba],gamma=0.1)],max_duration=62400ba,seed=6174,rewinding_steps=2000ba,pruning=PruningHparams(pruning_fraction=0.2),algorithms=[ChannelsLastHparams()],callbacks=[LRMonitorHparams()]) 3 1


Prepare train and test dataloaders as well as model hparams.  
Train batch size = 128.

In [4]:
# Train loader: the first positional arg True shows up as CIFAR10(train=True) in exp.name; batch size 128.
train_data = CIFAR10DataHparams(True).initialize_object(128, DataLoaderHparams(persistent_workers=False))
# Test loader: 1000 is presumably the batch size (the progress bars below show 10 batches per pass) — confirm.
test_data = CIFAR10DataHparams(False).initialize_object(1000, DataLoaderHparams(persistent_workers=False))
# Positional args (10, 20) match ResNetCIFAR(num_classes=10,num_layers=20) printed in exp.name.
model_hparams = ResNetCIFARClassifierHparams(10, 20)

Load the final model from replicate 3 level 1 as a PrunedClassifier. Also load the mask for level 1.

In [5]:
# The two commented lines are the original download from the object store; the active
# lines read the local copies under crazy/ instead (written by the save cell that follows
# when its lines are re-enabled).
# model = load_pruned_model_from_cloud(exp, replicate, level, "final", model_hparams, object_store)
# mask = load_mask_from_cloud(exp, replicate, level, object_store)
model = load_model("model_1__1", "mask_1", model_hparams)
mask = load_mask("mask_1")

Check the density of the model and the mask and the accuracy of the model.  
Both the mask and the model have density = 80%.  
The model achieves test accuracy = 91.76%.

In [6]:
# Mask density comes from the Mask object itself; model density is recomputed from the
# mask tensors attached to the model, so the two numbers should agree.
print("Mask density:", mask.density.item())
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))

Mask density: 0.7999933362007141
Model density: 0.7999933362007141


100%|██████████| 10/10 [00:01<00:00,  9.96it/s]

Model test accuracy: 0.9175999760627747





Save the level 1 final model and mask as `model_1__1` and `mask_1`.

In [7]:
# save_model(model, "model_1__1")
# save_mask(mask, "mask_1")

Magnitude prune 20% of the remaining weights of `model_1__1` (`mask_1`) to generate `mask_2`.

In [8]:
# Reload model and mask from crazy/ so this cell does not depend on earlier kernel state.
model = load_model("model_1__1", "mask_1", model_hparams)
mask = load_mask("mask_1")
pruning_fraction = 0.2
# `mask` is rebound to the mask returned by prune(); it is persisted only if the save
# below is re-enabled.
mask = PruningHparams(pruning_fraction).prune(model, mask)
# save_mask(mask, "mask_2")

Sanity check `mask_2`. It should have 64% of weights remaining and should be identical to the mask in level 2 in the cloud.

In [9]:
def sanity_check():
    """Print the density of the local ``mask_2`` and compare it with the level-2 mask in the cloud.

    Reads the notebook globals ``exp``, ``replicate`` and ``object_store``. Every entry of
    the cloud mask is compared element-wise against the local entry with the same key;
    the comparison stops at the first mismatch.
    """
    local_mask = load_mask("mask_2")
    print(local_mask.density.item())
    cloud_mask = load_mask_from_cloud(exp, replicate, 2, object_store)
    # all() short-circuits, so later entries are not compared after a mismatch.
    identical = all((tensor == local_mask[key]).all().item() for key, tensor in cloud_mask.items())
    print("Sanity check passed" if identical else "Sanity check failed")

sanity_check()

0.6399909853935242
Sanity check passed


Next we construct `model_2__0` by applying `mask_2` to `model_1__1`.

In [10]:
# load_model wraps the model_1__1 weights with mask_2 via PrunedClassifier.
# NOTE(review): the save is commented out, so the next cell needs crazy/model_2__0.pt
# to exist from an earlier run.
model = load_model("model_1__1", "mask_2", model_hparams)
# save_model(model, "model_2__0")

What is the density and accuracy of `model_2__0`?  
Density = 64%  
Test Accuracy = 91.84%  
`model_2__0` is even better than `model_1__1`.

In [11]:
# Reads crazy/model_2__0.pt and crazy/mask_2.pt from disk, then reports density and test accuracy.
model = load_model("model_2__0", "mask_2", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))

Model density: 0.6399909853935242


100%|██████████| 10/10 [00:00<00:00, 10.04it/s]

Model test accuracy: 0.91839998960495





We calculate the test accuracy of `model_1__1`, `model_2__0`, and the midpoint between them. For the midpoint, we use the mask of the denser model.  
`model_1__1` = 91.76%  
midpoint = 91.87%  
`model_2__0` = 91.84%  
Both models are linearly connected in the same error sublevel set. We call them ***ERROR CONNECTED***.

In [12]:
# Argument order matters: the first model/mask pair must be the denser one
# (compare_to_midpoint asserts this and applies the first mask to the midpoint).
print(
    "Test accuracy of (model_denser, model_midpoint, model_sparser) =",
    compare_to_midpoint("model_1__1", "mask_1", "model_2__0", "mask_2", test_data, get_accuracy, model_hparams)
)

100%|██████████| 10/10 [00:00<00:00, 10.90it/s]
100%|██████████| 10/10 [00:00<00:00, 10.71it/s]
100%|██████████| 10/10 [00:00<00:00, 10.58it/s]

Test accuracy of (model_denser, model_midpoint, model_sparser) = (0.9175999760627747, 0.9187999963760376, 0.91839998960495)





Because of this lovely property, `model_2__1` will just be `model_2__0`.

In [13]:
# model_2__1 is a copy of model_2__0 (no further training at this level).
# NOTE(review): the save is commented out, so later cells need crazy/model_2__1.pt
# to exist from an earlier run.
model = load_model("model_2__0", "mask_2", model_hparams)
# save_model(model, "model_2__1")

Magnitude prune 20% of the remaining weights of `model_2__1` (`mask_2`) to generate `mask_3` which has 51% of the weights remaining.

In [14]:
# Prune model_2__1 starting from mask_2; `mask` is rebound to the new mask returned by
# prune() and is persisted only if the save below is re-enabled.
model = load_model("model_2__1", "mask_2", model_hparams)
mask = load_mask("mask_2")
pruning_fraction = 0.2
mask = PruningHparams(pruning_fraction).prune(model, mask)
print("Mask density =", mask.density.item())
# save_mask(mask, "mask_3")

Mask density = 0.5119861364364624


Next we construct `model_3__0` by applying `mask_3` to `model_2__1`.

In [15]:
# load_model wraps the model_2__1 weights with mask_3 via PrunedClassifier.
# NOTE(review): the save is commented out, so the next cell needs crazy/model_3__0.pt
# to exist from an earlier run.
model = load_model("model_2__1", "mask_3", model_hparams)
# save_model(model, "model_3__0")

What is the density and accuracy of `model_3__0`?  
Density = 51%  
Test Accuracy = 91.01%  
`model_3__0` is not as good and needs further training.

In [16]:
# Reads crazy/model_3__0.pt and crazy/mask_3.pt from disk, then reports density and test accuracy.
model = load_model("model_3__0", "mask_3", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))

Model density: 0.5119861364364624


100%|██████████| 10/10 [00:00<00:00, 10.54it/s]

Model test accuracy: 0.910099983215332





I tried a bunch of different training hyperparameters. This seemed to work. Saving models along the way so we can load a good one. 

In [17]:
# model = load_model("model_3__0", "mask_3", model_hparams)
# model.cuda()
# model.train()
# max_duration = "7800ba"
# algorithms = [ChannelsLastHparams().initialize_object()]
# optimizer = SGDHparams(lr=0.001, momentum=0.9, weight_decay=0.0001).initialize_object(model.parameters())
# scheduler = CosineAnnealingWithWarmupSchedulerHparams(t_warmup="3120ba", alpha_f=0.0).initialize_object()
# seed = 789
# logger = [WandBLoggerHparams("lth_diet", "crazy", entity="prunes").initialize_object()]
# callback = [LRMonitorHparams().initialize_object()]
# trainer = Trainer(model=model, 
#                   train_dataloader=train_data, 
#                   max_duration=max_duration, 
#                   eval_dataloader=test_data,
#                   algorithms=algorithms,
#                   optimizers=optimizer, 
#                   schedulers=scheduler,
#                   device="gpu", 
#                   validate_every_n_batches=195, 
#                   validate_every_n_epochs=-1,
#                   precision="amp", 
#                   seed=seed, 
#                   loggers=logger,
#                   callbacks=callback,
#                   save_folder="/home/mansheej/lth_diet/ipynbs/crazy/ckpts",
#                   save_name_format="{batch}ba",
#                   save_latest_format=None,
#                   save_interval="195ba",
#                   save_weights_only=True)
# trainer.fit()
# model.cpu()
# model.eval();

The model at 4680ba worked pretty well. Load in the composer format and save to our format. 

In [18]:
# model = load_model("model_3__0", "mask_3", model_hparams)
# state_dict = torch.load("crazy/ckpts/4680ba")
# model.load_state_dict(state_dict["state"]["model"])
# model._apply_mask()
# save_model(model, "model_3__1")

The trained model at level 3: `model_3__1`  
Density = 51%  
Test Accuracy = 91.77%  

In [19]:
# Reads crazy/model_3__1.pt (written by the commented-out conversion cell above) and
# crazy/mask_3.pt, then reports density and test accuracy.
model = load_model("model_3__1", "mask_3", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))

Model density: 0.5119861364364624


100%|██████████| 10/10 [00:00<00:00, 10.67it/s]

Model test accuracy: 0.9176999926567078





Next we check that `model_2__1` and `model_3__1` are indeed ***ERROR CONNECTED***.  
Test accuracy `model_2__1` = 91.84%  
Test accuracy midpoint = 91.86%  
Test accuracy `model_3__1` = 91.77%  

In [20]:
# Denser pair (model_2__1, mask_2) goes first; mask_2 is the mask applied to the midpoint.
print(
    "Test accuracy of (model_denser, model_midpoint, model_sparser) =",
    compare_to_midpoint("model_2__1", "mask_2", "model_3__1", "mask_3", test_data, get_accuracy, model_hparams)
)

100%|██████████| 10/10 [00:00<00:00, 10.18it/s]
100%|██████████| 10/10 [00:00<00:00, 10.82it/s]
100%|██████████| 10/10 [00:00<00:00, 10.60it/s]

Test accuracy of (model_denser, model_midpoint, model_sparser) = (0.91839998960495, 0.9185999631881714, 0.9176999926567078)





We now want to see if `mask_3`, which was obtained through Sublevel Error Search, can be applied to the dense model at the rewind step such that the model successfully trains to high test accuracy, into the same error sublevel set.  
Load the dense rewind model from the cloud and save it as `model_rewind`. This model is the initial model in level 0. For `model_rewind`:  
Density = 100%  
Test Accuracy = 68.96%

In [21]:
# Level 0, checkpoint tag "init": fetches model_init.pt and the level-0 mask from the object store.
model = load_pruned_model_from_cloud(exp, replicate, 0, "init", model_hparams, object_store)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))
# NOTE(review): the save is commented out, so later cells need crazy/model_rewind.pt
# to exist from an earlier run.
# save_model(model, "model_rewind")

Model density: 1.0


100%|██████████| 10/10 [00:01<00:00,  9.95it/s]

Model test accuracy: 0.6895999908447266





We can now apply `mask_3` to `model_rewind` to get `model_3__i`. `model_3__i` should have a density of 51% and low test accuracy. Indeed, for `model_3__i`,  
Density = 51%  
Test Accuracy = 15.01%  

In [22]:
# Dense rewind weights wrapped with mask_3; density and accuracy are reported for the masked model.
model = load_model("model_rewind", "mask_3", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))
# save_model(model, "model_3__i")

Model density: 0.5119861364364624


100%|██████████| 10/10 [00:00<00:00, 10.32it/s]

Model test accuracy: 0.1500999927520752





Now we train `model_3__i` using the same hyperparameters as the dense model, accounting for the fact that 2000 batches of pretraining are already complete. This tests whether the mask obtained without rewinding can be used to train a model that has been rewound.

In [23]:
# model = load_model("model_3__i", "mask_3", model_hparams)
# model.cuda()
# model.train()
# max_duration = "60400ba"
# algorithms = [ChannelsLastHparams().initialize_object()]
# optimizer = SGDHparams(lr=0.1, momentum=0.9, weight_decay=0.0001).initialize_object(model.parameters())
# scheduler = MultiStepSchedulerHparams(milestones=["29200ba", "44800ba"], gamma=0.1).initialize_object()
# seed = 6174
# logger = [WandBLoggerHparams("lth_diet", "crazy", entity="prunes").initialize_object()]
# callback = [LRMonitorHparams().initialize_object()]
# trainer = Trainer(
#     model=model, 
#     train_dataloader=train_data, 
#     max_duration=max_duration, 
#     eval_dataloader=test_data,
#     algorithms=algorithms,
#     optimizers=optimizer, 
#     schedulers=scheduler,
#     device="gpu", 
#     # validate_every_n_batches=195, 
#     # validate_every_n_epochs=-1,
#     precision="amp", 
#     step_schedulers_every_batch=True,
#     seed=seed, 
#     loggers=logger,
#     callbacks=callback,
#     # save_folder="/home/mansheej/lth_diet/ipynbs/crazy/ckpts",
#     # save_name_format="{batch}ba",
#     # save_latest_format=None,
#     # save_interval="195ba",
#     # save_weights_only=True)
# )
# trainer.fit()
# model.cpu()
# model.eval();

We save the model trained from the rewind point but with `mask_3` as `model_3__f`. It has  
Density = 51%  
Test Accuracy = 91.64%  

In [24]:
# model = load_model("model_3__f", "mask_3", model_hparams)
# print("Model density:", get_pruned_model_density(model))
# print("Model test accuracy:", get_accuracy(model, test_data))
# # save_model(model, "model_3__f")

Remarkably, `model_3__f` and `model_3__1` are indeed ***ERROR CONNECTED***.  
Test accuracy `model_3__f` = 91.84%  
Test accuracy midpoint = 91.86%  
Test accuracy `model_3__1` = 91.77%  

In [25]:
# print(
#     "Test accuracy of (model_denser, model_midpoint, model_sparser) =",
#     compare_to_midpoint("model_3__f", "mask_3", "model_3__1", "mask_3", test_data, get_accuracy, model_hparams)
# )

Magnitude prune 20% of the remaining weights of `model_3__1` (`mask_3`) to generate `mask_4` which has 41% of the weights remaining.

In [27]:
# Prune model_3__1 starting from mask_3; `mask` is rebound to the new mask returned by
# prune() and is persisted only if the save below is re-enabled.
model = load_model("model_3__1", "mask_3", model_hparams)
mask = load_mask("mask_3")
pruning_fraction = 0.2
mask = PruningHparams(pruning_fraction).prune(model, mask)
print("Mask density =", mask.density.item())
# save_mask(mask, "mask_4")

Mask density = 0.40958523750305176


Next we construct `model_4__0` by applying `mask_4` to `model_3__1`.

In [28]:
# load_model wraps the model_3__1 weights with mask_4 via PrunedClassifier.
# NOTE(review): the save is commented out, so the next cell needs crazy/model_4__0.pt
# to exist from an earlier run.
model = load_model("model_3__1", "mask_4", model_hparams)
# save_model(model, "model_4__0")

What is the density and accuracy of `model_4__0`?  
Density = 41%  
Test Accuracy = 89.05%  
`model_4__0` is not as good and needs further training.

In [29]:
# Reads crazy/model_4__0.pt and crazy/mask_4.pt from disk, then reports density and test accuracy.
model = load_model("model_4__0", "mask_4", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))

Model density: 0.40958523750305176


100%|██████████| 10/10 [00:00<00:00, 10.05it/s]

Model test accuracy: 0.8904999494552612





In [30]:
# model = load_model("model_4__0", "mask_4", model_hparams)
# model.cuda()
# model.train()
# max_duration = "7800ba"
# algorithms = [ChannelsLastHparams().initialize_object()]
# optimizer = SGDHparams(lr=0.01, momentum=0.9, weight_decay=0.0001).initialize_object(model.parameters())
# scheduler = MultiStepSchedulerHparams(gamma=0.1, milestones=["3120ba"]).initialize_object()
# seed = 1234
# logger = [WandBLoggerHparams("lth_diet", "crazy", entity="prunes").initialize_object()]
# callback = [LRMonitorHparams().initialize_object()]
# trainer = Trainer(model=model, 
#                   train_dataloader=train_data, 
#                   max_duration=max_duration, 
#                   eval_dataloader=test_data,
#                   algorithms=algorithms,
#                   optimizers=optimizer, 
#                   schedulers=scheduler,
#                   device="gpu", 
#                   validate_every_n_batches=195, 
#                   validate_every_n_epochs=-1,
#                   precision="amp", 
#                   seed=seed, 
#                   loggers=logger,
#                   callbacks=callback,
#                   save_folder="/home/mansheej/lth_diet/ipynbs/crazy/ckpts",
#                   save_name_format="{batch}ba",
#                   save_latest_format=None,
#                   save_interval="195ba",
#                   save_weights_only=True)
# trainer.fit()
# model.cpu()
# model.eval();

The model at 5265ba worked pretty well. Load in the composer format and save to our format. 

In [31]:
# model = load_model("model_4__0", "mask_4", model_hparams)
# state_dict = torch.load("crazy/ckpts/5265ba")
# model.load_state_dict(state_dict["state"]["model"])
# model._apply_mask()
# save_model(model, "model_4__1")

The trained model at level 4: `model_4__1`  
Density = 41%  
Test Accuracy = 91.62%  

In [32]:
# Reads crazy/model_4__1.pt (written by the commented-out conversion cell above) and
# crazy/mask_4.pt, then reports density and test accuracy.
model = load_model("model_4__1", "mask_4", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))

Model density: 0.40958523750305176


100%|██████████| 10/10 [00:00<00:00, 10.53it/s]

Model test accuracy: 0.9161999821662903





Next we check that `model_3__1` and `model_4__1` are indeed ***ERROR CONNECTED***.  
Test accuracy `model_3__1` = 91.77%  
Test accuracy midpoint = 91.62%  
Test accuracy `model_4__1` = 91.62%  

In [33]:
# Denser pair (model_3__1, mask_3) goes first; mask_3 is the mask applied to the midpoint.
print(
    "Test accuracy of (model_denser, model_midpoint, model_sparser) =",
    compare_to_midpoint("model_3__1", "mask_3", "model_4__1", "mask_4", test_data, get_accuracy, model_hparams)
)

100%|██████████| 10/10 [00:00<00:00, 10.08it/s]
100%|██████████| 10/10 [00:01<00:00,  9.70it/s]
100%|██████████| 10/10 [00:00<00:00, 10.39it/s]

Test accuracy of (model_denser, model_midpoint, model_sparser) = (0.9176999926567078, 0.9161999821662903, 0.9161999821662903)





We can now apply `mask_4` to `model_rewind` to get `model_4__i`. `model_4__i` should have a density of 41% and low test accuracy. Indeed, for `model_4__i`,  
Density = 41%  
Test Accuracy = 10.29%  

In [34]:
# Dense rewind weights wrapped with mask_4; density and accuracy are reported for the masked model.
model = load_model("model_rewind", "mask_4", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))
# save_model(model, "model_4__i")

Model density: 0.40958523750305176


100%|██████████| 10/10 [00:00<00:00, 10.19it/s]

Model test accuracy: 0.10289999842643738





Now we train `model_4__i` using the same hyperparameters as the dense model, accounting for the fact that 2000 batches of pretraining are already complete. This tests whether the mask obtained without rewinding can be used to train a model that has been rewound.

In [35]:
# model = load_model("model_4__i", "mask_4", model_hparams)
# model.cuda()
# model.train()
# max_duration = "60400ba"
# algorithms = [ChannelsLastHparams().initialize_object()]
# optimizer = SGDHparams(lr=0.1, momentum=0.9, weight_decay=0.0001).initialize_object(model.parameters())
# scheduler = MultiStepSchedulerHparams(milestones=["29200ba", "44800ba"], gamma=0.1).initialize_object()
# seed = 6174
# logger = [WandBLoggerHparams("lth_diet", "crazy", entity="prunes").initialize_object()]
# callback = [LRMonitorHparams().initialize_object()]
# trainer = Trainer(
#     model=model, 
#     train_dataloader=train_data, 
#     max_duration=max_duration, 
#     eval_dataloader=test_data,
#     algorithms=algorithms,
#     optimizers=optimizer, 
#     schedulers=scheduler,
#     device="gpu", 
#     # validate_every_n_batches=195, 
#     # validate_every_n_epochs=-1,
#     precision="amp", 
#     step_schedulers_every_batch=True,
#     seed=seed, 
#     loggers=logger,
#     callbacks=callback,
#     # save_folder="/home/mansheej/lth_diet/ipynbs/crazy/ckpts",
#     # save_name_format="{batch}ba",
#     # save_latest_format=None,
#     # save_interval="195ba",
#     # save_weights_only=True)
# )
# trainer.fit()
# model.cpu()
# model.eval();

We save the model trained from the rewind point but with `mask_4` as `model_4__f`. It has  
Density = 41%  
Test Accuracy = 91.79%  

In [40]:
# NOTE(review): "model_4__f" is read from disk here, while the only save_model call for
# it is commented out and sits after the load; on a fresh kernel this cell requires
# crazy/model_4__f.pt to already exist.
model = load_model("model_4__f", "mask_4", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))
# save_model(model, "model_4__f")

Model density: 0.40958523750305176


100%|██████████| 10/10 [00:00<00:00, 10.49it/s]

Model test accuracy: 0.9178999662399292





Remarkably, `model_4__f` and `model_4__1` are indeed ***ERROR CONNECTED***.  
Test accuracy `model_4__f` = 91.79%  
Test accuracy midpoint = 91.92%  
Test accuracy `model_4__1` = 91.62%  
This is working phenomenally!

In [41]:
# Both models use mask_4, so their densities are equal and the density assertion in
# compare_to_midpoint holds with equality; the output labels denser/sparser are nominal here.
print(
    "Test accuracy of (model_denser, model_midpoint, model_sparser) =",
    compare_to_midpoint("model_4__f", "mask_4", "model_4__1", "mask_4", test_data, get_accuracy, model_hparams)
)

100%|██████████| 10/10 [00:00<00:00, 10.43it/s]
100%|██████████| 10/10 [00:00<00:00, 10.81it/s]
100%|██████████| 10/10 [00:01<00:00,  9.91it/s]

Test accuracy of (model_denser, model_midpoint, model_sparser) = (0.9178999662399292, 0.9192000031471252, 0.9161999821662903)





Magnitude prune 20% of the remaining weights of `model_4__1` (`mask_4`) to generate `mask_5` which has 33% of the weights remaining.

In [73]:
# Prune model_4__1 starting from mask_4; `mask` is rebound to the new mask returned by prune().
model = load_model("model_4__1", "mask_4", model_hparams)
mask = load_mask("mask_4")
pruning_fraction = 0.2
mask = PruningHparams(pruning_fraction).prune(model, mask)
print("Mask density =", mask.density.item())
# Unlike the earlier levels, this save is active: every run overwrites crazy/mask_5.pt.
save_mask(mask, "mask_5")

Mask density = 0.32766449451446533


Next we construct `model_5__0` by applying `mask_5` to `model_4__1`.

In [74]:
# load_model wraps the model_4__1 weights with mask_5 via PrunedClassifier.
model = load_model("model_4__1", "mask_5", model_hparams)
# Active save: every run overwrites crazy/model_5__0.pt.
save_model(model, "model_5__0")

What is the density and accuracy of `model_5__0`?  
Density = 33%  
Test Accuracy = 86.74%  
`model_5__0` is not as good and needs further training.

In [75]:
# Reads crazy/model_5__0.pt and crazy/mask_5.pt from disk, then reports density and test accuracy.
model = load_model("model_5__0", "mask_5", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))

Model density: 0.32766449451446533


100%|██████████| 10/10 [00:00<00:00, 10.34it/s]

Model test accuracy: 0.8673999905586243





In [77]:
# Fine-tune the freshly pruned level-5 model (model_5__0 with mask_5) for 7800 batches,
# validating and checkpointing every 195 batches so a good intermediate model can be
# picked afterwards.
model = load_model("model_5__0", "mask_5", model_hparams)
model.cuda()
model.train()
max_duration = "7800ba"
algorithms = [ChannelsLastHparams().initialize_object()]
optimizer = SGDHparams(lr=0.01, momentum=0.9, weight_decay=0.0001).initialize_object(model.parameters())
# Alternative schedulers that were tried:
# scheduler = ConstantSchedulerHparams().initialize_object()
# scheduler = CosineAnnealingSchedulerHparams().initialize_object()
# Single 10x learning-rate drop at 1560 batches.
scheduler = MultiStepSchedulerHparams(gamma=0.1, milestones=["1560ba"]).initialize_object()
seed = 7890
logger = [WandBLoggerHparams("lth_diet", "crazy", entity="prunes").initialize_object()]
callback = [LRMonitorHparams().initialize_object()]
trainer = Trainer(model=model, 
                  train_dataloader=train_data, 
                  max_duration=max_duration, 
                  eval_dataloader=test_data,
                  algorithms=algorithms,
                  optimizers=optimizer, 
                  schedulers=scheduler,
                  device="gpu", 
                  validate_every_n_batches=195, 
                  validate_every_n_epochs=-1,
                  precision="amp", 
                  seed=seed, 
                  loggers=logger,
                  callbacks=callback,
                  # Absolute path derived from the working directory rather than a
                  # hard-coded home directory, so it matches the relative
                  # "crazy/ckpts/<N>ba" path the conversion cell loads from.
                  save_folder=os.path.abspath("crazy/ckpts"),
                  save_name_format="{batch}ba",
                  save_latest_format=None,
                  save_interval="195ba",
                  save_weights_only=True)
trainer.fit()
model.cpu()
# Trailing semicolon suppresses the module repr in the cell output.
model.eval();




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
/Event.INIT,▁
accuracy/val,▁▂▄▁▄▃▇▅▅▆▆█▇▇▇▆▆▇▇▆▇█▇▇▇▇▇▆▇▇▇█▇▆▆█▇▇█▇
crossentropyloss/val,█▄▇█▅▆▄▆▄▃▃▂▃▂▃▂▁▃▂▂▃▃▂▂▃▂▂▃▃▂▃▃▃▄▄▄▄▃▄▄
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
loss/train,▆█▆▇▃▄▄▃▇▅█▅▃▂▆▂▄▅▄▂▃▂▁▃▂▅▅▃▂▁▃▂▂▄▄▁▄▄▃▁
lr-SGD/group0,████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/batch_idx,▂▆▂▆▂▆▁▆▁▆▁▆▁▇▂▇▂▇▂▆▂▆▂▆▂▆▃▇▃▇▃▇▃▇▂▇▂▇▂█
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
/Event.INIT,1.0
accuracy/val,0.9147
crossentropyloss/val,0.34235
epoch,19.0
loss/train,0.02422
lr-SGD/group0,0.001
trainer/batch_idx,389.0
trainer/global_step,7800.0


The model at 7605ba worked pretty well. Load in the composer format and save to our format. 

In [78]:
# Start from the pruned wrapper so the checkpoint is loaded into a PrunedClassifier.
model = load_model("model_5__0", "mask_5", model_hparams)
# Checkpoint written by the training cell above at batch 7605.
state_dict = torch.load("crazy/ckpts/7605ba")
# The trainer checkpoint nests the model weights under ["state"]["model"]; they are
# loaded into the PrunedClassifier itself here (not into model.module).
model.load_state_dict(state_dict["state"]["model"])
# NOTE(review): private PrunedClassifier method; presumably re-applies the mask to the
# loaded weights — confirm.
model._apply_mask()
# Active save: every run overwrites crazy/model_5__1.pt.
save_model(model, "model_5__1")

The trained model at level 5: `model_5__1`  
Density = 33%  
Test Accuracy = 91.52%  

In [79]:
# Reads crazy/model_5__1.pt (written by the conversion cell above) and crazy/mask_5.pt,
# then reports density and test accuracy.
model = load_model("model_5__1", "mask_5", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))

Model density: 0.32766449451446533


100%|██████████| 10/10 [00:00<00:00, 10.72it/s]

Model test accuracy: 0.9151999950408936





Next we check that `model_4__1` and `model_5__1` are indeed ***ERROR CONNECTED***.  
Test accuracy `model_4__1` = 91.62%  
Test accuracy midpoint = 91.78%  
Test accuracy `model_5__1` = 91.52%  

In [81]:
# Denser pair (model_4__1, mask_4) goes first; mask_4 is the mask applied to the midpoint.
print(
    "Test accuracy of (model_denser, model_midpoint, model_sparser) =",
    compare_to_midpoint("model_4__1", "mask_4", "model_5__1", "mask_5", test_data, get_accuracy, model_hparams)
)

100%|██████████| 10/10 [00:00<00:00, 10.38it/s]
100%|██████████| 10/10 [00:00<00:00, 10.39it/s]
100%|██████████| 10/10 [00:00<00:00, 10.08it/s]

Test accuracy of (model_denser, model_midpoint, model_sparser) = (0.9161999821662903, 0.9177999496459961, 0.9151999950408936)





We can now apply `mask_5` to `model_rewind` to get `model_5__i`. `model_5__i` should have a density of 33% and low test accuracy. Indeed, for `model_5__i`,  
Density = 33%  
Test Accuracy = 10%  

In [82]:
# Dense rewind weights wrapped with mask_5; density and accuracy are reported for the masked model.
model = load_model("model_rewind", "mask_5", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))
# Active save: every run overwrites crazy/model_5__i.pt.
save_model(model, "model_5__i")

Model density: 0.32766449451446533


100%|██████████| 10/10 [00:00<00:00, 10.69it/s]

Model test accuracy: 0.09999999403953552





Now we train `model_5__i` using the same hyperparameters as the dense model, accounting for the fact that 2000 batches of pretraining are already complete. This tests whether the mask obtained without rewinding can be used to train a model that has been rewound.

In [83]:
# Train the masked rewind model with the dense experiment's optimizer settings, with the
# schedule shifted by the 2000 rewind batches (compare with exp.name printed earlier):
# max_duration 62400ba -> 60400ba, milestones 31200ba/46800ba -> 29200ba/44800ba.
model = load_model("model_5__i", "mask_5", model_hparams)
model.cuda()
model.train()
max_duration = "60400ba"
algorithms = [ChannelsLastHparams().initialize_object()]
optimizer = SGDHparams(lr=0.1, momentum=0.9, weight_decay=0.0001).initialize_object(model.parameters())
scheduler = MultiStepSchedulerHparams(milestones=["29200ba", "44800ba"], gamma=0.1).initialize_object()
seed = 6174
logger = [WandBLoggerHparams("lth_diet", "crazy", entity="prunes").initialize_object()]
callback = [LRMonitorHparams().initialize_object()]
# Intermediate validation and checkpointing are disabled for this run (commented kwargs).
trainer = Trainer(
    model=model, 
    train_dataloader=train_data, 
    max_duration=max_duration, 
    eval_dataloader=test_data,
    algorithms=algorithms,
    optimizers=optimizer, 
    schedulers=scheduler,
    device="gpu", 
    # validate_every_n_batches=195, 
    # validate_every_n_epochs=-1,
    precision="amp", 
    step_schedulers_every_batch=True,
    seed=seed, 
    loggers=logger,
    callbacks=callback,
    # save_folder="/home/mansheej/lth_diet/ipynbs/crazy/ckpts",
    # save_name_format="{batch}ba",
    # save_latest_format=None,
    # save_interval="195ba",
    # save_weights_only=True)
)
trainer.fit()
# NOTE(review): nothing in this cell persists the trained weights; they live only in the
# kernel variable `model` unless save_model(model, "model_5__f") is run before `model`
# is rebound.
model.cpu()
model.eval();




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
/Event.INIT,▁
accuracy/val,▁▄▅▅▅▅▇▆▅▆▇▆▆▆▆▆▆▆▆█████████████████████
crossentropyloss/val,█▅▃▃▃▄▂▃▄▄▂▃▃▃▄▃▂▃▃▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss/train,█▇▆▆▅▄▄▄▄▅▃▃▅▅▃▃▄▃▂▅▂▂▃▂▁▂▁▁▁▁▁▁▁▂▂▁▁▂▁▁
lr-SGD/group0,████████████████████▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
trainer/batch_idx,▃█▇▇▆▆▆▅▅▄▄▃▃▂▂▁▁█▇▇▇▆▆▅▅▄▄▃▃▂▂▂▁██▇▇▆▆▅
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
/Event.INIT,1.0
accuracy/val,0.9156
crossentropyloss/val,0.3508
epoch,155.0
loss/train,0.03535
lr-SGD/group0,0.001
trainer/batch_idx,339.0
trainer/global_step,60400.0


We save the model trained from the rewind point but with `mask_5` as `model_5__f`. It has  
Density = 33%  
Test Accuracy = 91.56%

In [84]:
# NOTE(review): "model_5__f" is read from disk here, while the only save_model call for
# it is commented out and sits after the load; on a fresh kernel this cell requires
# crazy/model_5__f.pt to already exist, and re-enabling the save as written would store
# the reloaded model, not the one trained in the previous cell.
model = load_model("model_5__f", "mask_5", model_hparams)
print("Model density:", get_pruned_model_density(model))
print("Model test accuracy:", get_accuracy(model, test_data))
# save_model(model, "model_5__f")

Model density: 0.32766449451446533


100%|██████████| 10/10 [00:00<00:00, 10.73it/s]

Model test accuracy: 0.9156000018119812





Remarkably, `model_5__f` and `model_5__1` are indeed ***ERROR CONNECTED***.  
Test accuracy `model_5__f` = 91.56%  
Test accuracy midpoint = 91.42%  
Test accuracy `model_5__1` = 91.52%  
Sad, this didn't work as well: the midpoint is slightly below both endpoints.

In [85]:
# Both models use mask_5, so their densities are equal and the density assertion in
# compare_to_midpoint holds with equality; the output labels denser/sparser are nominal here.
print(
    "Test accuracy of (model_denser, model_midpoint, model_sparser) =",
    compare_to_midpoint("model_5__f", "mask_5", "model_5__1", "mask_5", test_data, get_accuracy, model_hparams)
)

100%|██████████| 10/10 [00:00<00:00, 10.83it/s]
100%|██████████| 10/10 [00:00<00:00, 10.56it/s]
100%|██████████| 10/10 [00:00<00:00, 10.43it/s]

Test accuracy of (model_denser, model_midpoint, model_sparser) = (0.9156000018119812, 0.914199948310852, 0.9151999950408936)





---