In [1]:
import torch
from src.module2graph import GraphInterperterWithGamma
from src.resnet18 import ResNet18
import numpy as np

import graphviz
import itertools
import copy
from torchvision import transforms
import torchvision

import networkx as nx
from tqdm.auto import tqdm
from typing import Tuple, Dict # actually we don't need it for py>=3.9, but I have 3.8 on my laptop
from src.utils import train_loop, test_loop
#from numba import njit
from src.cifar_data import get_dataloaders




In [3]:
# forward for target model with gamma values for each edge.
# means - mean values for arguments
def forward_with_gammas(model, gammas: Dict[Tuple[str, str], torch.Tensor], 
                        means: Dict[str, torch.Tensor] = None, *torch_model_args):
    args_iter = iter(torch_model_args)
    env : Dict[str, Node] = {}
    used_edges = set()
    def load_arg(a):    
        return torch.fx.graph.map_arg(a, lambda n: env[n.name])

    def fetch_attr(target : str):
        target_atoms = target.split('.')
        attr_itr = model.graph
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr
    named_modules = dict(model.named_modules())
    for node in model.graph.nodes:
        edges = []

        if node.op in ['call_module', 'call_function', 'output']:    
            if node.op == 'output':
                edges = [(node.args[0][0].name, node.name)]
            else:
    
                for arg in node.args:
                    if type(arg) == torch.fx.Node:  # ignore constants
                        edges.append((arg.name, node.name))
                    else:
                        edges.append(None)
            gammas_node = [int(gammas[e])  if (e is not None) else 1 for e in edges ]
                
            #print (edges, gammas_node)
        if node.op == 'placeholder':
            result = next(args_iter) 
        elif node.op == 'get_attr':
            result = fetch_attr(node.target)
        elif node.op == 'call_function':
            args = [a*g + (1.0 - g) * means[str(a0)] if str(a0) in means else a*g  for a0,a,g in zip(node.args,
                                                                           load_arg(node.args), gammas_node)]
            #print (len(args), len(node.args))
            #print (node, [a for a in node.args])
            #print (node, [a.shape for a in args])
            result = node.target(*args, **load_arg(node.kwargs)) 
        elif node.op == 'call_method':
            self_obj, *args = load_arg(node.args) 
            kwargs = load_arg(node.kwargs)
            args =  [a*g + (1.0 - g) * means[str(a0)] if str(a0) in means else a*g   for a0, a,g in zip(node.args[1:], 
                                                                        args, gammas_node)]
            result = getattr(self_obj, node.target)(*args, **kwargs)
        elif node.op == 'call_module':
            args = [a*g + (1.0 - g) * means[str(a0)] if str(a0) in means else a*g   for a0, a,g in zip(node.args, 
                                                                           load_arg(node.args), gammas_node)]
            
            result = named_modules[node.target](*args, **load_arg(node.kwargs)) 
        
        result = result
        for e in edges:
            used_edges.add(e)
            
        if node.op == 'output':
            
            return result, env # currently ignorign means for output
        #print (node.args, node.name, node.op, abs(result).sum().item())
        env[node.name] = result
        
    return result

# a wrapper that takes model and uses forward_with_Gammas
class PrunedModel(torch.nn.Module):
    def __init__(self, base, prune_dict, means = None):
        super().__init__()
        self.base = base
        self.prune_dict = prune_dict
        self.means = means 
    def forward(self, x):
        return forward_with_gammas(self.base, self.prune_dict,  self.means, x)



