In [None]:
from collections import OrderedDict
import os
import torch
import torch.nn as nn
import tiktoken
import numpy as np

from mugato.data.utils import create_combined_dataloader
from mugato.mugato import MugatoConfig, Mugato, TransformerConfig
from mugato.nano_gpt import Block
from mugato.utils import data_home, select_device, generic_collate_fn
from mugato.tokenizer import Tokenizer

In [None]:
# Transformer architecture hyperparameters.
n_layer, n_head, n_embd = 6, 4, 512
bias = False
dropout = 0.0

# Sequence length and batch size shared by every dataloader below.
block_size = 768
batch_size = 4

# Device picked by the project's `select_device` helper.
device = select_device()

In [None]:
# Wrap the tiktoken "r50k_base" encoding in the project's multi-modal tokenizer.
text_tokenizer = tiktoken.get_encoding("r50k_base")
tokenizer = Tokenizer(text_tokenizer)

# `create_combined_dataloader` returns a dataloader that cycles through all
# datasets and yields a batch from each one on each iteration. One iterator
# is built per split, in the order train / val / test.
train_dataloader, val_dataloader, test_dataloader = (
    iter(create_combined_dataloader(tokenizer, batch_size, split=split, block_size=block_size))
    for split in ("train", "val", "test")
)

In [None]:
# Text tokens are encoded directly by the text tokenizer.
# If the text tokenizer encodes a string as [15496, 11, 995, 0],
# then the tokenizer will encode it as [[15496, 11, 995, 0]]
hello_ids = tokenizer.text_tokenizer.encode("Hello, world!")
print(hello_ids)
# Round-trip the ids back to text.
print(tokenizer.text_tokenizer.decode(hello_ids))
# The same string through the project tokenizer.
print(tokenizer.encode_text("Hello, world!"))

In [None]:
# Discrete tokens (and continuous ones, which we'll get to later) are encoded
# to the 1024 token positions immediately after the text tokens.
# Displays: text vocab size, tokenizer text/discrete sizes, and the text that
# token id 50256 decodes to.
(
    text_tokenizer.n_vocab,
    tokenizer.n_text,
    tokenizer.n_discrete,
    tokenizer.decode_text(torch.tensor([[50256]])),
)

In [None]:
# The first discrete token, 0, gets encoded immediately after the last text token.
tokenizer.encode_discrete(0)

In [None]:
# Encode the first five discrete values in one call.
print(tokenizer.encode_discrete(list(range(5))))

In [None]:
# Keyword arguments for `TransformerConfig`, taken from the hyperparameter cell.
transformer_model_args = {
    "n_layer": n_layer,
    "n_head": n_head,
    "n_embd": n_embd,
    "block_size": block_size,
    "bias": bias,
    "vocab_size": 50257,  # tiktoken.get_encoding("r50k_base").n_vocab
    "dropout": dropout,
}

# Keyword arguments for `MugatoConfig`.
mugato_model_args = {
    "n_embd": n_embd,
    "block_size": block_size,
    "vocab_size": 51281,  # text vocab + discrete vocab
}

In [None]:
# Create the (untrained) model: a positional embedding, dropout and a stack of
# `n_layer` transformer blocks, wrapped by `Mugato` together with the tokenizer.
transformer_config = TransformerConfig(**transformer_model_args)
transformer = nn.ModuleDict(
    {
        "wpe": nn.Embedding(transformer_config.block_size, transformer_config.n_embd),
        "drop": nn.Dropout(transformer_config.dropout),
        "h": nn.ModuleList(
            [Block(transformer_config) for _ in range(transformer_config.n_layer)]
        ),
    }
)
mugato_config = MugatoConfig(**mugato_model_args)
untrained_model = Mugato(tokenizer, transformer, mugato_config)

In [None]:
# Move the untrained model's parameters to the selected device.
untrained_model = untrained_model.to(device)

In [None]:
# Display the untrained model's module tree.
untrained_model

In [None]:
# Switch to eval mode before sampling.
untrained_model.eval()
text = "First Citizen:\n"
# Prepend the end-of-text token to the encoded prompt, then add a leading
# dimension with `torch.stack`.
eot_token = torch.tensor([tokenizer.eot_token_id]).unsqueeze(0)
prompt_tokens = tokenizer.encode_text(text)
tokens = torch.stack([torch.concat([eot_token, prompt_tokens])])

