In [1]:
from collections import OrderedDict
import os
import torch
import torch.nn as nn
import tiktoken
import numpy as np

from mugato.data.utils import create_combined_dataloader
from mugato.mugato import MugatoConfig, Mugato, TransformerConfig
from mugato.nano_gpt import Block
from mugato.utils import data_home, select_device, generic_collate_fn
from mugato.tokenizer import Tokenizer

In [2]:
n_layer = 6
n_head = 4
n_embd = 512
bias = False
dropout = 0.0
block_size=768
batch_size=4
device = select_device()

In [3]:
text_tokenizer = tiktoken.get_encoding("r50k_base")
tokenizer = Tokenizer(text_tokenizer)
train_dataloader = iter(create_combined_dataloader(tokenizer, batch_size, split="train", block_size=block_size))
val_dataloader = iter(create_combined_dataloader(tokenizer, batch_size, split="val", block_size=block_size))
test_dataloader = iter(create_combined_dataloader(tokenizer, batch_size, split="test", block_size=block_size))

In [4]:
print(tokenizer.text_tokenizer.encode("Hello, abcDEF world!"))
print(tokenizer.text_tokenizer.decode(tokenizer.text_tokenizer.encode("Hello, abcDEF world!")))

[15496, 11, 450, 66, 32988, 995, 0]
Hello, abcDEF world!


In [5]:
print(f'{"token":>10}: {"text":>10}')
for token in tokenizer.text_tokenizer.encode("Hello, abcDEF world!"):
    decoded = tokenizer.text_tokenizer.decode([token])
    print(f'{token:>10}: {decoded:>10}')

     token:       text
     15496:      Hello
        11:          ,
       450:         ab
        66:          c
     32988:        DEF
       995:      world
         0:          !


In [6]:
text_tokenizer.n_vocab

50257

In [7]:
tokenizer.n_text, tokenizer.n_discrete

(50257, 1024)

In [8]:
tokenizer.encode_text("Hello, world!")

tensor([[15496],
        [   11],
        [  995],
        [    0]])

In [9]:
tokenizer.encode_discrete(1)

tensor([50258])

In [13]:
tokenizer.decode_text(torch.tensor([[50256]]))

'<|endoftext|>'

In [15]:
for i in range(5):
    print(tokenizer.encode_discrete(i))

tensor([50257])
tensor([50258])
tensor([50259])
tensor([50260])
tensor([50261])


In [4]:
# model init
transformer_model_args = dict(
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    block_size=block_size,
    bias=bias,
    vocab_size=50257,  # tiktoken.get_encoding("r50k_base").n_vocab
    dropout=dropout,
)  # start with model_args from command line

mugato_model_args = dict(
    n_embd=n_embd,
    block_size=block_size,
    vocab_size=51281,  # text vocab + discrete vocab
)

In [5]:
# create the model
transformer_config = TransformerConfig(**transformer_model_args)
transformer = nn.ModuleDict(
    dict(
        wpe=nn.Embedding(transformer_config.block_size, transformer_config.n_embd),
        drop=nn.Dropout(transformer_config.dropout),
        h=nn.ModuleList(
            [
                Block(transformer_config)
                for _ in range(transformer_config.n_layer)
            ]
        ),
    )
)
mugato_config = MugatoConfig(**mugato_model_args)
untrained_model = Mugato(tokenizer, transformer, mugato_config)

In [6]:
untrained_model = untrained_model.to(device);

In [7]:
untrained_model

