In [None]:
# Re-import edited modules (e.g. the local `main` and `three_goals`) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import einops
import plotly.express as px
import plotly.figure_factory as ff
import scipy.cluster.hierarchy as sch
import scipy.spatial.distance as ssd
import torch
import wandb
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from wandb.integration.sb3 import WandbCallback

import main as M
import three_goals as tg

In [None]:
def env():
    """Build a fresh ThreeGoalsEnv(0, 4), wrapped in FlatOneHotWrapper and then AddTrueGoalWrapper."""
    base = tg.ThreeGoalsEnv(0, 4)
    return tg.AddTrueGoalWrapper(M.FlatOneHotWrapper(base))


print(env())
print(env().observation_space)
print(env().action_space)



In [None]:
# Stand-alone SwitchMLP used for the forward-pass sanity check in the next cell.
# NOTE(review): positional args are presumably (in_dim=4, hidden=(64, 64), out_dim=2,
# n_switches=3, switched_layer=0), by analogy with the arch_kwargs of the training
# cell — confirm against three_goals.SwitchMLP.
net = tg.SwitchMLP(4, (64, 64), 2, 3, 0)
print(net)

In [None]:
# Two hand-written 4-feature inputs, each routed through a different switch,
# selected by a one-hot row over the 3 switches.
x = torch.tensor([
    [1.0, 2, 3, 4],
    [4.0, 3, 2, 1],
])
switch_idx = [0, 2]
switch = torch.nn.functional.one_hot(torch.tensor(switch_idx), num_classes=3).float()

print(x)
print(switch)
print(net(x, switch))


In [None]:
def unique(x, *, __previous=set()):
    """Pass `x` through unchanged the first time it is seen; raise on any repeat.

    The mutable default `__previous` is intentional: it is created once, when
    this `def` runs, and persists across calls as the record of every value
    returned so far. Re-running the defining cell starts a fresh record.
    `x` must be hashable.

    Raises:
        ValueError: if `x` was already passed to an earlier call.
    """
    seen = __previous
    if x not in seen:
        seen.add(x)
        return x
    raise ValueError(f"Duplicate value {x}")

In [None]:
# Training env factory: same wrapper stack as the demo env, but built with
# ThreeGoalsEnv(None, 4) (presumably a goal drawn per episode — confirm in three_goals).
def env():
    """Build a fresh ThreeGoalsEnv(None, 4), wrapped in FlatOneHotWrapper and then AddTrueGoalWrapper."""
    return tg.AddTrueGoalWrapper(M.FlatOneHotWrapper(tg.ThreeGoalsEnv(None, 4)))


SEED = 42
n_env = 100
# `policy` holds the PPO model; its actor-critic network is `policy.policy`.
policy = PPO(
    tg.SwitchActorCriticPolicy,
    make_vec_env(env, n_envs=n_env, seed=SEED),
    policy_kwargs=dict(
        arch_kwargs=dict(
            switched_layer=0,
            hidden=[64, 64],
            out_dim=4,
            n_switches=3,
            l1_reg=0.0001,
        ),
    ),
    verbose=2,
    n_epochs=40,
    # Collect ~8000 transitions per rollout, split evenly across the parallel envs.
    n_steps=8_000 // n_env,
    batch_size=400,
    # SB3 calls the schedule with progress_remaining, which goes from 1 (start)
    # to 0 (end): this decays quadratically from 5e-4 down to 2e-5.
    learning_rate=lambda f: 5e-4 * f ** 2 + 2e-5 * (1 - f ** 2),
    tensorboard_log="run_logs",
    # Seed the model's own RNGs (network init, action sampling), not only the envs,
    # so the run is reproducible.
    seed=SEED,
)
print(policy.policy)
wandb.init(
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    save_code=True,
    # `unique` raises if these exact notes were already used in this session,
    # forcing a fresh run description before the cell can be re-run.
    notes=unique("""
    
"""),
)
policy.learn(total_timesteps=1_200_000, callback=WandbCallback(verbose=2))

In [None]:
# Lower the learning rate for fine-tuning to a constant 2e-5, the value the
# decaying schedule of the first learn() call ends at.
# PPO.train() reads `lr_schedule`, but SB3 rebuilds `lr_schedule` from
# `learning_rate` when a saved model is loaded. Update both so the change
# survives save/load.
FINETUNE_LR = 0.00002
policy.learning_rate = FINETUNE_LR
policy.lr_schedule = lambda _progress_remaining: FINETUNE_LR

