In [None]:
from stable_baselines3.common.monitor import Monitor
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from random import randint

import einops
import numpy as np
import plotly.express as px
import torch
import wandb
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from torch import nn
from torchinfo import torchinfo

import src as M

# Reusable keyword bundles for plotly figures.
# BIG: large square canvas, splatted into calls as **BIG.
BIG = {"width": 1600, "height": 1600}
# ZERO_CENTERED: diverging red/blue colour scale whose midpoint is pinned at 0.
ZERO_CENTERED = {"color_continuous_scale": "RdBu", "color_continuous_midpoint": 0}

In [None]:
ENV_SIZE = 4


def mk_env_generator(full_color):
    """Return the environment factory produced by `M.wrap` for this experiment.

    The base environment is a `ThreeGoalsEnv` of side `ENV_SIZE` with a zero
    step reward. `full_color` is forwarded as `disabled=` to the one-hot
    colour-blind wrapper.
    """
    return M.wrap(
        lambda: M.ThreeGoalsEnv(ENV_SIZE, step_reward=0.0),
        # lambda e: M.ColorBlindWrapper(e, reduction='max', reward_indistinguishable_goals=True, disabled=full_color),
        lambda inner: M.OneHotColorBlindWrapper(
            inner,
            reward_indistinguishable_goals=True,
            disabled=full_color,
        ),
        lambda inner: M.AddTrueGoalToObsFlat(inner),
        # lambda e: M.AddSwitch(e, 1, lambda _: 0),  # No switch, but we still use SwitchMLP as the architecture...
    )


# Two ready-made factories: colour-blind wrapper active vs. disabled.
mk_env = mk_env_generator(full_color=False)
mk_env_full_color = mk_env_generator(full_color=True)

In [None]:
# Sanity-check one environment instance and look at its first observation.
# env = mk_env_full_color()
env = mk_env()

print(env)
print("Observation space:", env.observation_space)
print("Action space:", env.action_space)
check_env(env)
obs, _ = env.reset()

# Dict observations (switch-style wrappers) carry the flat grid under 'obs'.
if isinstance(obs, dict):
    switch = obs['switch']
    print(f"{switch=}")
    obs = obs['obs']

print(f"{obs.shape=}")
# Everything but the last 3 entries is the flattened (row, col, channel) grid.
# The channel count is inferred instead of hard-coded: it is 3 for RGB-style
# observations, but the architecture cell un-flattens the same vector with 5
# one-hot channels, for which a fixed reshape(4, 4, 3) raises.
grid = obs[:-3].reshape(ENV_SIZE, ENV_SIZE, -1)
if grid.shape[-1] == 3:
    # Three channels: plotly renders the array directly as a colour image.
    px.imshow(grid).show()
else:
    # Any other channel count: one heatmap per channel, side by side.
    px.imshow(grid, facet_col=2, facet_col_wrap=grid.shape[-1]).show()
# Trailing 3 entries (presumably the true-goal vector added by AddTrueGoalToObsFlat — confirm).
print(obs[-3:])
# print(obs[:80].astype(int).reshape(4, 4, -1))
# print(obs[80:].astype(int))



In [None]:
# Feature extractor handed to the policy below.
# NOTE(review): M.Split is defined outside this notebook — presumably it cuts the
# last dimension at index -3, runs the first part through `left` and the last 3
# entries through `right`, then joins the results. TODO confirm.
arch = M.Split(-3,
   left= nn.Sequential(
       # Un-flatten the grid into channels-first layout; 5 input channels are hard-coded here.
       M.Rearrange("... (h w c) -> ... c h w", h=ENV_SIZE, w=ENV_SIZE, c=5),
       # First conv wrapped in the per-channel L1 weight-decay helper. It starts at 0
       # here; the training cell passes a M.WeightDecayCallback schedule.
       M.PerChannelL1WeightDecay(
           nn.Conv2d(5, 8, 3, padding=1),
           weight_decay=0,
           name_filter="weight",
       ),
       nn.ReLU(),
       nn.Conv2d(8, 8, 3, padding=1),
       nn.ReLU(),
       # nn.Conv2d(16, 4, 3, padding=1),
       # nn.ReLU(),
       # Collapse (channel, row, col) back into one feature dimension.
       nn.Flatten(-3),
   ),
   # The last 3 entries bypass the convolutions unchanged.
   right=nn.Identity(),
)