In [4]:
# gets intermediate representations of nodes
def get_inter(model, *torch_model_args) -> dict:
    args_iter = iter(torch_model_args)
    env : Dict[str, Node] = {}
    used_edges = set()
    inter = {}
    def load_arg(a):    
        return torch.fx.graph.map_arg(a, lambda n: env[n.name])

    def fetch_attr(target : str):
        target_atoms = target.split('.')
        attr_itr = model.graph
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr
    named_modules = dict(model.named_modules())
    for node in model.graph.nodes:
        edges = []

        if node.op in ['call_module', 'call_function', 'output']:    
            if node.op == 'output':
                edges = [(node.args[0][0].name, node.name)]
            else:
    
                for arg in node.args:
                    if type(arg) == torch.fx.Node:  # ignore constants
                        edges.append((arg.name, node.name))
                    else:
                        edges.append(None)
                
            #print (edges, gammas_node)
        if node.op == 'placeholder':
            result = next(args_iter) 
        elif node.op == 'get_attr':
            result = fetch_attr(node.target)
        elif node.op == 'call_function':
            
            args = load_arg(node.args)
            for a_, a in zip(node.args, args):
                inter[a_] = a
            #print (len(args), len(node.args))
            #print ([a.shape for a in load_arg(node.args)], [a.shape for a in args])
            result = node.target(*args, **load_arg(node.kwargs)) 
        elif node.op == 'call_method':
            self_obj, *args = load_arg(node.args) 
            
            for a_, a in zip(node.args[1:], args):
                inter[a_] = a
            kwargs = load_arg(node.kwargs)
            result = getattr(self_obj, node.target)(*args, **kwargs)
        elif node.op == 'call_module':
            args = load_arg(node.args)
            for a_, a in zip(node.args, args):
                inter[a_] = a
            result = named_modules[node.target](*args, **load_arg(node.kwargs)) 
        
        
        result = result
        for e in edges:
            used_edges.add(e)
    
        if node.op == 'output':
            
            return inter 
        #print (node.args, node.name, node.op, abs(result).sum().item())
        env[node.name] = result
        
    return inter


In [5]:


train_dl, test_dl = get_dataloaders([0,1,2,3,4,5,6,7], )


Files already downloaded and verified
Files already downloaded and verified


In [6]:
from statistics import mean 

def module_to_graph(m: torch.nn.Module):
    graph = torch.fx.symbolic_trace(m).graph
    named_dict = dict(m.named_modules())
    edges = [] # (from, to)
    grad = {'x': 0}
    params = {'x': 0}
    weights = {'x': 0} # node: params
    means = {}
    for node in graph.nodes:
    
        # no placeholder and call_mathod
        if node.op == 'call_module':
            n_params = 0
            grad_1 = []
            params_1 = []
            for p in named_dict[node.target].parameters():
                n_params += p.numel()
                grad_1.append(p.grad)
                params_1.append(p)
            try:
                grad[node.name] = torch.mean(torch.stack(grad_1))
                params[node.name] = torch.mean(torch.stack(params_1))
            except:
                grad[node.name] = 0
                params[node.name] = 0
            weights[node.name] = n_params
            assert len(node.args) == 1
            for arg in node.args:
                if type(arg) == torch.fx.Node:  # ignore constants
                    edges.append((arg.name, node.name))
        elif node.op == 'call_function':
            for arg in node.args:
                if type(arg) == torch.fx.Node:  # ignore constants
                    edges.append((arg.name, node.name))
            weights[node.name] = 0
        elif node.op == 'output':
            try:
                edges.append((node.args[0][0].name, node.name))
            except:
                edges.append((node.args[0].name, node.name))
            weights['output'] = 0
        
        #if  len(edges)>0 and edges[-1] == ('model_maxpool', 'add'):
        #    print (node.args, node.name, node.op)
    return edges, {'_'.join(k.split('.')): v for k, v in weights.items()}, {'_'.join(k.split('.')): v for k, v in params.items()}, {'_'.join(k.split('.')): v for k, v in grad.items()}

module_to_graph(ResNet18())

Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


