In [1]:
from typing import Callable
from dataclasses import dataclass, fields
from functools import partial
from itertools import cycle
import pdb
import random
from einops import rearrange
import numpy as np
import minari
from minigrid.core import constants as mgc
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, ResNet18_Weights
import torchvision.transforms.v2 as transforms
from tqdm.notebook import tqdm
from transformers import GPT2Tokenizer, GPT2Config, GPT2Model

In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
minigrid_dataset = minari.load_dataset('D4RL/minigrid/fourrooms-v0', download=True)
env  = minigrid_dataset.recover_environment()

In [4]:
# Note on shapes:
# You're probably familiar with the old (B, T, C) shape – batch, timestep, channel.
# I'd like to introduce a new dimension: sequence. It fits between timestep and channel.
# `stack` here concats together sequences. A timestep is a list of sequence.
# IMPORTANT! Shape must always be (T, S, ...) for the below code to work.
# Stack pads along S and concats along T.
@dataclass
class TokenData:
    tokens: torch.Tensor
    targets: torch.Tensor
    attention_mask: torch.Tensor
    embedding: torch.Tensor = torch.tensor([])  # Optional at first.

    def combine(self, other):
        """Concats attributes of self to attributes of other."""
        # Requires padding to already be handled.
        # Requires shapes to be (T', T, C, ...)
        # Where T' is episode timestep and T is the usual timestep.
        return type(self)(
            tokens=torch.concat([self.tokens, other.tokens]),
            targets=torch.concat([self.targets, other.targets]),
            attention_mask=torch.concat([self.attention_mask, other.attention_mask]),
        )

    def embed(self, embedder):
        raise Exception('TODO: Override')

    def to(self, device):
        return type(self)(
            tokens=self.tokens.to(device),
            targets=self.targets.to(device),
            attention_mask=self.attention_mask.to(device),
        )
        
    @property
    def size(self):
        """The number of tokens this will consume of the context window"""
        return self.tokens.size(0) * self.tokens.size(1)

class TextTokenData(TokenData):
    def embed(self, embedder):
        return type(self)(
            tokens=self.tokens,
            targets=self.targets,
            attention_mask=self.attention_mask,
            embedding=embedder.text(self.tokens),
        ) 

class ImageTokenData(TokenData):
    def embed(self, embedder):
        return type(self)(
            tokens=self.tokens,
            targets=self.targets,
            attention_mask=self.attention_mask,
            embedding=embedder.image(self.tokens),
        ) 

class DiscreteTokenData(TokenData):
    def embed(self, embedder):
        return type(self)(
            tokens=self.tokens,
            targets=self.targets,
            attention_mask=self.attention_mask,
            embedding=embedder.discrete(self.tokens),
        ) 

In [5]:
@dataclass
class EpisodeData:
    def __getitem__(self, i):
        # Iterate over fields
        return type(self)(**{
            field.name: type(getattr(self, field.name))(
                tokens=getattr(self, field.name).tokens[i:i+1],
                targets=getattr(self, field.name).targets[i:i+1],
                attention_mask=getattr(self, field.name).attention_mask[i:i+1],
            )
            for field in fields(self)
        })

    def combine(self, other):
        return type(self)(**{
            field.name: getattr(self, field.name).combine(getattr(other, field.name))
            for field in fields(self)
        })

    @property
    def size(self):
        return sum(getattr(self, field.name).size for field in fields(self))

    @property
    def num_timesteps(self):
        return next(getattr(self, field.name) for field in fields(self)).tokens.size(0)

    def embed(self, embedder):
        return type(self)(**{
            field.name: getattr(self, field.name).embed(embedder)
            for field in fields(self)
        })

    def to(self, device):
        return type(self)(**{
            field.name: getattr(self, field.name).to(device)
            for field in fields(self)
        })

    def sequence(self, embeddings):
        raise Exception('Override me')

@dataclass
class FourRoomsTimestep(EpisodeData):
    mission: TextTokenData  # torch.Size((length of episode subsequence, length of _max_ (pad) mission text tokens))
    image: ImageTokenData
    direction: DiscreteTokenData
    actions: DiscreteTokenData

    def sequence(self, sequence_length):
        xs = torch.concat([self.mission.embedding, self.image.embedding, self.direction.embedding, self.actions.embedding], dim=1)
        ys = torch.concat([self.mission.targets, self.image.targets, self.direction.targets, self.actions.targets], dim=1)
        ms = torch.concat([self.mission.attention_mask, self.image.attention_mask, self.direction.attention_mask, self.actions.attention_mask], dim=1)
        T, S, C = xs.shape
        xs, ys, ms = xs.reshape(T*S, C), ys.reshape(T*S), ms.reshape(T*S)
        padding_len = sequence_length - T*S
        xs = F.pad(xs, (0, 0, 0, padding_len), value=0)
        ys, ms = [F.pad(x, (0, padding_len), value=0) for x in [ys, ms]]
        return xs, ys, ms