In [None]:
# Collate the prompt into model inputs; `generic_collate_fn` returns the
# inputs, targets and masks as a triple.
xs = OrderedDict(text=tokens)
xs, ys, ms = generic_collate_fn([[xs, xs]])
next_word_token = None
i = 0
xs, ys, ms = (item.to(device) for item in (xs, ys, ms))

# Forward pass without gradient tracking.
with torch.no_grad():
    logits, loss = untrained_model(xs, pad=False)

# Temperature-scale the logits and sample one token from the last position
# of the first sequence in the batch.
temp = 0.6
scaled_logits = logits / temp
probs = torch.softmax(scaled_logits, dim=2)
next_word_token = torch.multinomial(probs[0, -1:], num_samples=1)
next_word = tokenizer.decode_text(next_word_token)

# Append the sampled text to the prompt and re-encode it.
text = text + next_word
tokens = torch.stack([tokenizer.encode_text(text)])
print(text)

In [None]:
from mugato.data.four_rooms import (
    initialize as initialize_four_rooms, 
    create_dataloader as create_four_rooms_dataloader, 
    tokenize as four_rooms_tokenize
)


In [None]:
# Load the Four Rooms dataset and pull a single batch from its test split.
four_rooms_dataset = initialize_four_rooms()
four_rooms_dataloader = create_four_rooms_dataloader(tokenizer, batch_size=batch_size, split="test")
batch = next(iter(four_rooms_dataloader))
X, Y, M = batch
X, Y, M = X.to(device), Y.to(device), M.to(device)

# This is evaluation only (nothing calls `backward`), so skip building the
# autograd graph, consistent with the other inference cells in this notebook.
with torch.no_grad():
    logits, loss = untrained_model(X, Y, M)
loss

# Test Four Rooms

In [None]:
# Select the "test" split of the Four Rooms dataset.
test_data = four_rooms_dataset["test"]

In [None]:
# Take the item at index 1 of the test split.
episode = test_data[1]

In [None]:
# Recover the environment from the test split and start a fresh episode.
env = test_data.recover_environment(render_mode="human")
obs, info = env.reset()

# `tokenize` iterates over every field, so wrap each value in a length-1
# container.
for key in ("direction", "image"):
    obs[key] = np.array([obs[key]])
obs["mission"] = [obs["mission"]]

# Placeholder action; it will be sliced off after sequencing.
dummy_action = 0
obs["action"] = np.array([dummy_action])

In [None]:
from mugato.data.four_rooms import four_rooms_to_rgb
from mugato.utils import image_transform
from mugato.utils import Timesteps

In [None]:
# Display the tokenizer's separator value, which is encoded in front of every action.
tokenizer.separator

In [None]:
def tokenize(obs):
    """Tokenize a Four Rooms observation into a `Timesteps` of stacked tensors.

    Args:
        obs: Mapping with the keys "mission", "direction", "image" and
            "action". Each value is iterated item by item; "image" is first
            passed through `four_rooms_to_rgb` and `image_transform`.

    Returns:
        A `Timesteps` with the keys "mission", "direction", "image" and
        "action", each holding the `torch.stack` of that field's per-item
        token tensors.
    """
    return Timesteps({
        # Mission strings go through the text encoder.
        "mission": torch.stack(
            [tokenizer.encode_text(sentence) for sentence in obs["mission"]]
        ),
        # Each direction is encoded as a one-element discrete sequence.
        "direction": torch.stack(
            [tokenizer.encode_discrete([heading]) for heading in obs["direction"]]
        ),
        # Images are converted to RGB and transformed before encoding.
        "image": torch.stack(
            [
                tokenizer.encode_image(frame)
                for frame in image_transform(four_rooms_to_rgb(obs["image"]))
            ]
        ),
        # Every action is preceded by the separator value.
        "action": torch.stack(
            [
                tokenizer.encode_discrete([tokenizer.separator, chosen])
                for chosen in obs["action"]
            ]
        ),
    })

In [None]:
# Tokenize the prepared observation and display the resulting `Timesteps`.
xs = tokenize(obs)
xs