([('x', 'model_conv1'),
  ('model_conv1', 'model_bn1'),
  ('model_bn1', 'model_relu'),
  ('model_relu', 'model_maxpool'),
  ('model_maxpool', 'model_layer1_0_conv1'),
  ('model_layer1_0_conv1', 'model_layer1_0_bn1'),
  ('model_layer1_0_bn1', 'model_layer1_0_relu'),
  ('model_layer1_0_relu', 'model_layer1_0_conv2'),
  ('model_layer1_0_conv2', 'model_layer1_0_bn2'),
  ('model_layer1_0_bn2', 'add'),
  ('model_maxpool', 'add'),
  ('add', 'model_layer1_0_relu_1'),
  ('model_layer1_0_relu_1', 'model_layer1_1_conv1'),
  ('model_layer1_1_conv1', 'model_layer1_1_bn1'),
  ('model_layer1_1_bn1', 'model_layer1_1_relu'),
  ('model_layer1_1_relu', 'model_layer1_1_conv2'),
  ('model_layer1_1_conv2', 'model_layer1_1_bn2'),
  ('model_layer1_1_bn2', 'add_1'),
  ('model_layer1_0_relu_1', 'add_1'),
  ('add_1', 'model_layer1_1_relu_1'),
  ('model_layer1_1_relu_1', 'model_layer2_0_conv1'),
  ('model_layer2_0_conv1', 'model_layer2_0_bn1'),
  ('model_layer2_0_bn1', 'model_layer2_0_relu'),
  ('model_layer2_0_r

In [7]:
edges, weights, a, b = module_to_graph(ResNet18())
# edges, weights
edges[:10], list(weights.items())[:10]

Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


([('x', 'model_conv1'),
  ('model_conv1', 'model_bn1'),
  ('model_bn1', 'model_relu'),
  ('model_relu', 'model_maxpool'),
  ('model_maxpool', 'model_layer1_0_conv1'),
  ('model_layer1_0_conv1', 'model_layer1_0_bn1'),
  ('model_layer1_0_bn1', 'model_layer1_0_relu'),
  ('model_layer1_0_relu', 'model_layer1_0_conv2'),
  ('model_layer1_0_conv2', 'model_layer1_0_bn2'),
  ('model_layer1_0_bn2', 'add')],
 [('x', 0),
  ('model_conv1', 9408),
  ('model_bn1', 128),
  ('model_relu', 0),
  ('model_maxpool', 0),
  ('model_layer1_0_conv1', 36864),
  ('model_layer1_0_bn1', 128),
  ('model_layer1_0_relu', 0),
  ('model_layer1_0_conv2', 36864),
  ('model_layer1_0_bn2', 128)])

In [8]:
# getting subset to evaulate mean
edges, weights, a, b = module_to_graph(ResNet18())
# edges, weights
edges[:10], list(weights.items())[:10]



train_dl_limit, _ = get_dataloaders([0,1,2,3,4,5,6,7], train_limit=256)# 256
len(train_dl_limit)

Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


Files already downloaded and verified
Files already downloaded and verified


16

In [9]:
inter = {}
model = ResNet18(8)
model.load_state_dict(torch.load('./model_last.ckpt', map_location='cpu'))
tr = torch.fx.symbolic_trace(model)
elem_count = 0
for x,_ in train_dl_limit:
    elem_count += x.shape[0]
    i_ = get_inter(tr, x)
    for k in i_:
        try:
            if k not in inter:
                inter[str(k)] = i_[k].sum(0).detach()
            else:
                inter[str(k)] += i_[k].sum(0).detach()
        except:
            print ('bad inter', k)
            #inter[str(k)] = 0.0
    for k in inter:
        inter[k] /= elem_count

Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1
bad inter 1


In [10]:
# likelihood of the model without pruning
model = ResNet18(8)
model.load_state_dict(torch.load('./model_last.ckpt', map_location='cpu'))

    
full_ll = test_loop(model, train_dl_limit, nc=8, return_ll=True, device='cpu')
full_ll

Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/16 [00:00<?, ?it/s]

12.21679581142962

In [11]:
# ll for pruned models
edge_ll = {}
model = ResNet18(8)
model.load_state_dict(torch.load('./model_last.ckpt', map_location='cpu'))

wrapped = torch.fx.symbolic_trace(model)
    
    
for e in tqdm(edges):
    pruned = PrunedModel(wrapped, {k:1.0 if k != e else 0.0 for k in edges }, inter)
    edge_ll[e] = test_loop(pruned, train_dl_limit, nc=8, return_ll=True, device='cpu')
    

Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/78 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

In [13]:
# don't process logit edge, so setting its ll to min
edge_ll[('model_fc', 'output')] = max(edge_ll.values())
edges_importance = [edge_ll[e]/max(edge_ll.values()) for e in edges] #[1.0 - edge_ll[e]/max(edge_ll.values()) + EPS for e in edges]
edges_importance

[0.6867269778513567,
 0.686726980815562,
 0.7028506840060982,
 0.7028506751134821,
 0.5510818931627941,
 0.5510818857522807,
 0.22123894801958208,
 0.22123894301748553,
 0.2212389333838181,
 0.2625129312717981,
 0.637146508061444,
 0.8549246308990748,
 0.2700808469524391,
 0.27008084398823373,
 0.17254874674256804,
 0.1725487485951964,
 0.17254874896572206,
 0.1660348703001716,
 0.7709851022536136,
 1.0,
 0.5571379158972058,
 0.5571379070045898,
 0.334882749167205,
 0.3348827462029997,
 0.33488275657771843,
 0.41678110673124735,
 0.41678111562386344,
 0.3447857561585027,
 0.22906685804327434,
 0.8587983671093765,
 0.18137245321329376,
 0.18137245469539645,
 0.15517665322485003,
 0.1551766526227458,
 0.15517665512379408,
 0.16542635710295725,
 0.7857241984404829,
 0.9755678692730418,
 0.31775229364134494,
 0.31775228771293423,
 0.3061377500262021,
 0.3061377485440994,
 0.3061377478030481,
 0.22559176262090913,
 0.22559176484406315,
 0.32042563339036484,
 0.20896800280219377,
 0.76597884

<!-- ### Conclusion

It seems that the problem is NP-hard. We need to come up with a new approach.
 -->

<!-- # The second attempt

Consider the following heuristic

1. Top sort (v_i, v_j) => i < j

2. for k in {n, ..., 1}

Consider 2 cases:

a) put v_k into the layer of its nearest child + prune some edges

b) put vk into a new layer => 

Consider all subsets of outcoming edges. We instantly identify a layer given a subset. So, we aggregate this layer with the answer of the nearest child to v_k
 -->

## The second attempt + deleting of edges

1. Find all (sample) topological sorts

https://www.geeksforgeeks.org/all-topological-sorts-of-a-directed-acyclic-graph/

2. Apply greedy dynamic programming to find a monotonous solution

3. Postprocess a graph: remove all nodes that are unreacheble from "x".


In [14]:
DG = nx.DiGraph(edges)
all_sorts = list(nx.all_topological_sorts(DG))
len(all_sorts)



9261

In [15]:
len(edges), len(edges_importance)

(78, 78)

In [16]:
# @njit
def dp_for_top_sort(edges, weights, e_importance, top_sort_str, memory=1e10):
    node_ids = {k: i for i, k in enumerate(weights)}
    id_to_node = [node for _, node in enumerate(weights)]
    top_sort = np.array([node_ids[n] for n in top_sort_str])
    assert top_sort[-1] == node_ids['output']
    assert top_sort[0] == node_ids['x']
    assert top_sort.shape[0] == len(node_ids)
    m = np.zeros((len(node_ids), len(node_ids))).astype(np.int32)
    id_to_weight = np.array([weights[n] for n in id_to_node])
    assert len(edges) == len(e_importance)
    for (src, dst), w in zip(edges, e_importance):
        src_id, dst_id = node_ids[src], node_ids[dst]
        m[src_id, dst_id] = w
        
    node_to_layers = np.ones((len(node_ids), len(node_ids))).astype(np.int32) * (-100)  # ans for each v ->
    node_to_layers[top_sort[-1], top_sort[-1]] = 0
    dp = [1e9] * len(node_ids)
    dp[top_sort[-1]] = 0
    for i in range(len(node_ids) - 2, -1, -1):
        v = top_sort[i]
        for j in range(i, len(node_ids)):  # the last node of the first layer (starting from v)
            if id_to_weight[top_sort[i: j + 1]].sum() > memory:
                continue
            if j == len(node_ids) - 1:
                dp[v] = 0
                node_to_layers[v, top_sort[i:]] = 0
                continue
            v_j = top_sort[j + 1]
            next_layer_ids = [] if j + 1 >= len(node_ids) else \
            [k for k in range(m.shape[0]) if node_to_layers[v_j, k] == node_to_layers[v_j, v_j]]
            pruned_value = sum([m[top_sort[k], l] for k in range(i, j + 1) for l in next_layer_ids if m[top_sort[k], l] != 0])
            if dp[v_j] + pruned_value <= dp[v]:
                dp[v] = dp[v_j] + pruned_value
                node_to_layers[v] = node_to_layers[v_j]
                node_to_layers[v, top_sort[i:j + 1]] = node_to_layers[v].max() + 1
                
    # prune restricted edges (TODO: also prune unreacheble nodes)
    ans = node_to_layers[node_ids['x']]
    pruned_edges_ids = [(i, j) for i in range(m.shape[0]) for j in range(m.shape[0]) \
                    if m[i, j] != 0 and abs(ans[i] - ans[j]) > 1]
    pruned_edges = [(id_to_node[i], id_to_node[j]) for i, j in pruned_edges_ids]
    pruned_value = sum([m[i, j] for i, j in pruned_edges_ids])
    
    reach_ids = [set() for _ in range(len(node_ids))]
    for i in range(len(node_ids) - 1, -1, -1):
        v = top_sort[i]
        reach_ids[v].add(v)
        for k in range(i + 1, m.shape[0]):
            v_c = top_sort[k]
            if m[v, v_c] != 0 and (v, v_c) not in pruned_edges_ids:
                reach_ids[v] |= reach_ids[v_c]
    conn_g = top_sort[-1] in reach_ids[top_sort[0]]
                
    return {'node_to_layer': {id_to_node[i]: ans.max() - l for i, l in enumerate(ans)},
            'pruned_value': pruned_value, 'pruned_edges': pruned_edges,
            'connected_graph': conn_g}



In [17]:
### find the best solution
import pickle
with open('./naive_mean.pckl', 'rb') as inp:
    eval_dict, fine_dict = pickle.loads(inp.read())
#eval_dict = {}
#fine_dict = {}

for edge_mem in range(2,6):
    if edge_mem in eval_dict and edge_mem in fine_dict:
        print ('skip', edge_mem)
        continue
    #if edge_mem == 3 and attemp > 0:
    #    continue
    #if edge_mem == 4 and attemp > 0:
    #    continue

    model = ResNet18(8)
    model.load_state_dict(torch.load('./model_last.ckpt', map_location='cpu'))
    warp = GraphInterperterWithGamma(model)

    named_dict = dict(model.named_modules())

    for node in warp.graph.nodes:
        if node.op == 'call_module':
            pass
            # print('Norm', np.sqrt(sum([(p ** 2).sum().item() \
            #                            for p in named_dict[node.target].parameters()])))
        # print(node.op, node.name, node.args)
        # break


    edges, weights, a, b = module_to_graph(ResNet18())
    # edges, weights
    edges[:10], list(weights.items())[:10]
    edges_importance = [1]*len(edges)

    best_pruned = 1e10
    best_val = None
    for s in tqdm(all_sorts):
        # for mem=8 the computation takes time
        res = dp_for_top_sort(edges, {k: 1 for k in weights}, edges_importance, s, edge_mem)  
        if res['connected_graph'] == True and best_pruned > res['pruned_value']:
            best_pruned = res['pruned_value']
            best_val = res

        if best_pruned == 0:
            print(res)
            break


    model = ResNet18(8)
    model.load_state_dict(torch.load('./model_last.ckpt', map_location='cpu'))

    wrapped = torch.fx.symbolic_trace(model)
    pruned = PrunedModel(wrapped, {k:1.0 if k not in best_val['pruned_edges'] else 0.0 for k in edges }, inter)

    res = test_loop(pruned, test_dl,  "cpu", nc=8)
    if edge_mem not in eval_dict:
        eval_dict[edge_mem] = []
    eval_dict[edge_mem].append(res)

    train_loop(pruned, train_dl, test_dl, 9999999999, 1, 1e-3,  "cpu")
    res = test_loop(pruned, test_dl,  "cpu", nc=8)
    if edge_mem not in fine_dict:
        fine_dict[edge_mem] = []
    fine_dict[edge_mem].append(res)

    import pickle
    with open('naive_mean.pckl', 'wb') as out:
        out.write(pickle.dumps([eval_dict, fine_dict]))


skip 2


Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/9261 [00:00<?, ?it/s]

Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/9261 [00:00<?, ?it/s]

Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/9261 [00:00<?, ?it/s]

Using cache found in /home/legin/.cache/torch/hub/pytorch_vision_v0.10.0


  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

In [106]:
fine_dict

{}

In [67]:
if edge_mem not in eval_dict:
    eval_dict[edge_mem] = []
eval_dict[edge_mem].append(res)

train_loop(pruned, train_dl, test_dl, 9999999999, 1, 1e-3,  "cpu")
res = test_loop(pruned, test_dl,  "cpu", nc=8)
if edge_mem not in fine_dict:
    fine_dict[edge_mem] = []
fine_dict[edge_mem].append(res)

import pickle
with open('naive.pckl', 'wb') as out:
    out.write(pickle.dumps([eval_dict, fine_dict]))


  0%|          | 0/2500 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

In [2]:
import pickle
with open('naive.pckl', 'rb') as inp:
    data = pickle.loads(inp.read())
data

[{2: [0.125], 3: [0.125], 4: [0.12524999678134918], 5: [0.23675000667572021]},
 {2: [0.5216249823570251],
  3: [0.5532500147819519],
  4: [0.656624972820282],
  5: [0.6958749890327454]}]

In [2]:
import pickle
with open('randn.pckl', 'rb') as inp:
    data = pickle.loads(inp.read())
data

[{2: [0.125, 0.12600000202655792, 0.125],
  3: [0.125, 0.125, 0.125],
  4: [0.12524999678134918, 0.12524999678134918, 0.125],
  5: [0.1368750035762787, 0.125, 0.21812500059604645]},
 {2: [0.5640000104904175, 0.5180000066757202, 0.5730000138282776],
  3: [0.5625, 0.5435000061988831, 0.5742499828338623],
  4: [0.6353750228881836, 0.6573749780654907, 0.6784999966621399],
  5: [0.6420000195503235, 0.6677500009536743, 0.6775000095367432]}]

In [16]:
import pickle
with open('naive_mean.pckl', 'rb') as inp:
    data = pickle.loads(inp.read())
data

[{2: [0.125], 3: [0.125], 4: [0.1264999955892563], 5: [0.2516250014305115]},
 {2: [0.5951250195503235],
  3: [0.6156250238418579],
  4: [0.6588749885559082],
  5: [0.6837499737739563]}]

In [5]:
import pickle
with open('random_mean.pckl', 'rb') as inp:
    data = pickle.loads(inp.read())
data

[{2: [0.12962499260902405, 0.125, 0.125],
  3: [0.125, 0.125, 0.125],
  4: [0.12612499296665192, 0.12612499296665192, 0.125],
  5: [0.12587499618530273, 0.12587499618530273]},
 {2: [0.5892500281333923, 0.5241249799728394, 0.6132500171661377],
  3: [0.5706250071525574, 0.6041250228881836, 0.5644999742507935],
  4: [0.6047499775886536, 0.6507499814033508, 0.6359999775886536],
  5: [0.6508749723434448, 0.6942499876022339]}]