In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from random import randint

import einops
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
import wandb
from tqdm.notebook import tqdm
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from sklearn.linear_model import LinearRegression, Lasso
from sklearn.neural_network import MLPRegressor
from torch import nn
from torchinfo import torchinfo
from pprint import pprint

import src as M

# For plotly, to have larger images
BIG = dict(width=1600, height=1600)
ZERO_CENTERED = dict(color_continuous_scale="RdBu", color_continuous_midpoint=0)

In [None]:
# Side length of the square grid; also used below to reshape flat observations.
ENV_SIZE = 4

def mk_env_generator(full_color: bool, weighted: bool = False):
    """Build an environment factory: a ThreeGoalsEnv with a stack of wrappers.

    Args:
        full_color: forwarded as ``disabled=`` to ``M.ColorBlindWrapper``, so
            ``True`` switches the colour-blind reduction off.
        weighted: when ``False`` the ``M.WeightedChannelWrapper`` is built with
            ``disabled=True``.

    Returns:
        Whatever ``M.wrap`` returns. In this notebook the result is called
        like a factory, either with no argument (``mk_env()``) or with a base
        environment (``mk_env(const)``).
        NOTE(review): how ``M.wrap`` treats the zero-argument first lambda
        when a base env is supplied is not visible here - confirm in ``src``.
    """
    return M.wrap(
        lambda: M.ThreeGoalsEnv(ENV_SIZE, step_reward=-0.003),
        lambda e: M.ColorBlindWrapper(e, reduction='max', reward_indistinguishable_goals=True, disabled=full_color),
        # lambda e: M.OneHotColorBlindWrapper(e, reward_indistinguishable_goals=True, disabled=full_color),
        lambda e: M.WeightedChannelWrapper(e, weights=(1, 0.5, 1), disabled=not weighted),
        lambda e: M.AddTrueGoalToObsFlat(e),
        # lambda e: M.AddSwitch(e, 1, lambda _: 0),  # No switch, but we still use SwitchMLP as the architecture...
    )
# Two ready-made factories: colour-blind observations and full-colour observations.
mk_env = mk_env_generator(full_color=False)
mk_env_full_color = mk_env_generator(full_color=True)

In [None]:
const = M.ThreeGoalsEnv.constant(ENV_SIZE, true_goal={"red": 1, "green": 1})

In [None]:
# Sanity-check one environment instance and look at its first observation.
# env = mk_env_full_color()
env = mk_env(const)

print(env)
print("Observation space:", env.observation_space)
print("Action space:", env.action_space)
# stable-baselines3 helper that validates the env against the Gym API.
check_env(env)
obs, _ = env.reset()

# Some wrapper stacks return a dict observation; keep only the array part.
if isinstance(obs, dict):
    switch = obs['switch']
    print(f"{switch=}")
    obs = obs['obs']
    
print(f"{obs.shape=}")
# The flat observation is the ENV_SIZE x ENV_SIZE x 3 grid followed by 3 extra
# values (presumably the true goal appended by AddTrueGoalToObsFlat - TODO confirm).
px.imshow(obs[:-3].reshape(ENV_SIZE, ENV_SIZE, 3)).show()
print(obs[-3:])


In [None]:
# Show colorblind vs full color observations
# The same constant base env is wrapped three ways so the images are comparable.
env = mk_env(const)
obs_blind, _ = env.reset()
env = mk_env_full_color(const)
obs_full_color, _ = env.reset()
env = mk_env_generator(False, True)(const)  # colour-blind + weighted channels
obs_weighted, _ = env.reset()

# Drop the 3 trailing non-image values and restore the (row, col, channel) grid.
img_blind = obs_blind[:-3].reshape(ENV_SIZE, ENV_SIZE, 3)
img_full_color = obs_full_color[:-3].reshape(ENV_SIZE, ENV_SIZE, 3)
img_weighted = obs_weighted[:-3].reshape(ENV_SIZE, ENV_SIZE, 3)

# Show them on the same figure. Note obs are between 0 and 1
# NOTE(review): `go` is already imported in the top import cell; the second
# import below is redundant but harmless.
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(rows=1, cols=3, subplot_titles=("Colorblind", "Full color", "Weighted blind"))
# go.Image expects 0-255 colour values, hence the rescaling.
fig.add_trace(go.Image(z=img_blind * 255), row=1, col=1)
fig.add_trace(go.Image(z=img_full_color * 255), row=1, col=2)
fig.add_trace(go.Image(z=img_weighted * 255), row=1, col=3)
fig.update_layout(width=1000, height=400, title="Observations")
fig.show()


# Train models

In [None]:
# Feature extractor, part 1: the image part of the observation goes through a
# small conv stack (`left`), the remaining values pass through unchanged (`right`).
# NOTE(review): presumably M.Split(-3, ...) splits off the last 3 entries of
# the flat observation and concatenates both branch outputs - confirm in `src`.
arch = M.Split(-3,
   left= nn.Sequential(
       # Flat (h w c) grid -> channel-first image for the convolutions.
       M.Rearrange("... (h w c) -> ... c h w", h=ENV_SIZE, w=ENV_SIZE),
       # M.PerChannelL1WeightDecay(
           nn.LazyConv2d(8, 3, padding=1),
           # weight_decay=0,
           # name_filter="weight",
       # ),
       nn.ReLU(),
       nn.Conv2d(8, 8, 3, padding=1),
       nn.ReLU(),
       nn.Flatten(-3),
   ),
   right=nn.Identity(),
)