In [None]:
# Add batch dimension: every modality tensor gains a new leading axis.
# Note that re-running this cell adds another axis each time.
xs = Timesteps([
    (modality, torch.stack([tokens_for_modality]))
    for modality, tokens_for_modality in xs.items()
])

In [None]:
def sequence_four_rooms(embedder, xs, ys=None, ms=None, sequence_length=1024, pad=True):
    """Embed every modality and flatten the result into one token sequence.

    Args:
        embedder: Object exposing `embed(tensor)`; it is called once per
            value in `xs`.
        xs: Mapping of modality name to tensor. The embedded tensors are
            concatenated along dimension 2, so they must agree on every
            other dimension.
        ys, ms, sequence_length, pad: Accepted but not used by this function.

    Returns:
        A tensor of shape `(B, E * T - 1, C)`: the `(B, E, T, C)` embeddings
        flattened over `E` and `T`, with the final position dropped.
    """
    per_modality = [embedder.embed(tokens) for _, tokens in xs.items()]
    stacked = torch.concat(per_modality, dim=2)
    n_batch, n_steps, n_tokens, n_channels = stacked.shape
    flattened = stacked.view(n_batch, n_steps * n_tokens, n_channels)
    # Slice off the final action, so we can predict it.
    return flattened[:, :-1]

In [None]:
# Predict the first action from the freshly reset observation.
xs = xs.to(device)

# Inference only: skip autograd bookkeeping, as the rollout loop does.
with torch.no_grad():
    logits, loss = untrained_model(xs, pad=False, sequence=sequence_four_rooms)

# Temperature-scale the logits and sample one token from the last position
# of the first sequence in the batch.
temp = 0.6
scaled_logits = logits / temp
probs = scaled_logits.softmax(dim=2)
next_token = torch.multinomial(probs[0, [-1]], num_samples=1)
next_token = tokenizer.decode_discrete(next_token)

In [None]:
# Display the decoded token sampled in the previous cell.
next_token

In [None]:
def get_action(token, action_space):
    """Map a decoded discrete token onto a valid action index.

    Args:
        token: Decoded discrete token; anything supporting `%`.
        action_space: Either an action space exposing `.n`, or an environment
            exposing `.action_space.n`. The callers in this notebook pass the
            environment itself, so both forms are accepted.

    Returns:
        `token % tokenizer.n_text % n`, where `n` is the number of actions.
    """
    # Use the argument instead of the global `env`, unwrapping an environment
    # to its action space when one is passed.
    space = getattr(action_space, "action_space", action_space)
    return token % tokenizer.n_text % space.n

In [None]:
# Track memory usage
import gc
import torch.cuda

def print_gpu_memory():
    """Print the CUDA memory currently allocated and reserved, in GB (1e9 bytes)."""
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"Allocated: {allocated_gb:.2f}GB")
    # "Cached" is the figure reported by `torch.cuda.memory_reserved`.
    print(f"Cached: {reserved_gb:.2f}GB")

# Report the memory baseline before the rollout loop runs.
print("Initial GPU memory:")
print_gpu_memory()


