In [None]:
# Enable IPython's autoreload extension (mode 2) so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import sys
# Put the project directory (relative to this notebook's working directory) on
# the import path; presumably this is where the `common` package used below
# lives -- TODO confirm.
sys.path.append("../../project")

In [None]:
import torch
import torch.nn.utils.prune as prune
import networkx as nx
from common.torch_utils import module_is_trainable
from common.constants import *
import common.constants as vocab
from common import testutils
from common import nxutils
import wandb

from common.models import BinaryClassifierMLP

def __draw_nx(G, layout_G=None):
    """Draw the nodes and edges of ``G``, coloured by their ``'state'`` attribute.

    Node positions come from ``nx.multipartite_layout`` keyed on the ``'layer'``
    node attribute. Colours are obtained by passing each node's / edge's
    ``'state'`` attribute (``None`` when absent) to ``nxutils.__colormap``.

    Args:
        G: networkx graph to draw.
        layout_G: optional graph used only to compute the positions; when
            omitted, ``G`` itself is used. Presumably this lets a subgraph be
            drawn at the positions of the full graph -- TODO confirm.

    NOTE(review): the cells in this notebook call ``nxutils.__draw_nx`` rather
    than this local definition.
    """
    layout_source = G if layout_G is None else layout_G
    positions = nx.multipartite_layout(layout_source, 'layer')

    def state_color(attrs):
        # Map an attribute dict to a colour via its (possibly missing) 'state'.
        return nxutils.__colormap(attrs.get('state'))

    # The attribute dict is always the last element of each edge tuple.
    edge_palette = [state_color(edge[-1]) for edge in G.edges(data=True)]
    node_palette = [state_color(attrs) for _, attrs in G.nodes(data=True)]

    nx.draw_networkx_nodes(
        G,
        positions,
        nodelist=G.nodes(),
        node_color=node_palette,
    )
    nx.draw_networkx_edges(
        G,
        positions,
        edgelist=G.edges(),
        width=6,
        alpha=0.7,
        edge_color=edge_palette,
    )

In [None]:
# Build a small MLP with layer sizes 4 -> 6 -> 2 and draw its graph representation.
model = BinaryClassifierMLP((4, 6, 2))
G = nxutils.build_nx(model)
nxutils.__draw_nx(G)

In [None]:
# Iterative global pruning demo: prune the same network several times and draw
# what is left after every round.

# Number of pruning rounds and the `amount` passed to each L1 pruning call.
NUM_PRUNE_ROUNDS = 3
PRUNE_AMOUNT = 0.5

# Only one model is needed; the original cell also built a seed=2 model that
# was overwritten immediately.
model = BinaryClassifierMLP((4, 6, 8, 6, 4), seed=1)

# Every weight and bias tensor of every layer takes part in global pruning.
params = (
    [(module, 'weight') for module in model.layers]
    + [(module, 'bias') for module in model.layers]
)

# Identity pruning removes nothing; it only installs the pruning
# reparametrization (masks) before the graph is built.
prune.global_unstructured(params, prune.Identity)
G = nxutils.build_nx(model)

out_features = nxutils.out_features(G, len(model.layers))
in_features = nxutils.in_features(G)

for prune_round in range(NUM_PRUNE_ROUNDS):
    # Each call is applied on top of the masks left by the previous rounds.
    prune.global_unstructured(
        params, prune.L1Unstructured, amount=PRUNE_AMOUNT
    )
    nxutils.tag_params(G, model, in_features, out_features)
    # Draw only the non-pruned part, laid out at the positions of the full graph.
    nxutils.__draw_nx(
        nxutils.subgraph_by_state(G, exclude=[vocab.ParamState.pruned]),
        G,
    )

In [None]:
# Every task has the same input / output width.
task_in_size = 2
task_out_size = 1

# One (name, (in_size, out_size)) entry per task.
task_description = tuple(
    (task_name, (task_in_size, task_out_size))
    for task_name in ('moons', 'circles', 'spirals')
)
num_tasks = len(task_description)

# Layer widths of the joint network: each per-task width is scaled by the
# number of tasks (input, two hidden layers of 5 units per task, output).
shape = [
    width * num_tasks
    for width in (task_in_size, 5, 5, task_out_size)
]

model = testutils.make_splitable_model(
    BinaryClassifierMLP(shape),
    num_tasks,
    tile_offset=0.1,
)
gm = nxutils.GraphManager(model, shape, task_description, start_iteration=0)

# Last expression of the cell: the figure is shown as the cell output.
gm.fig()

In [None]:
from common.log import Logger


def _log_graph_state(gm, log, model):
    """Refresh the graph manager from ``model`` and push one logging step.

    Calls, in order: ``gm.update(model)``, ``log.splitting(gm)``,
    ``log.graphs(gm)``, ``log.feature_categorization(gm)`` and ``log.commit()``.
    """
    gm.update(model)
    log.splitting(gm)
    log.graphs(gm)
    log.feature_categorization(gm)
    log.commit()


# wandb run mode. The original cell toggled between 'online', 'disabled' and
# None by reassigning the variable; only the last value (None) was effective.
mode = None

# Number of prune-and-log iterations.
it = 10
torch.manual_seed(1)

# Every task has the same input / output width.
task_in_size = 2
task_out_size = 1
task_description = (
    ('moons', (task_in_size, task_out_size)),
    ('circles', (task_in_size, task_out_size)),
    ('spirals', (task_in_size, task_out_size)),
)
num_tasks = len(task_description)
# Layer widths of the joint network, scaled by the number of tasks.
shape = [
    task_in_size * num_tasks,
    5 * num_tasks,
    5 * num_tasks,
    task_out_size * num_tasks,
]

log = Logger(task_description, True)
model = BinaryClassifierMLP(shape)
model = testutils.make_splitable_model(model, num_tasks, tile_offset=0.1)

# Every weight and bias tensor of every layer takes part in global pruning.
params = (
    [(module, 'weight') for module in model.layers]
    + [(module, 'bias') for module in model.layers]
)
# Identity pruning removes nothing; it only installs the pruning
# reparametrization (masks) before the graph manager is created.
prune.global_unstructured(params, prune.Identity)
gm = nxutils.GraphManager(model, shape, task_description, start_iteration=0)

with wandb.init(project="test-plots", mode=mode) as run:

    for i in range(it):
        print(f'{"-"*30} it {i} begin {"-"*30} ')

        # Log the state *before* this iteration's pruning step.
        _log_graph_state(gm, log, model)
        gm.fig(include=[vocab.ParamState.active]).show()

        print('prune the network')
        prune.global_unstructured(
            parameters=params,
            pruning_method=prune.L1Unstructured,
            amount=0.2,
        )
        print(f'{"-"*30} it {i} end {"-"*30} ')

    # Log once more so the effect of the last pruning step is recorded too.
    _log_graph_state(gm, log, model)
    gm.fig().show()
