# MiniGrid Environment

Try out the environment by running the following command:


```bash
python -m minigrid.manual_control
```

Later we can benchmark against torch-rl 


In [271]:
import gymnasium as gym
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper

env = gym.make('MiniGrid-Empty-8x8-v0')
env = RGBImgPartialObsWrapper(env) # Get pixel observations
env = ImgObsWrapper(env) # Get rid of the 'mission' field
obs, _ = env.reset() # This now produces an RGB tensor only

In [274]:
import torch as t 
import plotly.express as px
obs = t.tensor(obs)
obs.shape
px.imshow(obs)


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [275]:
env = gym.make('MiniGrid-Empty-8x8-v0')
env = RGBImgPartialObsWrapper(env) # Get pixel observations
env = ImgObsWrapper(env) # Get rid of the 'mission' field
obs, _ = env.reset() # This now produces an RGB tensor only

# take several actions, store the observations, actions, returns and timesteps
all_obs = []
all_actions = []
all_returns = []
all_timesteps = []


for i in range(10):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    all_obs.append(obs)
    all_actions.append(action)
    all_returns.append(reward)
    all_timesteps.append(i)

# convert to tensors.unsqueeze(0)
all_obs = t.tensor(all_obs)
all_actions = t.tensor(all_actions).reshape(-1, 1)
all_returns = t.tensor(all_returns)
all_returns = t.randn((10, 1))
all_returns_to_go = all_returns.flip(0).cumsum(0).flip(0).reshape(-1, 1)
all_timesteps = t.tensor(all_timesteps).reshape(-1, 1)

In [276]:
print(all_returns.shape)

torch.Size([10, 1])


In [277]:
print(all_returns_to_go.shape)

torch.Size([10, 1])


# Getting a basic architecture


In [278]:
# for the grid world environment we will a small CNN to extract features from the image
# we will use the same CNN as in the original paper

obs, _, _, _, _ = env.step(2)
obs = t.tensor(obs)

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# to do: make this a custom class with hooks from transformer lense
# to do: work out how to feature visualize this

class StateEncoder(nn.Module):
    def __init__(self, n_embed):
        super(StateEncoder, self).__init__()
        self.n_embed = n_embed
        # input has shape 56 x 56 x 3
        # output has shape 1 x 1 x 512
        self.conv1 = nn.Conv2d(3, 32, 8, stride=4, padding=0) # 56 -> 13
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2, padding=0) # 13 -> 5
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1, padding=0) # 5 -> 3
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(576, n_embed)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.flatten(x)
        x = self.fc(x)
        x = F.relu(x)
        return x

# we will use the same CNN as in the original paper
cnn = StateEncoder(64).to("cpu")
x = obs.unsqueeze(0).to(t.float32)
x = rearrange(x, 'b h w c-> b c h w')
cnn(x)

tensor([[ 0.0000,  9.9623,  0.0000,  1.3285,  0.0000,  0.0000,  8.6895,  0.0000,
          0.0000, 14.3877,  6.6269,  0.0000,  0.0000, 10.2458,  9.8999,  0.0000,
          5.6620,  0.0000, 17.3232,  8.7777,  0.0000,  0.0000,  1.6546, 10.5847,
          0.0000,  0.0000,  7.8607,  0.0000, 12.7252,  1.4921,  5.1477,  0.0000,
          0.0000,  6.2238,  0.0000,  0.0000,  0.0000,  6.8364,  3.2929,  0.0000,
          2.7457,  4.7237,  0.0000, 10.5503,  2.5434,  0.0000,  0.1653,  9.5687,
          6.8229, 13.4620,  6.4368,  0.0000,  0.0000, 11.4811,  0.0000,  0.0000,
          0.0000, 11.6980,  0.0000, 11.9936, 12.6789,  0.0000,  9.6849,  0.0000]],
       grad_fn=<ReluBackward0>)

For reference: https://github.com/kzl/decision-transformer/blob/master/atari/mingpt/model_atari.py

In [280]:

    # def forward(self, R: t.tensor, s: t.tensor, a: t.tensor, t: t.tensor):
    #     '''
    #     R: return
    #     s: state
    #     a: action
    #     t: timestep
    #     '''
    #     s = s.to(torch.float32)
    #     s = rearrange(s, 'b h w c-> b c h w')

    #     pos_emb = self.pos_embedding(t)
    #     state_emb = self.state_embedding(s) + pos_emb
    #     action_emb = self.action_embeddings(a) + pos_emb
    #     ret_emb = self.ret_emb(R) + pos_emb

    #     input_embeds = torch.stack([state_emb, action_emb, ret_emb], dim=1)
    #     print(input_embeds.shape)
    #     print(input_embeds.dtype)
    #     input_embeds = rearrange(input_embeds, 'batch sar block_size d_model -> batch sar block_size d_model')
        
    #     x = self.transformer(input_embeds)
        
    #     return x


In [None]:
print("------------------")

