# PPO Transformer Model



In [73]:
import torch.nn as nn
from dataclasses import dataclass, field
import gymnasium as gym

@dataclass
class TransformerModelConfig:
    '''
    Hyperparameters for the transformer used by the trajectory models.

    All fields are plain constructor arguments. ``d_head`` is not a field:
    it is derived from ``d_model`` and ``n_heads`` after initialisation.
    '''
    # Width of the model; must be an exact multiple of n_heads.
    d_model: int = 128
    # Number of attention heads.
    n_heads: int = 4
    # Width of the MLP layers.
    d_mlp: int = 256
    # Number of transformer layers.
    n_layers: int = 2
    # Context length handed to the transformer.
    n_ctx: int = 3
    # Whether the transformer uses layer normalisation.
    layer_norm: bool = False
    # True -> linear time embedding, False -> learned lookup-table embedding.
    linear_time_embedding: bool = False
    # Kind of state embedding (e.g. 'grid' or 'CNN').
    state_embedding_type: str = 'grid'
    # Name of the time-embedding scheme.
    time_embedding_type: str = 'learned'
    # Random seed for the transformer.
    seed: int = 1
    # Device the transformer is created on.
    device: str = 'cpu'

    def __post_init__(self):
        # Every head gets an equal slice of d_model, so the split must be exact.
        per_head, leftover = divmod(self.d_model, self.n_heads)
        assert leftover == 0
        self.d_head = per_head

transformer_config = TransformerModelConfig()

In [74]:


@dataclass
class EnvironmentConfig():
    '''
    Settings describing the gymnasium environment a model is trained on.

    ``action_space`` and ``observation_space`` may be supplied explicitly.
    Any that are left as ``None`` are read from a temporary environment
    created with ``gym.make(env_id)``, which is closed again straight away.
    '''
    # Unannotated on purpose-or-accident: this makes ``env`` a plain class
    # attribute, not a dataclass field / constructor argument.
    env = None
    # Environment id passed to ``gym.make``.
    env_id: str = 'MiniGrid-Empty-8x8-v0'
    one_hot: bool = False
    fully_observed: bool = False
    # Also used by the models to size the learned time-embedding table.
    max_steps: int = 1000
    seed: int = 1
    view_size: int = 7
    capture_video: bool = False
    video_dir: str = 'videos'
    render_mode: str = 'rgb_array'
    # NOTE: the misspelling is kept because the field name is public API.
    num_parralel_envs: int = 1
    # String annotations so the class body does not need gym at definition time.
    action_space: "gym.Space | None" = None
    observation_space: "gym.Space | None" = None
    device: str = 'cpu'

    def __post_init__(self):
        '''Fill in any space the caller did not provide from ``env_id``.'''
        if self.action_space is not None and self.observation_space is not None:
            # Nothing to look up, so do not build an environment at all.
            return

        env = gym.make(self.env_id)
        try:
            # ``is None`` rather than ``or``: some spaces define ``__len__``
            # and would be falsy when empty.
            if self.action_space is None:
                self.action_space = env.action_space
            if self.observation_space is None:
                self.observation_space = env.observation_space
        finally:
            # The environment was only needed to read its spaces.
            env.close()


environment_config = EnvironmentConfig()
environment_config.action_space

Discrete(7)

In [75]:
environment_config.action_space

Discrete(7)

In [92]:
import torch 
from typing import Dict, Union, Tuple
from torchtyping import TensorType as TT
from abc import ABC, abstractmethod
from einops import rearrange
from src.decision_transformer.model import StateEncoder, PosEmbedTokens
from src.decision_transformer.model import DecisionTransformer as DecisionTransformerOld
import numpy as np
from transformer_lens import HookedTransformer, HookedTransformerConfig
from gymnasium.spaces import Box, Dict
# from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

