In [72]:
class Episode:
    pass

class TextEpisode(Episode):
    def __init__(self, observations):
        self.observations = observations

class VQAEpisode(Episode):
    pass

class AgentEpisode(Episode):
    pass

In [73]:
class Step:
    pass

class TextEpisode(Episode):
    def __init__(self, observations):
        self.observations = observations

class VQAEpisode(Episode):
    pass

class AgentEpisode(Episode):
    pass

In [74]:
class Observation:
    def __init__(self, modalities):
        self.modalities = modalities

In [75]:
import torch

In [625]:
from dataclasses import dataclass
from operator import attrgetter
import torch.nn.functional as F

@dataclass
class Data:
    tokens: torch.Tensor
    targets: torch.Tensor
    attention_mask: torch.Tensor

    def stack(self, other, pad_token=torch.tensor(0)):
        padded = min([self, other], key=lambda x: len(x.tokens))
        other = max([self, other], key=lambda x: len(x.tokens))
        pad_by = len(other.tokens) - len(padded.tokens)
        padded = type(padded)(
            tokens=F.pad(padded.tokens, (0, 0, 0, pad_by), value=pad_token),
            targets=F.pad(padded.targets, (0, 0, 0, pad_by), value=pad_token),
            attention_mask=F.pad(padded.attention_mask, (0, pad_by), value=pad_token),
        )
        return type(self)(
            tokens=torch.stack([padded.tokens, other.tokens]),
            targets=torch.stack([padded.targets, other.targets]),
            attention_mask=torch.stack([padded.attention_mask, other.attention_mask])
        )

class TextData(Data):
    pass

class ImageData(Data):
    pass

class DiscreteData(Data):
    pass

class ContinuousData(Data): 
    pass

In [626]:
a = TextData(tokens=torch.arange(4), targets=torch.arange(4), attention_mask=torch.ones((4,)))
b = TextData(tokens=torch.arange(5), targets=torch.arange(5), attention_mask=torch.ones((5,)))
c = a.stack(b)
c.tokens.shape, c.targets.shape, c

RuntimeError: Padding length should be less than or equal to two times the input dimension but got padding length 4 and input of dimension 1

In [581]:
SEQUENCE_LENGTH = 1024
BATCH_SIZE = 4
NUM_WORKERS = 4  # DataLoader

In [582]:
from transformers import GPT2Tokenizer

In [583]:
# Prefixing with _ to signify global.
__text_tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True)
__text_tokenizer.pad_token = __text_tokenizer.eos_token
_text_tokenizer = partial(
    __text_tokenizer,
    max_length=SEQUENCE_LENGTH+1,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)

In [584]:
BOS_TOK, EOS_TOK = __text_tokenizer.bos_token, __text_tokenizer.eos_token

In [585]:
import os
from pathlib import Path
import re
import tempfile
import requests

def acquire_shakespeare_dataset():
    temp_dir = tempfile.gettempdir()
    shakespeare_filepath = Path(temp_dir)/"shakespeare.txt"
    if not os.path.exists(shakespeare_filepath):
        data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        with open(shakespeare_filepath, 'w', encoding='utf-8') as f:
            f.write(requests.get(data_url).text)
    
    with open(shakespeare_filepath, 'r', encoding='utf-8') as f:
        data = f.read()

    # Split the dataset into each character's lines.
    # Continue taking lines until you have at least 250 words in the sample.
    # Add that sample to the dataset.
    characters_lines = re.split(r"\n\s*\n", data.strip())
    characters_lines = [BOS_TOK + line + EOS_TOK for line in characters_lines]
    MIN_WORDS_PER_BATCH = 250
    sample = [characters_lines[0]]
    num_words_in_sample = len(characters_lines[0].split())
    text_dataset = []
    i = 1
    while i < len(characters_lines):
        if num_words_in_sample > MIN_WORDS_PER_BATCH:
            text_dataset.append("".join(sample))
            num_words_in_sample -= len(sample[0].split())
            sample = sample[1:]
        sample += [characters_lines[i]]
        num_words_in_sample += len(characters_lines[i].split())
        i += 1

    return text_dataset

In [586]:
text_data = acquire_shakespeare_dataset()

In [587]:
from functools import partial
import torch
from torch.utils.data import Dataset

