In [None]:
# Reload edited modules (e.g. the local `main` and `three_goals`) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import einops
import torch
import wandb
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.env_checker import check_env
from wandb.integration.sb3 import WandbCallback
import plotly.express as px

import main as M
import three_goals as tg

# Large figure size, splatted as **BIG into M.show_behavior calls.
BIG = {"width": 1600, "height": 1600}

In [None]:
# Sanity-check the wrapped environment: gym API compliance, its spaces, and one observation.
env_size = 4  # board side; the training cell below uses the same value
red, green, blue = tg.ThreeGoalsEnv.GOAL_CELLS


def env():
    """Build a ThreeGoalsEnv(env_size, true_goal=None) wrapped in AddTrueGoalWrapper,
    then ColorBlindWrapper.merged over the red and green goal cells."""
    return tg.ColorBlindWrapper.merged(
        tg.AddTrueGoalWrapper(tg.ThreeGoalsEnv(env_size, true_goal=None)),
        red, green,
    )


check_env(env())
# Inspect one instance instead of constructing a fresh env for every print.
probe_env = env()
print(probe_env)
print(probe_env.observation_space)
print(probe_env.action_space)
o, _ = probe_env.reset()
switch = o['switch']
print(f"{switch=}")
obs = o['obs']
# NOTE(review): assumes the flat observation is laid out per board cell, row-major.
print(obs.astype(int).reshape(env_size, env_size, -1))



In [None]:
# Toy, untrained SwitchMLP; the next cell runs it on a hand-made batch of shape (2, 4)
# with a one-hot switch of width 3.
# NOTE(review): positional args presumably mirror the policy's `arch_kwargs`
# (input dim, hidden sizes, output dim, number of switches, switched layer) — TODO confirm
# against three_goals.SwitchMLP. The later weight-analysis cells also read this `net`.
net = tg.SwitchMLP(4, (64, 64), 2, 3, 0)
print(net)

In [None]:
# Hand-made batch of two inputs; each row selects one switch via a one-hot vector
# (row 0 -> switch 0, row 1 -> switch 2).
x = torch.tensor([[1.0, 2, 3, 4], [4.0, 3, 2, 1]])
switch_idx = [0, 2]
switch = torch.nn.functional.one_hot(torch.tensor(switch_idx), num_classes=3).float()

print(x)
print(switch)
print(net(x, switch))


In [None]:
def unique(x, *, __previous=set()):
    """Return `x` unchanged, but refuse any value already returned in this session.

    The default set is built once, when this cell runs, and is intentionally shared by
    every call: it is the memory of values seen so far (e.g. wandb run notes, so a new run
    cannot start with the same notes as an earlier one).

    Raises:
        ValueError: if `x` was passed in a previous call.
    """
    seen = __previous
    if x not in seen:
        seen.add(x)
        return x
    raise ValueError(f"Duplicate value {x}")

class WandbWithBehaviorCallback(WandbCallback):
    """WandbCallback that also logs a sample of the agent's behaviour every few rollouts.

    Args:
        show_every: log behaviour once every `show_every` rollouts (must be >= 1).
        env_fn: zero-argument factory for the environment to roll out in. When None,
            the notebook-global `env` is used, looked up at the time of each rollout.
        max_len: maximum trajectory length passed to `M.show_behavior`.
        **kwargs: forwarded unchanged to `WandbCallback`.

    Raises:
        ValueError: if `show_every` is smaller than 1.
    """

    def __init__(self, show_every=10, env_fn=None, max_len=20, **kwargs):
        if show_every < 1:
            # Fail early instead of a ZeroDivisionError / negative modulo mid-training.
            raise ValueError(f"show_every must be >= 1, got {show_every}")
        self.show_every = show_every
        self.env_fn = env_fn
        self.max_len = max_len
        self.time = 0  # number of rollouts started so far
        super().__init__(**kwargs)

    def _on_rollout_start(self) -> None:
        super()._on_rollout_start()
        # Counts rollouts (SB3 data-collection phases), not episodes.
        self.time += 1
        if self.time % self.show_every == 0:
            make_env = self.env_fn if self.env_fn is not None else env
            M.show_behavior(self.model, make_env(), max_len=self.max_len,
                            add_to_wandb=True, plot=False)

