In [10]:
%load_ext autoreload
%autoreload 2
import wandb
from datetime import datetime
from training_pipelines import imp
from training_pipelines import regular

from common.torch_utils import get_pytorch_device
from common.architectures import SimpleMLP
from common.datasets.independence import INPUT_DIM, OUTPUT_DIM, DATASET_NAME, build_loaders
from common.tracking import Config, SGD, ADAM, PROJECT, save_hparams
from common.training import build_optimizer, build_model

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
def main(device='cpu'):
    """Build the model, optimizer and data loaders, then run the ``imp`` pipeline.

    All hyperparameters are hard-coded in the ``Config`` below. They are
    persisted with ``save_hparams`` before training, and the run is tracked
    in Weights & Biases under ``PROJECT`` with ``config.experiment`` as the
    run name.

    Args:
        device: Device name recorded in the config. Defaults to ``'cpu'``,
            which matches the previous hard-coded override. Pass ``None`` to
            use whatever ``get_pytorch_device()`` selects.

    Returns:
        The model returned by ``imp.run``.
    """
    # Previously get_pytorch_device() was called and its result silently
    # overwritten with 'cpu'; make that choice explicit instead.
    if device is None:
        device = get_pytorch_device()

    config = Config(
        experiment='IMP-reinit-with nograd',
        dataset=DATASET_NAME,
        model_shape=[INPUT_DIM, 20, 20, OUTPUT_DIM],
        model_class=SimpleMLP,

        # pruning
        pruning_levels=30,
        pruning_rate=0.1,
        pruning_strategy='global',
        prune_weights=True,
        prune_biases=False,

        # training
        training_epochs=1500,
        lr=0.001,
        # NOTE(review): momentum is presumably ignored when optimizer=ADAM;
        # confirm in build_optimizer.
        momentum=0,
        optimizer=ADAM,
        batch_size=None,

        # seeds
        model_seed=3,
        data_seed=2,

        # lottery
        reinit=True,

        # storage
        persist=True,
        timestamp=datetime.now().strftime("%Y_%m_%d_%H%M%S"),
        device=str(device),
        # NOTE(review): wandb.init below runs regardless of this flag;
        # confirm whether imp.run / save_hparams actually read it.
        wandb=True,
    )

    # create the model, optimizer and dataloaders
    model, loss_fn = build_model(config)
    optim = build_optimizer(model, config)
    train_loader, test_loader = build_loaders(config.batch_size)

    save_hparams(config)

    # run the experiment; the wandb run is closed when the block exits
    with wandb.init(project=PROJECT, name=config.experiment, config=config):
        model = imp.run(
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            optim=optim,
            loss_fn=loss_fn,
            config=config,
        )

    return model


# Bind the result so it can be inspected in later cells; a bare `main()` as
# the last expression would make Jupyter echo the model's repr.
final_model = main()

0,1
epoch,▃▁▆▄▄▂▇▄▅▂▇▇▅▃██▆▃▃▁▆▄▄▂▇▇▅▂▇▇▅▃▃█▆▃▃▁▆▆
loss/eval/0.00,█▇▆▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.10,█▅▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.20,█▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.28,█▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.36,█▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.42,█▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.49,█▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.54,█▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.59,█▅▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,1500.0
loss/eval/0.00,0.12753
loss/eval/0.10,0.11442
loss/eval/0.20,0.1421
loss/eval/0.28,0.1246
loss/eval/0.36,0.0933
loss/eval/0.42,0.1017
loss/eval/0.49,0.09867
loss/eval/0.54,0.07098
loss/eval/0.59,0.07358
