# Agents
These simulations evaluate several agents exploring the "Thread the Needle" gridworld environment.

In [1]:
# Automatically reload edited project modules (e.g. `state_inference`) so code
# changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Optional (disabled): widen the notebook container to 90% of the window.
# from IPython.display import display, HTML
# display(HTML("<style>.container { width:90% !important; }</style>"))

In [4]:
%matplotlib inline
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import pairwise_distances
from stable_baselines3 import A2C, DQN, PPO

from state_inference.gridworld_env import CnnWrapper, OpenEnv, ThreadTheNeedleEnv
from state_inference.model.agents import ValueIterationAgent, ViAgentWithExploration, ViDynaAgent
from state_inference.utils.pytorch_utils import DEVICE, convert_8bit_to_float, make_tensor
from state_inference.utils.training_utils import (
    get_policy_prob,
    parse_task_config,
    train_model,
    vae_get_pmf,
)

# Record the runtime environment (interpreter, torch build, compute device) so
# results can be tied to the versions that produced them.
print(f"python {sys.version}\ntorch {torch.__version__}\ndevice = {DEVICE}")


python 3.10.11 (main, Apr 20 2023, 13:58:42) [Clang 14.0.6 ]
torch 2.0.1
device = mps


In [5]:
# Task selection: the YAML config file and task name are passed to
# `parse_task_config` below; TASK_CLASS is the environment class whose
# `create_env` builds the task.
CONFIG_FILE = "state_inference/env_config.yml"
TASK_NAME = "thread_the_needle"
TASK_CLASS = ThreadTheNeedleEnv

In [6]:
# Load environment and training settings for the task, then override the
# training budget. NOTE(review): training_kwargs is not consumed by the cells below.
env_kwargs, training_kwargs = parse_task_config(TASK_NAME, CONFIG_FILE)
training_kwargs.update(n_train_steps=50_000, n_epochs=1)

# Build the environment and wrap it with CnnWrapper.
task = CnnWrapper(TASK_CLASS.create_env(**env_kwargs))

# Ground-truth optimal policy `pi` (the second return value is unused); it is
# the reference the learned policy is scored against further down.
pi, _ = task.get_optimal_policy()
training_kwargs["optimal_policy"] = pi

  0%|          | 0/1000 [00:00<?, ?it/s]

In [9]:
from state_inference.model.vae import (
    DEVICE, Encoder, Decoder, StateVae
)

### Model + Training Parameters
N_EPOCHS = 20  # NOTE(review): not referenced by the cells below (training uses agent.learn)
EMBEDDING_LAYERS = 5  # number of discrete latent layers (passed to StateVae as z_layers)
EMBEDDING_DIM = len(task.observation_model.states) // 2  # per-layer latent size: half the number of states
OBSERVATION_DIM = task.observation_model.map_height**2  # flattened observation size (map_height x map_height)
LR = 3e-4
beta = 1.0  # passed to StateVae; presumably the KL weight -- confirm
tau = 2.0  # passed to StateVae; presumably a relaxation temperature -- confirm
gamma = 0.99  # passed to StateVae -- TODO confirm its role there
dropout = 0.0  # dropout used by both Encoder and Decoder

optim_kwargs = dict(lr=LR)  # optimizer settings forwarded to the agent

