In [None]:
# Re-import edited project modules automatically before each cell executes,
# so changes under ../../project are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import sys
# Make the project's `common` package (imported in the next cell) importable.
# NOTE(review): this path is resolved against the kernel's current working
# directory, so the notebook presumably must be started from its own folder.
sys.path.append("../../project")

In [None]:
import torch
import torch.nn.utils.prune as prune

from common.torch_utils import module_is_trainable
from common.constants import *
from common import testutils
from common import nxutils
import wandb

In [None]:
# Build a three-task splitable model and display its graph representation.
#
# Seed first: BaseModel presumably initialises its weights randomly (only the
# biases are forced to zero here), so without a seed the graph shown by this
# cell would differ on every "Restart & Run All".
torch.manual_seed(1)

# Three toy tasks, each mapping `task_in_size` inputs to `task_out_size` outputs.
task_in_size = 2
task_out_size = 1
hidden_per_task = 5  # hidden units contributed by each task to every hidden layer
task_description = tuple(
    (task_name, (task_in_size, task_out_size))
    for task_name in ('moons', 'circles', 'spirals')
)
num_tasks = len(task_description)

# Every layer width is a per-task size multiplied by num_tasks
# (presumably one slice per task -- confirm against testutils.make_splitable_model).
shape = [
    task_in_size * num_tasks,
    hidden_per_task * num_tasks,
    hidden_per_task * num_tasks,
    task_out_size * num_tasks,
]

model = testutils.BaseModel(shape, init_bias_zero=True)
model = testutils.make_splitable_model(model, num_tasks, tile_offset=0.1)

# Last expression of the cell: the notebook renders the GraphManager's repr.
nxutils.GraphManager(model, shape, task_description)

In [None]:
# Iteratively prune a two-task splitable model with global L1 magnitude pruning,
# logging the GraphManager's plotly figures to wandb after every round.

# Reproducible initialisation for this experiment.
torch.manual_seed(1)

# Two tasks with 3 inputs and 1 output each -> 6 inputs and 2 outputs overall.
shape = (6, 6, 6, 2)
task_description = ('moons', (3, 1)), ('circles', (3, 1))
num_tasks = len(task_description)

# Experiment knobs (previously inline magic numbers).
NUM_PRUNE_ITERS = 9   # number of update -> log -> prune rounds
PRUNE_AMOUNT = 0.2    # fraction of the still-unpruned entries removed per round
PLOT_INLINE = False   # set True to also render gm.plot() inside the notebook

model = testutils.BaseModel(shape, init_bias_zero=False)
model = testutils.make_splitable_model(model, num_tasks, tile_offset=0.1)

# `nn.Module.modules` is a method: iterate over its result. Iterating the bound
# method itself raises "TypeError: 'method' object is not iterable".
modules = [m for m in model.modules() if module_is_trainable(m)]

# Weights and biases are pruned jointly (one global threshold across all of
# them). Modules whose bias is None are skipped: prune cannot reparametrize a
# missing tensor and would fail on it.
params = [(module, 'weight') for module in modules]
params += [
    (module, 'bias')
    for module in modules
    if getattr(module, 'bias', None) is not None
]

gm = nxutils.GraphManager(model, shape, task_description)

print('INITIAL STATE')
if PLOT_INLINE:
    gm.plot()
gm.print_info()

print('START PRUNING')
with wandb.init(project="test-plots") as run:
    for i in range(NUM_PRUNE_ITERS):
        print(f'{"-"*30} it {i} begin {"-"*30} ')

        print('update the graph manager')
        gm.update(model)

        print('print info')
        gm.print_info()

        print('plot the model')
        if PLOT_INLINE:
            gm.plot()

        # Log every graph of this round in ONE call. Each run.log() call
        # advances the wandb step, so one call per graph would spread a single
        # round's plots over several steps and misalign them across rounds.
        run.log({name: gm.make_plotly(g) for name, g in gm.catalogue.items()})

        print('prune the network')
        prune.global_unstructured(
            parameters=params,
            pruning_method=prune.L1Unstructured,
            amount=PRUNE_AMOUNT,
        )
        # NOTE(review): the prune applied in the final round is never passed to
        # gm.update()/logged; add a trailing update+log after the loop if that
        # last state should be inspected too.
        print(f'{"-"*30} it {i} end {"-"*30} ')