# Part 2: two 32-unit fully connected layers on top of the split features.
arch = nn.Sequential(
    arch,
    nn.LazyLinear(32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
)

# arch = M.L1WeightDecay(arch, 0)

print(arch)
# Relies on `obs` from the environment-inspection cell; the summary is computed
# for a batch of 7 observations.
print(torchinfo.summary(arch, input_size=(7, *obs.shape), depth=4))

In [None]:
# Training hyper-parameters. The seed is drawn at random and recorded in the
# wandb config below so that a run can be reproduced afterwards.
learning_rate = 5e-4
# weight_decay = 1e-2
seed = randint(0, 2**32 - 1)
n_env = 4
use_wandb = True

# assert not isinstance(arch, M.L1WeightDecay), "You forgot to re-run the arch definition"
# arch = M.L1WeightDecay(arch, l1_weight_decay)

policy = PPO(
    M.CustomActorCriticPolicy,
    make_vec_env(mk_env, n_envs=n_env),
    policy_kwargs=dict(
        arch=arch,
        # optimizer_kwargs=dict(weight_decay=0),
    ),
    # Keep the rollout at 2048 transitions in total across the parallel envs.
    n_steps=2_048 // n_env,
    tensorboard_log="../run_logs",
    seed=seed,
    # stable-baselines3 calls a learning-rate schedule with the remaining
    # progress (1 at the start, 0 at the end), so this decays linearly to 0.
    learning_rate=lambda f: f * learning_rate,
    device='cpu',
)

if use_wandb:
    wandb.init(
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        save_code=True,
        config=dict(
            # l1_weight_decay=weight_decay,
            arch=str(arch),
            seed=seed,
        ),
        project="3goals-blind",
        notes=M.unique("""
Can I scale my results?
614: Trying again with 1M steps on 7x7 env. -> learn slowly, in >1M steps. Still exibits preference for one color
615: Added a step penalty and increased lr to 5e-4
616: Decrease step penalty (-0.01 -> -0.003)
"""))

callbacks = [M.ProgressBarCallback(),
             # M.WeightDecayCallback(lambda f: (1-f) * weight_decay),
             M.LogChannelNormsCallback(),
             ]
if use_wandb:
    callbacks.append(M.WandbWithBehaviorCallback(mk_env()))
# Trailing ';' suppresses the cell's repr output.
policy.learn(total_timesteps=1000_000, callback=callbacks);

In [None]:
# Continue training the same policy with a constant, smaller learning rate.
# reset_num_timesteps=False keeps the timestep counter (and thus the logs) continuous.
policy.lr_schedule = lambda _: 1e-4
# policy.policy.mlp_extractor.policy_net.weight_decay = 3e-4
policy.learn(total_timesteps=400_000, reset_num_timesteps=False, callback=callbacks)

In [None]:
M.make_stats(policy, mk_env(), n_episodes=10_000,
             wandb_name="eval/blind",
             subtitle="Blind agent with blind inputs");

In [None]:
M.make_stats(policy, mk_env_full_color(),
             n_episodes=10_000,
             wandb_name="eval/full-color-non-weighted",
             subtitle="Blind agent with full color inputs");

In [None]:
color = 0
M.evaluate(policy, mk_env(), n_episodes=400, show_n=3, height=1000)

In [None]:
M.show_behavior(policy, mk_env_full_color(), 4, **BIG)

In [None]:
# Build 6 pairs of environments sharing the same constant base env, once with
# colour-blind observations and once with full-colour observations, so the
# agent's behaviour can be compared side by side.
envs = []
for i in range(6):
    base_env = M.ThreeGoalsEnv.constant(ENV_SIZE, true_goal={"red": 1, "green": 1})
    envs.append(mk_env(base_env))
    envs.append(mk_env_full_color(base_env))
        
M.show_behavior(policy, envs, **BIG)

# Load and Save models


In [None]:
import train

exp = train.BlindThreeGoalsOneHot()
policy, stats = exp.load(0)
mk_env = exp.get_env(False)
mk_env_full_color = exp.get_env(True)
pprint(stats)

In [None]:
M.make_stats(policy, mk_env_full_color(), n_episodes=10_000)

In [None]:
M.make_stats(policy, mk_env_full_color(), n_episodes=500_000)

In [None]:
# Print all models
models_dir = Path("../models/3-goal-blind-one-hot")
for model_path in sorted(models_dir.glob("*.zip")):
    print(model_path.name)

In [None]:
# Load the checkpoint called `name` if it already exists on disk; otherwise
# save the current in-memory `policy` under that name.
name = "no-red-channel"
path = models_dir / f"{name}.zip"
if path.exists():
    print("Loading existing model")
    try:
        previous_policy = policy  # Saved, in case I wanted to save it, but forgot to change the name
    except NameError:
        # No policy exists yet in this kernel, so there is nothing to preserve.
        pass
    policy = PPO.load(path)
else:
    print(f"Saving model to {path}")
    try:
        # Use a dedicated loop variable so the model `name` above is not clobbered.
        for module_name, module in policy.policy.named_modules():
            if hasattr(module, "logger"):
                print("Removing logger on", module)
                del module.logger
        policy.save(path)
    except Exception:
        # Do not leave a partially written archive behind. The save can fail
        # before the file is created, so a missing file must not raise here
        # and hide the original error.
        path.unlink(missing_ok=True)
        # Without this I get TypeError: can't pickle LazyModule objects
        # I don't know why, but my hack for the weight_decay seem to interfere with
        # their saving mechanism
        raise

# Visualize the weights

In [None]:
def imshow(x, symetric: bool = True, **kwargs):
    """Display ``x`` with ``px.imshow`` using a figure size derived from its shape.

    The size is based on the last two dimensions of ``x`` (25 px per cell),
    enlarged for faceted plots, doubled until one side reaches 500 px and
    never below 300 px. Any keyword given explicitly overrides these defaults.

    Args:
        x: array-like with a ``.shape``; at least 2-D.
        symetric: when true, default to a red/blue colour scale centred on 0.
        **kwargs: forwarded to ``px.imshow``.
    """
    n_rows, n_cols = x.shape[-2:]
    if 'facet_col' in kwargs:
        # Facets are laid out on a grid `per_row` wide.
        per_row = kwargs.get('facet_col_wrap', 1)
        n_rows = n_rows * np.ceil(x.shape[kwargs['facet_col']] / per_row)
        n_cols = n_cols * per_row

    title_margin = 50 if 'title' in kwargs else 0
    fig_width = 50 + n_cols * 25
    fig_height = n_rows * 25 + title_margin
    # Scale small figures up, keeping the aspect ratio.
    while fig_width < 500 and fig_height < 500:
        fig_width, fig_height = fig_width * 2, fig_height * 2

    defaults = {
        "width": max(300, fig_width),
        "height": max(300, fig_height),
        "facet_row_spacing": 0.01,
        "facet_col_spacing": 0.01,
    }
    if symetric:
        defaults["color_continuous_midpoint"] = 0
        defaults["color_continuous_scale"] = "RdBu"

    # Caller-supplied keywords win over the computed defaults.
    kwargs = dict(defaults, **kwargs)
    px.imshow(x, **kwargs).show()

In [None]:
# Inspect the biases of the switch layers of the policy's feature extractor.
# NOTE(review): requires a policy whose mlp_extractor exposes `switch_biases`,
# which the `arch` built earlier in this notebook does not obviously provide.
switch_biases = policy.policy.mlp_extractor.switch_biases
print("Biases shape:", switch_biases.shape)
print("Max abs bias:", switch_biases.abs().max(dim=-1).values)
imshow(switch_biases, title='Biases of the switch layers')


In [None]:
# Plot the three switch layers

switch_layers_weights = policy.policy.mlp_extractor.switch_weights

# shape of a switch (out_dim, row, col, obj_type)
# The flat input axis is unpacked into a 4x4 grid (hard-coded here, matching
# ENV_SIZE = 4) with one plane per object type.
w1 = einops.rearrange(switch_layers_weights, 'agent out_dim (row col obj_type) -> agent obj_type out_dim row col', row=4, col=4)
# Biases are broadcast to a 4-row, 1-column strip so they can be appended to
# the weight mosaic further down.
b1 = einops.repeat(policy.policy.mlp_extractor.switch_biases, 'agent out_dim -> agent 1 out_dim 4 1')

avg = w1.mean(dim=0)
TYPES = ['empty', 'agent', 'goal_red', 'goal_green', 'goal_blue']
EMPTY, AGENT, RED, GREEN, BLUE = range(5)

print(w1.shape, avg.shape)
# d = w1[:, EMPTY] - avg[EMPTY] 
# d = avg

# NOTE(review): this drops the last 3 entries of the *col* axis, not of the
# object-type axis - confirm this is the intended slice.
d = w1[..., :-3]
print(d.shape)
# Add one black col
# NaN cells render as gaps and act as separators between tiles of the mosaic.
d = torch.cat([d, torch.zeros(*d.shape[:-1], 1) + float('nan')], dim=-1)
d = torch.cat([d, torch.zeros(*d.shape[:-2], 1, d.shape[-1]) + float('nan')], dim=-2)
# Tile agents vertically and object types horizontally; drop the trailing separator row.
d = einops.rearrange(d, 'agent obj out row col -> out (agent row) (obj col)')[..., :-1, :]
b1 = torch.cat([b1, torch.zeros(*b1.shape[:-2], 1, b1.shape[-1]) + float('nan')], dim=-2)
b1 = einops.rearrange(b1, 'agent obj out row col -> out (agent row) (obj col)')[..., :-1, :]
print(d.shape, b1.shape)
# Append the bias strip as the right-most column(s) of every tile row.
d = torch.cat([d, b1], dim=-1)
# Remove weights close to zero
# d[abs(d) < 0.1] = float('nan')
imshow(d[:16], title='First layer weights', facet_col=0, facet_col_wrap=4,
          # height=4000,
          # width=None,
          )


In [None]:
# One heatmap per summary statistic of the switch-layer weights. Everything
# after the first two axes of `w1` is flattened and reduced, leaving an
# (agent, object type) matrix for each statistic.
flat_switch_weights = w1.flatten(2)
for summary, summary_title in (
    (flat_switch_weights.mean(dim=2), "Mean of the weights of the switch layers"),
    (flat_switch_weights.abs().mean(dim=2), "Mean absolute value of the weights of the switch layers"),
    (flat_switch_weights.std(dim=2), "Std of the weights of the switch layers"),
):
    imshow(
        summary,
        title=summary_title,
        labels=dict(x="Object type", y="Agent"),
        symetric=False,
    )


In [None]:
# Compose the action head with each of the three switch layers to see which
# switch units drive which action logits.
last_layer: torch.nn.Linear = policy.policy.action_net
last_weights = last_layer.weight.detach().cpu().clone()
last_bias = last_layer.bias.detach().cpu().clone()  # NOTE(review): unused in this cell
net = policy.policy.mlp_extractor.policy_net.module

# NOTE(review): the product multiplies last_weights.T with each switch weight
# matrix; the shapes must be compatible for the loaded architecture - confirm.
weights = torch.cat([
    last_weights.T @ net.switches[i].weight.detach().cpu().clone()
    for i in range(3)
], dim=0)
imshow(weights)

# One column of biases per switch.
biases = torch.stack([net.switches[i].bias.detach().cpu() for i in range(3)], dim=1)
imshow(biases)

In [None]:
# compute correlations between rows of w2
# Rows are L2-normalised but not mean-centred, so `corr` holds the cosine
# similarity between every pair of rows.
w2 = net.post_switch[1].weight.detach().cpu().clone()  # (64, 64)
imshow(w2)
w2 = w2 / w2.norm(dim=1, keepdim=True)
corr = w2 @ w2.T
# Derive the size from the matrix instead of hard-coding 64, so the cell also
# works for layers of a different width.
n_units = corr.shape[0]

# Cluster the correlations matrix
import scipy.cluster.hierarchy as sch
import scipy.spatial.distance as ssd
import plotly.figure_factory as ff

# Compute and plot first dendrogram.
fig = ff.create_dendrogram(
    corr.numpy(),
    orientation='left',
    labels=list(range(n_units)),
    linkagefun=lambda x: sch.linkage(x, 'single'),
    distfun=lambda x: ssd.pdist(x, 'euclidean'),
)
fig.update_layout(width=1000, height=1000)
fig.show()

# Remove the diagonal
# Self-similarity is always 1 and would dominate the colour scale.
corr[range(n_units), range(n_units)] = float('nan')
px.imshow(corr, width=1000, height=1000)

# Looking at the activations

In [None]:
import train
from pprint import pprint

# Scan a range of saved checkpoints for the run trained with a specific seed.
exp = train.BlindThreeGoalsOneHot()
for idx in range(170, 190):
    policy, stats = exp.load(idx)
    if stats["seed"] == 616_632_426:
        print(idx)
        pprint(stats)
        break
else:
    # Without this, `policy` would silently stay bound to the last checkpoint
    # that was loaded, and every later cell would analyse the wrong model.
    raise LookupError("No checkpoint with seed 616632426 found in indices 170-189")

In [None]:
# Sample `n` initial observations and record the network's activations on them.
envs: list[M.ThreeGoalsEnv]
n = 2000
envs = [exp.get_env(False)(M.ThreeGoalsEnv.constant()) for _ in range(n)]
# Only the initial observation of each env is used (reset() returns (obs, info)).
inputs = [env.reset()[0] for env in envs]
# Everything the policy computes inside this block is captured in `cache`;
# later cells index it by module name (e.g. "left.0", "right").
with M.record_activations(policy.policy) as cache:
    for i in tqdm(inputs):
        policy.predict(i)
    
# cache.remove_batch_dim()
print(cache)

In [None]:
# Show what the input looks like
px.imshow(cache["left.0"][0], facet_col=0).show()
print("Goal", cache["right"])


In [None]:
# Plot the activations after the first conv2d, one for each output channel
# "left.2" and "left.4" are the outputs of the two ReLUs in the conv branch
# (indices within the `left` nn.Sequential defined above - TODO confirm the
# loaded model uses the same layout).
act = torch.stack([cache[f"left.{i}"] for i in (2, 4)])
act = einops.rearrange(act, 'layer state out row col -> state (layer out) row col')
# Append the network input ("left.0") and 3 all-zero 4x4 planes as padding so
# the facet grid is complete.
act = torch.cat([act, cache["left.0"], torch.zeros(n, 3, 4, 4)], dim=1)
# Only the first 20 states are shown; the animation slider iterates over them.
px.imshow(act[:20], facet_col=1, facet_col_wrap=8,
          animation_frame=0, height=1000,
          **ZERO_CENTERED).show()

In [None]:
observations = np.stack([obs['obs'] for obs in inputs])
switches = np.stack([obs['switch'] for obs in inputs])
print("Observations shape:", observations.shape)
print("Switches shape:", switches.shape)

In [None]:
# Positions of the goals and of the agent for every sampled environment, as
# float tensors, followed by the L1 (Manhattan) distance from the agent to
# each of the three goals.
goal_positions = torch.tensor([e.goal_positions for e in envs]).float()
agent_positions = torch.tensor([e.agent_pos for e in envs]).float()

# Goals are stored in the order red, green, blue along axis 1.
red_goals, green_goals, blue_goals = (
    goal_positions[:, goal_index] for goal_index in range(3)
)

dist_to_red, dist_to_green, dist_to_blue = (
    torch.linalg.norm(agent_positions - goals, dim=1, ord=1)
    for goals in (red_goals, green_goals, blue_goals)
)

In [None]:
# Try to predict a scalar property of the state from the activations recorded
# at "left.4", using a sparse linear probe.

act = cache["left.4"]
act = einops.rearrange(act, 'batch out row col -> batch (out row col)')
# Candidate regression targets: exactly one assignment should be active.
# to_predict = blue_goals[:, 1]
# to_predict = (blue_goals - agent_positions)[:, 0]
# to_predict = agent_positions[:, 0]
to_predict = torch.minimum(dist_to_red, dist_to_green)
# to_predict = to_predict - to_predict.mean(dtype=float)

# reg = LinearRegression().fit(act, to_predict)
reg = Lasso(1e-2).fit(act, to_predict)
# R^2 on the training data itself (no held-out split).
print(reg.score(act, to_predict))

predictions = reg.predict(act)
# x carries the true target and y the prediction; the labels must say so.
fig = px.scatter(x=to_predict, y=predictions,
           labels=dict(x="True target value", y="Predicted target value"),
           **ZERO_CENTERED)
# Add x=y line
# Convert to plain floats: one side is a torch scalar, the other a numpy one.
minimum = min(float(to_predict.min()), float(predictions.min()))
maximum = max(float(to_predict.max()), float(predictions.max()))
fig.add_shape(type="line", x0=minimum, y0=minimum, x1=maximum, y1=maximum,
              line=dict(color="red"))
fig.show()

# Plot histogram of the coefficients
# fig = px.histogram(reg.coef_, nbins=100,
#                    labels=dict(x="Coefficient", y="Count"))
# fig.show()

# Plot the coefficients (output=8 row=4 col=4)
print(reg.intercept_)
coef = einops.rearrange(reg.coef_, '(out row col) -> out row col', out=8, row=4, col=4)
px.imshow(coef, facet_col=0, facet_col_wrap=4, **ZERO_CENTERED).show()


In [None]:
# Find the correlation between each act and red_goal.x
def one_hot_encode(arr, maxi=4):
    """One-hot encode an array of class indices along a new trailing axis.

    Args:
        arr: array-like of indices (any shape). Integer-valued floats, such
            as a float tensor of grid coordinates, are accepted and converted.
        maxi: number of classes, i.e. the size of the new last axis.

    Returns:
        A float ``numpy`` array of shape ``(*arr.shape, maxi)`` containing a
        single 1 per index and 0 elsewhere.

    Raises:
        ValueError: if ``arr`` contains values that are not whole numbers.
        IndexError: if an index is out of bounds for ``maxi`` (raised by numpy).
    """
    indices = np.asarray(arr)
    if not np.issubdtype(indices.dtype, np.integer):
        # np.put_along_axis only accepts integer indices; convert, but refuse
        # to silently truncate fractional values.
        as_int = indices.astype(np.int64)
        if not np.array_equal(as_int, indices):
            raise ValueError("one_hot_encode expects integer-valued indices")
        indices = as_int
    # Set up one-hot encoding array
    one_hot = np.zeros((*indices.shape, maxi))
    # Populate the one-hot encoding array
    np.put_along_axis(one_hot, indices[..., None], 1, axis=-1)
    return one_hot

# Correlate every recorded activation with every position coordinate.
# to_check = goal_positions.flatten(1)
# to_check = one_hot_encode(to_check)
to_check = np.concatenate([
    goal_positions.flatten(1),
    agent_positions], 
    axis=-1)

for name in ("switches.0", "switches.1"):
    act = cache[name]
    # rowvar=False: columns are variables. The result is one square matrix
    # over [activations, position coordinates].
    corrs = np.corrcoef(act, to_check, rowvar=False)
    # remove the diagonal
    corrs[range(corrs.shape[0]), range(corrs.shape[0])] = 0
    # Rows from index 32 onwards; presumably the layer has 32 units so these
    # are the position-coordinate rows - TODO confirm the layer width.
    imshow(corrs[32:])

In [None]:
# Correlation of the "switches.0" units with the red goal position and of the
# "switches.1" units with the blue goal position.
corr_red_blind = np.corrcoef(cache['switches.0'], goal_positions[:, 0], rowvar=False)
corr_blue_color = np.corrcoef(cache['switches.1'], goal_positions[:, 2], rowvar=False)
# Keep the rows from index 32 on and drop the last two columns.
# NOTE(review): this assumes 32 activation columns followed by 2 coordinate
# columns - confirm against the recorded layer width.
corr_red_blind = corr_red_blind[32:, :-2]
corr_blue_color = corr_blue_color[32:, :-2]
# Interleave the two matrices row-wise: for each coordinate, red then blue.
corrs = einops.rearrange([corr_red_blind, corr_blue_color], 
                         "type dim neuron -> (dim type) neuron")

imshow(corrs)
imshow(corr_red_blind)
imshow(corr_blue_color)

In [None]:
# Ablation: zero the weights the first Conv2d applies to input channel 3.
# WARNING: this mutates the loaded policy in place; re-running earlier analysis
# cells afterwards will see the ablated network. Reload the policy to undo it.
conv1 = next(m for m in policy.policy.modules() if isinstance(m, torch.nn.Conv2d))
with torch.no_grad():
    conv1.weight[:, 3] = 0


In [None]:
def show_first_conv(policy):
    """Plot the weights of the first ``Conv2d`` layer found in ``policy.policy``.

    Two figures are shown: the norm of each (output channel, input channel)
    kernel, and the kernels themselves with one facet per output channel and
    the input channels laid out side by side.

    Args:
        policy: object whose ``.policy`` is a ``torch.nn.Module`` containing at
            least one ``torch.nn.Conv2d``.

    Raises:
        StopIteration: if the policy contains no ``Conv2d`` module.
    """
    first_conv = next(m for m in policy.policy.modules() if isinstance(m, torch.nn.Conv2d))
    # .cpu() so this also works when the policy was loaded onto an accelerator.
    conv1 = first_conv.weight.detach().cpu().numpy()
    
    in_channels = conv1.shape[1]
    # Names are aligned from the right, so a 3-channel input is labelled
    # Red/Green/Blue and a 5-channel one gets all five names.
    in_channel_names = ["Empty", "Agent", "Red", "Green", "Blue"][-in_channels:]

    # Per channel norms
    norms = np.linalg.norm(conv1, axis=(2, 3))
    print(np.linalg.norm(conv1))
    px.imshow(norms, 
              title="Norms of the weights of the first convolutional layer",
                labels=dict(x="Input channel", y="Output channel"),
              x=in_channel_names,
              width=500,
              **ZERO_CENTERED).show()

    # Lay the input channels of each kernel out horizontally.
    conv1 = einops.repeat(conv1, "out in row col -> out row (in col)")
    px.imshow(conv1, facet_col=0,
              **ZERO_CENTERED,
              # zmax=2, zmin=-2,
              facet_col_wrap=4).show()
    
    
show_first_conv(policy)

In [None]:
# Do a PCA on the data
import csv
import pandas as pd
from sklearn.decomposition import PCA

# Table of per-run metrics (presumably exported from wandb, given the
# "_wandb" column removed below - TODO confirm provenance).
file = "../data.csv"
with open(file, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    data = [row for row in reader]
    
data = pd.DataFrame(data)
# Drop the non-numeric columns so the whole frame can be cast to float.
del data["Name"]
del data["_wandb"]
data = data.astype(float)
x_axis = "eval/full_color/true_goal_red/end_type_red"
y_axis = "eval/full_color/true_goal_red/end_type_green"

# Print mean on the x and y axis
n_points = len(data)
print("Number of points", n_points)
print("Mean on x axis", data[x_axis].mean())
print("Mean on y axis", data[y_axis].mean())

# Project the two metrics onto their first principal component.
pca = PCA(n_components=1)
fitted = pca.fit_transform(data[[x_axis, y_axis]])

px.histogram(fitted, marginal="rug", title=f"PCA on Red vs. Green probas (n={n_points})",
             nbins=70,
             width=1000,
             ).show()

px.histogram(data[[x_axis, y_axis]],
             barmode="overlay",
             marginal="rug",
             nbins=70,
             width=1000,
             ).show()
                
fig = px.histogram(data[x_axis] - data[y_axis],
             title=f"Diff between red and green probas (n={n_points})",
             # Normalized histogram
             histnorm="probability density",
             marginal="rug",
             nbins=70,
             width=1000,
             )
# Add the pdf of a normal distribution
# NOTE(review): `stats` imported here shadows the `stats` variable that other
# cells bind to model statistics; re-run those cells before using it again.
from plotly import graph_objects as go
from scipy import stats
mean = (data[x_axis] - data[y_axis]).mean()
std = (data[x_axis] - data[y_axis]).std()
# Evaluate the fitted normal pdf over +/- 3 standard deviations.
x = np.linspace(mean - 3 * std, mean + 3 * std, 100)
y = stats.norm.pdf(x, mean, std)
fig.add_trace(go.Scatter(x=x, y=y, mode="lines", name="Normal distribution"))
fig.show()


In [None]:
# Compare last histogram with a normal distribution
mean = (data[x_axis] - data[y_axis]).mean()
std = (data[x_axis] - data[y_axis]).std()

# Generate samples from a normal distribution
# NOTE(review): no random seed is set, so this figure differs on every run.
samples = np.random.normal(mean, std, size=n_points)
# Plot the same histogram
px.histogram(samples,
             title=f"Gaussian samples with the same mean & std (n={n_points})",
             marginal="rug",
             nbins=70,
             width=1000,
             ).show()


# Loading and plotting metrics of the models

In [None]:
import train

train.Experiment.show_all_experiments()

In [None]:
train.MODELS_DIR = train.ROOT / "models"
# train.MODELS_DIR = train.ROOT / "cam_fs" / "models"
Exp = train.BlindThreeGoals
stats = Exp.load_all_checkpoints_stats()
print(len(stats))

In [None]:
Exp = train.BlindThreeGoalsOneHot
models, stats = Exp.load_all()

In [None]:
def get_red_stats(data: dict):
    """Return the end-type statistics recorded for episodes whose true goal is red.

    Several storage layouts are tried in turn, and the first one whose keys
    are all present wins:

    1. flat keys ``eval/full_color/true_goal_red/end_type_<type>`` in ``data["eval"]``
    2. nested ``data["eval"]["full_color"]["true_goal_red"]``
    3. nested ``data["eval"]["full_color_non_weighted"]["true_goal_red"]``
    4. ``data["stats_full_color"][0]``

    For layouts 1-3 the result is a list ordered red, green, blue, no goal.

    Raises:
        KeyError: when none of the layouts matches; ``data`` is printed first.
    """
    end_types = ["red", "green", "blue", "no goal"]

    # Each candidate is evaluated lazily so a KeyError moves on to the next one.
    candidates = (
        lambda: [data["eval"][f"eval/full_color/true_goal_red/end_type_{t}"] for t in end_types],
        lambda: [data["eval"]["full_color"]["true_goal_red"][f"end_type_{t}"] for t in end_types],
        lambda: [data["eval"]["full_color_non_weighted"]["true_goal_red"][f"end_type_{t}"] for t in end_types],
        lambda: data["stats_full_color"][0],
    )
    for candidate in candidates:
        try:
            return candidate()
        except KeyError:
            continue

    print("Could not find the red stats in")
    pprint(data)
    raise KeyError

# full_color_probas = np.array([get_red_stats(d) for d in tqdm(stats)])
# `stats` is iterated as a list of per-model lists of checkpoint statistics,
# giving an array indexed (model, checkpoint, end type).
stat_array = [[get_red_stats(d) for d in model_stats] for model_stats in tqdm(stats)]
full_color_probas_by_run = np.array(stat_array)
# Flattened view: one row per (model, checkpoint), 4 end-type columns.
full_color_probas = full_color_probas_by_run.reshape(-1, 4)

# Column 0 is the "red" end type (see the ordering in get_red_stats).
red_proba = full_color_probas[:, 0]

In [None]:
# Compute the norm of the red channel in the first conv layer
def get_norms(model):
    """Return the L2 norm of the first ``Conv2d`` weights, per input channel.

    Args:
        model: object whose ``.policy`` is a ``torch.nn.Module`` containing at
            least one ``torch.nn.Conv2d``.

    Returns:
        A 1-D ``numpy`` array with one entry per input channel: the norm over
        all output channels and kernel positions of that channel's weights.

    Raises:
        StopIteration: if the policy contains no ``Conv2d`` module.
    """
    conv1 = next(m for m in model.policy.modules() if isinstance(m, torch.nn.Conv2d))
    # Weight layout is (out, in, row, col): reduce over everything but `in`.
    norms = torch.linalg.vector_norm(conv1.weight.detach(), dim=(0, 2, 3))
    # .cpu() so models loaded onto an accelerator can be converted to numpy.
    return norms.cpu().numpy()


channel_1 = 2  # red
channel_2 = 3  # green  (NOTE(review): unused in this cell)
channel_norms = np.stack([get_norms(model) for model in tqdm(models)])

px.scatter(x=channel_norms[:, channel_1], y=red_proba,
           labels=dict(x="Norm of the red channel", y="Probability of red"),
           title="Norm of the red channel vs probability of red",
           width=1000).show()

# Common axis range for the red-vs-green scatter plot below.
minx = min(full_color_probas[:, 1].min(), red_proba.min())
maxx = max(full_color_probas[:, 1].max(), red_proba.max())
# Making it symetrical around 0.5
minx = min(minx, 1 - maxx)
maxx = max(maxx, 1 - minx)

# NOTE(review): points are coloured by their index in the array, yet the
# colour-bar label says "No goal"; the label matches the commented-out
# `color=full_color_probas[:, 3]` alternative instead.
fig = px.scatter(x=full_color_probas[:, 1], y=red_proba,
                 color=list(range(len(full_color_probas))),
                 # Link the dots
                 # color=full_color_probas[:, 3],
                 range_x=(minx, maxx),
                 range_y=(minx, maxx),
           labels=dict(x="Probability of green", y="Probability of red", color="No goal"),
           title=f"Behavior distribution when the target tile is red (n={len(full_color_probas)})",
           width=1000, height=1000)
# Reference line: points on it always end on either red or green.
M.add_line(fig, "y=1-x")
fig.show()


px.histogram(red_proba, marginal="rug", 
             nbins=70,
             title="Histogram of the probability of red", width=1000).show()


to_explain = red_proba
# Plot the correlation between channel norms and probabilities
# rowvar=False: columns are variables; the last row/column is `to_explain`.
corrs = np.corrcoef(channel_norms, to_explain, rowvar=False)
px.imshow(corrs, 
          title="Correlation between channel norms and probabilities",
          labels=dict(x="Input channel", y="Output channel"),
          x=["Empty", "Agent", "Red", "Green", "Blue", "Red prob"][-len(corrs):],
          width=1000,
          **ZERO_CENTERED).show()


In [None]:
# Plot the proba of red against the position in the list on one plot
px.scatter(x=range(len(red_proba)), y=red_proba,
           labels=dict(x="Model", y="Probability of red"),
           title="Probability of red vs model",
           width=1000).show()

In [None]:
# Training timestep of every (model, checkpoint) entry, flattened in the same
# order as `full_color_probas`.
checkpoint = np.array([[d["timesteps"] for d in model_stats] for model_stats in stats])
checkpoint = checkpoint.reshape(-1)

# Plot the proba of red against the proba of green, colored by the checkpoint
px.scatter(x=full_color_probas[:, 1], y=red_proba,
           color=checkpoint,
           labels=dict(x="Probability of green", y="Probability of red", color="Checkpoint"),
           title="Probability of red vs probability of green",
           width=1000).show()

# Plot mean and std of the proba of red for each run
# The statistics are taken over axis 0 (models), leaving one value per checkpoint.
mean = full_color_probas_by_run.mean(axis=0)
std = full_color_probas_by_run.std(axis=0)
# checkpoint[:len(std)] are the timesteps of the first model's checkpoints.
# NOTE(review): this assumes every model was checkpointed at the same
# timesteps; the x axis is labelled "Run" although it holds timesteps.
px.scatter(x=checkpoint[:len(std)], y=mean[:, 0],
           error_y=std[:, 0],
           labels=dict(x="Run", y="Mean probability of red", error_y="Std of the probability of red"),
           title="Mean probability of red vs run",
           width=1000).show()


In [None]:
# Do a linear regression from the channel norms to the probabilities
import utils

reg = utils.show_fit(LinearRegression(), channel_norms, to_explain,
                     title=f"Predicted probability of red vs true probability of red using a linear regression on the channel norms",
                     xaxis="True probability of red",
                     yaxis="Predicted probability of red")
print("Coefficients:", reg.coef_)
print("Intercept:", reg.intercept_)

In [None]:
px.scatter(
    x=red_proba,
    y=channel_norms[:, 2] - channel_norms[:, 3],
           title="Norm of the red channel vs probability of red",
           width=1000).show()

# Predicting which goal will be reached by an agent
Here, we try to predict the red frequency from the first convolutional layer,
using a linear regression.
Input: one output channel of the first convolutional layer
Output: probability of reaching the red goal of the corresponding model

In [None]:
models[0].policy

In [None]:
env = M.wrap(
    M.ThreeGoalsEnv.constant(),
    lambda e: M.OneHotColorBlindWrapper(e, reward_indistinguishable_goals=True),
    lambda e: M.AddTrueGoalToObsFlat(e),
)()

obs, _ = env.reset(seed=36)
with M.record_activations(models[0].policy) as cache:
    models[0].predict(obs)
print(cache)


In [None]:
import utils
from sklearn.linear_model import Ridge


# noinspection PyUnreachableCode
def get_inputs(model):
    """Return the features of ``model`` used as regression predictors.

    As written, the function returns the norm of the first convolution's
    weights for each input channel (reduced over output channels and kernel
    positions). Everything after that first ``return`` is unreachable and is
    kept only as a menu of alternative feature sets.

    Raises:
        ValueError: if the policy does not contain exactly two ``Conv2d``
            modules or has fewer than two ``Linear`` modules (tuple unpacking).
    """
    convs = [m.weight.detach() for m in model.policy.modules() if isinstance(m, torch.nn.Conv2d)]
    linears = [m.weight.detach() for m in model.policy.modules() if isinstance(m, torch.nn.Linear)]

    conv1, conv2 = convs
    lin1, lin2 = linears[:2]
    
    # Active feature set: one norm per input channel of conv1.
    return torch.linalg.vector_norm(conv1, dim=(0, 2, 3))
    
    # --- Unreachable below this line (alternative feature sets) ---
    with M.record_activations(model.policy) as cache:
        model.predict(obs)
    cache.apply(torch.flatten)
    return torch.cat([
        # cache['left.2'].flatten(),
        # conv1.flatten(),
        # conv2.flatten(),
        lin1.flatten(),
        lin2.flatten(),
        # cache['left.2'],
        # cache['left.4'],
        cache["Split"],
        cache['dule.2'],
        cache['dule.4'],
    ])
    return cache['left.2']

# One row of predictors per model, and the probability of ending on red
# (column 0 of full_color_probas) as the regression target.
predictors = np.stack([get_inputs(model).flatten() for model in models])
to_predict = full_color_probas[:, 0]

# NOTE(review): every per-model feature is flattened to 1-D above, so
# `predictors` is always 2-D and this branch cannot trigger as written.
if predictors.ndim == 3:
    to_predict = einops.repeat(to_predict, "n -> (n c)", c=predictors.shape[1])
    predictors = predictors.reshape(-1, predictors.shape[-1])
    

# Split the data into train and test
# X_train, X_test, y_train, y_test = train_test_split(flat_conv1s, to_predict, test_size=0.2, random_state=42)

# augment the training data by adding some permutations of the output channels
# to_add = []
# for i in range(30):
#     by_out_channel = einops.rearrange(X_train, "model (out rest) -> out model rest", out=n_out_channels)
#     by_out_channel = by_out_channel[np.random.permutation(n_out_channels)]
#     by_out_channel = einops.rearrange(by_out_channel, "out model rest -> model (out rest)")
#     to_add.append(by_out_channel)
# X_train = np.concatenate([X_train] + to_add)
# y_train = np.concatenate([y_train] * (len(to_add) + 1))
# 


# Pick the estimator by changing the index at the end of the list.
reg = [
    LinearRegression(),
    Ridge(),
    Lasso(0.0001),
    MLPRegressor(hidden_layer_sizes=(64, ), alpha=0.01)
][0]

# NOTE(review): the fitted coefficients are read from `reg` below, so
# utils.show_fit is expected to fit the estimator in place - confirm.
utils.show_fit(reg, predictors, to_predict,
               title=f"Predicted probability of red vs true probability of red using {reg}",
               xaxis="True probability of red",
               yaxis="Predicted probability of red")

# MLPRegressor exposes `coefs_` rather than `coef_`, so only plot when available.
if hasattr(reg, "coef_"):
    coef = np.asarray(reg.coef_)
    # The 8 x 5 layout only applies when there are exactly 40 coefficients
    # (e.g. a flattened 8-output, 5-input feature set). With the per-channel
    # norms currently returned by get_inputs there are far fewer, and an
    # unconditional reshape(8, 5) raises ValueError.
    coef = coef.reshape(8, 5) if coef.size == 8 * 5 else np.atleast_2d(coef)
    px.imshow(coef,
              title="Coefficients of the linear regression",
              # facet_col=0,
              width=1000,
              **ZERO_CENTERED).show()

# Finding how they make their decisions

In [None]:
Exp = train.BlindThreeGoalsOneHot

model, stat = Exp.load(42)
M.make_stats(model, Exp().get_eval_env().constant(), 1_000)

In [None]:
def stats_per_cell(policy: PPO, env):
    """Query the policy from every non-goal cell and plot what it would do.

    For each cell that is not a goal, the agent's start position is set to
    that cell, the env is reset, and the policy's action probabilities and
    value estimate for the resulting observation are recorded. Two figures
    are shown: a render of the environment, and a heatmap of the values
    overlaid with arrows whose length is proportional to the action
    probabilities, plus coloured squares on the three goal cells.

    Args:
        policy: trained PPO model.
        env: wrapped environment whose ``.unwrapped`` is an ``M.ThreeGoalsEnv``.
            Its ``agent_start`` attribute is overwritten by this function.

    Returns:
        ``(probas, values)``: tensors of shape (height, width, 4) and
        (height, width); goal cells are left as NaN.
    """
    unwrapped = env.unwrapped
    assert isinstance(unwrapped, M.ThreeGoalsEnv)
    # NaN marks cells that are never evaluated (the goal cells).
    probas = torch.zeros(unwrapped.height, unwrapped.width, 4) + float('nan')
    values = torch.zeros(unwrapped.height, unwrapped.width) + float('nan')
    for x in range(unwrapped.width):
        for y in range(unwrapped.height):
            if (x, y) in unwrapped.goal_positions:
                continue
            # Force the next reset to start the agent on this cell.
            unwrapped.agent_start = (x, y)
            obs, _ = env.reset()
            obs = torch.from_numpy(obs)
            distr = policy.policy.get_distribution(obs)
            probas[y, x] = distr.distribution.probs.detach()
            
            values[y, x] = policy.policy.predict_values(obs).detach()
            
    # Render the env
    img = env.render()
    px.imshow(np.flip(img, 0), title="Environment", width=500).show()
    
    # Plot the heatmap of the values and arrows between the cells of length corresponding to the probas
    fig = go.Figure()
    fig.add_trace(go.Heatmap(
        z=values,
        colorscale="Mint",
        colorbar=dict(title="Value"),
    ))
    # Add arrows
    for x in range(unwrapped.width):
        for y in range(unwrapped.height):
            if (x, y) in unwrapped.goal_positions:
                continue
            for i, (dx, dy) in enumerate(unwrapped.DIR_TO_VEC):
                proba = probas[y, x, i].item()
                # Skip unlikely actions to keep the figure readable.
                if abs(proba) < 0.1:
                    continue
                    
                # Arrow from the cell centre towards the action's direction;
                # a probability of 1 reaches the edge of the cell.
                fig.add_annotation(
                    text="",
                    ax=x,
                    ay=y,
                    x=x+ dx * proba / 2,
                    y=y + dy * proba / 2,
                    xref="x", yref="y", axref="x", ayref="y",
                    showarrow=True,
                    arrowhead=4,
                    arrowsize=1,
                    arrowwidth=3,
                    arrowcolor="red",
                )
    # Draw a filled colored square on the 3 goal cells
    for i, color in enumerate(["#ff0000", "#00ff00", "#0000ff"]):
        x, y = unwrapped.goal_positions[i]
        fig.add_shape(
            type="rect",
            x0=x - 0.5, y0=y - 0.5, x1=x + 0.5, y1=y + 0.5,
            fillcolor=color,
            line=dict(color=color, width=0),
            layer="above",
        )
    fig.update_layout(
        title="Values and probabilities of the policy",
        xaxis=dict(title="x"),
        yaxis=dict(title="y"),
        width=500,
        height=500,
    )
    fig.show()
    
    return probas, values

env = Exp().get_env(True)(M.ThreeGoalsEnv.constant(4, 'red'))

stats_per_cell(model, env);

In [None]:
agent_pos, goals_pos, end_type, true_goal = M.destination_stats(model, Exp().get_eval_env(), 100_000)

In [None]:
# Print all the shapes
print("Agent pos shape:", agent_pos.shape)
print("Goals pos shape:", goals_pos.shape)
print("End type shape:", end_type.shape)
print("True goal shape:", true_goal.shape)

In [None]:
# End type treated as the positive class below; the plots elsewhere order the
# end types red, green, blue, no goal, so 1 corresponds to green.
TARGET = 1

# Offset from the agent to each goal, per episode: (episode, goal, coordinate).
vec_to_goals = (goals_pos - agent_pos[:, None, :])
end_on_target = end_type == TARGET
# Episodes that ended on red (0) or green (1).
red_or_green = end_type < 2

# Find the distance to red and green goals
# L1 (Manhattan) distance over the coordinate axis, as integers.
dist_to_goal = torch.linalg.vector_norm(vec_to_goals.float(), dim=2, ord=1).long()
dist_to_goal.shape, red_or_green.shape, end_on_target.shape

In [None]:
from sklearn.linear_model import LogisticRegression

# Inputs: distances to goals 0 and 1, restricted to rollouts that ended on one of them
red_green_distances = dist_to_goal[red_or_green][:, [0, 1]]
# Labels: whether each of those rollouts ended on TARGET
went_to_target = end_on_target[red_or_green]

reg = M.show_fit(
    LogisticRegression(),
    red_green_distances,
    went_to_target,
    title="Predicting whether the policy goes to red or green end depending on the distance to the target cells",
    xaxis="Prediction: does it go to green?",
    yaxis="True: does it go to green?",
    classification=True,
)

# Fitted weights (one per distance column) and bias
print(reg.coef_, reg.intercept_)

In [None]:
# Plot the frequency of going to each end type, as a heatmap x=distance to red, y=distance to green

plot_only_red_green = True
n_plots = 2 if plot_only_red_green else 4

# Rollouts taken into account. When only red/green are plotted, the frequencies
# are conditioned on ending on one of those two; otherwise every rollout counts
# (conditioning on red/green would make the blue / no-goal panels trivially 0).
if plot_only_red_green:
    kept = red_or_green
else:
    kept = torch.ones_like(red_or_green, dtype=torch.bool)

# Hoisted out of the loops: these do not depend on the cell being filled.
dist_red = dist_to_goal[kept, 0]
dist_green = dist_to_goal[kept, 1]
ended_on = end_type[kept]

max_dist = dist_to_goal.max().item() + 1
# Both arrays are indexed [end type, distance to green (y), distance to red (x)].
# Cells without any sample keep frequency NaN and count 0.
frequency = torch.full((n_plots, max_dist, max_dist), float('nan'))
counts = torch.zeros((n_plots, max_dist, max_dist), dtype=torch.long)

for x in range(1, max_dist):
    for y in range(1, max_dist):
        in_cell = (dist_red == x) & (dist_green == y)
        n_in_cell = int(in_cell.sum())
        if n_in_cell == 0:
            continue
        ends_in_cell = ended_on[in_cell]
        for target in range(n_plots):
            frequency[target, y, x] = (ends_in_cell == target).float().mean()
            counts[target, y, x] = n_in_cell


fig = px.imshow(frequency,
                title="Frequency of going to each end type, as a heatmap x=distance to red, y=distance to green",
                labels=dict(x="Distance to red", y="Distance to green", color="Frequency"),
                facet_col=0, height=1000,
                zmin=0, zmax=1, color_continuous_scale="Blues")
# Set facet labels (must happen before adding annotations: the first
# n_plots annotations of the figure are the facet titles)
for i, label in enumerate(["Red", "Green", "Blue", "No goal"][:n_plots]):
    fig.layout.annotations[i].text = f"Ended on {label}"

# Add percentage and sample-count labels.
# frequency[i, y, x] is drawn by imshow at (x, y), so the label goes there too.
# Plotly subplot coordinates are 1-indexed, hence row=1 and col=i + 1.
for i in range(n_plots):
    for y in range(max_dist):
        for x in range(max_dist):
            value = frequency[i, y, x].item()
            if np.isnan(value):
                continue
            text_color = "white" if value > 0.5 else "black"
            fig.add_annotation(
                x=x, y=y, text=f"{value:.0%}", showarrow=False,
                col=i + 1, row=1,
                font=dict(color=text_color, size=20)
            )
            fig.add_annotation(
                x=x, y=y + 0.2, text=f"n={counts[i, y, x].item()}", showarrow=False,
                col=i + 1, row=1,
                font=dict(color=text_color, size=10)
            )

fig.show()


In [None]:
# Are there some cells that are preferred?

# Show what the baseline does 

Here is the baseline: an agent that goes to the red goal while avoiding the blue one and ignoring the green one.

In [None]:
# Environment from the factory built with full_color=True at the top of the notebook
env = mk_env_full_color()
# reset() returns a pair; only its first element (the observation) is used here
obs, _ = env.reset()
print(obs)

In [None]:
from baselines import Baseline
class BaselineAvoidGreen(Baseline):
    """Scripted policy: head for the red goal, treating the blue goal as an obstacle.

    The green goal is ignored: it is neither the destination nor added to the
    obstacles. (The class name is kept as is for existing callers.)
    """

    def _predict(self, obs, deterministic=False):
        """Return the action for one flat observation.

        Args:
            obs: flat observation. Its last 3 entries are dropped and the rest
                is viewed as an (ENV_SIZE, ENV_SIZE, channels) grid.
            deterministic: not used by this policy.

        Returns:
            The direction from the agent towards the second cell of the path
            to red, or a random action when no path is returned.
        """
        # Use the notebook-wide ENV_SIZE rather than a hard-coded grid size,
        # so the baseline follows the size the environments are built with.
        grid: np.ndarray = obs[:-3].reshape(ENV_SIZE, ENV_SIZE, -1)

        # Cells are located by their colour triple
        agent_pos = self.find(grid, [1, 1, 1])
        red_pos = self.find(grid, [1, 0, 0])
        blue_pos = self.find(grid, [0, 0, 1])

        # Only the blue goal blocks the way
        obstacles = np.zeros((ENV_SIZE, ENV_SIZE), dtype=bool)
        obstacles[blue_pos] = True

        path = self.find_path(agent_pos, red_pos, obstacles)
        if not path:
            return self.random_action()
        # path[0] is presumably the agent's own cell - TODO confirm in Baseline.find_path
        return self.direction_to(agent_pos, path[1])

# Other ways of inspecting the baseline, kept for reference:
# M.evaluate(BaselineAvoidGreen(), env, 1, height=400, width=1000)
# M.evaluate(BaselineAvoidGreen(), env)
M.make_stats(
    BaselineAvoidGreen(),
    env,
    100_000,
    "Perfect agent going to red, avoiding blue, ignoring green",
)


# Try to split the network into the red, green and blue parts

In [None]:
import train
# Experiment whose saved agent is studied in this section
Exp = train.BlindThreeGoalsOneHot

# Load the saved model with a single environment.
# 42 is the first argument of Exp.load; presumably a seed or run id - TODO confirm.
model, stat = Exp.load(42, n_envs=1)
print(stat, model.policy, sep="\n")

In [None]:
# Start from a freshly loaded model so that re-running this cell does not stack masks
model, stat = Exp.load(42, n_envs=1)

# Set the weight decay to 0
for module in model.policy.modules():
    if isinstance(module, M.WeightDecay):
        module.weight_decay = 0

# Turn off gradient for the weights: every parameter present at this point is frozen
for param in model.policy.parameters():
    param.requires_grad = False

env = Exp().get_eval_env()
dummy_obs = env.reset()[0]

# -- Add the mask -- #

target_type = torch.nn.Conv2d
nth_target = 0
# False: the mask is inserted right before the target layer; True: right after it
after = False

# Find the target and the sequential in which it is
pre_mask = [m for m in model.policy.modules() if isinstance(m, target_type)][nth_target]
sequential = next(m for m in model.policy.modules() if isinstance(m, torch.nn.Sequential) and pre_mask in m)

# Find the index of the target layer inside that sequential
target_idx = next(i for i, m in enumerate(sequential) if m is pre_mask)
print(sequential)
print("Target index:", target_idx)

mask = M.LazyMask()
# int(after) makes the 0/1 offset explicit instead of adding a bool to an index
sequential.insert(target_idx + int(after), M.ZeroOneRegularisation(mask, 0))
# One forward pass; presumably this is what lets the lazy mask create its parameter - TODO confirm
model.predict(dummy_obs)
print(sequential)

# Testing that only the mask is trainable.
# Compare by identity: `==` between lists of tensors may fall back to an
# element-wise tensor comparison, which raises instead of failing the assert.
trainable = [param for param in model.policy.parameters() if param.requires_grad]
assert len(trainable) == 1 and trainable[0] is mask.mask, trainable
n_masks = sum(isinstance(module, M.Mask) for module in model.policy.modules())
assert n_masks == 1, "Multiple masks, reload the model."


In [None]:
def reward_fn(env: M.ThreeGoalsEnv) -> float:
    """Reward 1.0 when the agent stands on the goal selected by the mask state.

    The rewarded goal is `env.goal_positions[0]` (red) while the global
    `mask.flipped` is truthy, and `env.goal_positions[1]` (green) otherwise.

    Args:
        env: environment exposing `agent_pos` and `goal_positions`.

    Returns:
        1.0 if the agent position equals the selected goal position, else 0.0.
    """
    goal_index = 0 if mask.flipped else 1  # 0: red, 1: green
    # np.array_equal yields a single bool whether positions are tuples or
    # arrays (a bare `==` on arrays would compare element-wise), and float()
    # makes the result match the annotated return type.
    return float(np.array_equal(env.agent_pos, env.goal_positions[goal_index]))
    
def mk_env():
    """Return a new evaluation environment wrapped in M.FunctionRewardWrapper with `reward_fn`."""
    return M.FunctionRewardWrapper(Exp().get_eval_env(), reward_fn)

# Alternative factories, kept for reference:
# mk_env = lambda: Exp().get_train_env()
# mk_env = lambda: Exp().get_eval_env()

In [None]:
from stable_baselines3.common.callbacks import BaseCallback

# Train the model
# Only the mask is handed to the optimizer below, so optimizer steps update the mask only.
# `use_wandb` is only referenced by the commented-out wandb block further down this cell.
use_wandb = False
# Value returned by the lambda given to M.WeightDecayCallback below
weight_decay = 0
learning_rate = 3e-4
# Constant schedule: the argument is ignored
model.lr_schedule = lambda _: learning_rate
model.verbose = 1
model.set_env(mk_env())
# With a falsy `flipped`, reward_fn rewards reaching goal_positions[1] (green)
mask.flipped = 0

# Replace the policy optimizer with one that only receives the mask's parameters
model.policy.optimizer = torch.optim.AdamW(mask.parameters(), learning_rate, weight_decay=0)
# Optional: shorter rollouts (disabled)
# model.batch_size = 100
# model.n_steps = 100
# model.rollout_buffer.buffer_size = model.n_steps
# model.rollout_buffer.reset()

class FlipMaskCallback(BaseCallback):
    """Training callback that flips a mask at the start of every rollout."""

    def __init__(self, mask: M.Mask):
        """Store the mask that will be flipped."""
        super().__init__()
        self.mask = mask

    def _on_step(self) -> bool:
        """Do nothing on individual steps; True means training carries on."""
        return True

    def _on_rollout_start(self) -> None:
        """Flip the mask and print its new `flipped` state."""
        flippable = self.mask
        flippable.flip()
        print(f"Flipped mask to {flippable.flipped}")

# Callbacks active during training; commented entries are optional extras.
callbacks = [
    M.ProgressBarCallback(),
    # FlipMaskCallback(mask),
    M.WeightDecayCallback(lambda f: weight_decay),
    # M.LogChannelNormsCallback()
]
# if use_wandb:
#     callbacks.append(M.WandbWithBehaviorCallback(mk_env()))


# The trailing `;` hides the value returned by learn() in the notebook output.
model.learn(total_timesteps=20_000, callback=callbacks);

In [None]:
# Overwrite the mask by hand, without gradient tracking:
# ones everywhere, then zeros at index 2 of the first axis.
with torch.no_grad():
    mask.mask.fill_(1)
    mask.mask[2].zero_()

In [None]:
# Show the current mask values, one facet per entry of the first axis
px.imshow(
    mask.mask.detach(),
    title="Mask of the first convolutional layer",
    facet_col=0,
    width=1000,
    zmin=0,
    zmax=1,
    color_continuous_scale="Blues",
).show()

In [None]:
# Statistics of `model` on a fresh environment from mk_env().
# 10_000 is presumably the number of rollouts - TODO confirm against M.make_stats.
M.make_stats(model, mk_env(), 10_000)

In [None]:
# Compute the statistics twice, flipping the mask after each run.
# Assuming flip() toggles between two states (TODO confirm), both states are
# evaluated and the mask ends in the state it started in.
for _ in range(2):
    M.make_stats(model, mk_env(), 1_000)
    mask.flip()

In [None]:
# Reload the saved model, replacing the in-memory one whose policy was modified above
model, _ = Exp.load(42, n_envs=1)
# Last expression of the cell: the notebook displays the policy
model.policy

In [None]:
# Collect destination statistics while M.record_activations is active on the policy.
# `cache` is indexed by name in the cells below; presumably it holds the recorded
# activations per layer - TODO confirm against M.record_activations.
with M.record_activations(model.policy) as cache:
    agent_pos, goal_pos, end_type, true_goal = M.destination_stats(model, Exp().get_eval_env(), 10_000)

In [None]:
print(cache)
# Shape of each array returned by M.destination_stats
for label, recorded in (
    ("Agent pos", agent_pos),
    ("Goals pos", goal_pos),
    ("End type", end_type),
    ("True goal", true_goal),
):
    print(f"{label} shape:", recorded.shape)

In [None]:
# Find which activation corresponds to which end type
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline

# Keep only the rollouts whose end type is 0 or 1
end_red_or_green = end_type < 2
# Binary label for those rollouts: False for end type 0, True for end type 1
ended_on_green = end_type[end_red_or_green].bool()

# Fit one classifier per recorded activation, stopping at the first "NOP" one
for name in cache:
    print(name)
    out = M.show_fit(
        # make_pipeline(StandardScaler(), LogisticRegression()),
        LogisticRegression(max_iter=1000),
        cache[name][end_red_or_green].flatten(1),
        ended_on_green,
        title=f"Predicting the end type from the activations of {name}",
        xaxis="Predicted end_type is green",
        yaxis="True end_type is green",
        classification=True,
    )
    if "NOP" not in name:
        continue
    # Keep this classifier for the next cell and skip the remaining activations
    reg = out
    break

In [None]:
print(reg.coef_)
print(reg.intercept_)

# Plot reg.coef_[0, :-3] (all weights of the first row except the last 3)
# as a stack of 4x4 images, one facet per image
coef = reg.coef_[0, :-3].reshape(-1, 4, 4)
px.imshow(coef, facet_col=0, **ZERO_CENTERED).show()

# Exploring reward hacking

In [None]:
# Toy experiment: a small MLP fed a constant input is trained online to output
# a target that changes at every step.
# NOTE(review): this rebinds the notebook-wide name `model`, which earlier
# cells use for the loaded agent.
model = M.MLP(1, 10, 10, 10, 1, activation=nn.ReLU)

lr = 0.2
optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=0.00)
epochs = 10_000

losses = []
prediction = []

# Target as a function of the step index
correct = lambda x: np.sin(x / 200 * np.pi * x ** 0.2)

for epoch in tqdm(range(epochs)):
    optimizer.zero_grad()
    out = model(torch.ones(1))
    # Squared error against the target of the current step
    loss = (out - correct(epoch)) ** 2
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    prediction.append(out.item())

# Show prediction and correct on one plot
corrects = np.array([correct(i) for i in range(epochs)])
predictions = np.array(prediction)
losses = np.array(losses)

# Clip them so that a few extreme values do not flatten the plots
predictions = np.clip(predictions, -1, 1)
losses = np.clip(losses, 0, 1)

# Number of initial steps left out of the plots
skip = 0
fig = go.Figure()
fig.add_scatter(y=corrects[skip:], name="Correct")
# Plot the clipped array; the raw `prediction` list would bypass the clipping above
fig.add_scatter(y=predictions[skip:], name="Prediction")
fig.show()

# Show the loss (log scale)
px.scatter(y=losses[skip:], log_y=True, title="Loss", labels=dict(y="Loss")).show()
px.scatter(y=(corrects - predictions)[skip:],
        title="Error", labels=dict(y="Error")).show()


