In [None]:
%load_ext autoreload
%autoreload 2

import torch
import wandb

from datasets import symbolic_1

from architectures.simlpe_mlp import SimpleMLP

import iterative_magnitude_pruning_with_reinit

from common.torch_utils import remaining_weights_by_pruning_steps
from common.training import build_optimizer
from common.tracking import Config, SGD, ADAM, PROJECT

In [None]:
# Layer widths of the MLP: input -> 20 -> 20 -> output.
MODEL_SHAPE = [symbolic_1.INPUT_DIM, 20, 20, symbolic_1.OUTPUT_DIM]

# Every hyper-parameter of the experiment lives here; this object is also
# passed to wandb.init(config=...) in the run cell.
config = Config(
    experiment='basic-test',
    lr=0.001,
    dataset=symbolic_1.DATASET_NAME,
    training_epochs=10,
    pruning_levels=10,
    # NOTE(review): confirm the unit of pruning_rate (fraction vs. percent);
    # if it is a fraction, 1 would remove every prunable weight at level 1.
    pruning_rate=1,
    # Previously the placeholder string 'num_layers'. Derived from the shape,
    # assuming SimpleMLP builds one layer per consecutive pair of widths —
    # TODO confirm against architectures.simlpe_mlp.
    num_layers=len(MODEL_SHAPE) - 1,
    prune_weights=True,
    prune_biases=False,
    pruning_strategy='global',
    model_shape=MODEL_SHAPE,
    optimizer=SGD,
    momentum=0,
    model_seed=1,
    data_seed=1,
    # presumably None means full-batch loading — TODO confirm in
    # symbolic_1.get_dataloaders.
    batch_size=None,
)

In [None]:
# Build the MLP from the configured layer widths with an explicit model seed.
model = SimpleMLP(config.model_shape, seed=config.model_seed)

# Record the architecture name on the config, which is later handed to wandb.
config.architecture = model.name

# Presumably reports how many weights survive each pruning level — displayed
# as this cell's output.
remaining_weights_by_pruning_steps(model, config.pruning_rate, config.pruning_levels)

In [None]:
# Optimizer for `model`, built by the project helper from the config
# (presumably dispatching on config.optimizer / lr / momentum).
optim = build_optimizer(model, config)

# Mean-squared-error objective, averaged over elements.
loss_fn = torch.nn.MSELoss(reduction="mean")

# Train/test loaders for the symbolic_1 dataset.
# NOTE(review): config.data_seed is never passed here — confirm that
# get_dataloaders seeds its data itself, otherwise data_seed has no effect.
loaders = symbolic_1.get_dataloaders(config.batch_size)
train_loader, test_loader = loaders

In [None]:
# Execute the pruning experiment inside a wandb run; the `with` block finishes
# the run when the cell exits. (`wandb` is imported in the top import cell.)
# NOTE(review): `model` is rebound to whatever run(...) returns, so re-running
# this cell alone continues from that returned model rather than a fresh one —
# re-run the model and optimizer cells first for a clean experiment.
with wandb.init(project=PROJECT, name=config.experiment, config=config):
    model = iterative_magnitude_pruning_with_reinit.run(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        optim=optim,
        loss_fn=loss_fn,
        config=config,
    )