In [1]:
import torch
import sys
import os
sys.path.append('..')
from src.module2graph import GraphInterperterWithGamma
from src.resnet18 import ResNet18
import numpy as np

import graphviz
import itertools
import copy
from torchvision import transforms
import torchvision.datasets as datasets
import torchvision

import networkx as nx
from tqdm.notebook import tqdm
from typing import Tuple, Dict # actually we don't need it for py>=3.9, but I have 3.8 on my laptop
from src.utils import train_loop, test_loop
from src.cifar_data import get_dataloaders




In [2]:
# ! wget http://cs231n.stanford.edu/tiny-imagenet-200.zip

In [3]:
# ! unzip tiny-imagenet-200.zip > /dev/null

In [4]:
# data_dir = "tiny-imagenet-200/"
# num_workers = {"train": 2, "val": 0, "test": 0}
# data_transforms = {
#     "train": transforms.Compose(
#         [
#             transforms.ToTensor(),
#             transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
#         ]
#     ),
#     "val": transforms.Compose(
#         [
#             transforms.ToTensor(),
#             transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
#         ]
#     ),
# }
# image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
#                                           data_transforms[x]) for x in ["train", "val"]}
# dataloaders = {
#     x: torch.utils.data.DataLoader(image_datasets[x], batch_size=128,
#                                    shuffle=True, num_workers=num_workers[x])
#     for x in ["train", "val"]
# }

In [5]:
cls = [0, 1]
train_dl, test_dl = get_dataloaders(classes=cls, batch_size=64,
                                   img_size=33, cifar100=False)

# Tiny-Imagenet
# train_dl, test_dl = dataloaders['train'], dataloaders['val']

Files already downloaded and verified
Files already downloaded and verified


### Plan