In [None]:
env_size = 4
red, green, blue = tg.ThreeGoalsEnv.GOAL_CELLS


def env():
    """Build one training environment.

    A ThreeGoalsEnv of size `env_size`, wrapped in AddTrueGoalWrapper and then in
    ColorBlindWrapper.merged over the red and green goal cells. Globals are read at call
    time, so re-binding `env_size` affects environments built afterwards.
    """
    goals_env = tg.AddTrueGoalWrapper(tg.ThreeGoalsEnv(env_size))
    return tg.ColorBlindWrapper.merged(goals_env, red, green)


# Previously tried, kept for reference:
#   env without colour blindness: tg.AddTrueGoalWrapper(M.FlatOneHotWrapper(tg.ThreeGoalsEnv(None, env_size)))
#   "MlpPolicy", n_epochs=40, batch_size=400, device='cuda:1'
n_env = 32
lr_start, lr_end = 5e-3, 2e-5  # endpoints of the (currently disabled) LR schedule below
switch_arch = dict(
    switched_layer=0,
    hidden=[32, 32],
    out_dim=4,
    n_switches=2,
    l1_reg=1e-5,
)
policy = PPO(
    tg.SwitchActorCriticPolicy,
    make_vec_env(env, n_envs=n_env, seed=42),
    policy_kwargs=dict(arch_kwargs=switch_arch),
    verbose=2,
    n_steps=2_048 // n_env,  # steps per env, so each rollout holds 2048 transitions in total
    tensorboard_log="run_logs",
    # learning_rate=lambda f: lr_start * f ** 2 + lr_end * (1 - f ** 2),
)
print(policy.policy)

# Start a wandb run mirroring SB3's tensorboard logs, then train.
# `unique` raises if these notes were already used in this session, forcing them to be updated.
wandb.init(
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    save_code=True,
    config=dict(
        env_size=env_size,
        lr_start=lr_start,
        lr_end=lr_end,
        # lr_start/lr_end only take effect if the LR schedule is enabled in the PPO cell,
        # so also record the learning rate PPO was actually built with.
        learning_rate=(
            "schedule" if callable(policy.learning_rate) else policy.learning_rate
        ),
        n_envs=n_env,
        n_steps=policy.n_steps,
        arch=policy.policy_kwargs['arch_kwargs'],
    ),
    project="switched-3goals-blind",
    notes=unique("""
Colorblind agent Rewards: true goal=1, wrong goal=0, nothing=-1 at the end of the trajectory
Arch 32,32, with regularisation L1=1e-6
Regularisation bumped to 1e-5, it wasn't enough
Lowered batch size to the default (400 -> 64)
Lowered epochs to the default (40 -> 10)
Lowered n_steps (8000 -> 2000)
Blind BOTH the red-green and blue agents
(fixed bug × 2)
Changed rewards: true goal=1, wrong goal=0, step=-1/16 (16=board size)
"""),
)

policy.learn(total_timesteps=600_000, callback=WandbWithBehaviorCallback())

In [None]:
# Continue training the same model with a constant, small learning rate.
policy.lr_schedule = lambda _progress_remaining: 1e-5
policy.learn(
    total_timesteps=100_000,
    reset_num_timesteps=False,  # keep counting timesteps from the previous learn() call
    callback=WandbWithBehaviorCallback(),
)

In [None]:
# Show the policy on a set of "interesting" boards, wrapped like the training env.
wrappers = [
    tg.AddTrueGoalWrapper,
    # NOTE(review): the training env calls `merged(env, red, green)`, while this passes
    # `[red, green], [blue]`. Confirm `merged` accepts groups of cells and that this
    # matches what the policy was trained on.
    lambda e: tg.ColorBlindWrapper.merged(e, [red, green], [blue], disabled=False),
]
M.show_behavior(policy, tg.ThreeGoalsEnv.interesting(env_size, 10, wrappers), **BIG)

In [None]:
def add_wrappers(e, disabled):
    """Wrap `e` like the training env: AddTrueGoalWrapper, then ColorBlindWrapper.merged
    over the red and green goal cells.

    `disabled` is forwarded to ColorBlindWrapper.merged (presumably turning the colour
    blindness off when True — confirm in three_goals).
    """
    with_goal = tg.AddTrueGoalWrapper(e)
    return tg.ColorBlindWrapper.merged(with_goal, red, green, disabled=disabled)

