In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
import pytorch_lightning as pl
import random
import dotenv
import omegaconf
import hydra
import logging
from typing import List
from datetime import datetime
import wandb
from datetime import date
import dotenv
import os
import pathlib
from typing import Dict, Any
from copy import deepcopy

from rigl_torch.models import ModelFactory
from rigl_torch.rigl_scheduler import RigLScheduler
from rigl_torch.rigl_constant_fan import RigLConstFanScheduler
from rigl_torch.datasets import get_dataloaders
from rigl_torch.optim import (
    get_optimizer,
    get_lr_scheduler,
)
from rigl_torch.utils.checkpoint import Checkpoint
from rigl_torch.utils.rigl_utils import get_T_end, get_fan_in_after_ablation, get_conv_idx_from_flat_idx
from rigl_torch.meters import SegmentationMeter
from hydra import initialize, compose
from rigl_torch.utils.dist_utils import get_steps_to_accumulate_grad
from rigl_torch.utils.wandb_utils import init_wandb


# Compose the experiment config with Hydra's compose API. The overrides pin
# this notebook to a single-process Mask R-CNN / COCO session with model
# saving and wandb logging switched off.
with initialize(config_path="../configs", version_base="1.2.0"):
    cfg = compose(
        config_name="config.yaml",
        overrides=[
            "compute.distributed=False",
            "dataset=coco",
            "model=maskrcnn",
            "training.test_batch_size=10",
            "training.save_model=False",
            "wandb.log_to_wandb=False",
        ],
    )

# Pull machine-specific settings from ../.env; values in the file take
# precedence over anything already set in the process environment.
dotenv.load_dotenv("../.env", override=True)

# Fail fast with KeyError('IMAGE_NET_PATH') when the variable is missing.
# NOTE(review): unclear why a COCO run needs this; presumably some config
# interpolation references it. Confirm before removing.
if "IMAGE_NET_PATH" not in os.environ:
    raise KeyError("IMAGE_NET_PATH")

print(cfg.model.name)
print(cfg.paths.base)

# Notebook-level logger; DEBUG so that helper functions below can emit
# verbose diagnostics.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)



maskrcnn
/home/mike/condensed-sparsity


In [2]:
# Session setup: resume state, wandb, seeding, device, data, model,
# optimizer / LR scheduler and (optionally) the RigL pruner.
rank = 0
# Set ``checkpoint`` to a loaded ``Checkpoint`` to resume a previous run; with
# ``None`` every piece of state below starts fresh.
checkpoint = None
if checkpoint is not None:
    run_id = checkpoint.run_id
    optimizer_state = checkpoint.optimizer
    scheduler_state = checkpoint.scheduler
    pruner_state = checkpoint.pruner
    model_state = checkpoint.model
    cfg = checkpoint.cfg
    wandb_init_resume = "must"
else:
    run_id, optimizer_state, scheduler_state, pruner_state, model_state, wandb_init_resume = (
        None,
        None,
        None,
        None,
        None,
        None,
    )

# Back-fill RigL options that older configs may not define.
if "diet" not in cfg.rigl:
    with omegaconf.open_dict(cfg):
        cfg.rigl.diet = None
if "keep_first_layer_dense" not in cfg.rigl:
    with omegaconf.open_dict(cfg):
        cfg.rigl.keep_first_layer_dense = False
print(cfg.compute)
cfg.compute.distributed = False
wandb_init_kwargs = dict(resume=wandb_init_resume, id=run_id)
run = init_wandb(cfg, wandb_init_kwargs)
pl.seed_everything(cfg.training.seed)

use_cuda = not cfg.compute.no_cuda and torch.cuda.is_available()
if not use_cuda:
    # This notebook is GPU-only: fail fast instead of silently falling back to
    # the CPU (the previous CPU fallback code after this raise was unreachable).
    raise SystemError("GPU has stopped responding...waiting to die!")
