# Advanced-эксперименты с библиотекой DepGraph

[Ссылка](https://github.com/VainF/Torch-Pruning/) на репозиторий библиотеки.

### Импорт модулей

In [1]:
!pip install torch-pruning torcheval --upgrade -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.2/70.2 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.2/179.2 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m75.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m56.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m37.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7

In [2]:
import torch
from torch import nn
from torch.fx import symbolic_trace
import torch.utils.data
import torchvision
from torchvision import transforms
from torcheval.metrics import BinaryAUROC

from torchvision.models import resnet50
import torch_pruning as tp

import numpy as np
from scipy.stats import spearmanr, kendalltau

import abc
from typing import Callable, List, Tuple, Dict
from functools import reduce, partial
import re
import copy
from collections import defaultdict

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm
import networkx

sns.set_style('darkgrid')



### Реализации модулей и функций

Данные.

In [3]:
def get_dataloaders(classes: List[int], batch_size: int = 16, img_size: int = 33, need_val: bool = False, cifar100: bool = False, train_limit = None) -> Tuple[torch.utils.data.DataLoader, ...]:
    """Build CIFAR data loaders restricted to a subset of classes.

    Samples whose label is not in ``classes`` are dropped, and the remaining
    labels are remapped to the position of the label inside ``classes``
    (so ``classes=[8, 9]`` yields targets 0 and 1).

    Args:
        classes: Original dataset labels to keep.
        batch_size: Batch size used for every returned loader.
        img_size: Size passed to ``transforms.Resize``.
        need_val: When true, the second half of the filtered training samples
            is split off into a validation loader.
        cifar100: Use CIFAR-100 instead of CIFAR-10.
        train_limit: When truthy, keep only the first ``train_limit`` training
            samples (applied after the validation split).

    Returns:
        ``(trainloader, testloader)``, or
        ``(trainloader, valloader, testloader)`` when ``need_val`` is true.
        The train and validation loaders shuffle; the test loader does not.
    """
    classes_to_ids = {cls: i for i, cls in enumerate(classes)}
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.Resize(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        normalize,
    ])
    dataset_cls = torchvision.datasets.CIFAR100 if cifar100 else torchvision.datasets.CIFAR10

    def load_split(train, transform):
        # Materialise the filtered split as a list of (tensor, new_label) pairs.
        # NOTE(review): the transform runs once here, so the random flip of the
        # training split is sampled once per image instead of once per epoch.
        dataset = dataset_cls(root='./data', train=train,
                              download=True, transform=transform)
        # Dict membership is O(1); the label is remapped in the same pass.
        return [(x, classes_to_ids[y]) for x, y in dataset if y in classes_to_ids]

    trainset = load_split(True, transform_train)
    valset = None
    if need_val:
        half = len(trainset) // 2
        trainset, valset = trainset[:half], trainset[half:]
    if train_limit:
        trainset = trainset[:train_limit]

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True)
    valloader = None
    if need_val:
        valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                                shuffle=True)

    testset = load_split(False, transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False)
    if need_val:
        return trainloader, valloader, testloader
    return trainloader, testloader

Циклы обучения и теста.

In [4]:
def train_loop(model, traindata, testdata, epoch_num=1, lr=1e-3, device='cuda'):
    """Train ``model`` with Adam and cross-entropy, evaluating after every batch.

    Args:
        model: Network to train. It is moved to ``device`` and put in train mode.
        traindata: Iterable of ``(inputs, targets)`` training batches.
        testdata: Iterable of evaluation batches, passed to ``test_loop``.
        epoch_num: Number of passes over ``traindata``.
        lr: Learning rate for the Adam optimizer.
        device: Device used for both training and evaluation.

    Returns:
        List with one ``test_loop`` result per training batch, in order.
    """
    history = []
    model.to(device)
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    # A separate loop variable keeps the ``epoch_num`` argument intact.
    for epoch in range(epoch_num):
        losses = []
        tq = tqdm(traindata, leave=False)

        for x, y in tq:
            optim.zero_grad()
            x = x.to(device)
            y = y.to(device)
            out = model(x)

            if not isinstance(out, torch.Tensor):
                out = out[0]  # when features are also returned in forward

            loss = criterion(out, y)
            loss.backward()
            optim.step()
            losses.append(loss.detach().cpu().numpy())

            # The full test set is evaluated after every optimizer step, so
            # ``history`` gets one entry per batch rather than per epoch.
            metric_result = test_loop(model, testdata, device)
            tq.set_description(f'Epoch: {epoch}, Loss: {str(np.mean(losses))}, ROC-AUC: {metric_result}')
            history.append(metric_result)

    return history

