# Bag of Words Encoding and Vision Transformer Notebook

Previously, I have been using relatively simplistic schemes for encoding images. I have been using a linear layer and then later a CNN encode images into a vector. However, I suspect a ViT will actually work very well here. Moreover, since there is little/no actual feature recognition to be done and more just combining/projecting into a space, I think a ViT will work well. I haven't see any examples of people working with a ViT in offline-RL so it could be cool to do it. 

## Visualizing the Bag of Word Embeddings

In [1]:
import sys 
sys.path.append('..')

import torch
from src.environments.memory import MemoryEnv
from src.models.components import MiniGridBOWEmbedding
from src.visualization import tensor_cosine_similarity_heatmap, tensor_2d_embedding_similarity
from minigrid.wrappers import ViewSizeWrapper
import plotly.express as px 

env = ViewSizeWrapper(MemoryEnv(size = 7, random_length=False, random_start_pos=False, max_steps=200, render_mode='rgb_array'), 7)
obs, info = env.reset()
obs['image'].shape
# px.imshow(env.render()).show()
# px.imshow(obs['image'][:,:,0].T).show()


bow_embedding = MiniGridBOWEmbedding(
    embedding_dim=32, 
    max_values=[11, 6, 3], 
    channel_names=['object', 'color', 'state'], 
    view_size=7,
    add_positional_enc=True)

print(bow_embedding)
obs, info = env.reset()
obs = torch.from_numpy(obs['image']).unsqueeze(0)
print(bow_embedding(obs).shape)
print("--")
print(bow_embedding.get_channel_embedding('object').shape)
print(bow_embedding.get_channel_embedding('color').shape)
print(bow_embedding.get_channel_embedding('state').shape)
print(bow_embedding.get_all_channel_embeddings().shape)
print(bow_embedding.get_positional_encoding().shape) # 2D positional encoding
print("--")
print("average norm channel embeddings:", torch.norm(bow_embedding.get_all_channel_embeddings(), dim = 1).mean())
print("average norm positional encoding:", torch.norm(bow_embedding.get_positional_encoding(), dim = 1).mean())


# Test the function with a sample PyTorch tensor
for channel in ['object', 'color', 'state']:
    tensor_cosine_similarity_heatmap(bow_embedding.get_channel_embedding(channel).detach())
tensor_cosine_similarity_heatmap(bow_embedding.get_all_channel_embeddings().detach())

env = ViewSizeWrapper(MemoryEnv(size = 7, random_length=False, random_start_pos=False, max_steps=200, render_mode='rgb_array'), 7)
obs, info = env.reset()
obs['image'].shape
# px.imshow(env.render()).show()
px.imshow(obs['image'][:,:,0].T).show()

MiniGridBOWEmbedding(
  (embedding): Embedding(33, 32)
  (position_encoding): Summer(
    (penc): PositionalEncoding2D()
  )
)
torch.Size([1, 32, 7, 7])
--
torch.Size([11, 32])
torch.Size([6, 32])
torch.Size([3, 32])
torch.Size([20, 32])
torch.Size([1, 7, 7, 32])
--
average norm channel embeddings: tensor(1.0078, grad_fn=<MeanBackward0>)
average norm positional encoding: tensor(1.4963)


In [2]:
# Test the function with an example image sample PyTorch tensor
obs, info = env.reset()
obs = torch.from_numpy(obs['image']).unsqueeze(0)
embed_2d = bow_embedding(obs).detach()
x = 2; y = 1
print(embed_2d[0].shape)
tensor_2d_embedding_similarity(embed_2d[0].permute(1,2,0), x, y)
# tensor_2d_embedding_similarity(embed_2d[0], x, y, mode="contour")

torch.Size([32, 7, 7])


In [3]:
non_positional_embeddings = bow_embedding.get_all_channel_embeddings().detach()
positional_encoding = bow_embedding.get_positional_encoding().detach()
concatenated_embeddings = torch.cat(
    [non_positional_embeddings, 
     positional_encoding.squeeze(0).view(-1, 32)], dim=0)

tensor_cosine_similarity_heatmap(concatenated_embeddings)

## Building a ViT using HookedTransformer

- things to note:
    - tradition ViT's are encoder models which use an mlp head and a class token. 
    - I suspect an attention only decoder ViT, will be perfectly fine for this task.
    - However, it's not going to be that hard to turn that on/off so can try both.