# Layer sizes for the model built by make_model(): hidden widths are fixed
# fractions of the flattened observation size.
encoder_hidden = [OBSERVATION_DIM // 5, OBSERVATION_DIM // 10]
decoder_hidden = [OBSERVATION_DIM // 2, OBSERVATION_DIM // 5]
# Total latent width seen by Encoder/Decoder = per-layer size x number of layers.
z_dim = EMBEDDING_DIM * EMBEDDING_LAYERS


def make_model():
    """Build a fresh VAE-based agent for the global `task`.

    All settings come from notebook globals: the MLP encoder/decoder sizes
    (OBSERVATION_DIM, encoder_hidden, decoder_hidden, z_dim, dropout), the
    StateVae settings (EMBEDDING_DIM, EMBEDDING_LAYERS, beta, tau, gamma) and
    the optimizer settings (optim_kwargs).

    Returns:
        A new, untrained ``ViAgentWithExploration`` whose state-inference
        model has been moved to ``DEVICE``.
    """
    # Alternative (disabled): a CnnEncoder(1, map_height, map_height, z_dim)
    # was tried in place of the MLP Encoder.
    # Encoder is constructed before Decoder, as in the original cell.
    mlp_encoder = Encoder(OBSERVATION_DIM, encoder_hidden, z_dim, dropout=dropout)
    mlp_decoder = Decoder(z_dim, decoder_hidden, OBSERVATION_DIM, dropout=dropout)

    # StateVae gets the per-layer size and the layer count separately, while the
    # encoder/decoder above work with their product, z_dim.
    state_vae = StateVae(
        mlp_encoder,
        mlp_decoder,
        z_dim=EMBEDDING_DIM,
        z_layers=EMBEDDING_LAYERS,
        beta=beta,
        tau=tau,
        gamma=gamma,
    ).to(DEVICE)

    # Four discrete actions, labelled 0..3.
    return ViAgentWithExploration(
        task, state_vae, set_action=set(range(4)), optim_kwargs=optim_kwargs
    )




In [10]:
# Build a fresh agent and train it for 10,000 steps (first argument of
# `learn`), with estimate_batch enabled and a progress bar.
agent = make_model()
agent.learn(
    10_000,
    estimate_batch=True,
    progress_bar=True,
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [01:25<00:00, 117.54it/s]


In [11]:
from state_inference.utils.pytorch_utils import DEVICE, convert_8bit_to_float, train
# Render one observation per ground-truth state and ask the trained VAE for the
# discrete latent code it assigns to each (one row per state). Grid/image sizes
# come from the task instead of the hard-coded 40 x 40 image and 400 states.
map_height = task.observation_model.map_height
obs = torch.stack([
    torch.tensor(agent.get_env().env_method("generate_observation", s)[0]).view(
        (map_height, map_height, 1)
    )
    for s in range(task.transition_model.n_states)
])
agent.state_inference_model.get_state(convert_8bit_to_float(obs))

array([[ 64,  64, 114, 181, 114],
       [ 64,  64, 114, 181, 114],
       [ 64,  64, 114, 181, 114],
       ...,
       [ 98,  91, 112,  70,  39],
       [ 98,  91, 112,  70,  39],
       [ 98,  91, 112,  70,  39]])

In [None]:
# Disabled experiment: rebuild `agent` from scratch, train it in 5 increments of
# 10k steps, and print the policy score after each increment. The score is
# sum_a pi * pmf per state, averaged over states (presumably the mean probability
# the learned policy puts on the optimal action(s)). Uncomment to run.
# from state_inference.utils.training_utils import get_policy_prob, vae_get_pmf

# agent = make_model()
# # agent.learn(10000, estimate_batch=True, progress_bar=True)
# for ii in range(5):
#     agent.learn(10000, estimate_batch=True, progress_bar=True)
    
#     pmf = get_policy_prob(
#         agent,
#         vae_get_pmf,
#         n_states=env_kwargs["n_states"],
#         map_height=env_kwargs["map_height"],
#         cnn=True,
#     )
#     score = np.sum(pi * pmf, axis=1).mean()
#     print(f"batch {ii+1}, score {score}")
    
    

In [None]:

    
# Score the trained agent's policy against the optimal policy `pi`.
# The agent is VAE-based, so its policy is read with `vae_get_pmf`
# (`cnn_get_pmf` is never imported in this notebook and raised NameError).
pmf = get_policy_prob(
    agent,
    vae_get_pmf,
    n_states=env_kwargs["n_states"],
    map_height=env_kwargs["map_height"],
    cnn=True,
)
# Per state: sum over actions of pi * pmf; then average over states.
score = np.sum(pi * pmf, axis=1).mean()
print(f"Single Big Batch, score {score}")

In [None]:

# Policy probabilities of the trained agent, one row per state (the plot below
# treats columns 0..3 as the four actions).
pmf_kwargs = dict(
    n_states=env_kwargs["n_states"],
    map_height=env_kwargs["map_height"],
    cnn=True,
)
pmf = get_policy_prob(agent, vae_get_pmf, **pmf_kwargs)
pmf

In [None]:
import seaborn as sns

# One heat map per action column of `pmf`, laid out on the h x w grid.
fig, axes = plt.subplots(2, 2)
h, w = env_kwargs["height"], env_kwargs["width"]

action_names = ["up", "down", "left", "right"]  # assumed column order of pmf
for action_idx, (ax, name) in enumerate(zip(axes.flat, action_names)):
    ax.imshow(pmf[:, action_idx].reshape(h, w))
    ax.set_title(name)

plt.subplots_adjust(hspace=0.3, wspace=-0.3)

plt.suptitle("Value Iteration Agent Learned Policy")

In [None]:
# Overall policy score: per-state sum over actions of pi * pmf, averaged over states.
(pi * pmf).sum(axis=1).mean()

In [None]:
# Policy score broken down by room. Masks are derived from the grid size rather
# than hard-coded 400/200/20 (identical for the 20 x 20 grid):
#   room 1 = top-left quadrant, room 2 = bottom-left quadrant, room 3 = right half.
# NOTE(review): assumes the rooms split the grid at its midlines -- confirm if
# the environment size or layout changes. State index = row * w + col.
h, w = env_kwargs["height"], env_kwargs["width"]
rows, cols = np.divmod(np.arange(h * w), w)

room_1_mask = (rows < h // 2) & (cols < w // 2)
room_2_mask = (rows >= h // 2) & (cols < w // 2)
room_3_mask = cols >= w // 2

# Same score as the overall one (sum_a pi * pmf, averaged), restricted to each room.
room_scores = [
    np.sum(pi[mask] * pmf[mask], axis=1).mean()
    for mask in (room_1_mask, room_2_mask, room_3_mask)
]
score_room_1, score_room_2, score_room_3 = room_scores

plt.bar([0, 1, 2], room_scores)
plt.xticks([0, 1, 2], ["room 1", "room 2", "room 3"])
plt.ylabel("policy score")
plt.title("Policy score by room")

sns.despine()

In [None]:
from state_inference.utils.pytorch_utils import make_tensor, convert_8bit_to_float
from sklearn.metrics import pairwise_distances

# Encode one observation per ground-truth state and collapse each state's
# multi-layer discrete code into a single integer id.
obs = convert_8bit_to_float(
    torch.stack(
        [
            make_tensor(task.observation_model(s))
            for s in range(task.transition_model.n_states)
        ]
    )
).to(DEVICE)
z = agent.state_inference_model.get_state(obs)

# Mixed-radix hash: layer ii is weighted by z_dim**ii, so distinct code tuples
# map to distinct ints (assuming each per-layer code lies in [0, z_dim)).
hash_vector = np.array(
    [
        agent.state_inference_model.z_dim**ii
        for ii in range(agent.state_inference_model.z_layers)
    ]
)

z = z.dot(hash_vector)
# d[i, j] = 1.0 when states i and j share a latent id, else 0.0 -- a same-cluster
# indicator, not a distance. Vectorized equivalent of
# pairwise_distances(..., metric=lambda x, y: x == y), which made ~n^2 Python calls.
d = np.equal.outer(z, z).astype(float)
plt.imshow(1 - d)
plt.title("1 - same-latent indicator (0 = same latent id)")

In [None]:
# Relabel the large, sparse hashed latent ids as consecutive cluster numbers
# 0..K-1 (in ascending id order) and draw them on the grid with its walls.
cluster_ids, clustered_states = np.unique(z, return_inverse=True)
plt.imshow(clustered_states.reshape(-1, 20))
task.display_gridworld(plt.gca(), wall_color='w', annotate=False)
plt.title('State Clusters')

In [None]:
# Compare latent-cluster agreement across walls vs. across open edges.
# `d` is a same-cluster indicator (d[i, j] == 1.0 when states i and j share a
# latent id), so the means below are same-cluster *rates*, not distances.
# Presumably a good model gives a low rate across walls and a high rate
# between open neighbours.
h, w = env_kwargs["height"], env_kwargs["width"]

# Manhattan grid distance between all pairs of states (index = row * w + col).
euc = pairwise_distances(
    [(row, col) for row in range(h) for col in range(w)],
    metric="manhattan",
)

d_w_wall = np.mean([d[s1][s2] for s1, s2 in task.transition_model.walls])
print(f"Same-cluster rate for neighboring states separated by a wall     {d_w_wall}")


wall_mask = np.zeros((task.n_states, task.n_states))
for s0, s1 in task.transition_model.walls:
    wall_mask[s0][s1] = 1.0
    wall_mask[s1][s0] = 1.0

# Adjacent pairs (grid distance 1) with no wall between them.
open_neighbours = (wall_mask.reshape(-1) == 0) & (euc.reshape(-1) == 1)
d_wo_wall = d.reshape(-1)[open_neighbours].mean()
print(f"Same-cluster rate for neighboring states NOT separated by a wall {d_wo_wall}")

In [None]:
# Reward predicted by the agent's reward estimator for each state's hashed
# latent id, laid out on the grid (sized from env_kwargs, not a fixed 20 x 20).
# Optionally refresh the estimate first: agent._estimate_reward_model()
h, w = env_kwargs["height"], env_kwargs["width"]
rews = np.array([agent.reward_estimator.get_reward(z0) for z0 in z]).reshape(h, w)
plt.imshow(rews)
plt.colorbar()
plt.title("Estimated reward per state")

In [None]:
# Recompute every state's hashed latent id from scratch (same procedure as the
# same-cluster cell above) and plot the estimated reward per state.
latent_model = agent.state_inference_model
state_images = [
    make_tensor(task.observation_model(s))
    for s in range(task.transition_model.n_states)
]
obs = convert_8bit_to_float(torch.stack(state_images)).to(DEVICE)
z = latent_model.get_state(obs)

# Layer `layer` of the code is weighted by z_dim**layer.
hash_vector = np.array([latent_model.z_dim**layer for layer in range(latent_model.z_layers)])
z = z.dot(hash_vector)

rews = np.array([agent.reward_estimator.get_reward(z0) for z0 in z]).reshape(20, 20)
plt.imshow(rews)

In [None]:
def _hashed_latent_ids(model, task):
    """Encode one observation per state of `task` and hash each latent code to one int."""
    vae = model.state_inference_model
    obs = convert_8bit_to_float(
        torch.stack(
            [
                make_tensor(task.observation_model(s))
                for s in range(task.transition_model.n_states)
            ]
        )
    ).to(DEVICE)
    codes = vae.get_state(obs)

    # Mixed-radix hash: layer ii of the code is weighted by z_dim**ii.
    hash_vector = np.array([vae.z_dim**ii for ii in range(vae.z_layers)])
    return codes.dot(hash_vector)


def get_value_function(model, task, shape=(20, 20)):
    """Return the learned value of every ground-truth state, laid out on the grid.

    Each state's observation is encoded by ``model.state_inference_model``; its
    per-layer discrete code is hashed to one integer id, which is looked up in
    ``model.value_function``.

    Args:
        model: agent exposing ``state_inference_model`` and a dict-like
            ``value_function`` (anything with ``.get(key, default)``).
        task: environment exposing ``observation_model`` and ``transition_model``.
        shape: output grid shape; its size must equal the number of states.
            Defaults to the 20 x 20 grid used in this notebook.

    Returns:
        ``np.ndarray`` of the given ``shape``; ``np.nan`` where a state's latent
        id has no entry in the value function.
    """
    latent_ids = _hashed_latent_ids(model, task)
    # Look values up on `model` (not a global agent) so any agent can be passed in.
    value_function = np.array(
        [model.value_function.get(z0, np.nan) for z0 in latent_ids]
    ).reshape(shape)
    return value_function


# Learned value per state, with the gridworld walls and annotations overlaid.
v = get_value_function(agent, task)

fig, ax = plt.subplots()
ax.imshow(v)
task.display_gridworld(ax, wall_color='w', annotate=True)
ax.set_title("Learned Value function")


In [None]:
# Value profile along grid row 5, shifted so the smallest learned value is 0.
plt.plot(v[5] - np.nanmin(v))
plt.xlabel("column")
plt.ylabel("value - min value")
plt.title("Learned value along row 5");

In [None]:
# Value profiles down columns 10 and 9 (either side of the grid's vertical
# midline), shifted so the smallest learned value is 0.
plt.plot(v[:, 10] - np.nanmin(v), label="column 10")
plt.plot(v[:, 9] - np.nanmin(v), label="column 9")
plt.xlabel("row")
plt.ylabel("value - min value")
plt.legend();

In [None]:
# `.parameters()` returns a generator whose repr shows nothing useful; list the
# VAE's named parameters and their shapes instead.
[(name, tuple(p.shape)) for name, p in agent.state_inference_model.named_parameters()]

In [None]:
# Inspect the VAE's `tau` (presumably the value passed as tau= to StateVae in make_model).
agent.state_inference_model.tau