# MiniGrid Environment

Try out the environment by running the following command:


```bash
python -m minigrid.manual_control
```

Later we can benchmark against torch-rl 


In [68]:
import gymnasium as gym
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper

env = gym.make('MiniGrid-Empty-8x8-v0')
env = RGBImgPartialObsWrapper(env) # Get pixel observations
env = ImgObsWrapper(env) # Get rid of the 'mission' field
obs, _ = env.reset() # This now produces an RGB tensor only

In [None]:
env

In [64]:
import torch as t 
import plotly.express as px
obs = t.tensor(obs)
obs.shape
px.imshow(obs)

In [198]:
env = gym.make('MiniGrid-Empty-8x8-v0')
env = RGBImgPartialObsWrapper(env) # Get pixel observations
env = ImgObsWrapper(env) # Get rid of the 'mission' field
obs, _ = env.reset() # This now produces an RGB tensor only

# take several actions, store the observations, actions, returns and timesteps
all_obs = []
all_actions = []
all_returns = []
all_timesteps = []


for i in range(10):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    all_obs.append(obs)
    all_actions.append(action)
    all_returns.append(reward)
    all_timesteps.append(i)

# convert to tensors.unsqueeze(0)
all_obs = t.tensor(all_obs)
all_actions = t.tensor(all_actions).reshape(-1, 1)
all_returns = t.tensor(all_returns)
all_returns = t.randn((10, 1))
all_returns_to_go = all_returns.flip(0).cumsum(0).flip(0).reshape(-1, 1)
all_timesteps = t.tensor(all_timesteps).reshape(-1, 1)

In [199]:
print(all_returns.shape)

torch.Size([10, 1])


In [200]:
print(all_returns_to_go.shape)

torch.Size([10, 1])


# Getting a basic architecture


In [174]:
# for the grid world environment we will a small CNN to extract features from the image
# we will use the same CNN as in the original paper

obs, _, _, _, _ = env.step(2)
obs = t.tensor(obs)

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# to do: make this a custom class with hooks from transformer lense
# to do: work out how to feature visualize this

class StateEncoder(nn.Module):
    def __init__(self, n_embed):
        super(StateEncoder, self).__init__()
        self.n_embed = n_embed
        # input has shape 56 x 56 x 3
        # output has shape 1 x 1 x 512
        self.conv1 = nn.Conv2d(3, 32, 8, stride=4, padding=0) # 56 -> 13
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2, padding=0) # 13 -> 5
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1, padding=0) # 5 -> 3
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(576, n_embed)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.flatten(x)
        x = self.fc(x)
        x = F.relu(x)
        return x

# we will use the same CNN as in the original paper
cnn = StateEncoder(64).to("cpu")
x = obs.unsqueeze(0).to(t.float32)
x = rearrange(x, 'b h w c-> b c h w')
cnn(x)

tensor([[ 3.1732,  0.0000,  8.5591,  9.1271,  0.0000,  9.6672,  8.2926,  0.0000,
          0.0000,  3.4197,  0.0000,  0.0000,  0.0000,  0.0000, 13.0668,  0.6257,
          3.3617,  0.8560,  0.0000,  0.0000, 10.8386,  1.3225, 13.4223,  4.6092,
          8.5116,  3.6104,  6.0747,  2.2926,  0.0000, 14.3694,  0.0000, 14.2668,
          0.0000,  5.7248,  7.7609,  0.0000,  2.8358,  0.0000,  3.1188, 11.4423,
          0.0000,  0.8637,  0.0000,  0.0000,  0.0000,  0.0000,  0.4379,  6.3394,
         12.1237,  0.0000,  0.0000,  0.0000,  0.2780,  0.0000, 11.5842,  0.0000,
          0.0000, 10.4421,  0.0000,  5.0201,  1.7064,  6.3745,  5.3732,  9.4602]],
       grad_fn=<ReluBackward0>)

