In [3]:
from composer.algorithms import ChannelsLastHparams
from composer.callbacks import LRMonitorHparams
from composer.core.time import Time
from composer.core.types import DataLoader
from composer.datasets import DataLoaderHparams
from composer.loggers import WandBLoggerHparams
from composer.models import ComposerClassifier
from composer.optim import (SGDHparams, ConstantSchedulerHparams, CosineAnnealingSchedulerHparams, 
                            CosineAnnealingWithWarmupSchedulerHparams, MultiStepSchedulerHparams, 
                            MultiStepWithWarmupSchedulerHparams)
from composer.trainer import Trainer, TrainerHparams
from composer.utils.object_store import ObjectStoreProviderHparams, ObjectStoreProvider
from copy import deepcopy
from lth_diet.data import CIFAR10DataHparams, DataHparams, CINIC10DataHparams
from lth_diet.exps import LotteryExperiment, LotteryRetrainExperiment
from lth_diet.models import ResNetCIFARClassifierHparams, ClassifierHparams
from lth_diet.pruning import Mask, PrunedClassifier, PruningHparams
from lth_diet.pruning.pruned_classifier import prunable_layer_names
from lth_diet.utils import utils
from numpy.typing import NDArray
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns
import torch
from tqdm import tqdm
from typing import Callable, Dict, Tuple
# Plot styling: start from matplotlib defaults, then apply compact, thin-lined figure settings
# through seaborn's theme (keys listed in the order they are applied).
plt.style.use("default")
rc = {
    "figure.figsize": (4, 3),
    "figure.dpi": 150,
    "figure.constrained_layout.use": True,
    "axes.grid": True,
    "axes.spines.right": False,
    "axes.spines.top": False,
    "axes.linewidth": 0.6,
    "grid.linewidth": 0.6,
    "xtick.major.width": 0.6,
    "ytick.major.width": 0.6,
    "xtick.major.size": 4,
    "ytick.major.size": 4,
    "axes.labelsize": 11,
    "axes.titlesize": 11,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "axes.titlepad": 4,
    "axes.labelpad": 2,
    "xtick.major.pad": 2,
    "ytick.major.pad": 2,
    "lines.linewidth": 1.2,
    'lines.markeredgecolor': 'w',
    "patch.linewidth": 0,
}
sns.set_theme(style='ticks', palette=sns.color_palette("colorblind"), rc=rc)

# Object store holding checkpoints and score files.
# NOTE(review): 'GCS_KEY' presumably names the credential env var — confirm.
object_store = ObjectStoreProviderHparams('google_storage', 'prunes', 'GCS_KEY').initialize_object()
bucket_dir = os.environ['OBJECT_STORE_DIR']  # raises KeyError if the variable is unset

In [2]:
# Inference-only notebook: switch autograd off globally. This is kernel-wide hidden state and
# must run after the imports cell; the loss helpers below should not rely on it alone.
# Trailing ';' suppresses the context-manager repr that would otherwise be shown as output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f14d7f3fac0>

In [4]:
# Retrain experiment; its hashed names (and those of its `load_exp`) locate the parent runs and
# their retrained children in the object store.
config = "../configs/lottery_cinic10_retrain.yaml"
exp = LotteryRetrainExperiment.create(f=config, cli_args=False)
load_replicates = [0, 1, 2, 3]  # parent replicates to load
replicates = [0, 1]  # retrain replicates under each parent
rewinding_steps = [400, 800, 1600]  # used as f"{step}ba" (batches) in the loop below
model_hparams = ResNetCIFARClassifierHparams(10, 56)  # presumably 10 classes, depth 56 — confirm
# NOTE(review): the positional boolean flags are opaque — prefer keyword arguments. The per-example
# loss comparisons downstream require `train_data` to yield examples in the same fixed order on every
# pass (no shuffling, no random augmentation); confirm these flags guarantee that.
train_data = CINIC10DataHparams(True, False, False, True).initialize_object(1000, DataLoaderHparams(persistent_workers=False))
test_data = CINIC10DataHparams(False).initialize_object(1000, DataLoaderHparams(persistent_workers=False))

