In [1]:
import torch
from transition_decoder import DecoderTransformerConfig, DecoderTransformerTorch

# Seed torch's global RNG first so both the decoder's weight initialisation
# and the dummy inputs below are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Create a small dummy config so the forward pass is cheap enough for a smoke test.
config = DecoderTransformerConfig(
    emb_dim=32,  # embedding dimension
    num_heads=4,
    attention_dropout_rate=0.1,
    dropout_rate=0.1,
    mlp_dim_factor=4,
    activation='relu',
    use_bias=True,
    max_rows=30,
    max_cols=30,
    vocab_size=10,
    num_layers=2,
)

# For testing purposes, reuse the top-level config as the `transformer_layer` sub-config.
# NOTE(review): presumably DecoderTransformerTorch reads per-layer hyper-parameters from
# config.transformer_layer — confirm against transition_decoder; this self-reference is a test hack.
config.transformer_layer = config

decoder = DecoderTransformerTorch(config)
decoder.eval()  # evaluation mode (affects train/eval-dependent modules such as dropout)

# Dummy inputs: random embeddings for the action and state token sequences.
B = 2
T_action = 5
# Cell 2 expects (T_action + T_state) - (T_action + 3) = T_state - 3 grid tokens,
# so T_state must be strictly greater than 3 for any grid tokens to be produced.
T_state = 8
# Generated on CPU (before moving) so the seeded values do not depend on the device.
embedded_action = torch.randn(B, T_action, config.emb_dim)
embedded_state = torch.randn(B, T_state, config.emb_dim)

device = 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using device: {device}')

decoder = decoder.to(device)
embedded_action = embedded_action.to(device)
embedded_state = embedded_state.to(device)

# Run a single forward pass. This is a shape check only, so autograd tracking is disabled.
# NOTE(review): `dropout_eval` is a DecoderTransformerTorch-specific flag; presumably it
# disables dropout inside the forward pass — confirm in transition_decoder.
with torch.no_grad():
    shape_row_logits, shape_col_logits, grid_logits = decoder(
        embedded_action, embedded_state, dropout_eval=True
    )


Using device: mps for world_model/transformer.py
Using device: {'mps'}
embedded_action shape: torch.Size([2, 5, 32])
embedded_state shape: torch.Size([2, 8, 32])


In [2]:
# Expected output shapes, derived from the config and the input sequence lengths
# defined in the previous cell.
assert shape_row_logits.shape == (B, config.max_rows), f"Expected shape_row_logits {(B, config.max_rows)}, got {shape_row_logits.shape}"
assert shape_col_logits.shape == (B, config.max_cols), f"Expected shape_col_logits {(B, config.max_cols)}, got {shape_col_logits.shape}"

# The test expects grid logits for every input token except the T_action action tokens
# and NUM_NON_GRID_STATE_TOKENS of the state tokens.
# NOTE(review): 3 is assumed to be the decoder's fixed count of non-grid state tokens —
# confirm against transition_decoder.
NUM_NON_GRID_STATE_TOKENS = 3
expected_grid_tokens = (T_action + T_state) - (T_action + NUM_NON_GRID_STATE_TOKENS)
# Guard against a vacuous check: with T_state <= 3 there would be no grid tokens to validate.
assert expected_grid_tokens > 0, f"T_state={T_state} leaves no grid tokens; need T_state > {NUM_NON_GRID_STATE_TOKENS}"
assert grid_logits.shape == (B, expected_grid_tokens, config.vocab_size), f"Expected grid_logits {(B, expected_grid_tokens, config.vocab_size)}, got {grid_logits.shape}"

print('All tests passed!')

All tests passed!
