In [1]:
from composer.datasets.dataloader import DataLoaderHparams
from composer.utils.object_store import ObjectStoreProviderHparams
from lth_diet.data import CIFAR10DataHparams, CIFAR100DataHparams, DataHparams
from lth_diet.exps import TrainExperiment
from lth_diet.models import ClassifierHparams, ResNetClassifierHparams
from lth_diet.utils import utils
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
import os
import tempfile
import torch
import torch.nn.functional as F
from tqdm import tqdm
# Reset matplotlib rcParams to the library defaults, then layer the ggplot look on top
# (styles in the list are applied first to last).
plt.style.use(["default", "ggplot"])

In [2]:
# ensemble: identifies which trained replicates are scored
network = "resnet50"
data = "cifar100"
max_duration = "7800ba"
seed = 789
num_replicates = 16
# scoring: the dataset class, class count and network depth are derived from
# `data` / `network` so that the scoring setup cannot silently disagree with the
# ensemble named above (which also determines the config path and upload target).
_scoring_datasets = {
    "cifar10": (CIFAR10DataHparams, 10),
    "cifar100": (CIFAR100DataHparams, 100),
}
assert data in _scoring_datasets, f"unsupported dataset for scoring: {data!r}"
assert network.startswith("resnet"), f"expected a 'resnet<depth>' network, got {network!r}"
_data_hparams_cls, _num_classes = _scoring_datasets[data]
# Unshuffled, no dropped batch, no augmentation: scores must line up with the
# training examples in their fixed dataset order.
# NOTE(review): assumes CIFAR10DataHparams accepts the same kwargs as CIFAR100DataHparams — confirm.
data_hparams = _data_hparams_cls(train=True, shuffle=False, drop_last=False, no_augment=True)
model_hparams = ResNetClassifierHparams(
    num_classes=_num_classes, num_layers=int(network[len("resnet"):]), low_res=True
)

In [5]:
# Rebuild the training experiment from its YAML config, apply the ensemble's
# max_duration/seed overrides, and hash the experiment name (after the overrides)
# to locate the replicates' checkpoints in the object store.
# The config directory can be overridden with LTH_DIET_CONFIG_DIR instead of relying
# on a hard-coded absolute local path; the default keeps the original location.
config_dir = os.environ.get("LTH_DIET_CONFIG_DIR", "/home/mansheej/lth_diet/configs")
config = os.path.join(config_dir, f"{data}_{network}.yaml")
exp = TrainExperiment.create(f=config, cli_args=False)
exp.max_duration = max_duration
exp.seed = seed
exp_hash = utils.get_hash(exp.name)
print(exp_hash)

d4821907c8b6831a5aeaeeb0e15f3b8f


In [6]:
def calculate_scores(
    exp_hash: str,
    num_replicates: int,
    data_hparams: DataHparams,
    model_hparams: ClassifierHparams,
    batch_size: int = 1000,
) -> NDArray[np.float32]:
    """Compute per-example error L2-norm scores averaged over trained replicates.

    For each replicate ``r`` in ``range(num_replicates)``, the final checkpoint
    ``$OBJECT_STORE_DIR/{exp_hash}/replicate_{r}/main/model_final.pt`` is downloaded
    from the object store built from ``('google_storage', 'prunes', 'GCS_KEY')``.
    It is loaded into a model built from ``model_hparams`` and run over every batch of
    the dataloader built from ``data_hparams``. Each example is scored as
    ``||softmax(model output) - one_hot(target)||_2``, and the per-replicate score
    vectors are averaged elementwise.

    Args:
        exp_hash: Experiment hash used as the checkpoint directory in the object store.
        num_replicates: Number of replicates (``replicate_0`` ...) to average over; must be >= 1.
        data_hparams: Dataset hparams. Presumably unshuffled and without dropped batches,
            so scores align with dataset indices — confirm with the caller.
        model_hparams: Classifier hparams; ``num_classes`` sets the one-hot width.
        batch_size: Evaluation batch size (default 1000, as previously hard-coded).

    Returns:
        Array with one averaged score per example, in dataloader order.

    Raises:
        ValueError: if ``num_replicates`` < 1.
        KeyError: if the ``OBJECT_STORE_DIR`` environment variable is not set.
    """
    if num_replicates < 1:
        raise ValueError(f"num_replicates must be >= 1, got {num_replicates}")
    object_store_dir = os.environ["OBJECT_STORE_DIR"]
    object_store = ObjectStoreProviderHparams("google_storage", "prunes", "GCS_KEY").initialize_object()
    dl = data_hparams.initialize_object(batch_size=batch_size, dataloader_hparams=DataLoaderHparams())
    model = model_hparams.initialize_object()
    model.cuda()
    # Scoring is inference: BatchNorm must use the trained running statistics
    # (and not update them), which requires eval mode; a fresh module is in train mode.
    model.eval()
    scores = []
    # no_grad is scoped to this call rather than flipping the global grad mode.
    # The checkpoint is staged in a private temp dir so it never clobbers a file in
    # the CWD and is cleaned up even if a download/load fails.
    with torch.no_grad(), tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, "model.pt")
        for r in tqdm(range(num_replicates)):
            object_store.download_object(
                f"{object_store_dir}/{exp_hash}/replicate_{r}/main/model_final.pt", checkpoint_path
            )
            # map_location avoids depending on the device the checkpoint was saved from;
            # load_state_dict copies the weights onto the (CUDA) model parameters.
            model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
            # Remove right away so the next replicate downloads to a fresh path.
            os.remove(checkpoint_path)
            scores_r = []
            for batch in dl:
                batch = [tensor.cuda() for tensor in batch]
                # The model is called on the whole (inputs, targets) batch; batch[1] is the targets.
                probs = F.softmax(model(batch), dim=-1)
                targs = F.one_hot(batch[1], model_hparams.num_classes)
                scores_r.append(torch.linalg.vector_norm(probs - targs, dim=-1))
            scores.append(torch.cat(scores_r))
    return torch.stack(scores).mean(dim=0).cpu().numpy()

In [7]:
# Average the error-norm scores of every training example across all replicates.
scores = calculate_scores(
    exp_hash=exp_hash,
    num_replicates=num_replicates,
    data_hparams=data_hparams,
    model_hparams=model_hparams,
)

100%|██████████| 16/16 [04:24<00:00, 16.52s/it]


In [11]:
# Upload the averaged scores under scores/ in the object store. The filename encodes
# the dataset, network, training duration, replicate count and seed.
# The target is built first so a missing OBJECT_STORE_DIR fails before any file is written.
target = f"{os.environ['OBJECT_STORE_DIR']}/scores/error_norm_{data}_{network}_{max_duration}_{num_replicates}reps_seed{seed}.npy"
object_store = ObjectStoreProviderHparams("google_storage", "prunes", "GCS_KEY").initialize_object()
# Stage the .npy in a private temp dir: it never clobbers a local "scores.npy" and is
# removed even if the upload raises.
with tempfile.TemporaryDirectory() as tmp_dir:
    scores_path = os.path.join(tmp_dir, "scores.npy")
    np.save(scores_path, scores)
    object_store.upload_object(scores_path, target)