In [6]:
class Tokenizer:
    def __init__(self, text_gen_tokenizer, text_obs_tokenizer):
        self.text_gen_tokenizer = text_gen_tokenizer
        self.text_obs_tokenizer = text_obs_tokenizer

    def text_gen(self, data, **kwargs):
        tokenized =  self.text_gen_tokenizer(data, **kwargs)
        return TextTokenData(**{
            "tokens": tokenized["input_ids"][:, :-1],
            "targets": tokenized["input_ids"][:, 1:].to(torch.long),
            "attention_mask": tokenized["attention_mask"][:, :-1],
        })

    def text_obs(self, data, **kwargs):
        tokenized =  self.text_obs_tokenizer(data, **kwargs)
        return TextTokenData(**{
            "tokens": tokenized["input_ids"],
            "targets": tokenized["input_ids"].to(torch.long),
            "attention_mask": torch.zeros_like(tokenized["attention_mask"]),
        })

    def image(self, data):
        if len(data.shape) == 3:
          data = data.unsqueeze(0)
        patches = images_to_patches(data, patch_size=16)
        # Hardcoding as a reminder to do something smarter
        SQUARE_ROOT_OF_PATCH_SIZE = 3.464
        xs = (
            apply_along_dimension(
                normalize_to_between_minus_one_plus_one, 2, patches
            )
            / SQUARE_ROOT_OF_PATCH_SIZE
        )
        # We don't predict images, but we need ys
        # becaues these image ys will be in our
        # concatenated ys of text/image/action/etc...
        ys = torch.zeros(xs.shape[:2])
        ms = torch.zeros(xs.shape[:2])  # Same story as above.
        return ImageTokenData(tokens=xs, targets=ys, attention_mask=ms)

    def discrete_obs(self, data):
        if len(data.shape) == 0:
            data = data.unsqueeze(0)
        if len(data.shape) == 1:
            data = data.unsqueeze(1)
        xs = data
        ys = torch.zeros(xs.shape[:2])
        ms = torch.zeros(xs.shape[:2])
        return DiscreteTokenData(tokens=xs, targets=ys, attention_mask=ms)

    def discrete_act(self, data):
        if len(data.shape) == 0:
            data = data.unsqueeze(0)
        if len(data.shape) == 1:
            data = data.unsqueeze(1)
        xs = torch.concat([torch.full((data.size(0), 1), 1023), data], dim=1)[:, :-1]  # Instead of '|' being the separator, like Gato...
        ys = data
        ms = torch.ones(*ys.shape)
        return DiscreteTokenData(tokens=xs, targets=ys, attention_mask=ms)

    def continuous(self, data):
        raise Exception('TODO: Tokenizer.continuous')

In [7]:
torch.concat([torch.arange(3), torch.tensor([5])])

tensor([0, 1, 2, 5])

In [8]:
class TransformDataset(Dataset):
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.transform(self.dataset[idx])

In [9]:
SEQUENCE_LENGTH = 1024

In [10]:
__text_tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True)
__text_tokenizer.pad_token = __text_tokenizer.eos_token
_text_gen_tokenizer = partial(
    __text_tokenizer,
    max_length=SEQUENCE_LENGTH+1,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)
_text_obs_tokenizer = partial(
    __text_tokenizer,
    max_length=SEQUENCE_LENGTH,
    truncation=True,
    padding="longest",
    return_tensors="pt",
)

In [11]:
def images_to_patches(images, patch_size=16):
    return rearrange(images, 'b c (h s1) (w s2) -> b (h w) (c s1 s2)', s1=patch_size, s2=patch_size)
def normalize_to_between_minus_one_plus_one(t: torch.Tensor):
    min_val, max_val = t.min(), t.max()
    if min_val == max_val:
        return torch.zeros_like(t)
    normalized = 2 * (t - min_val) / (max_val - min_val) - 1
    return normalized
# There's a small deviation in the NEKO codebase from the paper.
# The paper normalizes _per patch_. The NEKO codebase currently normalizes _per image_.
# https://github.com/eihli/NEKO/blob/master/gato/policy/embeddings.py#L38
# This notebook normalizeds per patch. That's what this utility helps.
def apply_along_dimension(func, dim, tensor):
    tensor = tensor.transpose(0, dim)
    shape = tensor.shape
    tensor = tensor.reshape(shape[0], -1)
    result = torch.stack([func(tensor[:, i]) for i in range(tensor.size(1))], dim=1)
    result = result.reshape(shape).transpose(0, dim)
    return result