In [None]:
# Roll the untrained model out in the environment for up to 10 steps.
for _ in range(10):
    # Clear memory before each iteration.
    torch.cuda.empty_cache()
    gc.collect()

    # Step with the previously predicted action: `next_token[0]`.
    obs, reward, terminated, truncated, info = env.step(get_action(next_token[0], env))
    if terminated or truncated:
        # The episode is over; stepping again without `env.reset()` is not
        # valid, so stop the rollout here.
        print(f"Episode ended (terminated={terminated}, truncated={truncated})")
        break

    # Prepare the next observation.
    obs['direction'] = np.array([obs['direction']])
    obs['image'] = np.array([obs['image']])
    obs['mission'] = [obs['mission']]
    # Prepare a temporary action token. Will be sliced off after sequencing.
    # We just need this because each modality of the episodes need to have
    # the same `E` dimension (remember - (B, E, T, C)), so that we can
    # concatenate them on the `T` dimension.
    dummy_action = 0
    obs['action'] = np.array([dummy_action])

    # Move old tensors to CPU to free GPU memory.
    xs = xs.to("cpu")

    xs_new = tokenize(obs)
    # Merge the new episode.
    # NOTE(review): `torch.concat` defaults to dim=0, so this grows the leading
    # dimension of `xs`, while sampling below only reads `probs[0]`. Confirm
    # whether concatenating on dim=1 (the `E` dimension) was intended.
    xs = Timesteps([
        (k, torch.concat([xs[k], xs_new[k].to("cpu").unsqueeze(0)])) for k in xs.keys()
    ])

    # Only move to GPU right before model inference.
    xs = xs.to(device)

    # Predict the next action. `no_grad` disables gradient tracking, which
    # avoids keeping the autograd graph in memory during inference.
    with torch.no_grad():
        logits, loss = untrained_model(xs, pad=False, sequence=sequence_four_rooms)

    temp = 0.8
    scaled_logits = logits / temp
    probs = scaled_logits.softmax(dim=2)
    next_token = torch.multinomial(probs[0, [-1]], num_samples=1)
    next_token = tokenizer.decode_discrete(next_token)

    # Move tensors back to CPU and clear GPU cache.
    xs = xs.to("cpu")
    logits = logits.to("cpu")
    probs = probs.to("cpu")
    torch.cuda.empty_cache()

    print(f"Next token: {next_token}")

In [None]:
# Close the environment now that the untrained rollout is finished.
env.close()

In [None]:
# Release cached CUDA memory and run the garbage collector before loading
# the trained model.
torch.cuda.empty_cache()
gc.collect()

# Trained model

In [None]:
# Load the training checkpoint from `<data_home>/out/ckpt.pt`.
out_dir = data_home / "out"
ckpt_path = os.path.join(out_dir, "ckpt.pt")
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

state_dict = checkpoint["model"]
# Strip the "_orig_mod." prefix from parameter names. `torch.compile` wraps
# the model and exposes the original module as `_orig_mod`, so a state dict
# saved from a compiled model carries this prefix on every key.
unwanted_prefix = "_orig_mod."
for k in list(state_dict):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)

# The architecture must match the checkpoint, so take these values from it.
checkpoint_model_args = checkpoint["model_args"]
for k in ["n_layer", "n_head", "n_embd", "block_size", "bias", "vocab_size"]:
    transformer_model_args[k] = checkpoint_model_args[k]

transformer_config = TransformerConfig(**transformer_model_args)
transformer = nn.ModuleDict(
    dict(
        wpe=nn.Embedding(transformer_config.block_size, transformer_config.n_embd),
        drop=nn.Dropout(transformer_config.dropout),
        h=nn.ModuleList(
            [
                Block(transformer_config)
                for _ in range(transformer_config.n_layer)
            ]
        ),
    )
)

# Keep the Mugato embedding width in step with the checkpointed transformer;
# otherwise a checkpoint trained with a different `n_embd` than this
# notebook's default cannot be loaded. `block_size` and `vocab_size` are left
# as configured in this notebook.
trained_mugato_model_args = {
    **mugato_model_args,
    "n_embd": checkpoint_model_args["n_embd"],
}
mugato_config = MugatoConfig(**trained_mugato_model_args)
trained_model = Mugato(tokenizer, transformer, mugato_config)
trained_model.load_state_dict(state_dict)

# Training progress recorded in the checkpoint.
iter_num = checkpoint["iter_num"]
best_val_loss = checkpoint["best_val_loss"]

trained_model = trained_model.to(device)

In [None]:
# Switch to eval mode before sampling.
trained_model.eval()
text = "First Citizen:\n"
# Prepend the end-of-text token to the encoded prompt, then add a leading
# dimension with `torch.stack`.
eot_token = torch.tensor([tokenizer.eot_token_id]).unsqueeze(0)
prompt_tokens = tokenizer.encode_text(text)
tokens = torch.stack([torch.concat([eot_token, prompt_tokens])])

In [None]:
# Collate the prompt into model inputs; `generic_collate_fn` returns the
# inputs, targets and masks as a triple.
xs = OrderedDict(text=tokens)
xs, ys, ms = generic_collate_fn([[xs, xs]])
xs, ys, ms = [x.to(device) for x in [xs, ys, ms]]

# Inference only: skip autograd bookkeeping, matching the untrained-model cell.
with torch.no_grad():
    logits, loss = trained_model(xs, pad=False)