Mugato(
  (lookup_embedding): Embedding(51281, 512)
  (image_embedding): ResNetV2(
    (stem): Sequential(
      (conv): StdConv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (stages): Sequential(
      (0): ResNetStage(
        (blocks): Sequential(
          (0): PreActBottleneck(
            (downsample): DownsampleConv(
              (conv): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (norm): Identity()
            )
            (norm1): GroupNormAct(
              32, 64, eps=1e-05, affine=True
              (drop): Identity()
              (act): ReLU(inplace=True)
            )
            (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (norm2): GroupNormAct(
              32, 64, eps=1e-05, affine=True
              (drop): Identity()
              (act): ReLU(inplace=True)
        

In [47]:
untrained_model.eval()
text = "First Citizen:\n"
tokens = torch.stack([torch.concat([torch.tensor([tokenizer.eot_token_id]).unsqueeze(0), tokenizer.encode_text(text)])])

In [16]:
xs = OrderedDict(text=tokens)
xs, ys, ms = generic_collate_fn([[xs, xs]])
next_word_token = None
i = 0
xs, ys, ms = [x.to(device) for x in [xs, ys, ms]]
with torch.no_grad():
    logits, loss = untrained_model(xs, pad=False)
temp = 0.6
scaled_logits = logits / temp
probs = scaled_logits.softmax(dim=2)
next_word_token = torch.multinomial(probs[0, [-1]], num_samples=1)
next_word = tokenizer.decode_text(next_word_token)
text += next_word
tokens = torch.stack([tokenizer.encode_text(text)])
print(text)

NameError: name 'tokens' is not defined

In [17]:
from mugato.data.four_rooms import (
    initialize as initialize_four_rooms, 
    create_dataloader as create_four_rooms_dataloader, 
    tokenize as four_rooms_tokenize
)


In [21]:
four_rooms_dataset = initialize_four_rooms()
four_rooms_dataloader = create_four_rooms_dataloader(tokenizer, batch_size=batch_size, split="test")
batch = next(iter(four_rooms_dataloader))
X, Y, M = batch
X, Y, M = X.to(device), Y.to(device), M.to(device)
logits, loss = untrained_model(X, Y, M)
loss

tensor(11.2002, device='cuda:0', grad_fn=<DivBackward0>)

# Test Four Rooms

In [61]:
test_data = four_rooms_dataset["test"]

In [62]:
episode = test_data[1]

In [63]:
env = test_data.recover_environment(render_mode="human")
obs, info = env.reset()
obs['direction'] = np.array([obs['direction']])
obs['image'] = np.array([obs['image']])
obs['mission'] = [obs['mission']]
dummy_action = 0  # Will be sliced off after sequencing.
obs['action'] = np.array([dummy_action])

In [64]:
from mugato.data.four_rooms import four_rooms_to_rgb
from mugato.utils import image_transform
from mugato.utils import Timesteps

In [65]:
tokenizer.separator

1023

In [66]:
def tokenize(obs):
    mission_tokens = [
        tokenizer.encode_text(mission)
        for mission in obs["mission"]
    ]
    direction_tokens = [
        tokenizer.encode_discrete([direction])
        for direction in obs["direction"]
    ]
    _image = obs["image"]
    _image = four_rooms_to_rgb(_image)
    image_tokens = [tokenizer.encode_image(image) for image in image_transform(_image)]
    action_tokens = [
        tokenizer.encode_discrete([tokenizer.separator, action])
        for action in obs["action"]
    ]

    mission = torch.stack(mission_tokens)
    direction = torch.stack(direction_tokens)
    image = torch.stack(image_tokens)
    action = torch.stack(action_tokens)
    xs = Timesteps({
        "mission": mission,
        "direction": direction,
        "image": image,
        "action": action,
    })
    return xs

In [67]:
xs = tokenize(obs)
xs

Timesteps([('mission',
            tensor([[[16250],
                     [  262],
                     [ 3061]]])),
           ('direction', tensor([[[50260]]])),
           ('image',
            tensor([[[-0.2500, -0.2500, -0.2500,  ...,  0.0850,  0.0850,  0.0850],
                     [-0.2500, -0.2500, -0.2500,  ...,  0.0850,  0.0850,  0.0850],
                     [-0.2500, -0.2500, -0.2500,  ...,  0.0850,  0.0850,  0.0850],
                     ...,
                     [ 0.2500,  0.2500,  0.2500,  ..., -0.2230, -0.2230, -0.2230],
                     [ 0.2500,  0.2500,  0.2500,  ..., -0.2230, -0.2230, -0.2230],
                     [ 0.2500,  0.2500,  0.2500,  ..., -0.2230, -0.2230, -0.2230]]])),
           ('action',
            tensor([[[51280],
                     [50257]]]))])

In [68]:
# Add batch dimension.
xs = Timesteps([
    (k, torch.stack([v])) for k, v in xs.items()
])

In [69]:
def sequence_four_rooms(embedder, xs, ys=None, ms=None, sequence_length=1024, pad=True):
    embeddings = torch.concat([embedder.embed(v) for k, v in xs.items()], dim=2)
    B, E, T, C = embeddings.shape
    embeddings = embeddings.view(B, E * T, C)
    # Slice off final actions, so we can predict it.
    return embeddings[:, :-1]

In [70]:
next_word_token = None
i = 0
xs = xs.to(device)
logits, loss = untrained_model(xs, pad=False, sequence=sequence_four_rooms)
temp = 0.6
scaled_logits = logits / temp
probs = scaled_logits.softmax(dim=2)
next_token = torch.multinomial(probs[0, [-1]], num_samples=1)
next_token = tokenizer.decode_discrete(next_token)

In [71]:
next_token

[2165]

In [72]:
def get_action(token, action_space):
    return token % tokenizer.n_text % env.action_space.n

In [73]:
# Track memory usage
import gc
import torch.cuda

def print_gpu_memory():
    print(f"Allocated: {torch.cuda.memory_allocated()/1e9:.2f}GB")
    print(f"Cached: {torch.cuda.memory_reserved()/1e9:.2f}GB")

print("Initial GPU memory:")
print_gpu_memory()


Initial GPU memory:
Allocated: 5.22GB
Cached: 6.89GB


In [74]:
for _ in range(10):
    # Clear memory before each iteration
    torch.cuda.empty_cache()
    gc.collect()
    
    # Step with the previously predicted action: `next_token[0]`
    obs, reward, terminated, truncated, info = env.step(get_action(next_token[0], env))
    
    # Prepare the next observation.
    obs['direction'] = np.array([obs['direction']])
    obs['image'] = np.array([obs['image']])
    obs['mission'] = [obs['mission']]
    # Prepare a temporary action token. Will be sliced off after sequencing.
    # We just need this because each modality of the episodes need to have 
    # the same `E` dimension (remember - (B, E, T, C)), so that we can 
    # concatenate them on the `T` dimension.
    dummy_action = 0
    obs['action'] = np.array([dummy_action])
    
    # Move old tensors to CPU to free GPU memory
    xs = xs.to("cpu")
    
    xs_new = tokenize(obs)
    # Merge the new episode.
    xs = Timesteps([
        (k, torch.concat([xs[k], xs_new[k].to("cpu").unsqueeze(0)])) for k in xs.keys()
    ])
    
    # Only move to GPU right before model inference
    xs = xs.to(device)
    
    # Predict the next action
    with torch.no_grad():  # Use mixed precision to reduce memory
        logits, loss = untrained_model(xs, pad=False, sequence=sequence_four_rooms)
    
    temp = 0.8
    scaled_logits = logits / temp
    probs = scaled_logits.softmax(dim=2)
    next_token = torch.multinomial(probs[0, [-1]], num_samples=1)
    next_token = tokenizer.decode_discrete(next_token)
    
    # Move tensors back to CPU and clear GPU cache
    xs = xs.to("cpu")
    logits = logits.to("cpu")
    probs = probs.to("cpu")
    torch.cuda.empty_cache()
    
    print(f"Next token: {next_token}")

Next token: [2970]
Next token: [25720]
Next token: [40118]
Next token: [30150]
Next token: [41947]
Next token: [35754]
Next token: [49789]
Next token: [24458]
Next token: [20414]
Next token: [24190]


: 

In [36]:
env.close()

In [37]:
torch.cuda.empty_cache()
gc.collect()

0

# Trained model

In [38]:
out_dir = data_home / "out"
ckpt_path = os.path.join(out_dir, "ckpt.pt")
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)

state_dict = checkpoint["model"]
# fix the keys of the state dictionary :(
# honestly no idea how checkpoints sometimes get this prefix, have to debug more
unwanted_prefix = "_orig_mod."
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
checkpoint_model_args = checkpoint["model_args"]
for k in ["n_layer", "n_head", "n_embd", "block_size", "bias", "vocab_size"]:
    transformer_model_args[k] = checkpoint_model_args[k]

transformer_config = TransformerConfig(**transformer_model_args)
transformer = nn.ModuleDict(
    dict(
        wpe=nn.Embedding(transformer_config.block_size, transformer_config.n_embd),
        drop=nn.Dropout(transformer_config.dropout),
        h=nn.ModuleList(
            [
                Block(transformer_config)
                for _ in range(transformer_config.n_layer)
            ]
        ),
    )
)

mugato_config = MugatoConfig(**mugato_model_args)
trained_model = Mugato(tokenizer, transformer, mugato_config)
trained_model.load_state_dict(state_dict)
iter_num = checkpoint["iter_num"]
best_val_loss = checkpoint["best_val_loss"]

trained_model = trained_model.to(device);

In [39]:
trained_model.eval()
text = "First Citizen:\n"
tokens = torch.stack([torch.concat([torch.tensor([tokenizer.eot_token_id]).unsqueeze(0), tokenizer.encode_text(text)])])

In [40]:
xs = OrderedDict(text=tokens)
xs, ys, ms = generic_collate_fn([[xs, xs]])
next_word_token = None
i = 0
xs, ys, ms = [x.to(device) for x in [xs, ys, ms]]
logits, loss = trained_model(xs, pad=False)
temp = 0.6
scaled_logits = logits / temp
probs = scaled_logits.softmax(dim=2)
next_word_token = torch.multinomial(probs[0, [-1]], num_samples=1)
next_word = tokenizer.decode_text(next_word_token)
text += next_word
tokens = torch.stack([tokenizer.encode_text(text)])
print(text)

First Citizen:
Where


In [41]:
env = test_data.recover_environment(render_mode="human")
obs, info = env.reset()
obs['direction'] = np.array([obs['direction']])
obs['image'] = np.array([obs['image']])
obs['mission'] = [obs['mission']]
dummy_action = 0  # Will be sliced off after sequencing.
obs['action'] = np.array([dummy_action])
xs = tokenize(obs)
# Add batch dimension.
xs = Timesteps([
    (k, torch.stack([v])) for k, v in xs.items()
])

In [45]:
next_word_token = None
i = 0
xs = xs.to(device)
logits, loss = trained_model(xs, pad=False, sequence=sequence_four_rooms)
temp = 0.6
scaled_logits = logits / temp
probs = scaled_logits.softmax(dim=2)
next_token = torch.multinomial(probs[0, [-1]], num_samples=1)
next_token = tokenizer.decode_discrete(next_token)

In [46]:
for _ in range(10):
    # Clear memory before each iteration
    torch.cuda.empty_cache()
    gc.collect()
    
    # Step with the previously predicted action: `next_token[0]`
    obs, reward, terminated, truncated, info = env.step(get_action(next_token[0], env))
    
    # Prepare the next observation.
    obs['direction'] = np.array([obs['direction']])
    obs['image'] = np.array([obs['image']])
    obs['mission'] = [obs['mission']]
    # Prepare a temporary action token. Will be sliced off after sequencing.
    # We just need this because each modality of the episodes need to have 
    # the same `E` dimension (remember - (B, E, T, C)), so that we can 
    # concatenate them on the `T` dimension.
    dummy_action = 0
    obs['action'] = np.array([dummy_action])
    
    # Move old tensors to CPU to free GPU memory
    xs = xs.to("cpu")
    
    xs_new = tokenize(obs)
    # Merge the new episode.
    xs = Timesteps([
        (k, torch.concat([xs[k], xs_new[k].to("cpu").unsqueeze(0)])) for k in xs.keys()
    ])
    
    # Only move to GPU right before model inference
    xs = xs.to(device)
    
    # Predict the next action
    with torch.no_grad():  # Use mixed precision to reduce memory
        logits, loss = trained_model(xs, pad=False, sequence=sequence_four_rooms)
    
    temp = 0.8
    scaled_logits = logits / temp
    probs = scaled_logits.softmax(dim=2)
    next_token = torch.multinomial(probs[0, [-1]], num_samples=1)
    next_token = tokenizer.decode_discrete(next_token)
    
    # Move tensors back to CPU and clear GPU cache
    xs = xs.to("cpu")
    logits = logits.to("cpu")
    probs = probs.to("cpu")
    torch.cuda.empty_cache()
    
    print(f"Next token: {next_token}")

error: cannot convert without pygame.display initialized

In [44]:
env.close()