class TrajectoryTransformer(nn.Module):
    '''
    Base Class for trajectory modelling transformers including:
        - Decision Transformer (offline, RTG, (R,s,a))
        - Online Transformer (online, reward, (s,a,r))

    The class owns the per-modality embeddings (state, action, reward, time),
    the transformer trunk and the three prediction heads. Subclasses decide
    how the embeddings are interleaved into a token sequence by implementing
    ``get_token_embeddings``.
    '''

    def __init__(
        self,
        transformer_config: TransformerModelConfig,
        environment_config: EnvironmentConfig
        ):
        '''
        Args:
            transformer_config: model hyperparameters (d_model, n_heads, ...).
            environment_config: environment description; its ``action_space``,
                ``observation_space`` and ``max_steps`` size the embeddings
                and prediction heads.
        '''
        super().__init__()

        self.transformer_config = transformer_config
        self.environment_config = environment_config

        # One more embedding row than there are actions
        # (presumably a padding / "no action yet" index - TODO confirm).
        self.action_embedding = nn.Sequential(
            nn.Embedding(environment_config.action_space.n+1, self.transformer_config.d_model))
        self.reward_embedding = nn.Sequential(
            nn.Linear(1, self.transformer_config.d_model, bias=False))
        self.time_embedding = self.initialize_time_embedding()
        self.state_embedding = self.initialize_state_embedding()

        # Initialize weights
        nn.init.normal_(
            self.action_embedding[0].weight, mean=0.0, std=1/((environment_config.action_space.n+1 + 1)*self.transformer_config.d_model))
        nn.init.normal_(
            self.reward_embedding[0].weight, mean=0.0, std=1/self.transformer_config.d_model)

        self.transformer = self.initialize_easy_transformer()

        self.action_predictor = nn.Linear(self.transformer_config.d_model, environment_config.action_space.n)
        self.reward_predictor = nn.Linear(self.transformer_config.d_model, 1)
        self.initialize_state_predictor()

    def forward(self,
                # has variable shape, starting with batch, position
                states: TT[...],
                actions: TT["batch", "position"],  # noqa: F821
                rtgs: TT["batch", "position"],  # noqa: F821
                timesteps: TT["batch", "position"],  # noqa: F821
                ) -> Tuple[TT[...], TT["batch", "position"], TT["batch", "position"]]:  # noqa: F821
        '''
        Embed a trajectory, run it through the transformer and apply the heads.

        Returns:
            (state_preds, action_preds, reward_preds)
        '''
        batch_size = states.shape[0]
        seq_length = states.shape[1]

        # embed states and recast back to (batch, block_size, n_embd)
        token_embeddings = self.to_tokens(states, actions, rtgs, timesteps)
        x = self.transformer(token_embeddings)
        state_preds, action_preds, reward_preds = self.get_logits(
            x, batch_size, seq_length)

        return state_preds, action_preds, reward_preds

    def get_logits(self, x, batch_size, seq_length):
        '''
        Split the transformer output into its three per-timestep token streams
        and apply the prediction heads.

        NOTE: the reshape requires ``x`` to hold exactly ``3 * seq_length``
        positions, i.e. three tokens per timestep.
        '''
        # TODO replace with einsum
        # (batch, 3 * seq, d_model) -> (batch, 3, seq, d_model)
        x = x.reshape(batch_size, seq_length, 3,
                      self.transformer_config.d_model).permute(0, 2, 1, 3)

        # predict next return given state and action
        reward_preds = self.predict_rewards(x[:, 2])
        # predict next state given state and action
        state_preds = self.predict_states(x[:, 2])
        # predict next action given state
        action_preds = self.predict_actions(x[:, 1])

        return state_preds, action_preds, reward_preds

    def predict_rewards(self, x):
        '''Apply the reward head.'''
        return self.reward_predictor(x)

    def predict_states(self, x):
        '''Apply the state head created by ``initialize_state_predictor``.'''
        return self.state_predictor(x)

    def predict_actions(self, x):
        '''Apply the action head.'''
        return self.action_predictor(x)

    def to_tokens(self, states, actions, rtgs, timesteps):
        '''
        Embed every modality and hand the embeddings to the subclass-specific
        ``get_token_embeddings`` to be interleaved into one token sequence.
        ``actions`` may be None (no action taken yet).
        '''
        # embed states and recast back to (batch, block_size, n_embd)
        state_embeddings = self.get_state_embeddings(
            states)  # batch_size, block_size, n_embd
        action_embeddings = self.get_action_embeddings(
            actions) if actions is not None else None  # batch_size, block_size, n_embd or None
        reward_embeddings = self.get_reward_embeddings(
            rtgs)  # batch_size, block_size, n_embd
        time_embeddings = self.get_time_embeddings(
            timesteps)  # batch_size, block_size, n_embd

        # combine the per-modality embeddings into the transformer input
        token_embeddings = self.get_token_embeddings(
            state_embeddings=state_embeddings,
            action_embeddings=action_embeddings,
            reward_embeddings=reward_embeddings,
            time_embeddings=time_embeddings
        )
        return token_embeddings

    def get_time_embeddings(self, timesteps):
        '''
        Embed timesteps of shape (batch, block, 1) into (batch, block, n_embd).
        '''
        # The learned table has max_steps + 1 rows, so max_steps itself is valid.
        assert timesteps.max() <= self.environment_config.max_steps, "timesteps must be less than max_timesteps"

        block_size = timesteps.shape[1]
        timesteps = rearrange(
            timesteps, 'batch block time -> (batch block) time')
        if self.transformer_config.linear_time_embedding:
            # nn.Linear needs floating point input.
            time_embeddings = self.time_embedding(
                timesteps.type(torch.float32))
        else:
            # Embedding lookup on (N, 1) indices yields (N, 1, n_embd);
            # drop the singleton axis.
            time_embeddings = self.time_embedding(timesteps).squeeze(-2)
        time_embeddings = rearrange(
            time_embeddings, '(batch block) n_embd -> batch block n_embd', block=block_size)
        return time_embeddings

    def get_state_embeddings(self, states):
        '''
        Embed states of shape (batch, block, height, width, channel) into
        (batch, block, n_embd).
        '''
        # embed states and recast back to (batch, block_size, n_embd)
        block_size = states.shape[1]
        if self.transformer_config.state_embedding_type == "CNN":
            # the CNN encoder takes channel-first images
            states = rearrange(
                states, 'batch block height width channel -> (batch block) channel height width')
        else:
            # the linear encoder takes one flat vector per observation
            states = rearrange(
                states, 'batch block height width channel -> (batch block) (channel height width)')
        state_embeddings = self.state_embedding(states.type(
            torch.float32).contiguous())  # (batch * block_size, n_embd)
        state_embeddings = rearrange(
            state_embeddings, '(batch block) n_embd -> batch block n_embd', block=block_size)
        return state_embeddings

    def get_reward_embeddings(self, rtgs):
        '''
        Embed rewards / returns-to-go of shape (batch, block, 1) into
        (batch, block, n_embd).
        '''
        block_size = rtgs.shape[1]
        rtgs = rearrange(rtgs, 'batch block rtg -> (batch block) rtg')
        rtg_embeddings = self.reward_embedding(rtgs.type(torch.float32))
        rtg_embeddings = rearrange(
            rtg_embeddings, '(batch block) n_embd -> batch block n_embd', block=block_size)
        return rtg_embeddings

    def get_action_embeddings(self, actions):
        '''
        Embed integer actions of shape (batch, block, action) into
        (batch, block, n_embd).
        '''
        block_size = actions.shape[1]
        actions = rearrange(
            actions, 'batch block action -> (batch block) action')
        # The lookup keeps the trailing action axis, giving (N, action, n_embd);
        # flatten(1) folds it away (assumes that axis has size 1 - TODO confirm).
        action_embeddings = self.action_embedding(actions).flatten(1)
        action_embeddings = rearrange(
            action_embeddings, '(batch block) n_embd -> batch block n_embd', block=block_size)
        return action_embeddings

    # Singular spellings kept as aliases so existing callers keep working;
    # ``to_tokens`` uses the plural names.
    get_time_embedding = get_time_embeddings
    get_state_embedding = get_state_embeddings
    get_reward_embedding = get_reward_embeddings
    get_action_embedding = get_action_embeddings

    @abstractmethod
    def get_token_embeddings(self,
                             state_embeddings,
                             time_embeddings,
                             action_embeddings,
                             reward_embeddings):
        '''
        Returns the token embeddings for the transformer input.
        Note that different subclasses will have different token embeddings
        such as the DecisionTransformer which will use RTG (placed before the
        state embedding).

        Args:
            state_embeddings: (batch, position, n_embd)
            time_embeddings: (batch, position, n_embd)
            action_embeddings: (batch, position, n_embd) or None
            reward_embeddings: (batch, position, n_embd)

        Returns:
            token_embeddings: (batch, tokens, n_embd)
        '''
        # nn.Module does not use ABCMeta, so @abstractmethod alone would let
        # this silently return None; fail loudly instead.
        raise NotImplementedError(
            "subclasses must implement get_token_embeddings")

    def get_action(self, states, actions, rewards, timesteps):
        '''Return the greedy (argmax) action for the last position of each batch row.'''
        state_preds, action_preds, reward_preds = self.forward(
            states, actions, rewards, timesteps)

        # get the action prediction
        action_preds = action_preds[:, -1, :]  # (batch, n_actions)
        action = torch.argmax(action_preds, dim=-1)  # (batch)
        return action

    def initialize_time_embedding(self):
        '''
        Create, store and return the time embedding: a learned lookup table
        with ``max_steps + 1`` rows, or a linear map from a scalar timestep
        when ``linear_time_embedding`` is set.
        '''
        if not self.transformer_config.linear_time_embedding:
            self.time_embedding = nn.Embedding(
                self.environment_config.max_steps+1, self.transformer_config.d_model)
        else:
            self.time_embedding = nn.Linear(
                1, self.transformer_config.d_model)

        return self.time_embedding

    def initialize_state_embedding(self):
        '''
        Create and return the state embedding: ``StateEncoder`` for the 'CNN'
        type, otherwise a bias-free linear layer over the flattened 'image'
        observation.
        '''
        if self.transformer_config.state_embedding_type == 'CNN':
            state_embedding = StateEncoder(self.transformer_config.d_model)
        else:
            n_obs = int(np.prod(
                self.environment_config.observation_space['image'].shape))
            state_embedding = nn.Linear(n_obs, self.transformer_config.d_model, bias=False)
            nn.init.normal_(state_embedding.weight, mean=0.0, std=0.02)

        return state_embedding

    def initialize_state_predictor(self):
        '''
        Create ``self.state_predictor``, a linear head whose output size is the
        flattened observation size.

        Raises:
            NotImplementedError: if the observation space is neither ``Box``
                nor ``Dict``.
        '''
        # Use the configured space rather than a module-level ``env`` global.
        observation_space = self.environment_config.observation_space
        if isinstance(observation_space, Box):
            n_obs = int(np.prod(observation_space.shape))
        elif isinstance(observation_space, Dict):
            n_obs = int(np.prod(observation_space['image'].shape))
        else:
            raise NotImplementedError(
                f"unsupported observation space: {type(observation_space).__name__}")

        # Stored as ``state_predictor`` (what ``predict_states`` calls); storing
        # it as ``predict_states`` would be hidden behind the method of that name.
        self.state_predictor = nn.Linear(
            self.transformer_config.d_model, n_obs)

    def initialize_easy_transformer(self):
        '''
        Build the ``HookedTransformer`` trunk. Its token embedding and
        unembedding are replaced by identities because this model feeds in
        ready-made embeddings and applies its own prediction heads.
        '''
        # Transformer
        cfg = HookedTransformerConfig(
            n_layers=self.transformer_config.n_layers,
            d_model=self.transformer_config.d_model,
            d_head=self.transformer_config.d_head,
            n_heads=self.transformer_config.n_heads,
            d_mlp=self.transformer_config.d_mlp,
            d_vocab=self.transformer_config.d_model,
            # context length in tokens, taken directly from the config
            n_ctx=self.transformer_config.n_ctx,
            act_fn="relu",
            normalization_type= "LN" if self.transformer_config.layer_norm else None,
            attention_dir="causal",
            d_vocab_out=self.transformer_config.d_model,
            seed=self.transformer_config.seed,
            device=self.transformer_config.device,
        )

        assert cfg.attention_dir == "causal", "Attention direction must be causal"
        # assert cfg.normalization_type is None, "Normalization type must be None"

        transformer = HookedTransformer(cfg)

        # Because we passing in tokens, turn off embedding and update the position embedding
        transformer.embed = nn.Identity()
        transformer.pos_embed = PosEmbedTokens(cfg)
        # initialize position embedding; initializer_range is the standard
        # deviation (passing it positionally would set the mean instead)
        nn.init.normal_(transformer.pos_embed.W_pos,
                        std=cfg.initializer_range)
        # don't unembed, we'll do that ourselves.
        transformer.unembed = nn.Identity()

        return transformer

