In [47]:
%load_ext autoreload
%autoreload 2
import torch

from datasets.symbolic_1 import get_dataloaders, INPUT_DIM, OUTPUT_DIM
from architectures import BioMLP
from nx_utils import STATE_DICT
import pathlib
from datetime import datetime

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [48]:
# ORIGINAL BIMT HYPERPARAMETERS
# NOTE: a later cell repeats these assignments with dump_every = 50;
# whichever of the two cells ran last determines the values in the kernel.

# --- network and optimiser ---
shp = [INPUT_DIM, 20, 20, OUTPUT_DIM]  # passed to BioMLP as `shp`
model = BioMLP(shp=shp)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-3,
    weight_decay=0.0,
)

# --- training schedule ---
epochs = 20000       # number of iterations of the outer training loop
lamb = 1e-3          # multiplies the model.get_cc() term in the training loss
weight_factor = 1.0  # forwarded to model.get_cc()

# --- periodic actions, expressed in steps ---
log = 200         # print the losses
swap_log = 200    # call model.relocate()
dump_every = 200  # save a checkpoint
plot_log = 50     # not used by the cells visible in this notebook

In [96]:
# ORIGINAL BIMT HYPERPARAMETERS
# Running this cell builds a fresh model and optimizer, discarding any
# previously trained weights held in `model`.

# --- network and optimiser ---
shp = [INPUT_DIM, 20, 20, OUTPUT_DIM]  # passed to BioMLP as `shp`
model = BioMLP(shp=shp)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-3,
    weight_decay=0.0,
)

# --- training schedule ---
epochs = 20000       # number of iterations of the outer training loop
lamb = 1e-3          # multiplies the model.get_cc() term in the training loss
weight_factor = 1.0  # forwarded to model.get_cc()

# --- periodic actions, expressed in steps ---
log = 200        # print the losses
swap_log = 200   # call model.relocate()
dump_every = 50  # save a checkpoint
plot_log = 50    # not used by the cells visible in this notebook

In [97]:
# Run name = class name followed by the layer sizes, e.g. "[4, 20, 20, 2]"
# is rewritten to "_4_20_20_2": separators and the opening bracket become
# underscores, the closing bracket is dropped.
shape_tag = str(shp).replace(', ', '_').translate(str.maketrans({'[': '_', ']': ''}))
name = BioMLP.__name__ + shape_tag

# One directory per run, keyed by the wall-clock time at which this cell executes.
timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
base = pathlib.Path("models") / name / timestamp
base.mkdir(parents=True, exist_ok=True)
base

PosixPath('models/BioMLP_4_20_20_2/2023-09-17_164345')

In [98]:
train_loader, test_loader = get_dataloaders()

# Step boundaries of the regularisation schedule (computed once, not per step).
boost_start = int(epochs/4)
boost_end = int(3*epochs/4)

for step in range(epochs):

    # Regularisation strength for this step: 10x the base value during the
    # middle half of training, the base value otherwise. It is derived from the
    # unmodified `lamb` instead of rescaling `lamb` in place, so interrupting
    # or re-running this cell never leaves a stale, compounded value behind.
    lamb_step = lamb * 10 if boost_start <= step < boost_end else lamb

    # do not penalize bias at first (this makes the weight graph look better);
    # the bias penalty starts on the same step on which lamb drops back.
    penalize_bias = step >= boost_end

    # Evaluation only: no gradients are needed, so skip building the autograd graph.
    # Note that loss_test ends up holding the loss of the LAST test batch only.
    with torch.no_grad():
        for x, y in test_loader:
            pred_test = model(x)
            loss_test = torch.mean((pred_test-y)**2)

    for x, y in train_loader:

        optimizer.zero_grad()
        pred = model(x)
        loss = torch.mean((pred-y)**2)

        reg = model.get_cc(bias_penalize=penalize_bias, weight_factor=weight_factor)

        total_loss = loss + lamb_step*reg
        total_loss.backward()
        optimizer.step()

    # The values reported/saved below are those of the last training batch of this step.
    if step % log == 0:
        print("step = %d | total loss: %.2e | train loss: %.2e | test loss %.2e | reg: %.2e "%(step, total_loss.detach().numpy(), loss.detach().numpy(), loss_test.detach().numpy(), reg.detach().numpy()))

    if step % swap_log == 0:
        # TODO: this results in large weights for one epoch. WHY?
        # (a previously tried alternative trigger was `(step+1) % swap_log == 0`)
        model.relocate()

    if step % dump_every == 0:
        # Checkpoint file is named after the step: "<step>.pt" inside `base`.
        torch.save({
                'epoch': step,
                STATE_DICT: model.state_dict(),
                # 'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': loss.item(),
                'test_loss': loss_test.item(),
                }, base / f"{step}.pt")