def test_loop(model, testdata, device='cuda', return_loss=False):
    """Evaluate ``model`` on ``testdata``.

    The model is moved to ``device``, switched to eval mode for the pass and
    switched back to train mode before returning.

    Args:
        model: Network to evaluate.
        testdata: Iterable of ``(inputs, targets)`` batches.
        device: Device used for the model, the batches and the metric.
        return_loss: When true, return the summed cross-entropy loss instead
            of the metric.

    Returns:
        The ``BinaryAUROC`` value as a float, or, when ``return_loss`` is true,
        the sum of the per-batch cross-entropy losses.
    """
    criterion = torch.nn.CrossEntropyLoss()
    metric = BinaryAUROC(device=device)
    model.to(device)
    model.eval()
    loss = 0.0

    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every test batch.
    with torch.no_grad():
        for x, y in testdata:
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            if not isinstance(out, torch.Tensor):
                out = out[0]  # when features are also returned in forward
            # NOTE(review): the metric is fed hard argmax labels rather than
            # class scores — confirm this is the intended AUROC input.
            pred = out.argmax(-1)
            metric.update(pred, y)
            if return_loss:
                loss += criterion(out, y).item()

    metric_result = metric.compute().item()

    # Always leaves the model in train mode, whatever mode it arrived in.
    model.train()
    if return_loss:
        return loss

    return metric_result

Напишем функцию для создания полной модели.

In [5]:
def get_model(device='cuda'):
    """Create an ImageNet-pretrained ResNet-50 fine-tuned on two CIFAR-10 classes.

    The classifier head is replaced with a 2-output linear layer and the model
    is trained with ``train_loop`` (default settings) on CIFAR-10 classes 8
    and 9 with batch size 64.

    Args:
        device: Device the trained model is moved to before it is returned.

    Returns:
        The fine-tuned model on ``device``.
    """
    trainloader, testloader = get_dataloaders([8, 9], batch_size=64)
    # ``pretrained=True`` is deprecated in torchvision; IMAGENET1K_V1 is the
    # checkpoint it used to select.
    full_model = resnet50(weights="IMAGENET1K_V1")
    full_model.fc = torch.nn.Linear(full_model.fc.in_features, 2)
    # Train on the GPU when one is present (as before); otherwise fall back to
    # the requested device instead of failing on the 'cuda' default.
    train_device = 'cuda' if torch.cuda.is_available() else device
    train_loop(full_model, trainloader, testloader, device=train_device)
    full_model = full_model.to(device)

    return full_model

### Создание графа

Инициализируем модель.

In [6]:
# Build and fine-tune the two-class ResNet-50, then keep it on the CPU.
model = get_model(device='cpu')

100%|██████████| 170M/170M [00:01<00:00, 91.3MB/s] 
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 169MB/s]


  0%|          | 0/157 [00:00<?, ?it/s]

**Из документации**

Dependency Graph (DepGraph) is the core feature of Torch-Pruning, which provides an automatic mechanism to group dependent layers. There are three key concepts for DepGraph:

- `tp.dependency.Dependency`: the dependency between layers.
- `tp.dependency.DependencyGraph`: A relational graph to model the dependency.
- `tp.dependency.Group`: A list of dependencies that represents the minimally-removable units.

Построим DependencyGraph.

In [7]:
# Build the dependency graph of the model from a dummy (1, 3, 224, 224) random input.
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

У него есть ноды.

In [29]:
# Collect every graph node (the values of the graph's module -> node mapping).
nodes = list(DG.module2node.values())

