# Bag of Words Encoding and Vision Transformer Notebook

Previously, I have been using relatively simplistic schemes for encoding images. I have been using a linear layer and then later a CNN encode images into a vector. However, I suspect a ViT will actually work very well here. Moreover, since there is little/no actual feature recognition to be done and more just combining/projecting into a space, I think a ViT will work well. I haven't see any examples of people working with a ViT in offline-RL so it could be cool to do it. 

## Visualizing the Bag of Word Embeddings

In [1]:
import sys 
sys.path.append('..')

import torch
from src.environments.memory import MemoryEnv
from src.models.components import MiniGridBOWEmbedding
from src.visualization import tensor_cosine_similarity_heatmap, tensor_2d_embedding_similarity
from minigrid.wrappers import ViewSizeWrapper
import plotly.express as px 

env = ViewSizeWrapper(MemoryEnv(size = 7, random_length=False, random_start_pos=False, max_steps=200, render_mode='rgb_array'), 7)
obs, info = env.reset()
obs['image'].shape
# px.imshow(env.render()).show()
# px.imshow(obs['image'][:,:,0].T).show()


bow_embedding = MiniGridBOWEmbedding(
    embedding_dim=32, 
    max_values=[11, 6, 3], 
    channel_names=['object', 'color', 'state'], 
    view_size=7,
    add_positional_enc=True)

print(bow_embedding)
obs, info = env.reset()
obs = torch.from_numpy(obs['image']).unsqueeze(0)
print(bow_embedding(obs).shape)
print("--")
print(bow_embedding.get_channel_embedding('object').shape)
print(bow_embedding.get_channel_embedding('color').shape)
print(bow_embedding.get_channel_embedding('state').shape)
print(bow_embedding.get_all_channel_embeddings().shape)
print(bow_embedding.get_positional_encoding().shape) # 2D positional encoding
print("--")
print("average norm channel embeddings:", torch.norm(bow_embedding.get_all_channel_embeddings(), dim = 1).mean())
print("average norm positional encoding:", torch.norm(bow_embedding.get_positional_encoding(), dim = 1).mean())


# Test the function with a sample PyTorch tensor
for channel in ['object', 'color', 'state']:
    tensor_cosine_similarity_heatmap(bow_embedding.get_channel_embedding(channel).detach())
tensor_cosine_similarity_heatmap(bow_embedding.get_all_channel_embeddings().detach())

env = ViewSizeWrapper(MemoryEnv(size = 7, random_length=False, random_start_pos=False, max_steps=200, render_mode='rgb_array'), 7)
obs, info = env.reset()
obs['image'].shape
# px.imshow(env.render()).show()
px.imshow(obs['image'][:,:,0].T).show()

MiniGridBOWEmbedding(
  (embedding): Embedding(33, 32)
  (position_encoding): Summer(
    (penc): PositionalEncoding2D()
  )
)
torch.Size([1, 32, 7, 7])
--
torch.Size([11, 32])
torch.Size([6, 32])
torch.Size([3, 32])
torch.Size([20, 32])
torch.Size([1, 7, 7, 32])
--
average norm channel embeddings: tensor(1.0078, grad_fn=<MeanBackward0>)
average norm positional encoding: tensor(1.4963)


In [2]:
# Test the function with an example image sample PyTorch tensor
obs, info = env.reset()
obs = torch.from_numpy(obs['image']).unsqueeze(0)
embed_2d = bow_embedding(obs).detach()
x = 2; y = 1
print(embed_2d[0].shape)
tensor_2d_embedding_similarity(embed_2d[0].permute(1,2,0), x, y)
# tensor_2d_embedding_similarity(embed_2d[0], x, y, mode="contour")

torch.Size([32, 7, 7])


In [3]:
non_positional_embeddings = bow_embedding.get_all_channel_embeddings().detach()
positional_encoding = bow_embedding.get_positional_encoding().detach()
concatenated_embeddings = torch.cat(
    [non_positional_embeddings, 
     positional_encoding.squeeze(0).view(-1, 32)], dim=0)

tensor_cosine_similarity_heatmap(concatenated_embeddings)

## Building a ViT using HookedTransformer

- things to note:
    - tradition ViT's are encoder models which use an mlp head and a class token. 
    - I suspect an attention only decoder ViT, will be perfectly fine for this task.
    - However, it's not going to be that hard to turn that on/off so can try both.

resource: https://github.com/lucidrains/vit-pytorch

In [6]:
import torch.nn as nn 
from transformer_lens import HookedTransformer, HookedTransformerConfig

bow_embedding = MiniGridBOWEmbedding(
    embedding_dim=64, 
    max_values=[11, 6, 3], 
    channel_names=['object', 'color', 'state'], 
    view_size=7,
    add_positional_enc=True)

vit_config = HookedTransformerConfig(
        n_layers=2,
        d_model=64,
        d_head=32,
        n_heads=2,
        d_mlp=256,
        d_vocab=128, # needs to match the model.
        n_ctx=7*7, # 7x7 grid
        normalization_type=None,
        attention_dir="causal",
        attn_only=True,
        )

vit = HookedTransformer(vit_config)
vit.embed = nn.Identity()
vit.unembed 


obs, info = env.reset()
obs = torch.from_numpy(obs['image']).unsqueeze(0)
bow = bow_embedding(obs)
print(bow.shape)
bow = bow.flatten(-2,-1).permute(0,2,1) # convert into 49 tokens (n_ctx)
print(bow.shape)
state_embedding = vit(bow)
state_embedding[:,-1].shape # 1 sequence token. 



torch.Size([1, 64, 7, 7])
torch.Size([1, 49, 64])


torch.Size([1, 128])