# Temperature-scale the logits and sample one token from the last position
# of the first sequence in the batch.
temp = 0.6
scaled_logits = logits / temp
probs = scaled_logits.softmax(dim=2)
next_word_token = torch.multinomial(probs[0, [-1]], num_samples=1)
next_word = tokenizer.decode_text(next_word_token)

# Append the sampled text to the prompt and re-encode it.
text += next_word
tokens = torch.stack([tokenizer.encode_text(text)])
print(text)

In [None]:
# Recover the environment from the test split and start a fresh episode.
env = test_data.recover_environment(render_mode="human")
obs, info = env.reset()

# `tokenize` iterates over every field, so wrap each value in a length-1
# container.
for key in ("direction", "image"):
    obs[key] = np.array([obs[key]])
obs["mission"] = [obs["mission"]]

# Placeholder action; it will be sliced off after sequencing.
dummy_action = 0
obs["action"] = np.array([dummy_action])

xs = tokenize(obs)
# Add batch dimension: every modality tensor gains a new leading axis.
xs = Timesteps([
    (modality, torch.stack([tokens_for_modality]))
    for modality, tokens_for_modality in xs.items()
])

In [None]:
# Predict the first action from the freshly reset observation.
xs = xs.to(device)

# Inference only: skip autograd bookkeeping, as the rollout loop does.
with torch.no_grad():
    logits, loss = trained_model(xs, pad=False, sequence=sequence_four_rooms)

# Temperature-scale the logits and sample one token from the last position
# of the first sequence in the batch.
temp = 0.6
scaled_logits = logits / temp
probs = scaled_logits.softmax(dim=2)
next_token = torch.multinomial(probs[0, [-1]], num_samples=1)
next_token = tokenizer.decode_discrete(next_token)

In [None]:
# Roll the trained model out in the environment for up to 10 steps.
for _ in range(10):
    # Clear memory before each iteration.
    torch.cuda.empty_cache()
    gc.collect()

    # Step with the previously predicted action: `next_token[0]`.
    obs, reward, terminated, truncated, info = env.step(get_action(next_token[0], env))
    if terminated or truncated:
        # The episode is over; stepping again without `env.reset()` is not
        # valid, so stop the rollout here.
        print(f"Episode ended (terminated={terminated}, truncated={truncated})")
        break

    # Prepare the next observation.
    obs['direction'] = np.array([obs['direction']])
    obs['image'] = np.array([obs['image']])
    obs['mission'] = [obs['mission']]
    # Prepare a temporary action token. Will be sliced off after sequencing.
    # We just need this because each modality of the episodes need to have
    # the same `E` dimension (remember - (B, E, T, C)), so that we can
    # concatenate them on the `T` dimension.
    dummy_action = 0
    obs['action'] = np.array([dummy_action])

    # Move old tensors to CPU to free GPU memory.
    xs = xs.to("cpu")

    xs_new = tokenize(obs)
    # Merge the new episode.
    # NOTE(review): `torch.concat` defaults to dim=0, so this grows the leading
    # dimension of `xs`, while sampling below only reads `probs[0]`. Confirm
    # whether concatenating on dim=1 (the `E` dimension) was intended.
    xs = Timesteps([
        (k, torch.concat([xs[k], xs_new[k].to("cpu").unsqueeze(0)])) for k in xs.keys()
    ])

    # Only move to GPU right before model inference.
    xs = xs.to(device)

    # Predict the next action. `no_grad` disables gradient tracking, which
    # avoids keeping the autograd graph in memory during inference.
    with torch.no_grad():
        logits, loss = trained_model(xs, pad=False, sequence=sequence_four_rooms)

    temp = 0.8
    scaled_logits = logits / temp
    probs = scaled_logits.softmax(dim=2)
    next_token = torch.multinomial(probs[0, [-1]], num_samples=1)
    next_token = tokenizer.decode_discrete(next_token)

    # Move tensors back to CPU and clear GPU cache.
    xs = xs.to("cpu")
    logits = logits.to("cpu")
    probs = probs.to("cpu")
    torch.cuda.empty_cache()

    print(f"Next token: {next_token}")

In [None]:
# Close the environment now that the trained rollout is finished.
env.close()