decision_transformer = DecisionTransformer(env, model_type='reward_conditioned')
print(all_obs.unsqueeze(0).shape)
logits, _ = decision_transformer(
    states = all_obs.repeat(3,1, 1, 1, 1), 
    actions = all_actions.repeat(3,1, 1),
    rtgs = all_returns_to_go.repeat(3,1, 1),
    timesteps = all_timesteps.repeat(3,1, 1)
)
print(all_obs.repeat(3,1, 1, 1, 1).shape)
print(logits.shape)
print(logits[0].softmax(-1).argmax(-1))


print("------------------")

decision_transformer = DecisionTransformer(env, model_type='reward_conditioned')
print(all_obs.unsqueeze(0).shape)
logits, _ = decision_transformer(
    states = all_obs.repeat(3,1, 1, 1, 1), 
    actions = None,#all_actions.repeat(3,1, 1),
    rtgs = all_returns_to_go.repeat(3,1, 1),
    timesteps = all_timesteps.repeat(3,1, 1)
)
print(all_obs.repeat(3,1, 1, 1, 1).shape)
print(logits.shape)
print(logits[0].softmax(-1).argmax(-1))

Notes:
- This turned out to be really complicated. 
- Specifically:
    - it seems like the model has very different formulations during reward_conditioned vs naive. And also changes size if targets are used.
    - This includes things like the token embeddings for actions not going through in naive mode, but going through in reward_conditioned mode.
        - This is fine because a model is only ever of one type
    - However, the model isn't using padding of any kind? I should look for evidence of this. 

In [233]:
all_timesteps.unsqueeze(1).shape

torch.Size([10, 1, 1])

In [228]:
all_timesteps.repeat(3,1, 1).shape

torch.Size([3, 10, 1])

In [203]:
print(all_actions.repeat(3,1, 1).shape)
print(all_returns_to_go.repeat(3,1, 1).shape)
print(all_timesteps.repeat(3,1, 1).shape)
print(all_obs.repeat(3,1, 1, 1, 1).shape)

torch.Size([3, 10, 1])
torch.Size([3, 10, 1])
torch.Size([3, 10, 1])
torch.Size([3, 10, 56, 56, 3])


In [117]:

cfg = EasyTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=32,
    n_heads=2,
    d_mlp=128,
    d_vocab= 64,
    n_ctx= 30,
    act_fn="relu",
    normalization_type=None,
    attention_dir="causal",
    positional_embedding_type="learned",
    d_vocab_out=env.action_space.n, #
)


transformer = EasyTransformer(cfg)

In [136]:
rearrange(t.tensor([[1,2],[4,5],[7,8]]), 'b (n d) -> b n d', n=1, d=2)

tensor([[[1, 2]],

        [[4, 5]],

        [[7, 8]]])

In [None]:

# reference code:

# all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, batch_size, dim=0) # batch_size, traj_length, n_embd
# position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :]
# x = self.drop(token_embeddings + position_embeddings)

In [2]:
import torch as t 
import gymnasium as gym
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper

env = gym.make('MiniGrid-Empty-8x8-v0')
env = RGBImgPartialObsWrapper(env) # Get pixel observations
env = ImgObsWrapper(env) # Get rid of the 'mission' field
obs, _ = env.reset() # This now produces an RGB tensor only

# take several actions, store the observations, actions, returns and timesteps
all_obs = []
all_actions = []
all_returns = []
all_timesteps = []


for i in range(10):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    all_obs.append(obs)
    all_actions.append(action)
    all_returns.append(reward)
    all_timesteps.append(i)

# convert to tensors.unsqueeze(0)
all_obs = t.tensor(all_obs).to(t.float32).unsqueeze(0)
all_actions = t.tensor(all_actions).reshape(-1, 1).unsqueeze(0)
all_returns = t.randn((10, 1))
all_returns_to_go = all_returns.flip(0).cumsum(0).flip(0).reshape(-1, 1).unsqueeze(0)
all_timesteps = t.tensor(all_timesteps).reshape(-1, 1).unsqueeze(0)


  all_obs = t.tensor(all_obs).to(t.float32).unsqueeze(0)


In [6]:
from typing import Union, Dict
import einops 
from torchtyping import TensorType as TT
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
import torch 

import torch.nn as nn
from src.model import DecisionTransformer

# Positional Embeddings
class PosEmbedTokens(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty(self.cfg.n_ctx, self.cfg.d_model))

    def forward(
        self, tokens: TT["batch", "position"], past_kv_pos_offset: int = 0
    ) -> TT["batch", "position", "d_model"]:
        """Tokens have shape [batch, pos]
        Output shape [pos, d_model] - will be broadcast along batch dim"""

        tokens_length = tokens.size(-2)
        pos_embed = self.W_pos[:tokens_length, :]  # [pos, d_model]
        broadcast_pos_embed = einops.repeat(
            pos_embed, "pos d_model -> batch pos d_model", batch=tokens.size(0)
        )  # [batch, pos, d_model]
        return broadcast_pos_embed


decision_transformer = DecisionTransformer(env)
cfg = decision_transformer.transformer.cfg
decision_transformer.transformer.pos_embed = PosEmbedTokens(cfg)


logits, _ = decision_transformer(
    states = all_obs,
    actions = all_actions,
    rtgs = all_returns_to_go,
    timesteps = all_timesteps
)