1. Fine-tune the last layer of the pretrained ResNet18 on Imagenet
2. Calculate the importance of each edge in a neive way (directly estimate the loss increment)
instead of fine-tuning the whole model first, since it contradicts to the protocol of our experimental section (we can't fine-tune the whole model due to the lack of GPU memory).


In [19]:
model = ResNet18(num_classes=len(cls)) # attention
for n, p in model.named_parameters():
    if 'fc' not in n:
        p.requires_grad = False

optimizer = torch.optim.Adam([p for n, p in model.named_parameters() if 'fc' in n],
                           lr=1e-3)
crit = torch.nn.CrossEntropyLoss()

Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


In [20]:
device = 'mps'
model = model.to(device)
model.eval()

for epoch in range(10):
    # model.train() # no need to switch to training mode
    for i ,(x, y) in enumerate(tqdm(train_dl)):
        logits = model(x.to(device))[0]
        loss = crit(logits, y.to(device))
        loss.backward()
        optimizer.step()
        if i % 100 == 0 and i > 0:
            n_corr = 0
            n_tot = 0
            model.eval()
            for j, (x, y) in enumerate(test_dl):
                with torch.inference_mode():
                    logits = model(x.to(device))[0]
                    n_corr += (logits.argmax(-1) == y.to(device)).sum().item()
                    n_tot += x.shape[0]
                if j >= 50:
                    break
            print(n_corr / n_tot)


  0%|          | 0/157 [00:00<?, ?it/s]

0.66


  0%|          | 0/157 [00:00<?, ?it/s]

0.727


  0%|          | 0/157 [00:00<?, ?it/s]

0.84


  0%|          | 0/157 [00:00<?, ?it/s]

0.9145


  0%|          | 0/157 [00:00<?, ?it/s]

0.8815


  0%|          | 0/157 [00:00<?, ?it/s]

0.9075


  0%|          | 0/157 [00:00<?, ?it/s]

0.9195


  0%|          | 0/157 [00:00<?, ?it/s]

0.898


  0%|          | 0/157 [00:00<?, ?it/s]

0.916


  0%|          | 0/157 [00:00<?, ?it/s]

0.9235


In [23]:
n_corr = 0
n_tot = 0
model.eval()
for j, (x, y) in enumerate(tqdm(test_dl)):
    with torch.inference_mode():
        logits = model(x.to(device))[0]
        n_corr += (logits.argmax(-1) == y.to(device)).sum().item()
        n_tot += x.shape[0]
print(n_corr / n_tot)
torch.save(model.model.fc.state_dict(), 'fc_best.ckpt')

  0%|          | 0/32 [00:00<?, ?it/s]

0.9065


In [24]:
for n, p in model.named_parameters():
    p.requires_grad = True

In [25]:
# forward for target model with gamma values for each edge.
# means - mean values for arguments
def forward_with_gammas(model, gammas: Dict[Tuple[str, str], torch.Tensor], 
                        means: Dict[str, torch.Tensor] = None, *torch_model_args):
    args_iter = iter(torch_model_args)
    env : Dict[str, Node] = {}
    used_edges = set()
    def load_arg(a):    
        return torch.fx.graph.map_arg(a, lambda n: env[n.name])

    def fetch_attr(target : str):
        target_atoms = target.split('.')
        attr_itr = model.graph
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr
    named_modules = dict(model.named_modules())
    for node in model.graph.nodes:
        edges = []

        if node.op in ['call_module', 'call_function', 'output']:    
            if node.op == 'output':
                edges = [(node.args[0][0].name, node.name)]
            else:
    
                for arg in node.args:
                    if type(arg) == torch.fx.Node:  # ignore constants
                        edges.append((arg.name, node.name))
                    else:
                        edges.append(None)
            gammas_node = [int(gammas[e])  if (e is not None) else 1 for e in edges ]
                
            #print (edges, gammas_node)
        if node.op == 'placeholder':
            result = next(args_iter) 
        elif node.op == 'get_attr':
            result = fetch_attr(node.target)
        elif node.op == 'call_function':
            args = [a*g + (1.0 - g) * means[str(a0)] if str(a0) in means else a*g  for a0,a,g in zip(node.args,
                                                                           load_arg(node.args), gammas_node)]
            #print (len(args), len(node.args))
            #print (node, [a for a in node.args])
            #print (node, [a.shape for a in args])
            result = node.target(*args, **load_arg(node.kwargs)) 
        elif node.op == 'call_method':
            self_obj, *args = load_arg(node.args) 
            kwargs = load_arg(node.kwargs)
            args =  [a*g + (1.0 - g) * means[str(a0)] if str(a0) in means else a*g   for a0, a,g in zip(node.args[1:], 
                                                                        args, gammas_node)]
            result = getattr(self_obj, node.target)(*args, **kwargs)
        elif node.op == 'call_module':
            args = [a*g + (1.0 - g) * means[str(a0)] if str(a0) in means else a*g   for a0, a,g in zip(node.args, 
                                                                           load_arg(node.args), gammas_node)]
            
            result = named_modules[node.target](*args, **load_arg(node.kwargs)) 
        
        result = result
        for e in edges:
            used_edges.add(e)
            
        if node.op == 'output':
            
            return result, env # currently ignorign means for output
        #print (node.args, node.name, node.op, abs(result).sum().item())
        env[node.name] = result
        
    return result

# a wrapper that takes model and uses forward_with_Gammas
class PrunedModel(torch.nn.Module):
    def __init__(self, base, prune_dict, means = None):
        super().__init__()
        self.base = base
        self.prune_dict = prune_dict
        self.means = means 
    def forward(self, x):
        return forward_with_gammas(self.base, self.prune_dict,  self.means, x)



In [26]:
# gets intermediate representations of nodes
def get_inter(model, *torch_model_args) -> dict:
    args_iter = iter(torch_model_args)
    env : Dict[str, Node] = {}
    used_edges = set()
    inter = {}
    def load_arg(a):    
        return torch.fx.graph.map_arg(a, lambda n: env[n.name])

    def fetch_attr(target : str):
        target_atoms = target.split('.')
        attr_itr = model.graph
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr
    named_modules = dict(model.named_modules())
    for node in model.graph.nodes:
        edges = []

        if node.op in ['call_module', 'call_function', 'output']:    
            if node.op == 'output':
                edges = [(node.args[0][0].name, node.name)]
            else:
    
                for arg in node.args:
                    if type(arg) == torch.fx.Node:  # ignore constants
                        edges.append((arg.name, node.name))
                    else:
                        edges.append(None)
                
            #print (edges, gammas_node)
        if node.op == 'placeholder':
            result = next(args_iter) 
        elif node.op == 'get_attr':
            result = fetch_attr(node.target)
        elif node.op == 'call_function':
            
            args = load_arg(node.args)
            for a_, a in zip(node.args, args):
                inter[a_] = a
            #print (len(args), len(node.args))
            #print ([a.shape for a in load_arg(node.args)], [a.shape for a in args])
            result = node.target(*args, **load_arg(node.kwargs)) 
        elif node.op == 'call_method':
            self_obj, *args = load_arg(node.args) 
            
            for a_, a in zip(node.args[1:], args):
                inter[a_] = a
            kwargs = load_arg(node.kwargs)
            result = getattr(self_obj, node.target)(*args, **kwargs)
        elif node.op == 'call_module':
            args = load_arg(node.args)
            for a_, a in zip(node.args, args):
                inter[a_] = a
            result = named_modules[node.target](*args, **load_arg(node.kwargs)) 
        
        
        result = result
        for e in edges:
            used_edges.add(e)
    
        if node.op == 'output':
            
            return inter 
        #print (node.args, node.name, node.op, abs(result).sum().item())
        env[node.name] = result
        
    return inter


In [27]:
from statistics import mean 

def module_to_graph(m: torch.nn.Module):
    graph = torch.fx.symbolic_trace(m).graph
    named_dict = dict(m.named_modules())
    edges = [] # (from, to)
    grad = {'x': 0}
    params = {'x': 0}
    weights = {'x': 0} # node: params
    means = {}
    for node in graph.nodes:
    
        # no placeholder and call_mathod
        if node.op == 'call_module':
            n_params = 0
            grad_1 = []
            params_1 = []
            for p in named_dict[node.target].parameters():
                n_params += p.numel()
                grad_1.append(p.grad)
                params_1.append(p)
            try:
                grad[node.name] = torch.mean(torch.stack(grad_1))
                params[node.name] = torch.mean(torch.stack(params_1))
            except:
                grad[node.name] = 0
                params[node.name] = 0
            weights[node.name] = n_params
            assert len(node.args) == 1
            for arg in node.args:
                if type(arg) == torch.fx.Node:  # ignore constants
                    edges.append((arg.name, node.name))
        elif node.op == 'call_function':
            for arg in node.args:
                if type(arg) == torch.fx.Node:  # ignore constants
                    edges.append((arg.name, node.name))
            weights[node.name] = 0
        elif node.op == 'output':
            try:
                edges.append((node.args[0][0].name, node.name))
            except:
                edges.append((node.args[0].name, node.name))
            weights['output'] = 0
        
        #if  len(edges)>0 and edges[-1] == ('model_maxpool', 'add'):
        #    print (node.args, node.name, node.op)
    return edges, {'_'.join(k.split('.')): v for k, v in weights.items()}, {'_'.join(k.split('.')): v for k, v in params.items()}, {'_'.join(k.split('.')): v for k, v in grad.items()}

module_to_graph(model)
pass

In [28]:
edges, weights, a, b = module_to_graph(ResNet18())
# edges, weights
edges[:10], list(weights.items())[:10]

Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


([('x', 'model_conv1'),
  ('model_conv1', 'model_bn1'),
  ('model_bn1', 'model_relu'),
  ('model_relu', 'model_maxpool'),
  ('model_maxpool', 'model_layer1_0_conv1'),
  ('model_layer1_0_conv1', 'model_layer1_0_bn1'),
  ('model_layer1_0_bn1', 'model_layer1_0_relu'),
  ('model_layer1_0_relu', 'model_layer1_0_conv2'),
  ('model_layer1_0_conv2', 'model_layer1_0_bn2'),
  ('model_layer1_0_bn2', 'add')],
 [('x', 0),
  ('model_conv1', 9408),
  ('model_bn1', 128),
  ('model_relu', 0),
  ('model_maxpool', 0),
  ('model_layer1_0_conv1', 36864),
  ('model_layer1_0_bn1', 128),
  ('model_layer1_0_relu', 0),
  ('model_layer1_0_conv2', 36864),
  ('model_layer1_0_bn2', 128)])

In [29]:
#likelihood/accuracy of the original model
full_ll = test_loop(model, test_dl, nc=len(cls), return_ll=False, device='cpu')
full_ll

  0%|          | 0/32 [00:00<?, ?it/s]

0.906499981880188

In [30]:
inter = {}
# TODO:fix an issue with all bad iter when dealing with test split
train_dl_limit, _ = get_dataloaders(cls, train_limit=256)
tr = torch.fx.symbolic_trace(ResNet18(2))
elem_count = 0
for x,_ in train_dl_limit:
    elem_count += x.shape[0]
    i_ = get_inter(tr, x)
    for k in i_:
        try:
            if k not in inter:
                inter[str(k)] = i_[k].sum(0).detach()
            else:
                inter[str(k)] += i_[k].sum(0).detach()
        except:
            print ('bad inter', k)
            #inter[str(k)] = 0.0
    for k in inter:
        inter[k] /= elem_count

Files already downloaded and verified
Files already downloaded and verified


Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1


In [32]:
# ll for pruned models (naively)
edge_ll = {}
model = ResNet18(2)
model.model.fc.load_state_dict(torch.load('fc_best.ckpt', map_location='cpu'))
wrapped = torch.fx.symbolic_trace(model)
    
for e in tqdm(edges):
    pruned = PrunedModel(wrapped, {k:1.0 if k != e else 0.0 for k in edges}, inter)
    # edge_ll[e] = test_loop(pruned, train_dl_limit, nc=2, return_ll=True, device='cpu')
    edge_ll[e] = test_loop(pruned, test_dl, nc=2, return_ll=False, device='cpu')
    

Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/78 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

In [33]:
# don't process logit edge, so setting its ll to min
edge_ll[('model_fc', 'output')] = max(edge_ll.values())
# edges_importance = [edge_ll[e]/max(edge_ll.values()) for e in edges]
edges_importance = [-edge_ll[e] / full_ll for e in edges]
edges_importance = [imp - min(edges_importance) for imp in edges_importance]

edges_importance[:10]
# {e : -edge_ll[e] / full_ll for e in edges}

[0.4721455949832417,
 0.4721455949832417,
 0.4721455949832417,
 0.4721455949832417,
 0.4721455949832417,
 0.4721455949832417,
 0.16822942837533583,
 0.16822942837533583,
 0.16822942837533583,
 0.1108659565483775]

In [35]:
# edge_ll

In [36]:
DG = nx.DiGraph(edges)
all_sorts = list(nx.all_topological_sorts(DG))
len(all_sorts)

9261

In [37]:
len(edges), len(edges_importance)

(78, 78)

In [38]:
# @njit
def dp_for_top_sort(edges, weights, e_importance, top_sort_str, memory=1e10):
    node_ids = {k: i for i, k in enumerate(weights)}
    id_to_node = [node for _, node in enumerate(weights)]
    top_sort = np.array([node_ids[n] for n in top_sort_str])
    assert top_sort[-1] == node_ids['output']
    assert top_sort[0] == node_ids['x']
    assert top_sort.shape[0] == len(node_ids)
    m = np.zeros((len(node_ids), len(node_ids))).astype(np.int32)
    id_to_weight = np.array([weights[n] for n in id_to_node])
    assert len(edges) == len(e_importance)
    for (src, dst), w in zip(edges, e_importance):
        src_id, dst_id = node_ids[src], node_ids[dst]
        m[src_id, dst_id] = w
        
    node_to_layers = np.ones((len(node_ids), len(node_ids))).astype(np.int32) * (-100)  # ans for each v ->
    node_to_layers[top_sort[-1], top_sort[-1]] = 0
    dp = [1e9] * len(node_ids)
    dp[top_sort[-1]] = 0
    for i in range(len(node_ids) - 2, -1, -1):
        v = top_sort[i]
        for j in range(i, len(node_ids)):  # the last node of the first layer (starting from v)
            if id_to_weight[top_sort[i: j + 1]].sum() > memory:
                continue
            if j == len(node_ids) - 1:
                dp[v] = 0
                node_to_layers[v, top_sort[i:]] = 0
                continue
            v_j = top_sort[j + 1]
            next_layer_ids = [] if j + 1 >= len(node_ids) else \
            [k for k in range(m.shape[0]) if node_to_layers[v_j, k] == node_to_layers[v_j, v_j]]
            pruned_value = sum([m[top_sort[k], l] for k in range(i, j + 1) for l in next_layer_ids if m[top_sort[k], l] != 0])
            if dp[v_j] + pruned_value <= dp[v]:
                dp[v] = dp[v_j] + pruned_value
                node_to_layers[v] = node_to_layers[v_j]
                node_to_layers[v, top_sort[i:j + 1]] = node_to_layers[v].max() + 1
                
    # prune restricted edges (TODO: also prune unreacheble nodes)
    ans = node_to_layers[node_ids['x']]
    pruned_edges_ids = [(i, j) for i in range(m.shape[0]) for j in range(m.shape[0]) \
                    if m[i, j] != 0 and abs(ans[i] - ans[j]) > 1]
    pruned_edges = [(id_to_node[i], id_to_node[j]) for i, j in pruned_edges_ids]
    pruned_value = sum([m[i, j] for i, j in pruned_edges_ids])
    
    reach_ids = [set() for _ in range(len(node_ids))]
    for i in range(len(node_ids) - 1, -1, -1):
        v = top_sort[i]
        reach_ids[v].add(v)
        for k in range(i + 1, m.shape[0]):
            v_c = top_sort[k]
            if m[v, v_c] != 0 and (v, v_c) not in pruned_edges_ids:
                reach_ids[v] |= reach_ids[v_c]
    conn_g = top_sort[-1] in reach_ids[top_sort[0]]
                
    return {'node_to_layer': {id_to_node[i]: ans.max() - l for i, l in enumerate(ans)},
            'pruned_value': pruned_value, 'pruned_edges': pruned_edges,
            'connected_graph': conn_g}



In [60]:
### find the best solution
# import pickle
# with open('./naive_mean.pckl', 'rb') as inp:
#     eval_dict, fine_dict = pickle.loads(inp.read())
eval_dict = {}
fine_dict = {}

# naive (proposed) or random importance
random_importance = True

for edge_mem in range(2,6):
    if edge_mem in eval_dict and edge_mem in fine_dict:
        print ('skip', edge_mem)
        continue
    #if edge_mem == 3 and attemp > 0:
    #    continue
    #if edge_mem == 4 and attemp > 0:
    #    continue

    model = ResNet18(len(cls))
    model.model.fc.load_state_dict(torch.load('fc_best.ckpt', map_location='cpu'))
    warp = GraphInterperterWithGamma(model)

    named_dict = dict(model.named_modules())

    for node in warp.graph.nodes:
        if node.op == 'call_module':
            pass
            # print('Norm', np.sqrt(sum([(p ** 2).sum().item() \
            #                            for p in named_dict[node.target].parameters()])))
        # print(node.op, node.name, node.args)
        # break


    edges, weights, a, b = module_to_graph(ResNet18())
    # edges, weights
    edges[:10], list(weights.items())[:10]
    # edges_importance = [1]*len(edges)  ### WHY???
    print(edges_importance)

    best_pruned = 1e10
    best_val = None
    for s in tqdm(all_sorts, leave=False):
        # for mem=8 the computation takes time
        res = dp_for_top_sort(edges, {k: 1 for k in weights},
                              edges_importance if not random_importance else [1] * len(edges),
                              s, edge_mem)  
        if res['connected_graph'] == True and best_pruned > res['pruned_value'] or best_val is None:
            best_pruned = res['pruned_value']
            best_val = res

        if best_pruned == 0:
            print(best_pruned)
            # TODO: fix a bug with connected graph = False when pruned = 0
            # print(res)
            break


    model = ResNet18(len(cls))
    model.model.fc.load_state_dict(torch.load('fc_best.ckpt', map_location='cpu'))

    wrapped = torch.fx.symbolic_trace(model)
    pruned = PrunedModel(wrapped,
                         {k:1.0 if k not in best_val['pruned_edges'] else 0.0 for k in edges},
                         inter)

    res = test_loop(pruned, test_dl, "cpu", nc=len(cls))
    if edge_mem not in eval_dict:
        eval_dict[edge_mem] = []
    eval_dict[edge_mem].append(res)

    train_loop(pruned, train_dl, test_dl, 9999999999, 1, 1e-3,  "cpu")
    res = test_loop(pruned, test_dl,  "cpu", nc=len(cls))
    if edge_mem not in fine_dict:
        fine_dict[edge_mem] = []
    fine_dict[edge_mem].append(res)

    import pickle
    with open('naive_mean.pckl' if not random_importance else 'random_mean.pckl', 'wb') as out:
        out.write(pickle.dumps([eval_dict, fine_dict]))


Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


[0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.16822942837533583, 0.16822942837533583, 0.16822942837533583, 0.1108659565483775, 0.4478764110651998, 0.4721455949832417, 0.3673468850999958, 0.3673468850999958, 0.12465524317709276, 0.12465524317709276, 0.12465524317709276, 0.05681187406080768, 0.4120242658305403, 0.4721455949832417, 0.4495311386111467, 0.4495311386111467, 0.24875895434053985, 0.24875895434053985, 0.24875895434053985, 0.3314947398653363, 0.3314947398653363, 0.12300051563114589, 0.3083286857270906, 0.4721455949832417, 0.15499167376026612, 0.15499167376026612, 0.2471042267945931, 0.2471042267945931, 0.2471042267945931, 0.10921122900243063, 0.4660782825656049, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.0, 0.0, 0.3182569852502666, 0.001103129779796319, 0.4721455949832417, 0.2504136818864867, 0.2504136818864867, 0.01213455908

  0%|          | 0/9261 [00:00<?, ?it/s]

Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


[0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.16822942837533583, 0.16822942837533583, 0.16822942837533583, 0.1108659565483775, 0.4478764110651998, 0.4721455949832417, 0.3673468850999958, 0.3673468850999958, 0.12465524317709276, 0.12465524317709276, 0.12465524317709276, 0.05681187406080768, 0.4120242658305403, 0.4721455949832417, 0.4495311386111467, 0.4495311386111467, 0.24875895434053985, 0.24875895434053985, 0.24875895434053985, 0.3314947398653363, 0.3314947398653363, 0.12300051563114589, 0.3083286857270906, 0.4721455949832417, 0.15499167376026612, 0.15499167376026612, 0.2471042267945931, 0.2471042267945931, 0.2471042267945931, 0.10921122900243063, 0.4660782825656049, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.0, 0.0, 0.3182569852502666, 0.001103129779796319, 0.4721455949832417, 0.2504136818864867, 0.2504136818864867, 0.01213455908

  0%|          | 0/9261 [00:00<?, ?it/s]

Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


[0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.16822942837533583, 0.16822942837533583, 0.16822942837533583, 0.1108659565483775, 0.4478764110651998, 0.4721455949832417, 0.3673468850999958, 0.3673468850999958, 0.12465524317709276, 0.12465524317709276, 0.12465524317709276, 0.05681187406080768, 0.4120242658305403, 0.4721455949832417, 0.4495311386111467, 0.4495311386111467, 0.24875895434053985, 0.24875895434053985, 0.24875895434053985, 0.3314947398653363, 0.3314947398653363, 0.12300051563114589, 0.3083286857270906, 0.4721455949832417, 0.15499167376026612, 0.15499167376026612, 0.2471042267945931, 0.2471042267945931, 0.2471042267945931, 0.10921122900243063, 0.4660782825656049, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.0, 0.0, 0.3182569852502666, 0.001103129779796319, 0.4721455949832417, 0.2504136818864867, 0.2504136818864867, 0.01213455908

  0%|          | 0/9261 [00:00<?, ?it/s]

Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


[0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.16822942837533583, 0.16822942837533583, 0.16822942837533583, 0.1108659565483775, 0.4478764110651998, 0.4721455949832417, 0.3673468850999958, 0.3673468850999958, 0.12465524317709276, 0.12465524317709276, 0.12465524317709276, 0.05681187406080768, 0.4120242658305403, 0.4721455949832417, 0.4495311386111467, 0.4495311386111467, 0.24875895434053985, 0.24875895434053985, 0.24875895434053985, 0.3314947398653363, 0.3314947398653363, 0.12300051563114589, 0.3083286857270906, 0.4721455949832417, 0.15499167376026612, 0.15499167376026612, 0.2471042267945931, 0.2471042267945931, 0.2471042267945931, 0.10921122900243063, 0.4660782825656049, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.4721455949832417, 0.0, 0.0, 0.3182569852502666, 0.001103129779796319, 0.4721455949832417, 0.2504136818864867, 0.2504136818864867, 0.01213455908

  0%|          | 0/9261 [00:00<?, ?it/s]

Using cache found in /Users/konstantinakovlev/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/157 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

In [61]:
import pickle
with open('naive_mean.pckl', 'rb') as inp:
    data = pickle.loads(inp.read())
data

[{2: [0.906499981880188],
  3: [0.906499981880188],
  4: [0.906499981880188],
  5: [0.906499981880188]},
 {2: [0.953000009059906],
  3: [0.8345000147819519],
  4: [0.925000011920929],
  5: [0.9514999985694885]}]

In [62]:
with open('random_mean.pckl', 'rb') as inp:
    data = pickle.loads(inp.read())
data

[{2: [0.5], 3: [0.5], 4: [0.5], 5: [0.5]},
 {2: [0.8544999957084656],
  3: [0.9150000214576721],
  4: [0.890500009059906],
  5: [0.9330000281333923]}]

### Conclusion

The proposed pruning method substantially outperforms a random baseline
in the case of a fair experimental protocol

__TODO__: TODO: fix a bug with connected graph = False when pruned = 0 