In [588]:
class TransformDataset(Dataset):
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.transform(self.dataset[idx])

In [589]:
from abc import ABC, abstractmethod
from typing import Any

class Sequencer(ABC):
    @abstractmethod
    def __call__(self, sample: Any) -> Episode:
        """Given a sample, tokenizes it and returns a dict of
        tokens, targets, attention_mask, and modality."""
        pass

class TextSequencer(Sequencer):
    def __init__(self, tokenizer):   
        self.tokenizer = tokenizer
        
    def __call__(self, sample, **kwargs) -> Episode:
        tokenized =  self.tokenizer(sample, **kwargs)
        return TextData(**{
            "tokens": tokenized["input_ids"][:, :-1].squeeze(0),
            "targets": tokenized["input_ids"][:, 1:].squeeze(0),
            "attention_mask": tokenized["attention_mask"][:, :-1].squeeze(0),
        })

    def decode(self, *args, **kwargs):
        return self.tokenizer.func.decode(*args, **kwargs)

    @property
    def vocab_size(self):
        return self.tokenizer.func.vocab_size

In [590]:
text_sequencer = TextSequencer(_text_tokenizer)
text_dataset = TransformDataset(text_data, text_sequencer)

In [591]:
text_dataset[0]

TextData(tokens=tensor([50256,  5962, 22307,  ..., 50256, 50256, 50256]), targets=tensor([ 5962, 22307,    25,  ..., 50256, 50256, 50256]), attention_mask=tensor([1, 1, 1,  ..., 0, 0, 0]))

In [592]:
import torchvision.transforms.v2 as transforms

In [593]:
# First things first, let's get the images resized, cropped, and normalized.
image_transform = transforms.Compose([
    # No particular reason to use `transforms.Compose` here since we're only doing one transform. But it's nice to know about.
    transforms.RandomResizedCrop((192, 192), (0.5, 1.0)),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])    
])

In [594]:
from operator import itemgetter
tta = itemgetter("tokens", "targets", "attention_mask")

class VQASequencer(Sequencer):
    def __init__(self, text_tokenizer, image_tokenizer):   
        self.text_tokenizer = text_tokenizer
        self.image_tokenizer = image_tokenizer
        
    def __call__(self, sample, **kwargs) -> Episode:
        question_tokens, question_targets, question_mask =  tta(self.text_tokenizer(sample["question"]))
        image_tokens, image_targets, image_mask =  tta(self.image_tokenizer(sample["image"]))
        answer_tokens, answer_targets, answer_mask =  tta(self.image_tokenizer(sample["answer"]))
        return {
            "tokens": tokenized["input_ids"][:, :-1].squeeze(0),
            "targets": tokenized["input_ids"][:, 1:].squeeze(0),
            "attention_mask": tokenized["attention_mask"][:, :-1].squeeze(0),
        }

    def decode(self, *args, **kwargs):
        return self.tokenizer.func.decode(*args, **kwargs)

    @property
    def vocab_size(self):
        return self.tokenizer.func.vocab_size

In [595]:
from dataclasses import dataclass
import random

In [596]:
torch.randn(5)

tensor([-1.3616,  0.3014,  0.3458,  0.6947, -0.0158])

In [597]:
a = FourRoomsStep(
    mission=torch.randn(5),
    image=torch.randn((3, 6, 6)),
    direction=torch.randn(5),
    separator=torch.randn(1),
    action=torch.randn(5)
)
b = FourRoomsStep(
    mission=torch.randn(5),
    image=torch.randn((3, 6, 6)),
    direction=torch.randn(5),
    separator=torch.randn(1),
    action=torch.randn(5)
)
a.stack(b).mission.size()

torch.Size([2, 5])

In [598]:
__text_tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True)
__text_tokenizer.pad_token = __text_tokenizer.eos_token
_text_tokenizer = partial(
    __text_tokenizer,
    max_length=SEQUENCE_LENGTH+1,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)

In [599]:
from einops import rearrange
def images_to_patches(images, patch_size=16):
    return rearrange(images, 'b c (h s1) (w s2) -> b (h w) (c s1 s2)', s1=patch_size, s2=patch_size)
def normalize_to_between_minus_one_plus_one(t: torch.Tensor):
    min_val, max_val = t.min(), t.max()
    if min_val == max_val:
        return torch.zeros_like(t)
    normalized = 2 * (t - min_val) / (max_val - min_val) - 1
    return normalized