In [12]:
# Create lookup table
lut = np.zeros((256, 3), dtype=np.uint8)
for idx, color_name in mgc.IDX_TO_COLOR.items():
    lut[idx] = mgc.COLORS[color_name]

def minigrid_to_rgb(episode):
    """Convert discrete "image" observations into actual images.
    I'm expecting this will improve our image modality while not losing
    much. The downside is we can fit less in our context window. Note:
    We might need to overlay the color/type image (index 1) with the
    state image (index 2), if we really don't want to lose any info."""
    # Apply lookup to second channel
    image = lut[episode.observations['image'][:, :, :, 1]]
    # Convert to PyTorch tensor and permute
    image = torch.from_numpy(image).permute(0, 3, 1, 2)
    return image

image_transform = transforms.Compose([
    # No particular reason to use `transforms.Compose` here since we're only doing one transform. But it's nice to know about.
    transforms.RandomResizedCrop((192, 192), (0.5, 1.0)),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def minigrid_tokenizer(tokenizer, episode):
    num_timesteps = len(episode.actions)
    image = image_transform(minigrid_to_rgb(episode)[:num_timesteps])
    image = tokenizer.image(image[:num_timesteps])
    mission = tokenizer.text_obs(episode.observations['mission'][:num_timesteps], padding=False)
    direction = tokenizer.discrete_obs(torch.from_numpy(episode.observations['direction'])[:num_timesteps])
    actions = tokenizer.discrete_act(torch.from_numpy(episode.actions))
    return FourRoomsTimestep(mission=mission, image=image, direction=direction, actions=actions)

In [13]:
tokenizer = Tokenizer(_text_gen_tokenizer, _text_obs_tokenizer)
minigrid_tokenize = partial(minigrid_tokenizer, tokenizer)

In [14]:
minigrid_dataset_xf = TransformDataset(minigrid_dataset, minigrid_tokenize)

In [15]:
minigrid_dataset_xf[0].actions.tokens.shape

torch.Size([19, 1])

In [16]:
minigrid_dataset_xf[0].mission.targets.shape

torch.Size([19, 3])

In [17]:
BATCH_SIZE=4

In [18]:
def minigrid_collate_fn(batch):
    result = []
    for sample in batch:
        i = random.randint(0, sample.num_timesteps - 1)

        # Starting at that index, we'll continue adding observations to our context window until
        # we run out of space.
        step = sample[i]
        while i < len(sample.actions.tokens) and step.size + step[0].size < SEQUENCE_LENGTH:
            i += 1
            step = step.combine(sample[i])
        result.append(step)
    return result

In [19]:
minigrid_dataloader = DataLoader(minigrid_dataset_xf, batch_size=BATCH_SIZE, collate_fn=minigrid_collate_fn)

In [20]:
minigrid_batch = next(iter(minigrid_dataloader))

In [21]:
# From section 2.2 of the Gato paper:
#
#    Tokens belonging to image patches for any time-step are embedded using a
#    single ResNet (He et al., 2016a) block to obtain a vector per patch. For
#    image patch token embeddings, we also add a learnable within-image position
#    encoding vector.
class ResNetV2Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, num_groups=24):
        super(ResNetV2Block, self).__init__()
        self.gn1 = nn.GroupNorm(1, in_channels)
        self.gelu = nn.GELU()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.gn2 = nn.GroupNorm(num_groups, out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=False
        )

    def forward(self, x):
        B, T, CHW = x.shape
        # TODO: Remove these hardcoded values.
        out = rearrange(x, 'b t (c h w) -> (b t) c h w', c=3, h=16)
        out = self.gn1(out)
        out = self.gelu(out)
        out = self.conv1(out)
        out = self.gn2(out)
        out = self.gelu(out)
        out = self.conv2(out)
        return x + rearrange(out, '(b t) c h w -> b t (c h w)', b=B, t=T)

In [22]:
@dataclass
class Embedder:
    text: Callable
    image: Callable
    discrete: Callable

In [23]:
@dataclass
class MiniGatoConfig:
    embedding_dim: int
    sequence_length: int
    vocab_size: int 
    transformer_config: GPT2Config
    transformer: GPT2Model

In [24]:
def init_default_config() -> MiniGatoConfig:
    transformer_config = GPT2Config()
    return MiniGatoConfig(
        embedding_dim=768,
        sequence_length=1024,
        vocab_size=__text_tokenizer.vocab_size,
        transformer_config=transformer_config,
        transformer=GPT2Model(transformer_config),
    )
default_config = init_default_config()

In [25]:
class MiniGato(nn.Module):
    def __init__(self, config: MiniGatoConfig=default_config):
        super().__init__()
        self.config = config
        self.sequence_length = self.config.sequence_length
        text_embedding = nn.Embedding(self.config.vocab_size, self.config.embedding_dim)
        image_embedding = ResNetV2Block(3, self.config.embedding_dim)
        discrete_embedding = nn.Embedding(1024, self.config.embedding_dim)
        self.embedder = Embedder(text=text_embedding, image=image_embedding, discrete=discrete_embedding)
        self.transformer = self.config.transformer
        self.lm_head = nn.Linear(self.transformer.config.hidden_size, self.config.vocab_size)     

    def forward(self, batch):
        batch = [
            sample.embed(self.embedder).sequence(self.sequence_length) for sample in batch
        ]
        xs, ys, ms = map(torch.stack, zip(*batch))
        xs, ys, ms = [x.to(device) for x in [xs, ys, ms]]
        out = self.transformer(inputs_embeds=xs)
        predicted = self.lm_head(out.last_hidden_state)
        return predicted, ys, ms

In [26]:
minigrid_dataloader = DataLoader(minigrid_dataset_xf, batch_size=BATCH_SIZE, collate_fn=minigrid_collate_fn)
minigrid_iterator = iter(minigrid_dataloader)

In [27]:
minigrid_batch = next(minigrid_iterator)

In [28]:
def infinite_dataloader(fn):
    it = iter(fn())
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = iter(fn())

In [29]:
dataloaders = [
    infinite_dataloader(partial(DataLoader, minigrid_dataset_xf, batch_size=BATCH_SIZE, collate_fn=minigrid_collate_fn, num_workers=4))
]

In [30]:
## Loss
##
## See section 2.3 of the Gato paper.
##
##   Let b index a training batch of sequences B. We define a masking function m
##   such that m(b, l) = 1 if the token at index l is either from text or from
##   the logged action of an agent, and 0 otherwise. The training loss for a
##   batch B can then be written as...
def cross_entropy(predicted, target, mask):
    # See: https://youtu.be/kCc8FmEb1nY?list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ&t=1553
    B, T, C = predicted.shape
    predicted = predicted.view(B * T, C)
    target = target.view(-1).to(torch.long)
    losses = F.cross_entropy(predicted, target, reduction="none")
    losses = losses * mask.squeeze(-1).view(-1)
    loss = losses.sum() / mask.sum()
    return loss

In [34]:
class MiniGatoTrainer:
    def __init__(self, model, optimizer, dataloaders):
        self.model = model
        self.optimizer = optimizer
        self.dataloaders = dataloaders
        self.dl_it = cycle(dataloaders)
        self.losses = []

    def train(self, iterations=50):
        self.model.train()
        for i in tqdm(range(iterations)):
            dl = next(self.dl_it)
            batch = next(dl)
            predicted, targets, attention_mask = self.model(batch)
            loss = cross_entropy(predicted, targets, attention_mask)
            self.losses.append(loss.item())
            loss.backward()
            self.optimizer.step()

In [35]:
config = init_default_config()
model = MiniGato(config).to(device)
optimizer = torch.optim.AdamW(model.parameters())
trainer = MiniGatoTrainer(
    model,
    optimizer,
    dataloaders,
)

In [44]:
trainer.train()
trainer.train()
trainer.train()
trainer.train()
trainer.train()

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [45]:
trainer.losses

[10.339223861694336,
 1.6069600582122803,
 2.229372024536133,
 1.258705735206604,
 2.7163641452789307,
 1.777280330657959,
 2.38631534576416,
 1.599935531616211,
 1.9081206321716309,
 1.6184971332550049,
 1.4483606815338135,
 1.645542860031128,
 0.7315148711204529,
 1.316531777381897,
 1.6763911247253418,
 0.5046938061714172,
 0.47034740447998047,
 0.7731776237487793,
 1.2030525207519531,
 2.016554594039917,
 2.9713335037231445,
 3.7278053760528564,
 4.147859573364258,
 4.27758264541626,
 4.491320610046387,
 4.8445353507995605,
 4.7534050941467285,
 4.044671058654785,
 2.968477487564087,
 1.7000150680541992,
 0.8968724608421326,
 0.7794124484062195,
 1.0121744871139526,
 0.5032528042793274,
 2.1747968196868896,
 3.135807514190674,
 0.7833988070487976,
 1.7329763174057007,
 2.0020411014556885,
 2.197782278060913,
 3.306472063064575,
 3.6065943241119385,
 3.0668156147003174,
 3.968012809753418,
 2.7870094776153564,
 5.604249000549316,
 9.405218124389648,
 7.041199684143066,
 10.381896018