if not cfg.compute.distributed:
    print(f"loading to device rank: {rank}")
# Every path reaching this point uses CUDA, so the device is always cuda:<rank>.
device = torch.device(f"cuda:{rank}")
print(cfg.paths.base)
train_loader, test_loader = get_dataloaders(cfg)

model = ModelFactory.load_model(
    model=cfg.model.name, dataset=cfg.dataset.name, diet=cfg.rigl.diet
)
model.to(device)
if cfg.compute.distributed:
    model = DistributedDataParallel(model, device_ids=[rank])
if model_state is not None:
    try:
        model.load_state_dict(model_state)
    except RuntimeError:
        # Checkpoint was saved from a DDP-wrapped model; convert its keys for
        # this single-process model.
        model_state = checkpoint.get_single_process_model_state_from_distributed_state()
        model.load_state_dict(model_state)

optimizer = get_optimizer(cfg, model, state_dict=optimizer_state)
scheduler = get_lr_scheduler(cfg, optimizer, state_dict=scheduler_state)
pruner = None
if cfg.rigl.dense_allocation is not None:
    if cfg.model.name == "skinny_resnet18":
        dense_allocation = (
            cfg.rigl.dense_allocation * cfg.model.sparsity_scale_factor
        )
        print(
            f"Scaling {cfg.rigl.dense_allocation} by "
            f"{cfg.model.sparsity_scale_factor:.2f} for SkinnyResNet18 "
            f"New Dense Alloc == {dense_allocation:.6f}"
        )
    else:
        dense_allocation = cfg.rigl.dense_allocation
    # NOTE(review): a 1251-element dummy list stands in for the train loader
    # here; presumably get_T_end only uses its length (steps per epoch).
    # 1251 does not obviously match this COCO loader. Confirm, or pass
    # ``train_loader`` instead.
    T_end = get_T_end(cfg, [0] * 1251)
    rigl_scheduler = (
        RigLConstFanScheduler if cfg.rigl.const_fan_in else RigLScheduler
    )
    pruner = rigl_scheduler(
        model,
        optimizer,
        # Use the (possibly model-scaled) allocation computed above; the
        # unscaled cfg value used to be passed here by mistake.
        dense_allocation=dense_allocation,
        alpha=cfg.rigl.alpha,
        delta=cfg.rigl.delta,
        static_topo=cfg.rigl.static_topo,
        T_end=T_end,
        ignore_linear_layers=cfg.rigl.ignore_linear_layers,
        grad_accumulation_n=cfg.rigl.grad_accumulation_n,
        sparsity_distribution=cfg.rigl.sparsity_distribution,
        erk_power_scale=cfg.rigl.erk_power_scale,
        state_dict=pruner_state,
        filter_ablation_threshold=cfg.rigl.filter_ablation_threshold,
        static_ablation=cfg.rigl.static_ablation,
        dynamic_ablation=cfg.rigl.dynamic_ablation,
        min_salient_weights_per_neuron=cfg.rigl.min_salient_weights_per_neuron,  # noqa
        use_sparse_init=cfg.rigl.use_sparse_initialization,
        init_method_str=cfg.rigl.init_method_str,
        use_sparse_const_fan_in_for_ablation=cfg.rigl.use_sparse_const_fan_in_for_ablation,  # noqa
        initialize_grown_weights=cfg.rigl.initialize_grown_weights,
        no_ablation_module_names=cfg.model.no_ablation_module_names
    )

Global seed set to 42


{'no_cuda': False, 'cuda_kwargs': {'num_workers': '${ oc.decode:${oc.env:NUM_WORKERS} }', 'pin_memory': True}, 'distributed': False, 'world_size': 4, 'dist_backend': 'nccl'}
No logging to WANDB! See cfg.wandb.log_to_wandb
loading to device rank: 0
/home/mike/condensed-sparsity
loading annotations into memory...
Done (t=8.54s)
creating index...
index created!
loading annotations into memory...
Done (t=4.80s)
creating index...