# There's a small deviation in the NEKO codebase from the paper.
# The paper normalizes _per patch_. The NEKO codebase currently normalizes _per image_.
# https://github.com/eihli/NEKO/blob/master/gato/policy/embeddings.py#L38
# This notebook normalizeds per patch. That's what this utility helps.
def apply_along_dimension(func, dim, tensor):
    tensor = tensor.transpose(0, dim)
    shape = tensor.shape
    tensor = tensor.reshape(shape[0], -1)
    result = torch.stack([func(tensor[:, i]) for i in range(tensor.size(1))], dim=1)
    result = result.reshape(shape).transpose(0, dim)
    return result

In [600]:
class Tokenizer:
    def __init__(self, text_tokenizer):
        self.text_tokenizer = text_tokenizer

    def text(self, data, **kwargs):
        tokenized =  self.text_tokenizer(data, **kwargs)
        return TextData(**{
            "tokens": tokenized["input_ids"][:, :-1].squeeze(0),
            "targets": tokenized["input_ids"][:, 1:].squeeze(0),
            "attention_mask": tokenized["attention_mask"][:, :-1].squeeze(0),
        })

    def image(self, data):
        patches = images_to_patches(data, patch_size=16)
        # Hardcoding as a reminder to do something smarter
        SQUARE_ROOT_OF_PATCH_SIZE = 3.464
        xs = (
            apply_along_dimension(
                normalize_to_between_minus_one_plus_one, 2, patches
            )
            / SQUARE_ROOT_OF_PATCH_SIZE
        )
        # We don't predict images, but we need ys
        # becaues these image ys will be in our 
        # concatenated ys of text/image/action/etc...
        ys = torch.zeros(xs.shape[:2])
        ms = torch.zeros(xs.shape[:2])  # Same story as above.
        return ImageData(tokens=xs, targets=ys, attention_mask=ms)

    def discrete(self, data):
        # The Gato paper talks about offsetting their discrete tokens. But it
        # seems rather arbitrary whether you offset it and use a single Embedding
        # table or if you don't offset it and maintain two separate Embedding
        # tables. I'm not going to offset it.
        xs = data.to(torch.uint16)
        ys = torch.zeros(xs.shape[:2])
        ms = torch.zeros(xs.shape[:2])
        return DiscreteData(tokens=xs, targets=ys, attention_mask=ms)

    def continuous(self, data):
        raise Exception('TODO: Tokenizer.continuous')

In [601]:
tokenizer = Tokenizer(_text_tokenizer)

In [602]:
tokenizer.text(["Hi", "There"]).targets.shape

torch.Size([2, 1024])

In [603]:
tokenizer.image(torch.randn((2, 3, 192, 192))).tokens.shape

torch.Size([2, 144, 768])

In [604]:
tokenizer.discrete(torch.randint(5, (5,)))

DiscreteData(tokens=tensor([1, 0, 4, 4, 3], dtype=torch.uint16), targets=tensor([0., 0., 0., 0., 0.]), attention_mask=tensor([0., 0., 0., 0., 0.]))

In [605]:
# It's annoying – our text sequencer is very generic and can be re-used across almost any text dataset.
# Our VQA sequencer is less generic, but it could still probably be re-used across other question/answering dataset.
# Our agent sequencer is very specific. It seems like every agent dataset is going to have its own datastructures.
@dataclass
class FourRoomsSample:
    mission: TextData
    image: ImageData
    direction: DiscreteData
    separator: TextData
    action: DiscreteData

    def stack(self, other):
        return FourRoomsSample(
            mission=self.mission.stack(other.mission),
            image=self.image.stack(other.image),
            direction=self.direction.stack(other.direction),
            separator=self.separator.stack(other.separator),
            action=self.action.stack(other.action),
        )