resource: https://github.com/lucidrains/vit-pytorch

In [6]:
import torch.nn as nn 
from transformer_lens import HookedTransformer, HookedTransformerConfig

bow_embedding = MiniGridBOWEmbedding(
    embedding_dim=64, 
    max_values=[11, 6, 3], 
    channel_names=['object', 'color', 'state'], 
    view_size=7,
    add_positional_enc=True)

vit_config = HookedTransformerConfig(
        n_layers=2,
        d_model=64,
        d_head=32,
        n_heads=2,
        d_mlp=256,
        d_vocab=128, # needs to match the model.
        n_ctx=7*7, # 7x7 grid
        normalization_type=None,
        attention_dir="causal",
        attn_only=True,
        )

vit = HookedTransformer(vit_config)
vit.embed = nn.Identity()
vit.unembed 


obs, info = env.reset()
obs = torch.from_numpy(obs['image']).unsqueeze(0)
bow = bow_embedding(obs)
print(bow.shape)
bow = bow.flatten(-2,-1).permute(0,2,1) # convert into 49 tokens (n_ctx)
print(bow.shape)
state_embedding = vit(bow)
state_embedding[:,-1].shape # 1 sequence token. 



torch.Size([1, 64, 7, 7])
torch.Size([1, 49, 64])


torch.Size([1, 128])

# Comparing Model Sizes
Since we are using a learned embedding for the position which is provided by T-Lens and helps the model recognize different tokens I guess.  

In [1]:
import sys 
sys.path.append('..')
import torch
from src.config import TransformerModelConfig, EnvironmentConfig, OnlineTrainConfig, LSTMModelConfig
from src.models.trajectory_transformer import DecisionTransformer
from src.environments.environments import make_env
from torchinfo import summary

env_config = EnvironmentConfig()
transformer_config = TransformerModelConfig(n_ctx=26)
DT = DecisionTransformer(env_config, transformer_config)
env = make_env(env_config, 0, 0, "test")()
obs, info = env.reset()

summary(DT, 
        input_size=[
            (2, 9, 7, 7, 3), # n_ctx + 1 // 3 is seq length
            (2, 8, 1), # one less action
            (2, 9, 1), # same rtg for each step
            (2, 9, 1), # time of each step
            ],
        dtypes=[torch.float32, torch.int, torch.int, torch.int],
        depth=3,
        )

Layer (type:depth-idx)                        Output Shape              Param #
DecisionTransformer                           [2, 9, 147]               --
├─Linear: 1-1                                 [18, 128]                 18,816
├─Sequential: 1-2                             [16, 1, 128]              --
│    └─Embedding: 2-1                         [16, 1, 128]              512
├─Sequential: 1-3                             [18, 128]                 --
│    └─Linear: 2-2                            [18, 128]                 128
├─Embedding: 1-4                              [18, 1, 128]              128,128
├─HookedTransformer: 1-5                      [2, 26, 128]              --
│    └─Identity: 2-3                          [2, 26, 128]              --
│    └─HookPoint: 2-4                         [2, 26, 128]              --
│    └─PosEmbedTokens: 2-5                    [2, 26, 128]              3,328
│    └─HookPoint: 2-6                         [2, 26, 128]              --
│    └

In [2]:
transformer_config = TransformerModelConfig(n_ctx=26)
DT = DecisionTransformer(env_config, transformer_config)
env = make_env(env_config, 0, 0, "test")()
obs, info = env.reset()

summary(DT, 
        input_size=[
            (2, 9, 7, 7, 3), # n_ctx + 1 // 3 is seq length
            (2, 8, 1), # one less action
            (2, 9, 1), # same rtg for each step
            (2, 9, 1), # time of each step
            ],
        dtypes=[torch.float32, torch.int, torch.int, torch.int],
        depth=6,
        )

Layer (type:depth-idx)                        Output Shape              Param #
DecisionTransformer                           [2, 9, 147]               --
├─Linear: 1-1                                 [18, 128]                 18,816
├─Sequential: 1-2                             [16, 1, 128]              --
│    └─Embedding: 2-1                         [16, 1, 128]              512
├─Sequential: 1-3                             [18, 128]                 --
│    └─Linear: 2-2                            [18, 128]                 128
├─Embedding: 1-4                              [18, 1, 128]              128,128
├─HookedTransformer: 1-5                      [2, 26, 128]              --
│    └─Identity: 2-3                          [2, 26, 128]              --
│    └─HookPoint: 2-4                         [2, 26, 128]              --
│    └─PosEmbedTokens: 2-5                    [2, 26, 128]              3,328
│    └─HookPoint: 2-6                         [2, 26, 128]              --
│    └

