In [None]:
import hydra
import torch
from omegaconf import DictConfig, OmegaConf

# Compose the experiment config (dataset=cifar10, model=wide_resnet22,
# non-distributed) without going through a @hydra.main entry point.
with hydra.initialize(config_path="../configs"):
    cfg = hydra.compose(
        config_name="config.yaml",
        overrides=[
            "dataset=cifar10",
            "model=wide_resnet22",
            "compute.distributed=False",
        ],
    )
# A bare `cfg` is only rendered when it is the LAST expression of a notebook
# cell; here it was followed by an assignment and so displayed nothing.
# Print the resolved config explicitly instead.
print(OmegaConf.to_yaml(cfg))

# NOTE(review): `net` looks unused in this view and presumably duplicates the
# model the next cell builds from cfg.model.name / cfg.dataset.name — confirm
# before removing.
net = ModelFactory.load_model("wide_resnet22", "cifar10")

In [None]:
# `use_cuda` was referenced here but never defined in this notebook (NameError).
# Derive it from the runtime instead.
# NOTE(review): if the config carries a CUDA on/off flag, prefer that here.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Data loaders and model are built from the composed Hydra config.
train_loader, test_loader = get_dataloaders(cfg)
model = ModelFactory.load_model(model=cfg.model.name, dataset=cfg.dataset.name)
model.to(device)

# No saved state_dict is passed, presumably meaning fresh optimizer / scheduler
# state rather than resuming from a checkpoint.
optimizer = get_optimizer(cfg, model, state_dict=None)
scheduler = get_lr_scheduler(cfg, optimizer, state_dict=None)

def get_T_end(cfg, steps_per_epoch=None) -> int:
    """Get the step number at which pruning / regrowth terminates.

    The total training length is ``cfg.training.max_steps`` when that is set.
    Otherwise it is ``cfg.training.epochs * steps_per_epoch``, further
    multiplied by ``cfg.compute.world_size`` in distributed mode. When
    ``cfg.rigl.use_t_end`` is true, pruning / regrowth stops after 75% of that
    length; otherwise the full length is returned.

    Args:
        cfg: Config exposing ``training.max_steps``, ``training.epochs``,
            ``compute.distributed``, ``compute.world_size`` and
            ``rigl.use_t_end``.
        steps_per_epoch: Number of batches per epoch on this process. Only
            used when ``cfg.training.max_steps`` is None. Defaults to
            ``len(train_loader)`` using the module-level ``train_loader``.

    Returns:
        int: Step number at which to stop pruning / regrowth.
    """
    if cfg.training.max_steps is None:
        if steps_per_epoch is None:
            # Backward-compatible fallback to the notebook-level loader.
            steps_per_epoch = len(train_loader)  # Dataset length // batch_size
        total_steps = cfg.training.epochs * steps_per_epoch
        if cfg.compute.distributed:
            # In distributed mode, len(train_loader) is reduced by
            # 1/world_size compared to a single device, so scale it back up.
            total_steps *= cfg.compute.world_size
    else:
        total_steps = cfg.training.max_steps

    if cfg.rigl.use_t_end:
        return int(0.75 * total_steps)
    # Use the full number of steps. Computed directly rather than by undoing
    # the 0.75 factor, which truncated twice and lost steps (e.g. 10 -> 7 -> 9).
    return int(total_steps)