For reference: https://github.com/kzl/decision-transformer/blob/master/atari/mingpt/model_atari.py

In [262]:
import torch
import torch.nn as nn
from transformer_lens import EasyTransformer, EasyTransformerConfig

class DecisionTransformer(torch.nn.Module):
    def __init__(self, env, max_game_length: int = 1000, model_type = "naive"):
        '''
        model = Classifier(cfg)
        '''
        super().__init__()

        self.model_type = model_type
        self.d_model = 64
        self.block_size = 10
        vocab_size = env.action_space.n
        self.max_timestep = max_game_length
        self.ctx_size = self.block_size 

        if self.model_type == "transformer":
            self.ctx_size = self.block_size

        # Embedding layers
        self.pos_emb = nn.Parameter(torch.zeros(1, self.block_size + 1, self.d_model))
        self.global_pos_emb = nn.Parameter(torch.zeros(1, self.max_timestep+1, self.d_model))
        self.state_encoder = StateEncoder(self.d_model)
        
        self.action_embeddings = nn.Sequential(nn.Embedding(env.action_space.n, self.d_model), nn.Tanh())
        nn.init.normal_(self.action_embeddings[0].weight, mean=0.0, std=0.02)

        self.ret_emb = nn.Sequential(nn.Linear(1, self.d_model), nn.Tanh())

        # Transformer

        cfg = EasyTransformerConfig(
            n_layers=2,
            d_model=self.d_model,
            d_head=32,
            n_heads=2,
            d_mlp=128,
            d_vocab= 64,
            n_ctx= self.ctx_size,
            act_fn="relu",
            normalization_type=None,
            attention_dir="causal",
            d_vocab_out=env.action_space.n, #
        )

        assert cfg.attention_dir == "causal", "Attention direction must be causal"
        assert cfg.normalization_type is None, "Normalization type must be None"


        self.transformer = EasyTransformer(cfg)
        self.transformer.embed = nn.Identity()
        self.transformer.pos_embed.W_pos.data[:] = 0.0
        self.transformer.pos_embed.W_pos.requires_grad = False


        
    # def forward(self, R: t.tensor, s: t.tensor, a: t.tensor, t: t.tensor):
    #     '''
    #     R: return
    #     s: state
    #     a: action
    #     t: timestep
    #     '''
    #     s = s.to(torch.float32)
    #     s = rearrange(s, 'b h w c-> b c h w')

    #     pos_emb = self.pos_embedding(t)
    #     state_emb = self.state_embedding(s) + pos_emb
    #     action_emb = self.action_embeddings(a) + pos_emb
    #     ret_emb = self.ret_emb(R) + pos_emb

    #     input_embeds = torch.stack([state_emb, action_emb, ret_emb], dim=1)
    #     print(input_embeds.shape)
    #     print(input_embeds.dtype)
    #     input_embeds = rearrange(input_embeds, 'batch sar block_size d_model -> batch sar block_size d_model')
        
    #     x = self.transformer(input_embeds)
        
    #     return x

    # state, action, and return
    def forward(self, states, actions, targets=None, rtgs=None, timesteps=None):
        # states: (batch, block_size, 56, 56, 3)
        # actions: (batch, block_size, 1)
        # targets: (batch, block_size, 1)
        # rtgs: (batch, block_size, 1)
        # timesteps: (batch, 1, 1) # this seems wrong because the time should be different for each element in the each block (incrememnting by 1)

        # asset all batch sizes are the same
        if actions is not None:
            assert states.shape[0] == actions.shape[0] == rtgs.shape[0] == timesteps.shape[0], "batch sizes must be the same"
        if targets is not None:
            assert states.shape[0] == targets.shape[0]

        
        # assert all block sizes are the same
        if actions is not None:
            assert states.shape[1] == actions.shape[1] == rtgs.shape[1], "block sizes must be the same"
        if targets is not None:
            assert states.shape[1] == targets.shape[1]

        batches = states.shape[0]
        block_size = states.shape[1]
        assert batches == 3 
        assert block_size == 10

        # embed states and recast back to (batch, block_size, n_embd)
        states = rearrange(states, 'batch block height width channel -> (batch block) channel height width')
        state_embeddings = self.state_encoder(states.type(torch.float32).contiguous()) # (batch * block_size, n_embd)
        state_embeddings = rearrange(state_embeddings, '(batch block) n_embd -> batch block n_embd', block=block_size)
        
        # generate time embeddings
        all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, batches, dim=0) # batch_size, trajectory_length, n_embd
        specific_global_pos_emb = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.d_model, dim=-1)) # batch_size, block_size, n_embd

        # reward conditioned in for evaluation
        if actions is not None and self.model_type == 'reward_conditioned': 
            rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))
            action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd)

            trajectory_length = block_size*3 - int(targets is None)

            token_embeddings = torch.zeros((batches, trajectory_length, self.d_model), dtype=torch.float32, device=state_embeddings.device)
            token_embeddings[:,::3,:] = rtg_embeddings
            token_embeddings[:,1::3,:] = state_embeddings
            token_embeddings[:,2::3,:] = action_embeddings[:,-block_size + int(targets is None):,:]
        # eval at the first time step
        elif actions is None and self.model_type == 'reward_conditioned': # only happens at very first timestep of evaluation
            rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))

            trajectory_length = block_size*2 
            token_embeddings = torch.zeros((batches, trajectory_length, self.d_model), dtype=torch.float32, device=state_embeddings.device)

            # really just [:,0,:]
            token_embeddings[:,::2,:] = rtg_embeddings - int(targets is None)
            # really just [:,1,:]
            token_embeddings[:,1::2,:] = state_embeddings  #+ specific_global_pos_emb[:, -block_size:, :]

        elif actions is not None and self.model_type == 'naive':
            action_embeddings = self.action_embeddings(actions.type(torch.long).squeeze(-1)) # (batch, block_size, n_embd)

            trajectory_length = block_size*2 - int(targets is None)
            token_embeddings = torch.zeros((batches, trajectory_length, self.d_model), dtype=torch.float32, device=state_embeddings.device)


            # it might make more sense to add the position embedding for time in here.
            token_embeddings[:,::2,:] = state_embeddings + specific_global_pos_emb # this line breaks if batch_size = 1!
            token_embeddings[:,1::2,:] = action_embeddings[:,-block_size + int(targets is None):,:] + specific_global_pos_emb[:, -block_size + int(targets is None):, :]
        elif actions is None and self.model_type == 'naive': # only happens at very first timestep of evaluation
            token_embeddings = state_embeddings + specific_global_pos_emb
        else:
            raise NotImplementedError()

        # position embeddings for transformer logic
        # position_embeddings = self.pos_emb[:, :trajectory_length, :]

        # no dropout 

        # x has shape 

        # let's just let easy transfomer add the position embeddings
        x = token_embeddings[:,-block_size:,:] # truncate to block size which should be less that trajectory_length (which varies depending on model type) 
        
        #+ position_embeddings # batch, trajectory_size, n_embd


        logits = self.transformer(x)

        print("logits shape is: ", logits.shape)

        if actions is not None and self.model_type == 'reward_conditioned':
            logits = logits[:, 1::3, :] # only keep predictions from state_embeddings (predictions over actions)
        elif actions is None and self.model_type == 'reward_conditioned':
            logits = logits[:, 1:, :]
        elif actions is not None and self.model_type == 'naive':
            logits = logits[:, ::2, :] # only keep predictions from state_embeddings (predictions over actions)
        elif actions is None and self.model_type == 'naive':
            logits = logits # for completeness
        else:
            raise NotImplementedError()

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        return logits, loss