In [8]:
import torch
import pandas as pd

def get_param_stats(model):
    param_stats = []

    for name, param in model.named_parameters():
        mean = param.data.mean().item()
        std = param.data.std().item()
        if param.data.dim() > 1:
            norm = torch.norm(param.data, dim=1).mean().item()
        else:
            norm = torch.norm(param.data).item()
        param_stats.append({'name': name, 'mean': mean, 'std': std, 'norm': norm})

    df = pd.DataFrame(param_stats)
    return df


import plotly.express as px
import plotly.graph_objects as go
import numpy as np 

def plot_param_stats(df):
    # Calculate log of standard deviation
    df['log_std'] = -1*np.log(df['std'])

    # add color column to df, red for ends with weight, blue for ends with bias, purple if embedding
    df['color'] = 'green'
    df.loc[df['name'].str.endswith('weight'), 'color'] = 'red'
    df.loc[df['name'].str.contains('W_'), 'color'] = 'red'
    df.loc[df['name'].str.endswith('bias'), 'color'] = 'blue'
    df.loc[df['name'].str.contains('b_'), 'color'] = 'blue'
    df.loc[df['name'].str.contains('embedding'), 'color'] = 'purple'

    # make a name label which is name.split('.')[-1]
    df['name_label'] = df['name'].apply(lambda x: x.split('.')[-1])

    # Create the mean bar chart
    fig_mean = go.Figure()
    fig_mean.add_trace(go.Bar(x=df['name'], y=df['mean'], text=df['mean'], textposition='outside', 
                          hovertext=df['name_label'], marker_color=df['color']))
    fig_mean.update_traces(texttemplate='%{text:.4f}', hovertemplate='Parameter: %{hovertext}<br>Mean: %{text:.4f}')
    fig_mean.update_yaxes(title_text='Mean')
    fig_mean.update_xaxes(title_text='Parameter Name')
    fig_mean.update_layout(title_text='Mean of Model Parameters')
    

    # Create the norm chart
    fig_norm = go.Figure()
    fig_norm.add_trace(go.Bar(x=df['name'], y=df['norm'], text=df['norm'], textposition='outside', 
                              hovertext=df['name_label'], marker_color=df['color']))
    fig_norm.update_traces(texttemplate='%{text:.4f}', hovertemplate='Parameter: %{hovertext}<br>Norm: %{text:.4f}')
    fig_norm.update_yaxes(title_text='Norm')
    fig_norm.update_xaxes(title_text='Parameter Name')
    fig_norm.update_layout(title_text='Norm of Model Parameters')

    # Create the log of standard deviation bar chart
    fig_log_std = go.Figure()
    fig_log_std.add_trace(go.Bar(x=df['name'], y=df['log_std'], text=df['log_std'], textposition='outside', 
                                 hovertext=df['name_label'], marker_color=df['color']))
    fig_log_std.update_traces(texttemplate='%{text:.4f}', hovertemplate='Parameter: %{hovertext}<br>Log Std: %{text:.4f}')
    fig_log_std.update_yaxes(title_text='Log of Standard Deviation')
    fig_log_std.update_xaxes(title_text='Parameter Name')
    fig_log_std.update_layout(title_text='Log of Standard Deviation of Model Parameters')
    # add a horizontal line at y = 1.69
    fig_log_std.add_shape(type='line', x0=0, y0=-np.log(0.02), x1=len(df['name']), y1=-np.log(0.02), line=dict(color='red', width=2, dash='dash'))

    # Show both plots
    fig_mean.show()
    fig_log_std.show()
    fig_norm.show()


df = get_param_stats(DT)
plot_param_stats(df)

# df = get_param_stats(lstm_model)
# plot_param_stats(df)


divide by zero encountered in log



In [None]:


import gymnasium as gym 
from src.models.trajectory_lstm import TrajectoryLSTM

lstm_config = LSTMModelConfig(env_config)
lstm_model = TrajectoryLSTM(lstm_config)

print(lstm_config)
summary(lstm_model, depth=4)