# Six base envs from ThreeGoalsEnv.constant, each shown twice: colour-blind, then not.
# NOTE(review): both wrappers in a pair share the same `base_env` instance. That is fine if
# show_behavior runs them one after the other, but wrong if it steps them together.
# Confirm in main.show_behavior.
envs = []
for _ in range(6):
    base_env = tg.ThreeGoalsEnv.constant(env_size, true_goal={"red": 1, "green": 1})
    envs.extend(add_wrappers(base_env, disabled=flag) for flag in (False, True))

M.show_behavior(policy, envs, **BIG)

In [None]:
# Print all models
models_dir = Path("models")
for model_path in sorted(models_dir.glob("*.zip")):
    print(model_path.name)

In [None]:
# Save the current policy under `name`, or load it instead if that checkpoint already exists.
name = "switched-blind=rg-4x4-switch=first-L1=1e-5"
path = models_dir / f"{name}.zip"
if not path.exists():
    print(f"Saving model to {path}")
    policy.save(path)
else:
    print("Loading existing model")
    # Keep the in-memory policy, in case the name simply wasn't changed before saving.
    previous_policy = policy
    policy = PPO.load(path)

# Visualize the weights

In [None]:
def imshow_figure_size(shape, kwargs):
    """Pick a (width, height) in pixels for `px.imshow` of an array with this `shape`.

    Each image cell gets ~25px plus margins (extra room when a 'title' is given). Small
    figures are doubled until one side reaches 500px, and both sides are at least 300px.
    With a non-None 'facet_col', facets are laid out as plotly does: 'facet_col_wrap'
    columns, or a single row when that is missing or 0 (plotly's default).
    """
    h, w = shape[-2:]
    if kwargs.get('facet_col') is not None:
        n_facets = shape[kwargs['facet_col']]
        wrap = kwargs.get('facet_col_wrap') or n_facets  # 0 / missing -> one row
        n_cols = max(1, min(wrap, n_facets))
        h *= int(np.ceil(n_facets / n_cols))
        w *= n_cols
    width = 50 + w * 25
    height = h * 25 + 50 * ('title' in kwargs)
    while width < 500 and height < 500:
        width *= 2
        height *= 2
    return max(300, width), max(300, height)


def imshow(x, **kwargs):
    """Show `x` as a plotly heatmap sized to its shape.

    `x` may be a numpy array or a torch tensor. Tensors are detached and moved to CPU
    first, because converting a tensor that requires grad or lives on a GPU to numpy fails.
    Any `px.imshow` keyword may be given, and it overrides the computed defaults (width,
    height, facet spacing).
    """
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    width, height = imshow_figure_size(x.shape, kwargs)
    defaults = dict(
        width=width,
        height=height,
        facet_row_spacing=0.01,
        facet_col_spacing=0.01,
    )
    px.imshow(x, **{**defaults, **kwargs}).show()


# Inspect the per-switch biases of the switched layer.
switch_biases = policy.policy.mlp_extractor.switch_biases
print("Biases shape:", switch_biases.shape)
print("Max abs bias:", torch.max(switch_biases.abs(), dim=-1).values)
imshow(
    switch_biases,
    title='Biases of the switch layers',
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0,
)

In [None]:
# Plot the switch-layer weights of every agent as board-shaped images, with biases alongside.

# NOTE(review): assumed order of the obj_type axis of the flattened input — TODO confirm
# against three_goals.
TYPES = ['empty', 'agent', 'goal_red', 'goal_green', 'goal_blue']
EMPTY, AGENT, RED, GREEN, BLUE = range(5)


def append_nan(t, dim):
    """Return `t` with one all-NaN slice appended along `dim`.

    NaN cells are left blank in plotly heatmaps, so this draws a separator between panels.
    `new_full` keeps the padding on the same dtype/device as `t`.
    """
    pad_shape = list(t.shape)
    pad_shape[dim] = 1
    return torch.cat([t, t.new_full(pad_shape, float('nan'))], dim=dim)


switch_layers_weights = policy.policy.mlp_extractor.switch_weights