FileNotFoundError: [Errno 2] No such file or directory: '/home/mansheej/lth_diet/data/cinic10/train'

In [5]:
def losses_and_state_dict(location, name, train_data):
    """Load a checkpoint from the object store and compute its unreduced train losses.

    Args:
        location: object-store prefix of the run; combined with ``name`` via ``utils.get_object_name``.
        name: checkpoint file name (e.g. ``"model_final.pt"``).
        train_data: iterable of batches whose items 0 and 1 are inputs and targets.

    Returns:
        ``(losses, state_dict)``: the cross-entropy losses with ``reduction="none"``, concatenated over
        batches in the order ``train_data`` yields them, as a numpy array (one value per example,
        assuming logits of shape (batch, classes)); and the state dict loaded via ``torch.load``.
    """
    state_dict = utils.load_object(location, name, object_store, torch.load)
    model = model_hparams.initialize_object()
    model.module.load_state_dict(state_dict)
    print(f"      Loaded {utils.get_object_name(location, name)}")
    model.cuda()
    model.eval()
    criterion = torch.nn.CrossEntropyLoss(reduction="none")  # built once, not per batch
    batch_losses = []
    # Explicit no_grad so this does not depend on the global set_grad_enabled(False) cell: with grad
    # enabled, every batch's graph would be kept alive and `.numpy()` below would raise.
    with torch.no_grad():
        for batch in train_data:
            batch = batch[0].cuda(), batch[1].cuda()
            logits = model(batch)
            batch_losses.append(criterion(logits, batch[1]))
    losses = torch.cat(batch_losses).cpu().numpy()
    print(f"      Evaluated {utils.get_object_name(location, name)}")
    return losses, state_dict

def losses_only(state_dict, train_data):
    """Compute unreduced train losses for a model built from an in-memory state dict.

    Args:
        state_dict: weights to load into a fresh ``model_hparams`` model (e.g. a midpoint).
        train_data: iterable of batches whose items 0 and 1 are inputs and targets.

    Returns:
        Cross-entropy losses with ``reduction="none"``, concatenated over batches in the order
        ``train_data`` yields them, as a numpy array.
    """
    model = model_hparams.initialize_object()
    model.module.load_state_dict(state_dict)
    model.cuda()
    model.eval()
    criterion = torch.nn.CrossEntropyLoss(reduction="none")  # built once, not per batch
    batch_losses = []
    # Explicit no_grad so this does not depend on the global set_grad_enabled(False) cell: with grad
    # enabled, every batch's graph would be kept alive and `.numpy()` below would raise.
    with torch.no_grad():
        for batch in train_data:
            batch = batch[0].cuda(), batch[1].cuda()
            logits = model(batch)
            batch_losses.append(criterion(logits, batch[1]))
    losses = torch.cat(batch_losses).cpu().numpy()
    print("      Evaluated midpoint")
    return losses

def midpoint(state_dict: Dict, state_dict_: Dict, alpha: float = 0.5) -> Dict:
    """Return the linear interpolation between two state dicts (their midpoint by default).

    Each entry is ``(1 - alpha) * state_dict[k] + alpha * state_dict_[k]``: ``alpha=0.5`` gives the
    midpoint, ``0`` gives ``state_dict`` and ``1`` gives ``state_dict_``. Non-floating tensors
    (e.g. integer buffers) are interpolated in floating point and cast back to their original
    dtype, so the result has the same dtypes as the inputs.

    Raises:
        KeyError: if the two state dicts do not have exactly the same keys.
    """
    mismatched = state_dict.keys() ^ state_dict_.keys()
    if mismatched:
        raise KeyError(f"state dicts have mismatched keys: {sorted(mismatched)}")
    interpolated = {}
    for k, v in state_dict.items():
        mixed = (1 - alpha) * v + alpha * state_dict_[k]
        interpolated[k] = mixed if torch.is_floating_point(v) else mixed.to(v.dtype)
    return interpolated

