In [1]:
# Reload edited project modules (e.g. `generalization.*`) automatically before
# each cell executes, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Experiment Config

In [2]:
import copy

from generalization.utils.train import DEFAULT_PARAMS
from generalization.randomization import available_corruptions

print("Available corruptions:\n", available_corruptions())

# Work on a private copy of the library defaults. Importing DEFAULT_PARAMS under
# another name and mutating it (as this cell used to) silently rewrites the
# module-level defaults for everything else in this kernel that imports them.
hparams = copy.deepcopy(DEFAULT_PARAMS)
hparams.update(
    {
        "dataset_name": "cifar10",
        "n_classes": 10,
        # NOTE(review): `normal_labels` is not in `available_corruptions()` above;
        # presumably it is the uncorrupted baseline handled separately — verify.
        "corrupt_name": "normal_labels",
        "corrupt_prob": 0,
        "gradient_clipping": True,
        # NOTE(review): DEFAULT_PARAMS already defines `learning_rate` (0.1); this
        # adds a separate `lr` key. Confirm which one LitModel actually reads.
        "lr": 0.04,
        "momentum": 0.9,
        "weight_decay": 0.0,
        "lr_scheduler": False,
    }
)

hparams  # same as generalization/configs/cifar-normal_labels.yaml

Available corruptions:
 ['gaussian_pixels', 'random_labels', 'random_pixels', 'partial_labels', 'shuffled_pixels']


{'seed': 88,
 'batch_size': 256,
 'learning_rate': 0.1,
 'epochs': 30,
 'val_every': 1,
 'log_dir': 'logs',
 'dataset_name': 'cifar10',
 'n_classes': 10,
 'corrupt_name': 'normal_labels',
 'corrupt_prob': 0,
 'gradient_clipping': True,
 'lr': 0.04,
 'momentum': 0.9,
 'weight_decay': 0.0,
 'lr_scheduler': False}

## Build Data for Experiment: `normal_labels`

In [3]:
from generalization.utils.data import build_experiment

# Take the corruption settings from the config cell instead of hard-coding them,
# so editing `hparams` cannot leave this cell building a different experiment.
# float() keeps the argument a float (0 -> 0.0), as it was when hard-coded.
# Note: the training loop further down trains on `dm`, not on `experiment_data`.
experiment_data = build_experiment(
    float(hparams["corrupt_prob"]),
    corrupt_name=hparams["corrupt_name"],
    batch_size=hparams["batch_size"],
)
experiment_data[hparams["corrupt_name"]]

Files already downloaded and verified
Files already downloaded and verified


{'train_set': Dataset CIFAR10
     Number of datapoints: 50000
     Root location: /data/cifar10
     Split: Train, Corruption: normal_labels,
 'val_set': <torch.utils.data.dataset.Subset at 0x7efcd2697ac0>,
 'test_set': <torch.utils.data.dataset.Subset at 0x7efcd2697b20>,
 'train_loader': <torch.utils.data.dataloader.DataLoader at 0x7efdecce3a00>,
 'val_loader': <torch.utils.data.dataloader.DataLoader at 0x7efcd2697a30>,
 'test_loader': <torch.utils.data.dataloader.DataLoader at 0x7efcd2697bb0>}

## Build Modules: `models` & `datamodule`

In [4]:
from generalization.models import get_cifar_models
from generalization.utils.model import LitDataModule, LitModel

# Every available CIFAR architecture in its torch implementation, keyed by name;
# the training loop below fits each entry in turn.
models = get_cifar_models(lib="torch")
print(models.keys())

# Lightning data module driven by the shared `hparams`; `setup()` is called
# eagerly here rather than being left to the Trainer.
dm = LitDataModule(hparams=hparams)
dm.setup()

dm

dict_keys(['alexnet', 'inception', 'mlp_1x512', 'mlp_3x512'])
Files already downloaded and verified
Files already downloaded and verified


DataModule:
Dataset CIFAR10
    Number of datapoints: 50000
    Root location: /data/cifar10
    Split: Train, Corruption: normal_labels
Val: <torch.utils.data.dataset.Subset object at 0x7efcc90be230>

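## Build `Trainer` & `fit()`

Train and then test every model in `models` on the shared data module `dm`, logging each run to Weights & Biases.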

In [5]:
import time

import torch
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

import wandb


def fit(trainer, model, datamodule):
    """Train ``model`` on ``datamodule`` with ``trainer`` and print the elapsed time.

    Side effect: sets torch's global float32 matmul precision to ``"medium"``,
    which allows float32 matmuls to use reduced-precision internal computation
    where the hardware supports it.

    Args:
        trainer: the Lightning ``Trainer`` that runs the optimisation.
        model: the ``LightningModule`` to train.
        datamodule: the ``LightningDataModule`` supplying the data.

    Returns:
        The ``(trainer, model, datamodule)`` tuple that was passed in, so callers
        can rebind their names in a single assignment.
    """
    torch.set_float32_matmul_precision(precision="medium")

    # perf_counter is monotonic, so a long training run's duration cannot be
    # skewed by system-clock adjustments the way time.time() can be.
    start_time = time.perf_counter()
    trainer.fit(model, datamodule)
    print(f"Training took {time.perf_counter() - start_time:.2f} seconds")

    return trainer, model, datamodule


# Loop-invariant run settings: every model logs under the same directory and
# W&B project for the configured corruption.
log_dir = "logs/dense"
project_name = f"generalization-dense-{hparams['corrupt_name']}"

for model_name, model in models.items():
    # Ensure no W&B run from a previous iteration (or an interrupted earlier
    # execution of this cell) is still active before starting a new one.
    wandb.finish()

    hparams["model_name"] = model_name
    # Also used as the W&B run id. Because the id is fixed, re-running this cell
    # presumably resumes the existing W&B run with that id rather than creating a
    # new one — NOTE(review): confirm that is intended.
    experiment_name = f"{hparams['model_name']}-{hparams['corrupt_prob']}"

    logger = WandbLogger(
        name=experiment_name,
        project=project_name,
        # Uploads the checkpoints written by `ckpt` below; with save_top_k=-1 that
        # is every checkpoint, which can add up to a lot of artifact storage.
        log_model="all",
        save_dir=log_dir,
        id=experiment_name,
        group=f"{hparams['corrupt_name']}",
        tags=[hparams["model_name"], hparams["corrupt_name"]],
    )

    ckpt = ModelCheckpoint(
        # dirpath=ckpt_dir, # overridden by default_root_dir
        # The metric name contains "/". With the default auto_insert_metric_name=True,
        # Lightning would render "...-valid/loss=1.06" and create a sub-directory
        # per checkpoint, so the "name=" prefixes are spelled out by hand instead.
        filename=experiment_name + "-epoch={epoch:02d}-valid_loss={valid/loss:.2f}",
        auto_insert_metric_name=False,
        save_top_k=-1,
        save_last=True,
    )

    trainer = Trainer(
        max_epochs=hparams["epochs"],
        logger=logger,
        callbacks=[ckpt],
        default_root_dir=log_dir,
        check_val_every_n_epoch=hparams["val_every"],
    )
    pl_model = LitModel(
        net=model,
        # Snapshot: this loop keeps mutating `hparams["model_name"]`, and that must
        # not leak into models that were already built.
        hparams=dict(hparams),
    )

    trainer, pl_model, dm = fit(trainer, pl_model, dm)

    trainer.test(pl_model, dm.test_dataloader())

    # No manual save needed: ModelCheckpoint(save_last=True) already writes a
    # `last` checkpoint for this run.

    # Ensure the W&B run for this model is closed before the next iteration.
    trainer.logger.experiment.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mstepp1[0m. Use [1m`wandb login --relogin`[0m to force relogin


ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None) will duplicate the last checkpoint saved.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params
------------------------------------------------------
0 | net            | SmallAlexNet       | 56.8 M
1 | train_acc      | MulticlassAccuracy | 0     
2 | valid_acc      | MulticlassAccuracy | 0     
3 | valid_top5_acc | MulticlassAccuracy | 0     
4 | test_acc       | MulticlassAccuracy | 0     
5 | test_top5_acc  | MulticlassAccuracy | 0     
------------------------------------------------------
56.8 M    Trainable params
0         Non-trainable params
56.8 M    Total params
227.307   Total estimated model params size (MB)


Epoch 4:  20%|██        | 40/196 [00:02<00:07, 19.99it/s, v_num=et-0, valid/loss=1.060, valid/acc=0.482, valid/top5_acc=0.903, train/loss=1.020, train/acc=0.638] 

Epoch 4:  27%|██▋       | 52/196 [00:02<00:06, 20.83it/s, v_num=et-0, valid/loss=1.060, valid/acc=0.482, valid/top5_acc=0.903, train/loss=1.020, train/acc=0.638]

In [None]:
# Manual cleanup: close any W&B run still open, e.g. if the training cell above
# stopped before reaching `trainer.logger.experiment.finish()`.
wandb.finish()

In [None]:
# NOTE(review): this disabled snippet is stale. `experiments` and `CORRUPT_NAME`
# are not defined anywhere in this notebook; the data cell builds `experiment_data`,
# keyed by corruption name (e.g. `experiment_data["normal_labels"]`). Also, `np`
# and `plt` are never imported. Update these names before re-enabling it, or
# delete the cell.
# from torchvision import transforms
# from generalization.randomization.utils import CIFAR10_NORMALIZE_MEAN, CIFAR10_NORMALIZE_STD


# idx = np.random.randint(len(experiments[CORRUPT_NAME]["train_set"]))

# unnormalize = transforms.functional.normalize(
#     experiments[CORRUPT_NAME]["train_set"][idx][0],
#     mean=[-m / s for m, s in zip(CIFAR10_NORMALIZE_MEAN, CIFAR10_NORMALIZE_STD)],
#     std=[1 / s for s in CIFAR10_NORMALIZE_STD],
# )

# label = experiments[CORRUPT_NAME]["train_set"][idx][1]
# class_name = experiments[CORRUPT_NAME]["train_set"].classes[label]


# f, ax = plt.subplots(1, 1, figsize=(2, 2))
# ax.imshow(unnormalize.permute(1, 2, 0))
# ax.set_title(f"Corrupted label: {class_name}")
# ax.axis("off")
# plt.show()