# (agent, out_dim, row*col*obj_type) -> (agent, obj_type, out_dim, row, col).
# Detached and on CPU so it can be plotted and reused by later cells.
w1 = einops.rearrange(
    switch_layers_weights,
    'agent out_dim (row col obj_type) -> agent obj_type out_dim row col',
    row=env_size, col=env_size,
).detach().cpu()
# One bias per (agent, out_dim), repeated down a single column of board height so it can be
# drawn to the right of the weight images.
b1 = einops.repeat(
    policy.policy.mlp_extractor.switch_biases,
    'agent out_dim -> agent 1 out_dim row 1',
    row=env_size,
).detach().cpu()

avg = w1.mean(dim=0)  # mean over agents; plot `w1 - avg` to see agent-specific structure
print(w1.shape, avg.shape)

# One facet per out_dim unit: agents stacked vertically, object types side by side,
# NaN separators between them; the trailing separator row is dropped.
d = append_nan(append_nan(w1, dim=-1), dim=-2)
d = einops.rearrange(d, 'agent obj out row col -> out (agent row) (obj col)')[..., :-1, :]
b1 = append_nan(b1, dim=-2)
b1 = einops.rearrange(b1, 'agent obj out row col -> out (agent row) (obj col)')[..., :-1, :]
print(d.shape, b1.shape)
d = torch.cat([d, b1], dim=-1)  # bias column on the right of each facet
imshow(
    d,
    title='First layer weights',
    facet_col=0,
    facet_col_wrap=4,
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0,
)


In [None]:
# Per (agent, object type) summaries of the switch-layer weights, pooled over all other axes.
w1_flat = w1.flatten(2)
for stat, title in (
    (w1_flat.mean(dim=2), "Mean of the weights of the switch layers"),
    (w1_flat.abs().mean(dim=2), "Mean absolute value of the weights of the switch layers"),
    (w1_flat.std(dim=2), "Std of the weights of the switch layers"),
):
    imshow(stat, title=title, labels=dict(x="Object type", y="Agent"))


In [None]:
# Attempt to relate the policy's action head to the weights and biases of the switch layers.
# NOTE(review): several things here look stale — verify before trusting the plots:
#   * `net` is the toy, untrained SwitchMLP from the sanity-check cell, not the trained policy;
#   * `range(3)` hard-codes 3 switches, while the trained policy was built with n_switches=2;
#   * composing linear maps would normally be `last_weights @ W`, not `last_weights.T @ W` —
#     check the shapes;
#   * `last_bias` is computed but never used.
last_layer: torch.nn.Linear = policy.policy.action_net
last_weights = last_layer.weight.detach().cpu().clone()
last_bias = last_layer.bias.detach().cpu().clone()

# One product per switch, stacked along dim 0.
weights = torch.cat([
    last_weights.T @ net.switches[i].weight.detach().cpu().clone()
    for i in range(3)
], dim=0)
imshow(weights)

# Switch biases side by side: one column per switch.
biases = torch.stack([net.switches[i].bias.detach().cpu() for i in range(3)], dim=1)
imshow(biases)

In [None]:
# Cluster the units of `net.post_switch[1]` by the cosine similarity of their weight rows.
# NOTE(review): this reads the toy, untrained `net` from the sanity-check cell, not the
# trained policy. Sizes come from the weight itself, so the cell still works if pointed
# at a different layer.
import scipy.cluster.hierarchy as sch
import scipy.spatial.distance as ssd
import plotly.figure_factory as ff

w2 = net.post_switch[1].weight.detach().cpu().clone()  # presumably nn.Linear: (out, in)
imshow(w2)
n_units = w2.shape[0]
# normalize() clamps tiny norms, so an all-zero row yields zeros instead of NaNs
# (NaNs would break the distance computation below).
w2 = torch.nn.functional.normalize(w2, dim=1)
corr = w2 @ w2.T

# Dendrogram over the rows of the similarity matrix (single linkage, Euclidean distance).
fig = ff.create_dendrogram(
    corr.numpy(),
    orientation='left',
    labels=list(range(n_units)),
    linkagefun=lambda x: sch.linkage(x, 'single'),
    distfun=lambda x: ssd.pdist(x, 'euclidean'),
)
fig.update_layout(width=1000, height=1000)
fig.show()

# Hide the trivial self-similarity on the diagonal before showing the matrix.
corr.fill_diagonal_(float('nan'))
px.imshow(corr, width=1000, height=1000)