class FourRoomsSequencer(Sequencer):
    def __init__(self, text_tokenizer, image_tokenizer, discrete_observation_tokenizer, discrete_action_tokenizer, image_transform):
        self.text_tokenizer = text_tokenizer
        self.image_tokenizer = image_tokenizer
        # The only difference between an observation and an action tokenizer is that the observation tokenizer returns a zero mask
        # because the Gato paper doesn't try to predict observations.
        self.discrete_observation_tokenizer = discrete_observation_tokenizer
        self.discrete_action_tokenizer = discrete_action_tokenizer
        self.image_transform = image_transform

    def __call__(self, episode_data, **kwargs) -> Episode:
        """What might a sample look like?

        https://minari.farama.org/datasets/D4RL/minigrid/fourrooms-v0/

        EpisodeData(
            observations = {
                direction: ndarray(20,),
                image: ndarray(20, 7, 7, 3),
                mission: ['reach the goal', 'reach the goal', ...]
            },
            actions = ndarray(19,)
        )
        """
        # We probably can't fit the entire episode in our context window, so
        # pick a random spot to start from
        i = random.randint(len(episode_data.actions))

        # Starting at that index, we'll continue adding observations to our context window until
        # we run out of space.
        mission = self.text_tokenizer(episode_data.observations["mission"][i])
        image = self.image_tokenizer(self.image_transform(episode_data.observations["image"][i]))
        direction = self.discrete_observation_tokenizer(episode_data.observations["direction"])
        separator = TextData(tokens=torch.tensor([91]), targets=torch.tensor([0]), mask=torch.tensor([0]))
        action = self.discrete_action_tokenizer(episode_data.actions[i])
        step = FourRoomsSample(mission=mission, image=image, direction=direction, separator=separator, action=action)
        steps = [step]

        observation_length = len(mission.tokens) + len(image.tokens) + len(direction.tokens) + len(separator.tokens) + len(action.tokens)

        current_sequence_length = observation_length
        while i < len(episode_data.actions) and current_sequence_length + observation_length < SEQUENCE_LENGTH:
            i += 1
            mission = tta(self.text_tokenizer(episode_data.observations["mission"][i]))
            image = tta(self.image_tokenizer(self.image_transform(episode_data.observations["image"][i])))
            direction = tta(self.discrete_observation_tokenizer(episode_data.observations["direction"]))
            separator = TextData(tokens=torch.tensor([91]), targets=torch.tensor([0]), mask=torch.tensor([0]))
            action = tta(self.discrete_action_tokenizer(episode_data.actions[i]))
            step = step.stack(FourRoomsSample(mission=mission, image=image, direction=direction, separator=separator, action=action))
            current_sequence_length += observation_length

        return sample

    def decode(self, *args, **kwargs):
        return self.tokenizer.func.decode(*args, **kwargs)

    @property
    def vocab_size(self):
        return self.tokenizer.func.vocab_size

In [606]:
import minari

minigrid_dataset = minari.load_dataset('D4RL/minigrid/fourrooms-v0', download=True)
env  = minigrid_dataset.recover_environment()

In [607]:
print("Observation space:", minigrid_dataset.observation_space)
print("Action space:", minigrid_dataset.action_space)
print("Total episodes:", minigrid_dataset.total_episodes)
print("Total steps:", minigrid_dataset.total_steps)

Observation space: Dict('direction': Discrete(4), 'image': Box(0, 255, (7, 7, 3), uint8), 'mission': Text(1, 14, charset=                                                              ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''(),,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdeeeffghijklmnnoopqrrssttuvwxyzz{}))
Action space: Discrete(7)
Total episodes: 590
Total steps: 10010


In [608]:
minigrid_dataset.set_seed(seed=123)

for i in range(5):
    # sample 5 episodes from the dataset
    episodes = minigrid_dataset.sample_episodes(n_episodes=5)
    # get id's from the sampled episodes
    ids = list(map(lambda ep: ep.id, episodes))
    print(f"EPISODE ID'S SAMPLE {i}: {ids}")

EPISODE ID'S SAMPLE 0: [31, 348, 9, 536, 400]
EPISODE ID'S SAMPLE 1: [103, 265, 544, 204, 477]
EPISODE ID'S SAMPLE 2: [302, 158, 14, 505, 522]
EPISODE ID'S SAMPLE 3: [240, 125, 371, 87, 435]
EPISODE ID'S SAMPLE 4: [468, 125, 305, 489, 469]


In [609]:
minigrid_dataset[0]