step = 0 | total loss: 1.22e+00 | train loss: 1.19e+00 | test loss 1.75e+00 | reg: 2.92e+01 
step = 200 | total loss: 2.85e-01 | train loss: 2.50e-01 | test loss 6.43e-01 | reg: 3.50e+01 
step = 400 | total loss: 2.18e-01 | train loss: 1.85e-01 | test loss 5.32e-01 | reg: 3.28e+01 
step = 600 | total loss: 8.72e-02 | train loss: 5.01e-02 | test loss 1.50e-01 | reg: 3.72e+01 
step = 800 | total loss: 4.09e-02 | train loss: 6.10e-03 | test loss 3.76e-02 | reg: 3.48e+01 
step = 1000 | total loss: 3.10e-02 | train loss: 2.64e-03 | test loss 2.29e-02 | reg: 2.84e+01 
step = 1200 | total loss: 2.66e-02 | train loss: 1.91e-03 | test loss 1.81e-02 | reg: 2.47e+01 
step = 1400 | total loss: 2.31e-02 | train loss: 1.54e-03 | test loss 1.56e-02 | reg: 2.15e+01 
step = 1600 | total loss: 1.99e-02 | train loss: 1.21e-03 | test loss 1.47e-02 | reg: 1.87e+01 
step = 1800 | total loss: 1.75e-02 | train loss: 8.88e-04 | test loss 1.39e-02 | reg: 1.66e+01 
step = 2000 | total loss: 1.58e-02 | train loss

In [104]:
import networkx as nx
from nx_utils import( 
    get_weights_and_biases_from_state_dict, 
    load_state_dict_from_file, 
    get_shape_from_state_dict,
    add_neuron_nodes,
    black_and_white,
    add_weight_edges_arrays,
    get_layers_of_nodes,
    layerwise_normalized_abs_value
)
from bokeh_utils import (
    draw_interactive_mlp_graph,
    LINE_COLOR_LIST,
    LINE_WIDTH_LIST
    )

In [108]:
# Checkpoints are written as "<step>.pt". Sort them by the integer step so the
# stacked "time" axis below is chronological; a plain sort of the paths is
# lexicographic (0.pt, 100.pt, 1000.pt, ..., 50.pt) and would scramble it.
# Assumes every *.pt file in `base` has a purely numeric stem.
list_of_files = sorted(base.glob("*.pt"), key=lambda f: int(f.stem))
if not list_of_files:
    raise FileNotFoundError(f"no '*.pt' checkpoints found in {base}")

# The shape information is read from the first checkpoint.
state_dict = load_state_dict_from_file(list_of_files[0])
weight_shapes = get_shape_from_state_dict(state_dict)

# retrieve weights from state_dicts (element [0] of the helper's return value)
weights = [
    get_weights_and_biases_from_state_dict(load_state_dict_from_file(f))[0]
    for f in list_of_files
]

# reshape weights to include "time"-dimension as the first dim:
# zip(*weights) groups the entries by position across checkpoints, and
# torch.stack adds the new leading dimension over checkpoints.
layers_of_weights = [torch.stack(x) for x in zip(*weights)]

In [123]:

# Edge attributes, keyed by attribute name: each entry is whatever the
# corresponding helper produces from the time-stacked weights.
attribute_functions = {}
attribute_functions[LINE_COLOR_LIST] = black_and_white(layers_of_weights)
attribute_functions[LINE_WIDTH_LIST] = layerwise_normalized_abs_value(layers_of_weights)

# Start from an empty directed graph, add the neuron nodes first, then the
# weight edges (which receive the attributes collected above).
G = nx.DiGraph()
add_neuron_nodes(G, weight_shapes)
add_weight_edges_arrays(G, get_layers_of_nodes(G), attribute_functions)

In [124]:

# draw the network with bokeh
# G is the DiGraph populated with neuron nodes and weight edges in the
# preceding cell; the call is the cell's last expression.
draw_interactive_mlp_graph(G)