In [None]:
%load_ext autoreload
%autoreload 2

import torch

from datasets import symbolic_1
from architectures.bio_mlp import BioMLP

from common import STATE_DICT

import pathlib
from datetime import datetime

In [None]:
# ORIGINAL BIMT HYPERPARAMETERS
# Seed torch's global RNG before building the model so the weight
# initialisation (and anything else drawing from torch's default generator,
# presumably including data-loader shuffling) is reproducible across runs.
seed = 0  # arbitrary; change to explore other initialisations
torch.manual_seed(seed)

shp = [symbolic_1.INPUT_DIM, 20, 20, symbolic_1.OUTPUT_DIM]  # layer widths: input, two hidden, output
model = BioMLP(shp=shp)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.002, weight_decay=0.0)
log = 200           # print losses every `log` steps of the training loop
lamb = 0.001        # base weight of the connection-cost term; rescaled by the loop's schedule
dump_every = 200    # save a checkpoint every `dump_every` steps
swap_log = 200      # call model.relocate() every `swap_log` steps
weight_factor = 1.  # forwarded to model.get_cc(...)
plot_log = 50       # not referenced by the cells visible in this notebook
epochs = 20000      # number of passes over the training loader

In [None]:
# Run-specific override of the BIMT hyperparameters defined in the previous
# cell: checkpoint 4x more often (50 instead of 200 steps), which gives the
# stacked weight history built from the checkpoints further down more time
# steps. Every other hyperparameter, the model and the optimizer are taken
# from the previous cell unchanged.
dump_every = 50

In [None]:
# TODO: this should be done in the save method

# Run directory layout: models/<ModelClass>_<w0>_<w1>_..._<wn>/<timestamp>,
# e.g. models/BioMLP_4_20_20_2/2023-09-20_112428 for integer layer widths.
name = BioMLP.__name__ + "_" + "_".join(map(str, shp))
run_stamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
base = pathlib.Path("models", name, run_stamp)
base.mkdir(parents=True, exist_ok=True)
base

In [None]:
train_loader, test_loader = symbolic_1.get_dataloaders()

# Regularisation schedule: lamb is boosted 10x at the end of the first quarter
# of training and scaled back by 0.1 at the start of the last quarter, which is
# also where bias penalisation is switched on.
quarter_step = int(epochs/4)
last_quarter_step = int(3*epochs/4)

for step in range(epochs):

    if step == quarter_step:
        lamb *= 10

    if step == last_quarter_step:
        lamb *= 0.1

    # Test MSE over the *whole* test set, measured before this step's updates.
    # Squared errors are summed across batches and divided by the total element
    # count, so the value equals torch.mean over the concatenated test set
    # instead of the mean of the last batch only. no_grad() avoids building an
    # autograd graph that is never backpropagated.
    with torch.no_grad():
        test_sq_err, test_count = [], 0
        for x, y in test_loader:
            sq_err = (model(x) - y)**2
            test_sq_err.append(sq_err.sum())
            test_count += sq_err.numel()
        loss_test = torch.stack(test_sq_err).sum() / test_count

    # do not penalize bias at first (this makes the weight graph look better);
    # biases are penalized for the whole last quarter, i.e. from
    # `last_quarter_step` inclusive -- the same step at which lamb is lowered.
    # Loop-invariant within the step, so computed once here.
    penalize = step >= last_quarter_step

    for x, y in train_loader:
        optimizer.zero_grad()
        pred = model(x)
        loss = torch.mean((pred-y)**2)
        reg = model.get_cc(bias_penalize=penalize, weight_factor=weight_factor)
        total_loss = loss + lamb*reg
        total_loss.backward()
        optimizer.step()

    # total_loss / loss / reg reported here are those of the last training batch.
    if step % log == 0:
        print("step = %d | total loss: %.2e | train loss: %.2e | test loss %.2e | reg: %.2e "%(step, total_loss.item(), loss.item(), loss_test.item(), reg.item()))

    # Alternative schedule that skips step 0: `if (step+1) % swap_log == 0:`
    if step % swap_log == 0:
        # TODO: this results in large weights for one epoch. WHY?
        # NOTE(review): a possible cause -- AdamW keeps per-parameter moment
        # buffers (exp_avg / exp_avg_sq) aligned element-wise with each weight
        # tensor. If relocate() permutes neurons inside the weights without
        # permuting those buffers too, the next updates apply moments that
        # belong to other neurons. Unverified; check relocate() against
        # optimizer.state.
        model.relocate()

    # Checkpoint is written after a possible relocate(); 'test_loss' was
    # measured at the start of this step, 'train_loss' is the last batch's.
    if step % dump_every == 0:
        torch.save({
                'epoch': step,
                STATE_DICT: model.state_dict(),
                # 'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': loss.item(),
                'test_loss': loss_test.item(),
                }, base / f"{step}.pt")