EpisodeData(id=0, total_steps=19, observations={direction: ndarray of shape (20,) and dtype int64, image: ndarray of shape (20, 7, 7, 3) and dtype uint8, mission: ['reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal']}, actions=ndarray of shape (19,) and dtype int64, rewards=ndarray of 19 floats, terminations=ndarray of 19 bools, truncations=ndarray of 19 bools, infos=dict with the following keys: [])

In [610]:
# First things first, let's get the images resized, cropped, and normalized.
image_transform = transforms.Compose([
    # No particular reason to use `transforms.Compose` here since we're only doing one transform. But it's nice to know about.
    transforms.RandomResizedCrop((192, 192), (0.5, 1.0)),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])    
])

In [611]:
from minigrid.core import constants as mgc

In [612]:
import numpy as np

In [613]:
# Create lookup table
lut = np.zeros((256, 3), dtype=np.uint8)
for idx, color_name in mgc.IDX_TO_COLOR.items():
    lut[idx] = mgc.COLORS[color_name]

def minigrid_to_rgb(episode):
    """Convert discrete "image" observations into actual images.
    I'm expecting this will improve our image modality while not losing
    much. The downside is we can fit less in our context window. Note:
    We might need to overlay the color/type image (index 1) with the 
    state image (index 2), if we really don't want to lose any info."""
    # Apply lookup to second channel
    image = lut[episode.observations['image'][:, :, :, 1]]
    # Convert to PyTorch tensor and permute
    image = torch.from_numpy(image).permute(0, 3, 1, 2)
    return image

image_transform = transforms.Compose([
    # No particular reason to use `transforms.Compose` here since we're only doing one transform. But it's nice to know about.
    transforms.RandomResizedCrop((192, 192), (0.5, 1.0)),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])    
])

def minigrid_tokenizer(tokenizer, episode):
    image = image_transform(minigrid_to_rgb(episode))
    image = tokenizer.image(image)
    mission = tokenizer.text(episode.observations['mission'])
    direction = tokenizer.discrete(torch.from_numpy(episode.observations['direction']))
    separator = TextData(tokens=torch.tensor([91]), targets=torch.tensor([0]), attention_mask=torch.tensor([0]))
    action = tokenizer.discrete(torch.from_numpy(episode.actions))
    return FourRoomsSample(mission=mission, image=image, direction=direction, separator=separator, action=action)

In [614]:
obs = minigrid_dataset[0]
obs

EpisodeData(id=0, total_steps=19, observations={direction: ndarray of shape (20,) and dtype int64, image: ndarray of shape (20, 7, 7, 3) and dtype uint8, mission: ['reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal']}, actions=ndarray of shape (19,) and dtype int64, rewards=ndarray of 19 floats, terminations=ndarray of 19 bools, truncations=ndarray of 19 bools, infos=dict with the following keys: [])

In [615]:
minigrid_tokenize = partial(minigrid_tokenizer, tokenizer)

In [616]:
minigrid_dataset_xf = TransformDataset(minigrid_dataset, minigrid_tokenize)
minigrid_dataset_xf[0];

In [617]:
from torch.utils.data import DataLoader

In [631]:
fields(minigrid_dataset_xf[0].mission)

(Field(name='tokens',type=<class 'torch.Tensor'>,default=<dataclasses._MISSING_TYPE object at 0x7e069826d3d0>,default_factory=<dataclasses._MISSING_TYPE object at 0x7e069826d3d0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
 Field(name='targets',type=<class 'torch.Tensor'>,default=<dataclasses._MISSING_TYPE object at 0x7e069826d3d0>,default_factory=<dataclasses._MISSING_TYPE object at 0x7e069826d3d0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
 Field(name='attention_mask',type=<class 'torch.Tensor'>,default=<dataclasses._MISSING_TYPE object at 0x7e069826d3d0>,default_factory=<dataclasses._MISSING_TYPE object at 0x7e069826d3d0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD))

In [618]:
def minigrid_collate_fn(batch):
    result = batch[0]
    for episode in batch[1:]:
        result = result.stack(episode)
    return result

In [621]:
minigrid_dataloader = DataLoader(minigrid_dataset_xf, collate_fn=minigrid_collate_fn, batch_size=BATCH_SIZE)

In [622]:
minigrid_batch = next(iter(minigrid_dataloader))

RuntimeError: stack expects each tensor to be equal size, but got [2, 20, 1051] at entry 0 and [29, 1024] at entry 1