class DecisionTransformer(TrajectoryTransformer):
    '''
    Trajectory transformer whose tokens are ordered
    (reward, state, action) within every timestep.
    '''

    def __init__(self, environment_config, transformer_config, **kwargs):
        super().__init__(
            environment_config = environment_config,
            transformer_config = transformer_config, 
            **kwargs)
    
    def get_token_embeddings(self,
                             state_embeddings,
                             time_embeddings,
                             action_embeddings,
                             reward_embeddings,
                             targets=None):
        '''
        We need to compose the embeddings for:
            - states
            - actions
            - rewards
            - time

        Handling the cases where:
        1. we are training:
            1. we may not have action yet (reward, state)
            2. we have (action, state, reward)...
        2. we are evaluating:
            1. we have a target "a reward" followed by state

        1.1 and 2.1 are the same, but we need to handle the target as the initial reward.

        Args:
            state_embeddings: (batch, timesteps, n_embd)
            time_embeddings: (batch, timesteps, n_embd), added to every other embedding
            action_embeddings: (batch, timesteps, n_embd), or None when no
                action has been taken yet
            reward_embeddings: (batch, timesteps, n_embd)
            targets: optional (batch, timesteps, n_embd) embedding that
                replaces the first reward token

        Returns:
            token_embeddings: (batch, 3 * timesteps, n_embd) when actions are
            given, otherwise (batch, 2, n_embd).
        '''
        batches = state_embeddings.shape[0]
        has_actions = action_embeddings is not None

        # every modality carries the embedding of the timestep it belongs to
        reward_embeddings = reward_embeddings + time_embeddings
        state_embeddings = state_embeddings + time_embeddings

        if has_actions:
            action_embeddings = action_embeddings + time_embeddings
        # ``is not None`` rather than truthiness: bool() of a multi-element
        # tensor raises a RuntimeError.
        if targets is not None:
            targets = targets + time_embeddings

        timesteps = time_embeddings.shape[1]  # number of timesteps
        if has_actions:
            trajectory_length = timesteps*3
        else:
            trajectory_length = 2  # one timestep, no action yet

        # create the token embeddings
        token_embeddings = torch.zeros(
            (batches, trajectory_length, self.transformer_config.d_model),
            dtype=torch.float32, device=state_embeddings.device)  # batches, blocksize, n_embd

        if has_actions:
            # interleave as reward, state, action, reward, state, action, ...
            token_embeddings[:, ::3, :] = reward_embeddings
            token_embeddings[:, 1::3, :] = state_embeddings
            token_embeddings[:, 2::3, :] = action_embeddings
        else:
            token_embeddings[:, 0, :] = reward_embeddings[:, 0, :]
            token_embeddings[:, 1, :] = state_embeddings[:, 0, :]

        if targets is not None:
            # the target takes the place of the initial reward token
            token_embeddings[:, 0, :] = targets[:, 0, :]

        return token_embeddings