In [None]:
import networkx as nx
from common.nx_utils import( 
    get_weights_from_state_dict, 
    load_state_dict, 
    get_shape_from_state_dict,
    add_neuron_nodes,
    black_and_white,
    add_weight_edges_arrays,
    get_layers_of_nodes,
    layerwise_normalized_abs_value
)
from common.bokeh_utils import (
    draw_interactive_mlp_graph,
    LINE_COLOR_LIST,
    LINE_WIDTH_LIST
    )

In [None]:
import re

# Directory of a previous training run to visualise (hard-coded run timestamp).
base = pathlib.Path("models/BioMLP_4_20_20_2/2023-09-20_112428")
chkpts = list(base.glob("*.pt"))


def sort_by_integer_in_filename_key(path):
    """Sort key: the single integer embedded in a checkpoint's file name.

    The training loop saves checkpoints as ``f"{step}.pt"``, so ``"1000.pt"``
    maps to ``1000``. Sorting numerically (rather than lexically) keeps
    ``"200.pt"`` before ``"1000.pt"``. Only the file name is inspected, so
    digits in parent directories are ignored.

    Raises:
        ValueError: if the file name contains no integer or more than one,
            since the step number would then be ambiguous.
    """
    numbers = re.findall(r"\d+", path.name)
    if len(numbers) != 1:
        raise ValueError(
            f"expected exactly one integer in checkpoint name {path.name!r}, found {numbers}"
        )
    return int(numbers[0])


sorted_chkpts = sorted(chkpts, key=sort_by_integer_in_filename_key)

In [None]:
# Load every checkpoint exactly once, in training-step order.
state_dicts = [load_state_dict(f) for f in sorted_chkpts]
if not state_dicts:
    raise FileNotFoundError(f"no *.pt checkpoints found in {base}")

# The layer shapes are taken from the earliest checkpoint; torch.stack below
# requires every checkpoint to share them anyway.
state_dict = state_dicts[0]
weight_shapes = get_shape_from_state_dict(state_dict)

# retrieve weights from state_dicts
weights = [get_weights_from_state_dict(sd) for sd in state_dicts]

# Regroup per layer and add a leading "time" (checkpoint) dimension:
# assuming get_weights_from_state_dict returns one tensor per layer,
# layers_of_weights[l] has shape (n_checkpoints, *shape_of_layer_l).
layers_of_weights = [torch.stack(x) for x in zip(*weights)]

In [None]:

# Edge attributes for the graph, keyed by attribute name. Each value is what
# the corresponding common.nx_utils helper builds from the stacked weight
# history (presumably one entry per layer and checkpoint -- see that module).
attribute_functions = dict()
attribute_functions[LINE_COLOR_LIST] = black_and_white(layers_of_weights)
attribute_functions[LINE_WIDTH_LIST] = layerwise_normalized_abs_value(layers_of_weights)

# Directed graph of the MLP: neuron nodes first, then weight edges between
# consecutive layers of nodes, carrying the attributes above.
G = nx.DiGraph()
add_neuron_nodes(G, weight_shapes)
node_layers = get_layers_of_nodes(G)
add_weight_edges_arrays(G, node_layers, attribute_functions)

In [None]:
# draw the network with bokeh
# Renders the graph G built in the previous cell via common.bokeh_utils.
# The call is left as the cell's last expression so Jupyter displays whatever
# it returns.
draw_interactive_mlp_graph(G)