# https://github.com/users/jbloomAus/projects/1/views/1?pane=issue&itemId=27197493




- [ ] Make a BOW model that encodes the object/color/state at each position using different embedding matrices
- [ ] Visualize each/make method to default to orthogonal or something? see how much by default?
- [ ] ensure init is in similar range to typical transformer situation
- [ ] Make a sinusoidal relative position encoding like what's used for some LLMs and add that on top. 
- [ ] Visualize positional embeddings via dot product as with did in arena. 

In [1]:
import sys 
sys.path.append('..')

import torch
from src.environments.memory import MemoryEnv
from src.models.components import MiniGridBOWEmbedding
from src.visualization import tensor_cosine_similarity_heatmap, tensor_2d_embedding_similarity
from minigrid.wrappers import ViewSizeWrapper
import plotly.express as px 

env = ViewSizeWrapper(MemoryEnv(size = 7, random_length=False, random_start_pos=False, max_steps=200, render_mode='rgb_array'), 7)
obs, info = env.reset()
obs['image'].shape
# px.imshow(env.render()).show()
# px.imshow(obs['image'][:,:,0].T).show()


state_embedding = MiniGridBOWEmbedding(
    embedding_dim=32, 
    max_values=[11, 6, 3], 
    channel_names=['object', 'color', 'state'], 
    view_size=7,
    add_positional_enc=True)

print(state_embedding)
obs, info = env.reset()
obs = torch.from_numpy(obs['image']).unsqueeze(0)
print(state_embedding(obs).shape)
print("--")
print(state_embedding.get_channel_embedding('object').shape)
print(state_embedding.get_channel_embedding('color').shape)
print(state_embedding.get_channel_embedding('state').shape)
print(state_embedding.get_all_channel_embeddings().shape)
print(state_embedding.get_positional_encoding().shape) # 2D positional encoding
print("--")
print("average norm channel embeddings:", torch.norm(state_embedding.get_all_channel_embeddings(), dim = 1).mean())
print("average norm positional encoding:", torch.norm(state_embedding.get_positional_encoding(), dim = 1).mean())


# Test the function with a sample PyTorch tensor
for channel in ['object', 'color', 'state']:
    tensor_cosine_similarity_heatmap(state_embedding.get_channel_embedding(channel).detach())
tensor_cosine_similarity_heatmap(state_embedding.get_all_channel_embeddings().detach())

env = ViewSizeWrapper(MemoryEnv(size = 7, random_length=False, random_start_pos=False, max_steps=200, render_mode='rgb_array'), 7)
obs, info = env.reset()
obs['image'].shape
# px.imshow(env.render()).show()
px.imshow(obs['image'][:,:,0].T).show()

# Test the function with an example image sample PyTorch tensor
obs = torch.from_numpy(obs['image']).unsqueeze(0)
embed_2d = state_embedding(obs).detach()
x = 2; y = 1
tensor_2d_embedding_similarity(embed_2d[0], x, y)
# tensor_2d_embedding_similarity(embed_2d[0], x, y, mode="contour")

non_positional_embeddings = state_embedding.get_all_channel_embeddings().detach()
positional_encoding = state_embedding.get_positional_encoding().detach()
concatenated_embeddings = torch.cat([non_positional_embeddings, positional_encoding.squeeze(0).view(-1, 32)], dim=0)

tensor_cosine_similarity_heatmap(concatenated_embeddings)

MiniGridBOWEmbedding(
  (embedding): Embedding(33, 32)
  (position_encoding): Summer(
    (penc): PositionalEncoding2D()
  )
)
torch.Size([1, 7, 7, 32])
--
torch.Size([11, 32])
torch.Size([6, 32])
torch.Size([3, 32])
torch.Size([20, 32])
torch.Size([1, 7, 7, 32])
--
average norm channel embeddings: tensor(0.9847, grad_fn=<MeanBackward0>)
average norm positional encoding: tensor(1.4963)


In [2]:
import torch
from positional_encodings.torch_encodings import PositionalEncoding1D, PositionalEncoding2D, PositionalEncoding3D, Summer

# now let's do it in 2d 
p_enc_2d_model = PositionalEncoding2D(128)
p_enc_2d_model_sum = Summer(PositionalEncoding2D(128))

x = torch.rand(1,7,7,128)
penc_no_sum = p_enc_2d_model(x) # penc_no_sum.shape == (1, 6, 10)
penc_sum = p_enc_2d_model_sum(x)
# print(penc_no_sum + x == penc_sum) # True

In [3]:

# Test the function with a sample PyTorch tensor
n_dim = 8
p_enc_2d_model = PositionalEncoding2D(128)
p_enc_2d_model_sum = Summer(PositionalEncoding2D(128))
x = torch.zeros(1,7,7,128)
tensor = p_enc_2d_model_sum(x)
x, y = 2, 5