In [None]:
# Continue training for another 300k environment steps. reset_num_timesteps=False
# keeps the timestep counter (used for logging) running on from the previous learn() call.
# NOTE(review): unlike the first learn() call, no WandbCallback is passed here.
policy.learn(total_timesteps=300_000, reset_num_timesteps=False)

In [None]:
# Visualise the trained model on a fresh env from the factory above.
# NOTE(review): `10` is presumably the number of episodes/rollouts shown — confirm in main.show_behavior.
M.show_behavior(policy, env(), 10, width=1600, height=1400)

In [None]:
# Save the trained model. SB3 appends a ".zip" extension when the path has none.
policy.save("models/switched-3goals-4x4-switch-last-2")

In [None]:
# Actor branch of the trained feature extractor, annotated as a SwitchMLP for the analysis cells below.
net: tg.SwitchMLP = policy.policy.mlp_extractor.policy_net
print(policy.policy)

In [None]:
# Inspect the env's ALL_CELLS attribute, read through the wrapper stack.
# NOTE(review): this relies on the wrappers forwarding unknown attributes to the
# inner env; newer Gym/Gymnasium versions may warn about or reject that, in which
# case `env().unwrapped.ALL_CELLS` may be needed if ALL_CELLS lives on ThreeGoalsEnv.
env().ALL_CELLS

In [None]:
# Reload the checkpoint saved above. PPO.load needs a file, not a directory;
# SB3 retries with the ".zip" suffix when the bare path does not exist.
loaded_policy = PPO.load("models/switched-3goals-4x4-switch-last-2")
loaded_policy

In [None]:
def imshow(x, **kwargs):
    """Display `x` as a plotly heatmap sized from its last two dimensions.

    Default size:
    - about 25 px per cell, plus a 50 px margin on the width;
    - the height also gets 50 px when a `title` is passed.

    Small figures are doubled, keeping their aspect ratio, until one side
    reaches 500 px, and each side is floored at 300 px. Any keyword argument,
    including `width`/`height` (pass None to let plotly decide), overrides
    these defaults and is forwarded to `plotly.express.imshow`.

    Args:
        x: array-like whose last two dims are (rows, cols); leading dims can
            be split into panels with `facet_col`. Torch tensors are detached
            and moved to the CPU before plotting.
        **kwargs: extra arguments for `px.imshow`.
    """
    if isinstance(x, torch.Tensor):
        # plotly converts its input with numpy, which fails for tensors that
        # require grad or live on a GPU.
        x = x.detach().cpu()

    n_rows, n_cols = x.shape[-2:]
    width = 50 + n_cols * 25
    height = n_rows * 25 + 50 * ('title' in kwargs)
    while width < 500 and height < 500:
        width *= 2
        height *= 2

    defaults = dict(
        width=max(300, width),
        height=max(300, height),
        facet_row_spacing=0.01,
        facet_col_spacing=0.01,
    )
    px.imshow(x, **{**defaults, **kwargs}).show()


imshow(policy.policy.mlp_extractor.switch_biases)

In [None]:
import einops
# Plot the three switch layers
import plotly.express as px

# First-layer (switch) weights laid out as boards:
# - one facet per output neuron;
# - within a facet, the switches are stacked vertically and the object-type
#   channels sit side by side, each as a GRID_SIZE x GRID_SIZE board;
# - NaN (blank) separators between boards, and a final column holding the bias.
GRID_SIZE = 4  # board side; matches ThreeGoalsEnv(..., 4) used for training

switch_layers_weights = policy.policy.mlp_extractor.switch_weights

# Assumes each switch's input is a board flattened as (row, col, obj_type).
# detach().cpu() so the NaN padding below can be concatenated and plotly can
# convert the result with numpy.
w1 = einops.rearrange(
    switch_layers_weights,
    'agent out_dim (row col obj_type) -> agent obj_type out_dim row col',
    row=GRID_SIZE,
    col=GRID_SIZE,
).detach().cpu()
# Bias repeated down GRID_SIZE rows so it lines up with the boards.
b1 = einops.repeat(
    policy.policy.mlp_extractor.switch_biases,
    'agent out_dim -> agent 1 out_dim row 1',
    row=GRID_SIZE,
).detach().cpu()

avg = w1.mean(dim=0)  # average over switches, e.g. for w1[:, EMPTY] - avg[EMPTY]
# Presumed order of the obj_type channels — confirm against FlatOneHotWrapper.
TYPES = ['empty', 'agent', 'goal_red', 'goal_green', 'goal_blue']
EMPTY, AGENT, RED, GREEN, BLUE = range(5)


