In [1]:
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
import pytorch_lightning as pl
import random
import dotenv
import omegaconf
from datetime import datetime
import hydra
import logging
import wandb
from datetime import date
import pathlib
from typing import Dict, Any, Optional
from copy import deepcopy

from rigl_torch.models.model_factory import ModelFactory
from rigl_torch.rigl_scheduler import RigLScheduler
from rigl_torch.rigl_constant_fan import RigLConstFanScheduler
from rigl_torch.datasets import get_dataloaders
from rigl_torch.optim import (
    get_optimizer,
    get_lr_scheduler,
)
from rigl_torch.utils.checkpoint import Checkpoint
from rigl_torch.utils.rigl_utils import get_T_end
from rigl_torch.meters import TrainingMeter
from rigl_torch.utils.wandb_utils import WandbRunName, wandb_log_check


def _get_checkpoint(cfg: omegaconf.DictConfig, rank: int, logger) -> Checkpoint:
    """Load the most recent checkpoint for the run that is being resumed.

    Args:
        cfg: Run configuration. Reads ``cfg.experiment.run_id`` and
            ``cfg.paths.checkpoints``.
        rank: Process rank, forwarded to ``Checkpoint.load_last_checkpoint``.
        logger: Logger used for the run_id mismatch warning and the resume
            message.

    Returns:
        The loaded checkpoint. Its ``run_id`` attribute is overwritten with
        the configured run_id when the two differ.

    Raises:
        ValueError: If ``cfg.experiment.run_id`` is ``None``.
    """
    run_id = cfg.experiment.run_id
    if run_id is None:
        # The resume flag is read from cfg.experiment (same section as
        # run_id), so the message points the user at that key.
        raise ValueError(
            "Must provide wandb run_id when "
            "cfg.experiment.resume_from_checkpoint is True"
        )
    checkpoint = Checkpoint.load_last_checkpoint(
        run_id=run_id,
        parent_dir=cfg.paths.checkpoints,
        rank=rank,
    )
    if checkpoint.run_id != run_id:
        # The configured run_id wins over whatever was stored on disk.
        logger.warning(f"Mismatched run_id found! {checkpoint.run_id} != {run_id}")
        checkpoint.run_id = run_id
    logger.info(f"Resuming training with run_id: {run_id}")
    return checkpoint


def init_wandb(cfg: omegaconf.DictConfig, wandb_init_kwargs: Dict[str, Any]):
    """Start a wandb run, or wrap wandb's logging entry points when disabled.

    Args:
        cfg: Run configuration. Reads ``cfg.wandb.log_to_wandb``,
            ``cfg.wandb.entity``, ``cfg.wandb.project``,
            ``cfg.wandb.start_method`` and ``cfg.experiment.name``.
        wandb_init_kwargs: Extra keyword arguments forwarded to
            ``wandb.init``.

    Returns:
        Whatever ``wandb.init`` returns, or ``None`` when
        ``cfg.wandb.log_to_wandb`` is falsy.
    """
    log_to_wandb = cfg.wandb.log_to_wandb
    if log_to_wandb:
        # Constructed only for validation of the run name; result is unused.
        _ = WandbRunName(name=cfg.experiment.name)
        run_name = cfg.experiment.name
        entity = cfg.wandb.entity
        project = cfg.wandb.project
        # Plain-container copy of the config with interpolations resolved;
        # raises if any mandatory value is still missing.
        resolved_cfg = omegaconf.OmegaConf.to_container(
            cfg=cfg, resolve=True, throw_on_missing=True
        )
        settings = wandb.Settings(start_method=cfg.wandb.start_method)
        return wandb.init(
            name=run_name,
            entity=entity,
            project=project,
            config=resolved_cfg,
            settings=settings,
            **wandb_init_kwargs,
        )

    # Logging disabled: replace the module-level wandb entry points with
    # their wandb_log_check-wrapped versions so later calls are intercepted
    # (presumably turned into no-ops -- see wandb_log_check).
    print("No logging to WANDB! See cfg.wandb.log_to_wandb")
    for attr_name in ("log", "log_artifact", "watch"):
        wrapped = wandb_log_check(getattr(wandb, attr_name), log_to_wandb)
        setattr(wandb, attr_name, wrapped)
    return None


@hydra.main(config_path="configs/", config_name="config", version_base="1.2")
def initalize_main(cfg: omegaconf.DictConfig) -> None:
    """Hydra entry point: check for CUDA, then run ``main`` on one or many GPUs.

    Args:
        cfg: Configuration composed by Hydra. Reads ``cfg.compute.no_cuda``,
            ``cfg.compute.distributed`` and ``cfg.compute.world_size``.

    Raises:
        SystemError: If CUDA is disabled in the config or not available.
    """
    # CUDA is mandatory here; availability is only queried when the config
    # has not already disabled it.
    if cfg.compute.no_cuda or not torch.cuda.is_available():
        raise SystemError("GPU has stopped responding...waiting to die!")

    if not cfg.compute.distributed:
        main(0, cfg)  # Single GPU
        return

    # Build the train and val loaders once in this process, with distributed
    # mode switched off, so that any .tar balls are decompressed before the
    # parallel workers start and try to write the same directories.
    warmup_cfg = deepcopy(cfg)
    warmup_cfg.compute.distributed = False
    train_loader, test_loader = get_dataloaders(warmup_cfg)
    del train_loader, test_loader, warmup_cfg

    wandb.setup()
    mp.spawn(
        main,
        args=(cfg,),
        nprocs=cfg.compute.world_size,
    )


