In [2]:
%load_ext autoreload
%autoreload 2
# import os
# os.environ["WANDB_MODE"]="offline"

import torch
import wandb
from datetime import datetime
from training_pipelines import regular, single_shot_pruning

from common.architectures import SimpleMLP
from common.datasets import symbolic_1
from common.tracking import Config, SGD, ADAM, PROJECT, save_hparams
from common.training import build_optimizer
i=1

In [5]:
# Architecture under test; only its class name is stored in the config.
ModelClass = SimpleMLP

config = Config(
    experiment=f'IMP-reinit {i}',
    dataset=symbolic_1.DATASET_NAME,
    # presumably layer widths: input -> 20 -> 20 -> output (TODO confirm against SimpleMLP)
    model_shape=[symbolic_1.INPUT_DIM, 20, 20, symbolic_1.OUTPUT_DIM],
    model_class=ModelClass.__name__,

    # pruning
    pruning_levels=30,
    pruning_rate=0.1,
    pruning_strategy='global',
    prune_weights=True,
    prune_biases=False,

    # training
    training_epochs=1500,
    lr=0.001,
    momentum=0,  # NOTE(review): torch Adam has no momentum arg; presumably only read for SGD — confirm in build_optimizer
    optimizer=ADAM,
    batch_size=None,  # NOTE(review): presumably means full-batch — confirm in symbolic_1.get_dataloaders

    # seeds
    model_seed=2,
    data_seed=2,

    # lottery
    reinit=True,

    # storage
    persist=True,
    timestamp=datetime.now().strftime("%Y_%m_%d_%H%M%S"),
)

# Build the model (seeded initialization), its optimizer and the loss.
model = ModelClass(config.model_shape, seed=config.model_seed)
optim = build_optimizer(model, config)
loss_fn = torch.nn.MSELoss(reduction="mean")

# `get_dataloaders` is not given `config.data_seed`, so seed torch's global RNG with it
# here; without this the recorded data seed is never applied anywhere in this notebook.
torch.manual_seed(config.data_seed)
train_loader, test_loader = symbolic_1.get_dataloaders(config.batch_size)

# Persist this run's hyperparameters.
save_hparams(config)

In [6]:

# Run the pruning experiment and log it to W&B.
# Rebuild the model and optimizer from the config first, so that re-running this cell
# starts from the same seeded initialization instead of continuing from the model
# returned by a previous execution, and so `optim` is always bound to the parameters
# of the model handed to `run`.
model = ModelClass(config.model_shape, seed=config.model_seed)
optim = build_optimizer(model, config)

with wandb.init(project=PROJECT, name=config.experiment, config=config):
    # `run` returns the resulting model; keep it as `model` for any downstream cells.
    model = single_shot_pruning.run(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        optim=optim,
        loss_fn=loss_fn,
        config=config,
    )

0,1
epoch,▃▁▆▄▄▂▇▄▅▂▇▇▅▃██▆▃▃▁▆▄▄▂▇▇▅▂▇▇▅▃▃█▆▃▃▁▆▆
loss/eval/0.00,█▇▆▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.10,█▆▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.20,█▅▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.28,█▅▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.36,█▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.43,█▅▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.49,█▅▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.54,█▅▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss/eval/0.59,█▆▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,1500.0
loss/eval/0.00,0.09218
loss/eval/0.10,0.08529
loss/eval/0.20,0.09015
loss/eval/0.28,0.10328
loss/eval/0.36,0.09804
loss/eval/0.43,0.1085
loss/eval/0.49,0.10798
loss/eval/0.54,0.08402
loss/eval/0.59,0.07163