In [None]:
# Per-example train-loss barrier at the weight-space midpoint of every pair of models:
#   barrier(a, b) = loss(midpoint(a, b)) - (loss(a) + loss(b)) / 2
# For each (rewinding step, load replicate) the models are the parent followed by one retrained
# child per entry of `replicates`; with replicates == [0, 1] the pair axis is ordered
# (parent, child 0), (parent, child 1), (child 0, child 1).
# Result shape: (len(rewinding_steps), len(load_replicates), n_pairs, n_examples). The example count
# comes from the evaluated losses: the previously hard-coded 50000 matched CIFAR-10, not CINIC-10.
# NOTE(review): per-example subtraction assumes `train_data` yields examples in the same order on
# every pass (no shuffling) — confirm.
barriers_per_rstep = []
for rstep in rewinding_steps:
    print("rewinding step:", rstep)
    barriers_per_lrep = []
    for lrep in load_replicates:
        print("  replicate:", lrep)
        # setup experiment (the hashed names below are computed from these fields)
        exp.load_exp.rewinding_steps = f"{rstep}ba"
        exp.load_replicate = lrep
        parent_dir = f"{utils.get_hash(exp.load_exp.name)}/replicate_{lrep}/level_0"
        child_dir = f"{parent_dir}/{utils.get_hash(exp.name)}"
        locations = [f"{parent_dir}/main"] + [f"{child_dir}/replicate_{rep}/main" for rep in replicates]
        # (losses, state_dict) for the parent, then each child
        models = [losses_and_state_dict(loc, "model_final.pt", train_data) for loc in locations]
        pair_barriers = []
        for a in range(len(models)):
            for b in range(a + 1, len(models)):
                (losses_a, state_a), (losses_b, state_b) = models[a], models[b]
                losses_mid = losses_only(midpoint(state_a, state_b), train_data)
                pair_barriers.append(losses_mid - (losses_a + losses_b) / 2)
        barriers_per_lrep.append(np.stack(pair_barriers))
    barriers_per_rstep.append(np.stack(barriers_per_lrep))
train_loss_barriers = np.stack(barriers_per_rstep).astype(np.float64)

In [None]:
# Mean train-loss barrier per rewinding step (rows) and model pair (columns), averaged over load
# replicates and training examples — a compact summary instead of dumping per-example arrays.
pd.DataFrame(
    train_loss_barriers.mean(axis=(1, 3)),
    index=pd.Index(rewinding_steps, name="rewinding step"),
).rename_axis(columns="pair")

In [None]:
# Precomputed scores loaded from the object store with np.load (presumably one value per example).
# NOTE(review): filename suggests CINIC-10 / ResNet-56, 8000 batches, 10 replicates, seed 1234 — confirm.
scores_dir = "scores"
scores_file = "error_norm_cinic10_resnet56_8000ba_10reps_seed1234.npy"
error_norm = utils.load_object(scores_dir, scores_file, object_store, np.load)

In [None]:
# Error norm vs. train-loss barrier per example, at the first rewinding step.
# train_loss_barriers[0] has axes (load replicate, pair, example); matplotlib cannot plot a 3-D y,
# so average over replicates and pairs to get one barrier per example.
barrier_per_example = train_loss_barriers[0].mean(axis=(0, 1))
# NOTE(review): assumes error_norm is indexed by the same training examples, in the same order.
assert len(error_norm) == len(barrier_per_example), "error_norm and barriers cover different examples"
fig, ax = plt.subplots()
ax.plot(error_norm, barrier_per_example, '.', alpha=0.2)
ax.set(xlabel="error norm", ylabel="train loss barrier", title=f"rewinding step: {rewinding_steps[0]} ba");