decision_transformer = DecisionTransformer(env)
print(all_obs.unsqueeze(0).shape)
logits, _ = decision_transformer(
    states = all_obs.repeat(3,1, 1, 1, 1), 
    actions = all_actions.repeat(3,1, 1),
    rtgs = all_returns_to_go.repeat(3,1, 1),
    timesteps = all_timesteps.repeat(3,1, 1)
)
print(all_obs.repeat(3,1, 1, 1, 1).shape)
print(logits.shape)
print(logits[0].softmax(-1).argmax(-1))

print("------------------")

decision_transformer = DecisionTransformer(env)
print(all_obs.unsqueeze(0).shape)
logits, _ = decision_transformer(
    states = all_obs.repeat(3,1, 1, 1, 1), 
    actions = None,#all_actions.repeat(3,1, 1),
    rtgs = all_returns_to_go.repeat(3,1, 1),
    timesteps = all_timesteps.repeat(3,1, 1)
)
print(all_obs.repeat(3,1, 1, 1, 1).shape)
print(logits.shape)
print(logits[0].softmax(-1).argmax(-1))

print("------------------")

decision_transformer = DecisionTransformer(env, model_type='reward_conditioned')
print(all_obs.unsqueeze(0).shape)
logits, _ = decision_transformer(
    states = all_obs.repeat(3,1, 1, 1, 1), 
    actions = all_actions.repeat(3,1, 1),
    rtgs = all_returns_to_go.repeat(3,1, 1),
    timesteps = all_timesteps.repeat(3,1, 1)
)
print(all_obs.repeat(3,1, 1, 1, 1).shape)
print(logits.shape)
print(logits[0].softmax(-1).argmax(-1))

torch.Size([1, 10, 56, 56, 3])
logits shape is:  torch.Size([3, 10, 7])
torch.Size([3, 10, 56, 56, 3])
torch.Size([3, 5, 7])
tensor([1, 1, 1, 1, 1])
------------------
torch.Size([1, 10, 56, 56, 3])
logits shape is:  torch.Size([3, 10, 7])
torch.Size([3, 10, 56, 56, 3])
torch.Size([3, 10, 7])
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
------------------
torch.Size([1, 10, 56, 56, 3])


Notes:
- This turned out to be really complicated. 
- Specifically:
    - it seems like the model has very different formulations during reward_conditioned vs naive. And also changes size if targets are used.
    - This includes things like the token embeddings for actions not going through in naive mode, but going through in reward_conditioned mode.
        - This is fine because a model is only ever of one type
    - However, the model isn't using padding of any kind? I should look for evidence of this. 

In [233]:
all_timesteps.unsqueeze(1).shape

torch.Size([10, 1, 1])

In [228]:
all_timesteps.repeat(3,1, 1).shape

torch.Size([3, 10, 1])

In [203]:
print(all_actions.repeat(3,1, 1).shape)
print(all_returns_to_go.repeat(3,1, 1).shape)
print(all_timesteps.repeat(3,1, 1).shape)
print(all_obs.repeat(3,1, 1, 1, 1).shape)

torch.Size([3, 10, 1])
torch.Size([3, 10, 1])
torch.Size([3, 10, 1])
torch.Size([3, 10, 56, 56, 3])


In [117]:

cfg = EasyTransformerConfig(
    n_layers=2,
    d_model=64,
    d_head=32,
    n_heads=2,
    d_mlp=128,
    d_vocab= 64,
    n_ctx= 30,
    act_fn="relu",
    normalization_type=None,
    attention_dir="causal",
    positional_embedding_type="learned",
    d_vocab_out=env.action_space.n, #
)


transformer = EasyTransformer(cfg)

In [119]:
transformer.embed

Embed()

In [136]:
rearrange(t.tensor([[1,2],[4,5],[7,8]]), 'b (n d) -> b n d', n=1, d=2)

tensor([[[1, 2]],

        [[4, 5]],

        [[7, 8]]])

In [None]:

# reference code:

# all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, batch_size, dim=0) # batch_size, traj_length, n_embd
# position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :]
# x = self.drop(token_embeddings + position_embeddings)