# Build the refactored model and the original one from the same settings so
# that their parameters can be compared side by side.
transformer_config = TransformerModelConfig()
environment_config = EnvironmentConfig()

# Create the environment before either model: the old model takes it as an
# argument, and nothing below should depend on a stale ``env`` left over from
# an earlier notebook cell.
env = gym.make(environment_config.env_id)

decision_transformer_new = DecisionTransformer(
    transformer_config = transformer_config,
    environment_config = environment_config,
)

decision_transformer_old = DecisionTransformerOld(
    env = env,
    d_model =transformer_config.d_model,
    n_heads = transformer_config.n_heads,
    d_mlp = transformer_config.d_mlp,
    n_layers = transformer_config.n_layers,
    layer_norm = transformer_config.layer_norm,
    state_embedding_type = transformer_config.state_embedding_type,
    # read from the config (defaults: "learned" / "cpu") instead of hard-coding,
    # so both models stay in step when the config changes
    time_embedding_type = transformer_config.time_embedding_type,
    max_timestep = environment_config.max_steps,
    n_ctx= transformer_config.n_ctx,
    seed= transformer_config.seed,
    device=transformer_config.device
)

import pandas as pd
import torch

def compare_models(model1, model2):
    """Tabulate the parameter shapes of two models side by side.

    Args:
        model1: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        model2: object exposing ``state_dict()``.

    Returns:
        A ``pandas.DataFrame`` with one row per parameter name and the columns
        ``Parameter``, ``Model 1 Shape``, ``Model 2 Shape`` and
        ``Sizes Match``. Rows follow the order of ``model1``'s state dict,
        then the parameters found only in ``model2``. A shape is the string
        ``'MISSING'`` when that model lacks the parameter.
    """
    params1 = model1.state_dict()
    params2 = model2.state_dict()

    param_info = []
    for param_name, tensor1 in params1.items():
        # Read model1's shape up front so it is defined (and current) even
        # when the parameter is absent from model2.
        shape1 = tensor1.shape
        if param_name in params2:
            shape2 = params2[param_name].shape
            sizes_match = shape1 == shape2
        else:
            shape2 = 'MISSING'
            sizes_match = False
        param_info.append({'Parameter': param_name,
                           'Model 1 Shape': shape1,
                           'Model 2 Shape': shape2,
                           'Sizes Match': sizes_match})

    # Parameters that exist in model2 but not in model1.
    param_info.extend(
        {'Parameter': param_name,
         'Model 1 Shape': 'MISSING',
         'Model 2 Shape': tensor2.shape,
         'Sizes Match': False}
        for param_name, tensor2 in params2.items()
        if param_name not in params1
    )

    return pd.DataFrame(param_info)