INFO:/home/mike/condensed-sparsity/src/rigl_torch/models/model_factory.py:Loading model maskrcnn/coco using <function get_maskrcnn at 0x7fe708b6cee0> with args: () and kwargs: {'diet': None}


index created!


INFO:/home/mike/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 62 set to 0.0
INFO:/home/mike/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 63 set to 0.0
INFO:/home/mike/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 1 set to 0.0
INFO:/home/mike/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 3 set to 0.0
INFO:/home/mike/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 4 set to 0.0
INFO:/home/mike/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 5 set to 0.0
INFO:/home/mike/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 7 set to 0.0
INFO:/home/mike/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 8 set to 0.0
INFO:/home/mike/condensed-sparsity/src/rigl_torch/rigl_scheduler.py:Sparsity of layer at index 10 set to 0.0
INFO:/home/mike/condensed

In [3]:
def train(
    cfg,
    model,
    device,
    train_loader,
    optimizer,
    epoch,
    pruner,
    step,
    logger,
    rank,
    max_batches=None,
):
    """Train ``model`` for one epoch and return the updated global step.

    Gradients are accumulated over ``steps_to_accumulate_grad`` mini-batches
    (derived from ``cfg.training.simulated_batch_size`` and
    ``cfg.training.batch_size``) before each optimizer step. After every
    optimizer step the RigL ``pruner`` (if any) is invoked. Every
    ``cfg.training.log_interval`` steps, rank 0 logs the loss to ``logger``
    and ``wandb``.

    Args:
        cfg: Hydra/OmegaConf config; reads ``training.*``, ``compute.*`` and
            ``wandb.log_filter_stats``.
        model: Model called as ``model(images, targets)``; in train mode it must
            return a dict whose values are loss tensors (they are summed).
        device: Device that images and tensor-valued target fields are moved to.
        train_loader: Iterable of ``(images, targets)`` batches, where
            ``targets`` is a sequence of dicts.
        optimizer: Optimizer stepped once per accumulation window.
        epoch: Current epoch number (used for logging only).
        pruner: Optional RigL scheduler; ``None`` disables topology updates.
        step: Global optimizer-step counter at the start of the epoch.
        logger: Logger for progress messages.
        rank: Process rank; only rank 0 writes the periodic progress log.
        max_batches: If given, stop after this many mini-batches (useful for
            quick debugging runs). ``None`` consumes the whole loader.

    Returns:
        The global step counter after this epoch (or after an early exit on
        ``dry_run`` / ``max_steps``).
    """
    model.train()
    steps_to_accumulate_grad = get_steps_to_accumulate_grad(
        cfg.training.simulated_batch_size, cfg.training.batch_size
    )
    world_size = (
        1 if cfg.compute.distributed is False else cfg.compute.world_size
    )
    # Drop partially accumulated gradients left over from a previous epoch
    # whose batch count was not a multiple of steps_to_accumulate_grad.
    optimizer.zero_grad()
    for batch_idx, (images, targets) in enumerate(train_loader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        # Step the optimizer only on the last mini-batch of each accumulation
        # window (every batch when steps_to_accumulate_grad == 1).
        apply_grads = (batch_idx + 1) % steps_to_accumulate_grad == 0
        images = [image.to(device) for image in images]
        targets = [
            {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in t.items()
            }
            for t in targets
        ]
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())
        logger.debug("loss_dict=%s loss=%s", loss_dict, loss)

        # Normalize loss for accumulated grad
        loss = loss / steps_to_accumulate_grad

        # Will call backwards hooks on model and accumulate dense grads if
        # within cfg.rigl.grad_accumulation_n mini-batch steps from update
        loss.backward()

        if not apply_grads:
            continue

        # Apply grads, then check for a topology update and log.
        if cfg.training.clip_grad_norm is not None:
            nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=cfg.training.clip_grad_norm
            )
        step += 1
        optimizer.step()
        pruner_called = False
        if pruner is not None:
            # pruner.__call__ returns False if rigl step taken
            pruner_called = not pruner()

        if step % cfg.training.log_interval == 0 and rank == 0:
            logger.info(
                "Step: {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(  # noqa
                    step,
                    epoch,
                    batch_idx * len(images) * world_size,
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
            wandb_data = {
                "Training Loss": loss.item(),
            }
            if pruner is not None:
                wandb_data["ITOP Rate"] = pruner.itop_rs
                if cfg.wandb.log_filter_stats and pruner_called:
                    # If we updated the pruner
                    # log filter-wise statistics to wandb
                    pruner.log_meters(step=step)
            wandb.log(wandb_data, step=step)

        # We zero grads after logging pruner filter meters
        optimizer.zero_grad()
        if cfg.training.dry_run:
            logger.warning("Dry run, exiting after one training step")
            return step
        if (
            cfg.training.max_steps is not None
            and step > cfg.training.max_steps
        ):
            return step
    return step


def test(  # TODO
    cfg, model, device, test_loader, epoch, step, rank, logger, training_meter
):
    model.eval()
    test_loss = 0
    correct = 0
    top_k_correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            logits = model(data)
            test_loss += F.cross_entropy(
                logits,
                target,
                label_smoothing=cfg.training.label_smoothing,
                reduction="mean",
            )
            pred = logits.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum()
            if cfg.dataset.name == "imagenet":
                _, top_5_indices = torch.topk(logits, k=5, dim=1, largest=True)
                top_5_pred = (
                    target.reshape(-1, 1).expand_as(top_5_indices)
                    == top_5_indices
                ).any(dim=1)
                top_k_correct += top_5_pred.sum()
            else:
                top_k_correct = None
    if cfg.compute.distributed:
        dist.all_reduce(test_loss, dist.ReduceOp.AVG, async_op=False)
        dist.all_reduce(correct, dist.ReduceOp.SUM, async_op=False)
        if cfg.dataset.name == "imagenet":
            dist.all_reduce(top_k_correct, dist.ReduceOp.SUM, async_op=False)
            top_k_correct = top_k_correct / len(test_loader.dataset)
    training_meter.accuracy = (correct / len(test_loader.dataset)).item()
    if rank == 0:
        logger.info(
            (
                "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n"
            ).format(
                test_loss,
                correct,
                len(test_loader.dataset),
                100.0 * correct / len(test_loader.dataset),
            )
        )
    return test_loss, correct / len(test_loader.dataset)

In [4]:
# Training loop with checkpointing. ``run`` is reset to None so that the
# checkpoint run id is derived from the current time rather than from a
# wandb run.
run = None  # No WANDB for us here
segmentation_meter = SegmentationMeter()
if not cfg.experiment.resume_from_checkpoint:
    step = 0
    if rank == 0:
        if run is None:
            # e.g. "2024-01-31-14-05". The previous "%h-%m-..." format repeated
            # the month (name and number) and omitted the year.
            run_id = datetime.now().strftime("%Y-%m-%d-%H-%M")
        else:
            run_id = run.id
        checkpoint = Checkpoint(
            run_id=run_id,
            cfg=cfg,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            pruner=pruner,
            epoch=0,
            step=step,
            parent_dir=cfg.paths.checkpoints,
        )
        if (pruner is not None) and (cfg.wandb.log_filter_stats):
            # Log inital filter stats before pruning
            pruner.log_meters(step=step)

    epoch_start = 1
else:  # Resuming from checkpoint
    if checkpoint is None:
        raise ValueError(
            "cfg.experiment.resume_from_checkpoint is set but no checkpoint "
            "was loaded in the setup cell"
        )
    checkpoint.model = model
    checkpoint.optimizer = optimizer
    checkpoint.scheduler = scheduler
    checkpoint.pruner = pruner
    # Start at the next epoch after the last that successfully was saved
    epoch_start = checkpoint.epoch + 1
    step = checkpoint.step
    # NOTE: we will use acc for checkpointing but this will hold mask_mAP
    segmentation_meter._max_mask_mAP = checkpoint.best_acc

for epoch in range(epoch_start, cfg.training.epochs + 1):
    if pruner is not None and rank == 0:
        logger.info(pruner)
    if cfg.compute.distributed:
        train_loader.sampler.set_epoch(epoch)
    step = train(
        cfg,
        model,
        device,
        train_loader,
        optimizer,
        epoch,
        pruner=pruner,
        step=step,
        logger=logger,
        rank=rank,
    )
    # DEBUG: intentionally stop after one training epoch. ``test`` is still a
    # classification loop returning two values, so the evaluation and
    # checkpointing below cannot run until a Mask R-CNN evaluator exists.
    break
    loss, box_mAP, mask_mAP = test(
        cfg,
        model,
        device,
        test_loader,
        epoch,
        step,
        rank,
        logger,
        segmentation_meter,
    )
    if rank == 0:
        wandb.log({"Learning Rate": scheduler.get_last_lr()[0]}, step=step)
        logger.info(f"Learning Rate: {scheduler.get_last_lr()[0]}")
        checkpoint.current_acc = mask_mAP
        checkpoint.step = step
        checkpoint.epoch = epoch
        checkpoint.save_checkpoint()
    if cfg.training.dry_run:
        break
    if cfg.training.max_steps is not None and step > cfg.training.max_steps:
        break
    scheduler.step()

if cfg.training.save_model and rank == 0:
    save_path = pathlib.Path(cfg.paths.artifacts)
    # Create intermediate directories too; tolerate an existing directory.
    save_path.mkdir(parents=True, exist_ok=True)
    f_path = save_path / f"{cfg.experiment.name}.pt"
    torch.save(model.state_dict(), f_path)
    logger.info(f"artifact path: {f_path}")
    # Only upload the artifact when wandb logging is actually enabled.
    if cfg.wandb.log_to_wandb:
        art = wandb.Artifact(name=cfg.experiment.name, type="model")
        art.add_file(str(f_path))
        wandb.log_artifact(art)
if rank == 0 and cfg.wandb.log_to_wandb:
    # ``run`` was reset to None above, so close the active wandb run through
    # the module-level API instead of ``run.finish()``.
    wandb.finish()

INFO:__main__:RigLScheduler(
layers=74,
nonzero_params=[9408/9408, 4096/4096, 15744/36864, 16384/16384, 16384/16384, 16384/16384, 15744/36864, 16384/16384, 16384/16384, 15744/36864, 16384/16384, 32768/32768, 30720/147456, 65536/65536, 90112/131072, 65536/65536, 30720/147456, 65536/65536, 65536/65536, 30720/147456, 65536/65536, 65536/65536, 30720/147456, 65536/65536, 90368/131072, 60928/589824, 150528/262144, 180224/524288, 150784/262144, 60928/589824, 150528/262144, 150784/262144, 60928/589824, 150528/262144, 150784/262144, 60928/589824, 150528/262144, 150784/262144, 60928/589824, 150528/262144, 150784/262144, 60928/589824, 150528/262144, 180736/524288, 120832/2359296, 301056/1048576, 360448/2097152, 301056/1048576, 120832/2359296, 301056/1048576, 301056/1048576, 120832/2359296, 301056/1048576, 60416/65536, 90368/131072, 150784/262144, 271360/524288, 60928/589824, 60928/589824, 60928/589824, 60928/589824, 60928/589824, 768/768, 3072/3072, 1596416/12845056, 240640/1048576, 93184/93184, 

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.56 GiB (GPU 0; 11.77 GiB total capacity; 9.83 GiB already allocated; 1.43 GiB free; 9.97 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF