In [None]:
import sys

# Put the directory two levels up on the import path so the project-local
# modules imported in the next cell can be resolved (assumes the kernel's
# working directory is this notebook's folder — TODO confirm).
REPO_ROOT = "../../"
sys.path.append(REPO_ROOT)

In [None]:
import gymnasium as gym
from vectorhash import build_vectorhash_architecture
from shifts import RatShift
from smoothing import RatSLAMSmoothing
from miniworld.params import DEFAULT_PARAMS
from miniworld_agent import MiniworldVectorhashAgent
import math
import torch
from preprocessing_cnn import GrayscaleAndFlattenPreprocessing, PreprocessingCNN
import matplotlib.pyplot as plt
from matplotlib import axes
from graph_utils import plot_probability_distribution_on_ax


### vhash
# Prefer Apple-silicon MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    _device_name = "mps"
elif torch.cuda.is_available():
    _device_name = "cuda"
else:
    _device_name = "cpu"
device = torch.device(_device_name)

# Shapes handed to build_vectorhash_architecture (presumably one tuple per grid module).
shapes = [(5, 5, 5), (8, 8, 8)]
# Number of perturbed poses sampled around the agent's true pose further below.
N = 10

def make_env():
    """Create a MiniWorld CollectHealth environment with randomization disabled.

    The environment is built from ``DEFAULT_PARAMS.copy().no_random()`` with
    ``domain_rand=False`` and ``max_episode_steps=-1`` (presumably meaning
    "no time limit" — confirm against the installed gymnasium version).

    Returns:
        (env, limits): the environment, and a float tensor
        ``[max_x - min_x, max_z - min_z, 2*pi]`` — the world's extent along
        x and z plus one full turn of heading.
    """
    env = gym.make(
        "MiniWorld-CollectHealth-v0",
        max_episode_steps=-1,
        params=DEFAULT_PARAMS.copy().no_random(),
        domain_rand=False,
    )
    bounds = {
        key: env.get_wrapper_attr(key)
        for key in ("min_x", "max_x", "min_z", "max_z")
    }
    x_extent = bounds["max_x"] - bounds["min_x"]
    z_extent = bounds["max_z"] - bounds["min_z"]

    limits = torch.tensor([x_extent, z_extent, 2 * math.pi]).float()
    return env, limits


env_cnn, limits = make_env()
# Both envs are built with randomization disabled, so this second `limits`
# presumably equals the first; it is the one both models below receive.
env_no_cnn, limits = make_env()


def _build_model(input_size):
    """Build a vector-hash model over `shapes` for sensory inputs of length `input_size`."""
    return build_vectorhash_architecture(
        shapes,
        N_h=600,
        input_size=input_size,
        initalization_method="by_sparsity",
        limits=limits,
        device=device,
        shift=RatShift(),
        smoothing=RatSLAMSmoothing(device=device),
    )


# The CNN model takes a 128-d latent; the raw model takes a flattened
# 60x80 frame (presumably the grayscale image from the flatten preprocessor).
model_cnn = _build_model(128)
model_no_cnn = _build_model(60 * 80)

#### preprocessor
# CNN path: encodes frames into a 128-d latent, which must match model_cnn's
# input_size; weights are presumably loaded from `model_path`.
_cnn_preproc_config = {
    "device": device,
    "latent_dim": 128,
    "input_channels": 3,
    "target_size": (224, 224),
    "model_path": "resnet18_adapter.pth",
}
cnn_preproc = PreprocessingCNN(**_cnn_preproc_config)

# Non-CNN path: grayscale-and-flatten preprocessing feeding model_no_cnn.
grayscale_flatten_preproc = GrayscaleAndFlattenPreprocessing(device=device)

#### agents
# Each agent pairs one model with its own environment and matching preprocessor.
agent_cnn = MiniworldVectorhashAgent(model_cnn, env_cnn, preprocessor=cnn_preproc)
agent_no_cnn = MiniworldVectorhashAgent(
    model_no_cnn,
    env_no_cnn,
    preprocessor=grayscale_flatten_preproc,
)

In [None]:
# Inspect the raw (no-CNN) agent's scaffold before any memory is stored.
scaffold_no_cnn = agent_no_cnn.vectorhash.scaffold
start_state = scaffold_no_cnn.g
for label, value in (
    ("start state:", start_state),
    ("grid limits:", scaffold_no_cnn.grid_limits),
    ("world limits:", limits),
    ("scale factor:", scaffold_no_cnn.scale_factor),
):
    print(label, value)

In [None]:
# Seed torch first so the sampled poses, and every result derived from them,
# are reproducible under Restart & Run All.
SAMPLE_SEED = 0
torch.manual_seed(SAMPLE_SEED)

# Draw N poses uniformly from the box [begin, begin + 1/scale_factor), i.e.
# within one grid unit of the agent's true pose along every dimension.
begin = agent_no_cnn.get_true_pos(env_no_cnn)
scale_factor = agent_no_cnn.vectorhash.scaffold.scale_factor
end = begin + 1 / scale_factor
dist = torch.distributions.uniform.Uniform(low=begin, high=end)
samples = dist.sample((N,))

# Offsets from the true pose expressed in grid units
# (each component lies in [0, 1) provided scale_factor is positive).
grid_states = scale_factor * (samples - begin)

In [None]:
# Show the sampled world poses next to their grid-unit offsets.
for caption, tensor in (("samples:", samples), ("grid states:", grid_states)):
    print(caption, tensor)

In [None]:
# Reset the same environment that is stepped on the next line. Resetting a
# different env (env_cnn) would leave the observation stored below unaffected.
agent_no_cnn._env_reset(agent_no_cnn.env)

# Step with action 4 and post-process the result into an image.
# NOTE(review): action 4 is presumably used to observe without moving — confirm.
first_img, first_pos = agent_no_cnn._obs_postpreprocess(agent_no_cnn.env.step(4), 4)
m = plt.imshow(first_img)
plt.colorbar(m)

# Store this same frame as a memory in both models, each encoded by its own preprocessor.
agent_no_cnn.vectorhash.store_memory(s=agent_no_cnn.preprocessor.encode(first_img))
agent_cnn.vectorhash.store_memory(s=agent_cnn.preprocessor.encode(first_img))

In [None]:
def _recall_denoised_grid(agent, encoded_obs):
    """Map an encoded observation through the agent's hippocampal layer to a denoised grid state."""
    vectorhash = agent.vectorhash
    hippocampal = vectorhash.hippocampal_sensory_layer.hippocampal_from_sensory(encoded_obs)
    noisy_grid = vectorhash.scaffold.grid_from_hippocampal(hippocampal)
    return vectorhash.scaffold.denoise(noisy_grid)


imgs, states_cnn, states_no_cnn = [], [], []
for sample in samples:
    # Move both agents to the sampled pose. Only agent_no_cnn's env is stepped
    # to render the frame, and both models then see that same frame.
    pose = sample.cpu().numpy()
    agent_no_cnn.set_agent_pos(pose)
    agent_cnn.set_agent_pos(pose)
    frame, _ = agent_no_cnn._obs_postpreprocess(agent_no_cnn.env.step(4), 4)

    # Encode with each preprocessor, then recall a grid state from each model.
    enc_cnn = agent_cnn.preprocessor.encode(frame)
    enc_no_cnn = agent_no_cnn.preprocessor.encode(frame)
    g_cnn = _recall_denoised_grid(agent_cnn, enc_cnn)
    g_no_cnn = _recall_denoised_grid(agent_no_cnn, enc_no_cnn)

    plt.colorbar(plt.imshow(frame))
    plt.show()

    imgs.append(frame)
    states_cnn.append(g_cnn)
    states_no_cnn.append(g_no_cnn)

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))

# First row of the grid state recalled for the first sampled pose (CNN model).
g_cnn_first = states_cnn[0][0]
plot_probability_distribution_on_ax(g_cnn_first.cpu(), ax)
ax.set(
    ylim=(0, 0.01),
    xlim=(0, len(g_cnn_first)),
    ylabel="probability mass",
    xlabel="g distribution",
)

fig.suptitle("probability mass across g (cnn)")

In [None]:
# Save the CNN grid-distribution figure drawn in the previous cell.
fig.savefig(fname="results_cnn_across_g.png")

In [None]:
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 4))

# First row of the grid state recalled for the first sampled pose (no-CNN model).
# The x-range is taken from this same tensor, not from the CNN states.
g_no_cnn_first = states_no_cnn[0][0]
plot_probability_distribution_on_ax(g_no_cnn_first.cpu(), ax)

ax.set_ylim(0, 0.01)
ax.set_xlim(0, len(g_no_cnn_first))
ax.set_ylabel("probability mass")
ax.set_xlabel("g distribution")

fig.suptitle("probability mass across g (no_cnn)")

In [None]:
# Save the no-CNN grid-distribution figure drawn in the previous cell.
fig.savefig(fname='results_across_g_no_cnn.png')

In [None]:
# Per-dimension distributions (dims 0, 1, 2 -> x, y, θ) for every recalled grid state.
cnn_x_dists, cnn_y_dists, cnn_theta_dists = [], [], []
no_cnn_x_dists, no_cnn_y_dists, no_cnn_theta_dists = [], [], []

_cnn_scaffold = agent_cnn.vectorhash.scaffold
_no_cnn_scaffold = agent_no_cnn.vectorhash.scaffold
_cnn_buckets = (cnn_x_dists, cnn_y_dists, cnn_theta_dists)
_no_cnn_buckets = (no_cnn_x_dists, no_cnn_y_dists, no_cnn_theta_dists)

for i in range(N):
    # NOTE: this overwrites each scaffold's `modules`; after the loop they
    # hold the state derived from the last sample.
    _cnn_scaffold.modules = _cnn_scaffold.modules_from_g(states_cnn[i][0])
    _no_cnn_scaffold.modules = _no_cnn_scaffold.modules_from_g(states_no_cnn[i][0])

    for dim, bucket in enumerate(_cnn_buckets):
        bucket.append(_cnn_scaffold.expand_distribution(dim).cpu())
    for dim, bucket in enumerate(_no_cnn_buckets):
        bucket.append(_no_cnn_scaffold.expand_distribution(dim).cpu())

In [None]:
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(6, 12))

# One panel per dimension, all from the CNN model's first sample.
cnn_panels = (
    (cnn_x_dists[0], 'x dist'),
    (cnn_y_dists[0], 'y dist'),
    (cnn_theta_dists[0], 'θ dist'),
)
for panel_ax, (dim_dist, label) in zip(axs, cnn_panels):
    plot_probability_distribution_on_ax(dim_dist, panel_ax)
    panel_ax.set_xlabel(label)
    panel_ax.set_xlim(0, len(dim_dist))
    panel_ax.set_ylabel('probability mass')
    panel_ax.set_ylim(0, 0.1)

fig.suptitle("probability mass across dimension distributions (cnn)")

In [None]:
# Save the CNN per-dimension figure drawn in the previous cell.
fig.savefig(fname="results_across_dims_cnn.png")

In [None]:
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(6, 12))

# One panel per dimension, all from the no-CNN model's first sample.
# Every panel must use the no_cnn_* lists; mixing in cnn_* data would
# mislabel the comparison.
no_cnn_panels = (
    (no_cnn_x_dists[0], 'x dist'),
    (no_cnn_y_dists[0], 'y dist'),
    (no_cnn_theta_dists[0], 'θ dist'),
)
for panel_ax, (dim_dist, label) in zip(axs, no_cnn_panels):
    plot_probability_distribution_on_ax(dim_dist, panel_ax)
    panel_ax.set_xlabel(label)
    panel_ax.set_xlim(0, len(dim_dist))
    panel_ax.set_ylabel('probability mass')
    panel_ax.set_ylim(0, 0.1)

fig.suptitle("probability mass across dimension distributions (no cnn)")

In [None]:
# Save the no-CNN per-dimension figure drawn in the previous cell.
fig.savefig(fname="results_across_dims_no_cnn.png")