In [1]:
import os
import tempfile

from composer.datasets.dataloader import DataLoaderHparams
from composer.utils.object_store import ObjectStoreProviderHparams
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
import torch
import torch.nn.functional as F
from tqdm import tqdm

from lth_diet.data import CIFAR10DataHparams, DataHparams
from lth_diet.exps import TrainExperiment
from lth_diet.models import ComposerClassifierHparams, ResNetCIFARClassifierHparams
from lth_diet.utils import utils
# Restore Matplotlib's default rcParams first, then layer the "ggplot" style on top of them.
plt.style.use(["default", "ggplot"])

In [2]:
# Experiment
# Build the training-experiment config from its YAML file, then override the fields that
# identify which trained runs are scored below. `cli_args=False` presumably stops
# TrainExperiment.create from parsing the kernel's command line — TODO confirm.
# NOTE(review): absolute, user-specific config path; consider making it configurable.
exp = TrainExperiment.create(f='/home/mansheej/lth_diet/configs/cifar10_resnet20.yaml', cli_args=False)
# Duration string; presumably 6240 batches in Composer's time notation — confirm.
exp.max_duration = "6240ba"
# NOTE(review): the seed is assigned as a string. Presumably this must match how the scored runs
# were configured, since changing it could change exp.name and therefore the hash — confirm.
exp.seed = "1009"
# Hashed after the overrides above (presumably exp.name encodes them — confirm). The hash is the
# object-store directory name used when downloading checkpoints (exps/{exp_hash}/...).
exp_hash = utils.get_hash(exp.name)
print(f"EXPERIMENT HASH: {exp_hash}")

EXPERIMENT HASH: 7021678d112f70acb93002144a553e8d


In [3]:
def calculate_scores(
    exp_hash: str,
    num_replicates: int,
    data_hparams: DataHparams,
    model_hparams: ComposerClassifierHparams,
    batch_size: int = 1000,
    device: str = "cuda",
) -> NDArray[np.floating]:
    """Compute per-example error-norm scores averaged over trained replicates.

    For each replicate ``r`` in ``range(num_replicates)``, the final checkpoint
    ``exps/{exp_hash}/replicate_{r}/main/model_final.pt`` is downloaded from the
    ``prunes`` Google Storage bucket and loaded into a model built from
    ``model_hparams``. The model is then evaluated on every batch of
    ``data_hparams``. The score of one example is
    ``|| softmax(model output) - one_hot(target) ||_2``, and the returned array
    is the mean of that score across replicates.

    Args:
        exp_hash: Hash naming the experiment directory in the object store.
        num_replicates: Number of replicates (``replicate_0`` .. ``replicate_{n-1}``)
            to average over. Must be positive.
        data_hparams: Dataset to score. It must iterate in a fixed order (e.g.
            ``shuffle=False``) so that entry ``i`` refers to the same example in
            every replicate.
        model_hparams: Hyperparameters used to build the model that each
            checkpoint is loaded into.
        batch_size: Evaluation batch size (default matches the original 1000).
        device: Torch device used for inference (default ``"cuda"``).

    Returns:
        NumPy array with one averaged score per example, in dataloader order.

    Raises:
        ValueError: If ``num_replicates`` is not positive.
    """
    if num_replicates < 1:
        raise ValueError(f"num_replicates must be positive, got {num_replicates}")
    object_store = ObjectStoreProviderHparams('google_storage', 'prunes', 'GCS_KEY').initialize_object()
    dl = data_hparams.initialize_object(batch_size=batch_size, dataloader_hparams=DataLoaderHparams())
    model = model_hparams.initialize_object()
    # load_state_dict copies weights into the existing parameters, so the model is moved only once.
    # eval() is required so that BatchNorm uses its stored running statistics. In the default
    # training mode it would use statistics of each evaluation batch instead.
    model.to(device).eval()
    scores = []
    # A local no_grad scope replaces torch.set_grad_enabled(False), which would disable autograd
    # for the rest of the kernel session. The temporary directory keeps checkpoints out of the
    # working directory and is removed even if a download or evaluation fails.
    with torch.no_grad(), tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = os.path.join(tmp_dir, "model.pt")
        for r in tqdm(range(num_replicates)):
            object_store.download_object(f"exps/{exp_hash}/replicate_{r}/main/model_final.pt", ckpt_path)
            model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
            # Remove it so the next replicate's download does not find an existing destination file.
            os.remove(ckpt_path)
            scores_r = []
            for batch in dl:
                batch = [tensor.to(device) for tensor in batch]
                probs = F.softmax(model(batch), dim=-1)
                targs = F.one_hot(batch[1], model.num_classes)
                scores_r.append(torch.linalg.vector_norm(probs - targs, dim=-1))
            scores.append(torch.cat(scores_r))
    return torch.stack(scores).mean(dim=0).cpu().numpy()

In [4]:
# Number of independently trained replicates whose scores are averaged.
num_replicates = 16
# Score the CIFAR-10 training split without augmentation and in a fixed order (shuffle=False).
# This makes entry i of every replicate's scores refer to the same training example.
scores = calculate_scores(
    exp_hash=exp_hash,
    num_replicates=num_replicates,
    data_hparams=CIFAR10DataHparams(train=True, shuffle=False, drop_last=False, no_augment=True),
    model_hparams=ResNetCIFARClassifierHparams(num_classes=10, num_layers=20),
)

100%|██████████| 16/16 [00:53<00:00,  3.32s/it]


In [19]:
# Persist the averaged scores to the object store. The object name is built from the run
# configuration, so the replicate count, duration and seed in the name cannot drift out of
# sync with `num_replicates` and `exp`.
scores_object_name = (
    f"exps/scores/error_norm_cifar10_resnet20_{exp.max_duration}_{num_replicates}reps_seed{exp.seed}.npy"
)
object_store = ObjectStoreProviderHparams('google_storage', 'prunes', 'GCS_KEY').initialize_object()
# Stage the file in a private temporary directory. Nothing is left in the working directory
# if saving or uploading fails, and concurrent runs cannot overwrite each other's file.
with tempfile.TemporaryDirectory() as tmp_dir:
    local_path = os.path.join(tmp_dir, "scores.npy")
    np.save(local_path, scores)
    object_store.upload_object(local_path, scores_object_name)