# Two-layer MLP head on top of the split extractor. LazyLinear infers its input
# width on the first forward pass (done by torchinfo.summary below).
arch = nn.Sequential(
    arch,
    # A.Rearrange("... h w c -> ... c h w", h=ENV_SIZE, w=ENV_SIZE, c=3),
    # nn.Conv2d(3, 8, 3, padding=1),
    # nn.ReLU(),
    # nn.Conv2d(8, 8, 3, padding=1),
    # nn.ReLU(),
    nn.LazyLinear(32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
)

# Uses `obs` from the environment-check cell to size a dummy batch of 7 observations.
print(torchinfo.summary(arch, input_size=(7, *obs.shape), depth=4))
print(arch)

In [None]:
# Training hyper-parameters for this run.
learning_rate = 1e-3
weight_decay = 1e-2
# A fresh random seed is drawn on every execution; it is recorded in the wandb config below.
seed = randint(0, 2**32 - 1)
n_env = 4
use_wandb = True

# assert not isinstance(arch, M.L1WeightDecay), "You forgot to re-run the arch definition"
# arch = M.L1WeightDecay(arch, l1_weight_decay)

# NOTE: `arch` is the module object built in the architecture cell; re-run that
# cell before this one to start from freshly initialised weights.
policy = PPO(
    M.CustomActorCriticPolicy,
    make_vec_env(mk_env, n_envs=n_env),
    policy_kwargs=dict(
        arch=arch,
        # optimizer_kwargs=dict(weight_decay=0),
    ),
    # Keep the rollout at 2048 transitions in total across the parallel envs.
    n_steps=2_048 // n_env,
    tensorboard_log="../run_logs",
    seed=seed,
    # Schedule callable: the learning rate is the schedule argument `f` times the base rate.
    learning_rate=lambda f: f * learning_rate,
    device='cpu',
)
print(policy.policy)
print("Total number of parameters:", sum(p.numel() for p in policy.policy.parameters()))

if use_wandb:
    # The `notes` text is an experiment log sent to wandb verbatim (after M.unique).
    wandb.init(
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        save_code=True,
        config=dict(
            l1_weight_decay=weight_decay,
            arch=str(arch),
            seed=seed,
        ),
        project="3goals-blind",
        notes=M.unique("""
Trying to get the per-channel regularisation to work.
533: First try -> Fail
534: Second try -> Slight red pref :s
535: Added logging to see the per-channel l1 norm, and also averaged over the out channels too now
536: Increase wd to 1e-3. Use now `w -= w * max(per_channel) * wd` instead of `w -= w / (1e-8 + norm(per_channel)) * wd`
539: Try tensorboard logging -> meh
540: Back to wandb
541: One more try to make the logs work
542: Forgot to add the logger
543: Didn't work before. It doens work... -> Seems to make every channel equal
544: Using w -= max(per channel) * wd, with wd=1e-3
546: It did not log. Added an assert to make sure the l1 weight decay is actually applied
552: Fixed l1 not applying -> failed impl
553: use `w -= w.sign() * max(per_channel) * wd`. Not better
554: now w[min channel] -= w * wd. Also wd = 1. Damn, the Agent channel got destroyed!
555: Re-run, but also use default l2 wd on other params
556: Re-run, lower wd to 1e-2
557: Use w -= w * max(per channel) ** 0.5 * wd
"""))


# Callbacks are kept in a global so the "continue training" cell can reuse them.
# The weight-decay schedule is (1 - f) * weight_decay for the callback's argument f.
callbacks = [M.ProgressBarCallback(), 
             M.WeightDecayCallback(lambda f: (1-f) * weight_decay)]
if use_wandb:
    callbacks.append(M.WandbWithBehaviorCallback(mk_env()))
policy.learn(total_timesteps=400_000, callback=callbacks)

In [None]:
# Optional: continue training the same policy for another 300k steps.
# reset_num_timesteps=False keeps the step counter (and log curves) continuous.
# Reuses the `callbacks` list created in the training cell.
# policy.lr_schedule = lambda _: 2e-4
# policy.policy.mlp_extractor.policy_net.weight_decay = 3e-4
policy.learn(total_timesteps=300_000, reset_num_timesteps=False, callback=callbacks)

In [None]:
# Statistics over 10k episodes with the colour-blind wrapper active (trailing `;` hides the cell output).
M.make_stats(policy, mk_env(), n_episodes=10_000, subtitle="Blind agent with blind inputs");

In [None]:
# Same statistics, but the agent is evaluated with the colour-blind wrapper disabled.
M.make_stats(policy, mk_env_full_color(),
             n_episodes=10_000,
             subtitle="Blind agent with full color inputs");

In [None]:
# `color` is passed as `full_color`, so the truthy value 1 disables the colour-blind wrapper.
color = 1
M.evaluate(policy, mk_env_generator(color)(), n_episodes=1000)

In [None]:
# Visualise the policy on environments picked by ThreeGoalsEnv.interesting.
# NOTE(review): the literal 4 presumably is the grid size (same value as ENV_SIZE) — confirm.
M.show_behavior(policy, M.ThreeGoalsEnv.interesting(4, 10, [mk_env]), **BIG)

In [None]:
# Build six pairs of environments: each base env is wrapped once with the
# colour-blind factory and once with the full-colour factory, then both are shown.
# NOTE(review): the factories are called with a base env here but without
# arguments elsewhere — confirm that M.wrap's result supports both forms.
envs = []
for _pair in range(6):
    base_env = M.ThreeGoalsEnv.constant(ENV_SIZE, true_goal={"red": 1, "green": 1})
    envs.extend((mk_env(base_env), mk_env_full_color(base_env)))

M.show_behavior(policy, envs, **BIG)

# Load and Save models

In [None]:
# Print all models
models_dir = Path("../models/3-goal-blind-one-hot")
for model_path in sorted(models_dir.glob("*.zip")):
    print(model_path.name)

In [None]:
# Keep a second reference to the current policy object under a descriptive name,
# so it survives `policy` being rebound by the load/save cell.
# This is an alias, not a copy: in-place weight edits affect both names.
no_green_channel_agent = policy

In [None]:
# Load the model called `name` if it exists on disk, otherwise save the current policy under that name.
name = "no-red-channel"
path = models_dir / f"{name}.zip"
if path.exists():
    print("Loading existing model")
    try:
        previous_policy = policy  # Saved, in case I wanted to save it, but forgot to change the name
    except NameError:
        # No policy has been trained in this kernel yet; nothing to back up.
        pass
    policy = PPO.load(path)
else:
    print(f"Saving model to {path}")
    try:
        # Without removing the loggers I get TypeError: can't pickle LazyModule objects.
        # I don't know why, but my hack for the weight_decay seems to interfere with
        # their saving mechanism.
        for m in policy.policy.modules():
            if hasattr(m, "logger"):
                print("Removing logger on", m)
                del m.logger
        policy.save(path)
    except Exception:
        # Do not leave a partial archive behind. missing_ok: the failure may happen
        # before the file is created, and a FileNotFoundError from unlink() would
        # otherwise hide the real error.
        path.unlink(missing_ok=True)
        raise

# Visualize the weights

In [None]:
def imshow(x, symetric: bool = True, **kwargs):
    """Display `x` with `px.imshow`, picking a figure size from the image shape.

    Args:
        x: Array-like with a `.shape`; the last two dimensions are taken as the
            rows and columns of one image.
        symetric: When true, default to the "RdBu" colour scale centred on 0.
            (The parameter spelling is kept for backward compatibility.)
        **kwargs: Forwarded to `px.imshow`. Explicit values win over the
            computed `width`, `height` and facet spacing defaults.

    The figure is shown immediately; nothing is returned.
    """
    h, w = x.shape[-2:]
    if 'facet_col' in kwargs:
        n_facets = x.shape[kwargs['facet_col']]
        # Without facet_col_wrap plotly puts every facet in one row, so the
        # grid is n_facets columns wide; a wrap larger than the number of
        # facets cannot add columns either.
        wrap = kwargs.get('facet_col_wrap') or n_facets
        n_cols = max(1, min(wrap, n_facets))
        n_rows = -(-n_facets // n_cols)  # integer ceil division
        h *= n_rows
        w *= n_cols
    if symetric:
        kwargs.setdefault("color_continuous_midpoint", 0)
        kwargs.setdefault("color_continuous_scale", "RdBu")
    # 25 px per cell, plus room for the colour bar (width) and the title (height).
    width = 50 + w * 25
    height = h * 25 + 50 * ('title' in kwargs)
    # Scale small figures up until one side reaches 500 px.
    while width < 500 and height < 500:
        width *= 2
        height *= 2
    new = dict(
        width=max(300, width),
        height=max(300, height),
        facet_row_spacing=0.01,
        facet_col_spacing=0.01,
    )
    # Caller-supplied kwargs override the computed defaults.
    kwargs = {**new, **kwargs}
    px.imshow(x, **kwargs).show()

In [None]:
# Inspect the biases of the switch layers.
# NOTE(review): requires an mlp_extractor exposing `switch_biases`; the `arch`
# built earlier in this notebook does not visibly define it, so this cell may
# only work with a switch-style policy — confirm.
switch_biases = policy.policy.mlp_extractor.switch_biases
print("Biases shape:", switch_biases.shape)
# Largest absolute bias along the last dimension.
print("Max abs bias:", switch_biases.abs().max(dim=-1).values)
imshow(switch_biases, title='Biases of the switch layers')


In [None]:
# Plot the three switch layers
# Builds one mosaic per output unit: weight tiles with a bias column appended,
# using NaN rows/columns as visual separators.

switch_layers_weights = policy.policy.mlp_extractor.switch_weights

# shape of a switch (out_dim, row, col, obj_type)
# Split the flat input axis into (row, col, obj_type) with a 4x4 grid and move obj_type forward.
w1 = einops.rearrange(switch_layers_weights, 'agent out_dim (row col obj_type) -> agent obj_type out_dim row col', row=4, col=4)
# Broadcast each bias to a 4-row, 1-column strip with a singleton obj_type axis.
b1 = einops.repeat(policy.policy.mlp_extractor.switch_biases, 'agent out_dim -> agent 1 out_dim 4 1')

avg = w1.mean(dim=0)
TYPES = ['empty', 'agent', 'goal_red', 'goal_green', 'goal_blue']
EMPTY, AGENT, RED, GREEN, BLUE = range(5)

print(w1.shape, avg.shape)
# d = w1[:, EMPTY] - avg[EMPTY] 
# NOTE: this assignment is overwritten by the next statement.
d = avg

# Drops the last 3 entries of the final (col) axis.
d = w1[..., :-3]
print(d.shape)
# Add one black col
d = torch.cat([d, torch.zeros(*d.shape[:-1], 1) + float('nan')], dim=-1)
# Add one NaN row below each tile as well.
d = torch.cat([d, torch.zeros(*d.shape[:-2], 1, d.shape[-1]) + float('nan')], dim=-2)
# Tile agents vertically and object types horizontally; drop the trailing separator row.
d = einops.rearrange(d, 'agent obj out row col -> out (agent row) (obj col)')[..., :-1, :]
# Same NaN-row padding and tiling for the bias strip so its height matches `d`.
b1 = torch.cat([b1, torch.zeros(*b1.shape[:-2], 1, b1.shape[-1]) + float('nan')], dim=-2)
b1 = einops.rearrange(b1, 'agent obj out row col -> out (agent row) (obj col)')[..., :-1, :]
print(d.shape, b1.shape)
# Append the bias column to the right of the weight tiles.
d = torch.cat([d, b1], dim=-1)
# Remove weights close to zero
# d[abs(d) < 0.1] = float('nan')
# Show the first 16 output units, 4 facets per row.
imshow(d[:16], title='First layer weights', facet_col=0, facet_col_wrap=4,
          # height=4000,
          # width=None,
          )


In [None]:
# Summary statistics of the first-layer weights, one heatmap per statistic.
# `w1.flatten(2)` merges every axis after the first two; each statistic is then
# taken over that merged axis. Plots use the plain (non zero-centred) scale.
_weight_stats = (
    ("Mean of the weights of the switch layers",
     lambda flat: flat.mean(dim=2)),
    ("Mean absolute value of the weights of the switch layers",
     lambda flat: flat.abs().mean(dim=2)),
    ("Std of the weights of the switch layers",
     lambda flat: flat.std(dim=2)),
)
for _stat_title, _stat_fn in _weight_stats:
    imshow(
        _stat_fn(w1.flatten(2)),
        title=_stat_title,
        labels=dict(x="Object type", y="Agent"),
        symetric=False,
    )


In [None]:
# Compose the action head with each switch layer and plot the results.
last_layer: torch.nn.Linear = policy.policy.action_net
last_weights = last_layer.weight.detach().cpu().clone()
last_bias = last_layer.bias.detach().cpu().clone()
# NOTE(review): assumes policy_net is a wrapper exposing `.module` with a
# `switches` list of at least 3 linear layers — confirm against the policy in use.
net = policy.policy.mlp_extractor.policy_net.module

# One product of the transposed action weights with each switch's weights, stacked along dim 0.
weights = torch.cat([
    last_weights.T @ net.switches[i].weight.detach().cpu().clone()
    for i in range(3)
], dim=0)
imshow(weights)

# Biases of the three switches, one column per switch.
biases = torch.stack([net.switches[i].bias.detach().cpu() for i in range(3)], dim=1)
imshow(biases)

In [None]:
# compute correlations between rows of w2
w2 = net.post_switch[1].weight.detach().cpu().clone()  # (64, 64)
imshow(w2)
w2 = w2 / w2.norm(dim=1, keepdim=True)
corr = w2 @ w2.T

# Cluster the correlations matrix
import scipy.cluster.hierarchy as sch
import scipy.spatial.distance as ssd
import plotly.figure_factory as ff

# Compute and plot first dendrogram.
fig = ff.create_dendrogram(
    corr.numpy(),
    orientation='left',
    labels=list(range(64)),
    linkagefun=lambda x: sch.linkage(x, 'single'),
    distfun=lambda x: ssd.pdist(x, 'euclidean'),
)
fig.update_layout(width=1000, height=1000)
fig.show()

# Remove the diagonal
corr[range(64), range(64)] = float('nan')
px.imshow(corr, width=1000, height=1000)

# Looking at the activations

In [None]:
# Record the policy's activations on the initial observation of `n` environments.
# NOTE(review): `add_wrappers` is not defined anywhere in the visible notebook,
# and `ThreeGoalsEnv.constant()` is called without the size argument used
# elsewhere — this cell looks stale and likely fails on a fresh kernel. Confirm.
envs: list[M.ThreeGoalsEnv]
n = 5000
envs = [add_wrappers(M.ThreeGoalsEnv.constant(), disabled=False) for _ in range(n)]
# First element of reset()'s return value is the observation.
inputs = [env.reset()[0] for env in envs]
# `cache` is filled while the context manager is active.
with M.record_activations(policy.policy) as cache:
    for i in inputs:
        policy.predict(i)
        
print(cache)

In [None]:
# Stack the recorded inputs into arrays. This indexes each input with 'obs' and
# 'switch', so it only works when the observations are dicts with those keys.
observations = np.stack([obs['obs'] for obs in inputs])
switches = np.stack([obs['switch'] for obs in inputs])
print("Observations shape:", observations.shape)
print("Switches shape:", switches.shape)

In [None]:
# Collect per-environment positions as tensors, one row per entry of `envs`.
agent_positions = torch.tensor([e.agent_pos for e in envs])
goal_positions = torch.tensor([e.goal_positions for e in envs])
# Index 0, 1 and 2 along dim 1 are bound to the red, green and blue goals.
red_goals, green_goals, blue_goals = (goal_positions[:, k] for k in range(3))

In [None]:
# Find the correlation between each act and red_goal.x
def one_hot_encode(arr, maxi=4):
    """One-hot encode an integer index array along a new trailing axis.

    Args:
        arr: Integer array of indices.
        maxi: Length of the new trailing axis (number of classes).

    Returns:
        Float array of shape `(*arr.shape, maxi)`: zeros everywhere except a 1
        at position `arr[...]` of the last axis.
    """
    # Indices need a trailing singleton axis to line up with the output.
    positions = arr[..., None]
    encoded = np.zeros((*arr.shape, maxi))
    np.put_along_axis(encoded, positions, 1, axis=-1)
    return encoded

# Variables to correlate against: all goal coordinates followed by the agent
# position, concatenated per environment (one row per environment).
# to_check = goal_positions.flatten(1)
# to_check = one_hot_encode(to_check)
to_check = np.concatenate([
    goal_positions.flatten(1),
    agent_positions], 
    axis=-1)

# NOTE: this loop rebinds the global `name`, which the load/save cell also uses
# for the model file name.
for name in ("switches.0", "switches.1"):
    act = cache[name]
    # With rowvar=False columns are variables: the result is the correlation
    # matrix of the activation columns followed by the `to_check` columns.
    corrs = np.corrcoef(act, to_check, rowvar=False)
    # remove the diagonal
    corrs[range(corrs.shape[0]), range(corrs.shape[0])] = 0
    # Rows from index 32 on — assumes 32 activation columns, so that the
    # remaining rows are the position variables. TODO confirm.
    imshow(corrs[32:])

In [None]:
# Correlate the activations of each switch with one goal's coordinates:
# switches.0 against goal index 0 (bound to `red_goals` above) and switches.1
# against goal index 2 (`blue_goals`).
corr_red_blind = np.corrcoef(cache['switches.0'], goal_positions[:, 0], rowvar=False)
corr_blue_color = np.corrcoef(cache['switches.1'], goal_positions[:, 2], rowvar=False)
# Keep rows from index 32 on and drop the last two columns — assumes 32
# activation columns followed by 2 coordinate columns. TODO confirm.
corr_red_blind = corr_red_blind[32:, :-2]
corr_blue_color = corr_blue_color[32:, :-2]
# Stack both matrices on a new leading "type" axis and interleave their rows.
corrs = einops.rearrange([corr_red_blind, corr_blue_color], 
                         "type dim neuron -> (dim type) neuron")

imshow(corrs)
imshow(corr_red_blind)
imshow(corr_blue_color)

In [None]:
# Ablation: zero every weight reading input channel 3 of the first Conv2d found
# in the policy (index 3 is labelled "Green" in show_first_conv's channel names).
# WARNING: this mutates `policy` in place and is not undone by re-running cells;
# reload or retrain the policy to restore the weights.
conv1 = next(m for m in policy.policy.modules() if isinstance(m, torch.nn.Conv2d))
with torch.no_grad():
    conv1.weight[:, 3] = 0


In [None]:
def show_first_conv(policy):
    """Plot the weights of the first `Conv2d` layer found in `policy.policy`.

    Prints the overall norm of the weight tensor, then shows two figures:
    a heatmap of the per (output channel, input channel) kernel norms, and one
    facet per output channel with its input-channel kernels laid side by side.

    Args:
        policy: Object whose `.policy` is a torch module containing at least
            one `torch.nn.Conv2d`; `StopIteration` is raised otherwise.
    """
    first_conv = next(m for m in policy.policy.modules() if isinstance(m, torch.nn.Conv2d))
    # .cpu() before .numpy(): a policy restored with PPO.load may live on the
    # GPU, and Tensor.numpy() only works on CPU tensors.
    kernels = first_conv.weight.detach().cpu().numpy()

    # Labels for the input channels (5 are assumed, matching the arch cell).
    in_channel_names = ["Empty", "Agent", "Red", "Green", "Blue"]

    # Per channel norms: one norm per (out, in) kernel, over its two spatial axes.
    norms = np.linalg.norm(kernels, axis=(2, 3))
    print(np.linalg.norm(kernels))
    px.imshow(norms,
              title="Norms of the weights of the first convolutional layer",
              labels=dict(x="Input channel", y="Output channel"),
              x=in_channel_names,
              width=500,
              **ZERO_CENTERED).show()

    # einops.repeat is used as a pure rearrangement here (no new axes): the
    # kernels of all input channels are placed side by side along the columns.
    tiled = einops.repeat(kernels, "out in row col -> out row (in col)")
    px.imshow(tiled, facet_col=0,
              **ZERO_CENTERED,
              # zmax=2, zmin=-2,
              facet_col_wrap=4).show()


show_first_conv(policy)

In [None]:
# Do a PCA on the data
# Reads the exported runs table and compares two evaluation metrics.
import csv
import pandas as pd
from sklearn.decomposition import PCA

file = "../data.csv"
with open(file, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    data = list(reader)

# Drop the two non-numeric bookkeeping columns, then parse every remaining column as float.
data = pd.DataFrame(data).drop(columns=["Name", "_wandb"]).astype(float)
x_axis = "eval/full_color/true_goal_red/end_type_red"
y_axis = "eval/full_color/true_goal_red/end_type_green"

# Print mean on the x and y axis
for _metric in (x_axis, y_axis):
    print(data[_metric].mean())

# Project the two metrics onto their first principal component.
pca = PCA(n_components=1)
fitted = pca.fit_transform(data[[x_axis, y_axis]])

_first_component_hist = px.histogram(
    fitted,
    marginal="rug",
    title="PCA histogram",
    width=1000,
)
_first_component_hist.show()

# Overlaid distributions of the two raw metrics.
_raw_metrics_hist = px.histogram(
    data[[x_axis, y_axis]],
    barmode="overlay",
    marginal="rug",
    width=1000,
)
_raw_metrics_hist.show()

# Distribution of the per-run difference between the two metrics.
_difference_hist = px.histogram(
    data[x_axis] - data[y_axis],
    marginal="rug",
    nbins=60,
)
_difference_hist.show()