def pad_nan(t, dim):
    """Return `t` with one all-NaN slice appended along `dim` (drawn blank by plotly)."""
    shape = list(t.shape)
    shape[dim] = 1
    return torch.cat([t, torch.full(shape, float('nan'), dtype=t.dtype, device=t.device)], dim=dim)


# Add a blank column after each board and a blank row after each switch, then
# tile into one (agent*row, obj*col) image per output neuron and drop the
# trailing blank row.
d = pad_nan(pad_nan(w1, dim=-1), dim=-2)
d = einops.rearrange(d, 'agent obj out row col -> out (agent row) (obj col)')[..., :-1, :]
b1 = pad_nan(b1, dim=-2)
b1 = einops.rearrange(b1, 'agent obj out row col -> out (agent row) (obj col)')[..., :-1, :]
d = torch.cat([d, b1], dim=-1)
imshow(d, title='First layer weights', facet_col=0, facet_col_wrap=4,
       height=4000,
       width=None,
       color_continuous_scale='RdBu',
       color_continuous_midpoint=0,
       zmax=0.6,
       zmin=-0.6,
       )


In [None]:
# Per-(agent, object type) summaries of the switch-layer weights. Each summary
# reduces over everything from dim 2 on (output neurons x board cells).
w1_flat = w1.flatten(2)
switch_weight_summaries = [
    (w1_flat.mean(dim=2), "Mean of the weights of the switch layers", {}),
    (w1_flat.abs().mean(dim=2), "Mean absolute value of the weights of the switch layers", {"width": 800}),
    (w1_flat.std(dim=2), "Std of the weights of the switch layers", {"width": 800}),
]
for summary, summary_title, size_kwargs in switch_weight_summaries:
    imshow(
        summary,
        title=summary_title,
        labels=dict(x="Object type", y="Agent"),
        **size_kwargs,
    )


In [None]:
# Effective per-switch weights from the switch inputs straight to the action
# logits: action_net(switch_i(h)) = A @ (W_i @ h + b_i) + c, so the composed
# weight is A @ W_i. nn.Linear stores `weight` as (out_features, in_features).
# NOTE(review): this is the true end-to-end map only when switch_i feeds
# directly into action_net (no layers/nonlinearities in between) — check switched_layer.
last_layer: torch.nn.Linear = policy.policy.action_net
# clone(): detach() shares storage, and .cpu() is a no-op for CPU tensors, so
# without a copy these snapshots would change if training resumes.
last_weights = last_layer.weight.detach().cpu().clone()  # A: (n_actions, d_latent)
last_bias = last_layer.bias.detach().cpu().clone()       # c: (n_actions,)

n_switch_layers = len(net.switches)
weights = torch.cat([
    last_weights @ net.switches[i].weight.detach().cpu()
    for i in range(n_switch_layers)
], dim=0)  # one (n_actions, in_dim) block per switch, stacked along rows
imshow(weights)

# Raw (not composed) bias of each switch, one column per switch.
biases = torch.stack([net.switches[i].bias.detach().cpu() for i in range(n_switch_layers)], dim=1)
imshow(biases)

In [None]:
# Similarity between the rows (output units) of net.post_switch[1].
# Rows are not mean-centred, so `corr` holds cosine similarities rather than
# Pearson correlations.
w2 = net.post_switch[1].weight.detach().cpu()  # (out_features, in_features); presumably (64, 64)
# normalize() divides by max(norm, eps): an all-zero row (plausible under the
# L1 penalty) stays zero instead of turning into NaNs that break the clustering.
w2 = torch.nn.functional.normalize(w2, dim=1)
corr = w2 @ w2.T

# Cluster the correlations matrix
import scipy.cluster.hierarchy as sch
import scipy.spatial.distance as ssd
import plotly.figure_factory as ff

# Hierarchical clustering of the units, using each unit's row of similarities
# as its feature vector (single linkage on Euclidean distances).
n_units = corr.shape[0]
fig = ff.create_dendrogram(
    corr.numpy(),
    orientation='left',
    labels=list(range(n_units)),
    linkagefun=lambda x: sch.linkage(x, 'single'),
    distfun=lambda x: ssd.pdist(x, 'euclidean'),
)
fig.update_layout(width=1000, height=1000)
fig.show()

# Heatmap with the self-similarity diagonal blanked out, so the colour scale is
# driven by the off-diagonal values. Work on a copy: writing NaNs into `corr`
# itself would feed them to the dendrogram above if this cell is re-run.
corr_offdiag = corr.clone()
corr_offdiag.fill_diagonal_(float('nan'))
px.imshow(corr_offdiag, width=1000, height=1000)