def _get_logger(rank, cfg: omegaconf.DictConfig) -> logging.Logger:
    """Configure root logging (file + stream) and return this module's logger.

    The root logger's existing handlers are dropped and replaced by a
    ``FileHandler`` writing to ``<cfg.paths.logs>/processor_<YYYY-MM-DD>.log``
    and a ``StreamHandler``, both at INFO level.

    Args:
        rank: Process rank. Currently unused; kept so callers do not change.
        cfg: Run configuration. Reads ``cfg.paths.logs``.

    Returns:
        A logger at INFO level, named after this file when ``__file__`` is
        defined and after ``__name__`` otherwise.
    """
    log_path = pathlib.Path(cfg.paths.logs)
    # A check-then-create sequence can raise FileExistsError when several
    # processes (e.g. one per rank) reach this point together; exist_ok makes
    # creation idempotent and parents=True covers missing parent directories.
    log_path.mkdir(parents=True, exist_ok=True)

    # __file__ is not defined in interactive/notebook sessions; fall back to
    # the module name there instead of failing with NameError.
    logger = logging.getLogger(globals().get("__file__", __name__))
    logger.setLevel(level=logging.INFO)

    current_date = date.today().strftime("%Y-%m-%d")
    logformat = (
        "[%(levelname)s] %(asctime)s G- %(name)s - %(funcName)s "
        "(%(lineno)d) : %(message)s"
    )
    # basicConfig is a no-op when the root logger already has handlers, so
    # clear them first to make sure this configuration is applied.
    logging.root.handlers = []
    logging.basicConfig(
        level=logging.INFO,
        format=logformat,
        handlers=[
            logging.FileHandler(log_path / f"processor_{current_date}.log"),
            logging.StreamHandler(),
        ],
    )
    return logger




In [None]:
# Notebook setup: compose the Hydra config by hand instead of going through
# the @hydra.main entry point, so the training steps can be run interactively.
from hydra import initialize, compose
import os
# Config directory is given relative to this notebook ("../configs").
with initialize("../configs", version_base="1.2.0"):
    cfg = compose(
        "config.yaml",
        # Overrides: ImageNet dataset, ResNet-50 model, distributed mode off.
        overrides=[
            "dataset=imagenet",
            "compute.distributed=False",
            "model=resnet50"
            ])
# Load environment variables from the .env file one directory up.
dotenv.load_dotenv("../.env")
print(cfg.model.name)
# Fixed rank for this single-process session.
rank=0


# Logging, optional checkpoint resume, process-group setup and wandb
# initialisation, executed at top level with the `rank` and `cfg` set above.
logger = _get_logger(rank, cfg)
if cfg.experiment.resume_from_checkpoint:
    # Resume: load the last checkpoint for cfg.experiment.run_id and switch to
    # the config stored in it, re-applying run_id and the resume flag on top.
    checkpoint = _get_checkpoint(cfg, rank, logger)
    # Passed to wandb.init as `resume=` further down (via init_wandb).
    wandb_init_resume = "must"
    run_id = checkpoint.run_id
    cfg = checkpoint.cfg
    cfg.experiment.run_id = run_id
    cfg.experiment.resume_from_checkpoint = True
else:
    # Fresh run: nothing to resume from.
    run_id = None
    wandb_init_resume = None
    checkpoint = None
logger.info(f"Running train_rigl.py with config:\n{cfg}")

if cfg.compute.distributed:
    dist.init_process_group(
        backend=cfg.compute.dist_backend,
        world_size=cfg.compute.world_size,
        rank=rank,
    )
# Reset run_id and the state holders to None; they are repopulated from the
# checkpoint below when one was loaded. Note that this overwrites the run_id
# assigned in the branch above.
run_id, optimizer_state, scheduler_state, pruner_state, model_state = (
    None,
    None,
    None,
    None,
    None,
)

if checkpoint is not None:
    run_id = checkpoint.run_id
    optimizer_state = checkpoint.optimizer
    scheduler_state = checkpoint.scheduler
    pruner_state = checkpoint.pruner
    model_state = checkpoint.model
    logger.info(f"Resuming training with run_id: {run_id}")
    # NOTE(review): cfg was already rebound to checkpoint.cfg above. If
    # checkpoint.cfg returns the same object each time this is redundant;
    # if it returns a fresh copy, the run_id / resume_from_checkpoint values
    # written above are lost here -- confirm against Checkpoint.
    cfg = checkpoint.cfg

# wandb is initialised on rank 0 only, so `wandb_init_kwargs` and `run` are
# not defined on other ranks.
if rank == 0:
    wandb_init_kwargs = dict(resume=wandb_init_resume, id=run_id)
    run = init_wandb(cfg, wandb_init_kwargs)