In [33]:
# Input nodes of the node at index 3.
nodes[3].inputs

[<Node: (_ElementWiseOp_3(ReluBackward0))>]

In [37]:
# Output nodes of the node at index 10.
nodes[10].outputs

[<Node: (_ElementWiseOp_6(AddBackward0))>,
 <Node: (layer4.1.conv1 (Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)))>]

У него есть Dependency (Edge of DepGraph).

For the dependency A -> B, the pruning operation ``trigger(A)`` will trigger 
the pruning operation ``handler(B)``.

In [42]:
# Dependencies (graph edges) attached to the node at index 2.
nodes[2].dependencies

[prune_out_channels on _ElementWiseOp_1(TBackward0) => prune_in_channels on fc (Linear(in_features=2048, out_features=2, bias=True))]

У этого объекта есть такие атрибуты.

In [43]:
# Source node of the first dependency of the node at index 2.
nodes[2].dependencies[0].source

<Node: (_ElementWiseOp_1(TBackward0))>

In [44]:
# Target node of the first dependency of the node at index 2.
nodes[2].dependencies[0].target

<Node: (fc (Linear(in_features=2048, out_features=2, bias=True)))>

In [51]:
# Trigger of the first dependency of the node at index 2 (a bound pruning method).
nodes[2].dependencies[0].trigger

<bound method DummyPruner.prune_out_channels of <torch_pruning.ops.ElementWisePruner object at 0x7f24ad334b10>>

Dependency складываются в Group для прунинга.

Group is the basic unit for pruning. It contains a list of dependencies and their corresponding indices.

    group := [ (Dep1, Indices1), (Dep2, Indices2), ..., (DepK, IndicesK) ]

Example: 

For a simple network Conv2d(2, 4) -> BN(4) -> Relu, we have:

    group1 := [ (Conv2d -> BN, [0, 1, 2, 3]), (BN -> Relu, [0, 1, 2, 3]) ]

There are 4 prunable elements, i.e., 4 channels in Conv2d.

The indices do not need to be full and can be a subset of the prunable elements.
For instance, if we want to prune the first 2 channels, we have:

    group2 := [ (Conv2d -> BN, [0, 1]), (BN -> Relu, [0, 1]) ]

When combined with tp.importance, we can compute the importance of corresponding channels.

    imp_1 = importance(group1) # len(imp_1)=4
    imp_2 = importance(group2) # len(imp_2)=2

For importance estimation, we should craft a group with full indices just like group1.
For pruning, we need to craft a new group with the to-be-pruned indices like group2.

Можно получить Group (список Dependency) для конкретного слоя модели, например:

In [54]:
# Pruning group obtained for output channels 2, 6 and 9 of layer1[0].conv1.
group = DG.get_pruning_group(model.layer1[0].conv1, pruning_fn=tp.prune_conv_out_channels, idxs=[2, 6, 9])

In [55]:
# For every dependency in the group, show the pruning functions involved and
# the layers they act on. The cell-level names are kept as-is on purpose.
for i, (dep, idxs) in enumerate(group):
    trigger, handler = dep.trigger, dep.handler
    source_layer, target_layer = dep.source.module, dep.target.module

    print("For Dep: ", dep)
    print(" > Trigger: ", trigger)
    print(" > Handler: ", handler)
    print(" > Source Layer: ", source_layer)
    print(" > Target Layer: ", target_layer)
    print("")

For Dep:  prune_out_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)) => prune_out_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False))
 > Trigger:  <bound method ConvPruner.prune_out_channels of <torch_pruning.pruner.function.ConvPruner object at 0x7f24c5c1d550>>
 > Handler:  <bound method ConvPruner.prune_out_channels of <torch_pruning.pruner.function.ConvPruner object at 0x7f24c5c1d550>>
 > Source Layer:  Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
 > Target Layer:  Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)

For Dep:  prune_out_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)) => prune_out_channels on layer1.0.bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
 > Trigger:  <bound method ConvPruner.prune_out_channels of <torch_pruning.pruner.function.ConvPruner object at 0x7f24c5c1d550>>
 > Han