compare_models(decision_transformer_new, decision_transformer_old)

Unnamed: 0,Parameter,Model 1 Shape,Model 2 Shape,Sizes Match
0,action_embedding.0.weight,"(8, 128)","(8, 128)",True
1,reward_embedding.0.weight,"(128, 1)","(128, 1)",True
2,time_embedding.weight,"(1001, 128)","(1001, 128)",True
3,state_embedding.weight,"(1001, 128)",MISSING,False
4,transformer.pos_embed.W_pos,"(3, 128)","(3, 128)",True
5,transformer.blocks.0.attn.W_Q,"(4, 128, 32)","(4, 128, 32)",True
6,transformer.blocks.0.attn.W_K,"(4, 128, 32)","(4, 128, 32)",True
7,transformer.blocks.0.attn.W_V,"(4, 128, 32)","(4, 128, 32)",True
8,transformer.blocks.0.attn.W_O,"(4, 32, 128)","(4, 32, 128)",True
9,transformer.blocks.0.attn.b_Q,"(4, 32)","(4, 32)",True


In [101]:
def count_parameters(model):
    """Return the total number of elements across the trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for parameter in model.parameters():
        if not parameter.requires_grad:
            continue  # frozen parameters are not counted
        total += parameter.numel()
    return total

count_parameters(decision_transformer_new)

432411

In [102]:
count_parameters(decision_transformer_old)

432411

To do:
- prove the refactor hasn't introduced any bugs:
    - check whether the parameter count changes
    - unit test components
    - check that the model can learn something
    - check that each hyperparameter works