# MiniGato

From the paper [A Generalist Agent](https://arxiv.org/abs/2205.06175).

The paper doesn't introduce a new architecture. Instead, the paper is all about tokenizing, embedding, and sequencing data from multiple modalities (text, image, proprioception) in such a way that it can be learned by a transformer.

Reproducing the paper is more of a software design exercise than an ML research exercise. How would you structure the data manipulation code – the tokenization, embedding, and sequencing of different modalities – in a way that's correct, easy to understand and extend, and performant?

## Imports

Just grouping and hiding these to avoid noise.

In [1]:
from itertools import cycle
import pdb
import random
from typing import List
from dataclasses import dataclass, fields
import datasets
from einops import rearrange
from functools import partial
from nano_gpt import GPT, GPTConfig
import numpy as np
import minari
import minigrid.core
from timm.models.resnetv2 import ResNetV2
import tiktoken
import torch
import torchvision.transforms.v2 as transforms
from torchvision.transforms.functional import pil_to_tensor
from torch.utils.data import Dataset, DataLoader

ModuleNotFoundError: No module named 'nano_gpt'

## Tokenization [§ 2.1](https://arxiv.org/pdf/2205.06175)

> There are infinite possible ways to transform data into tokens, including directly using the raw underlying byte stream.

We're going to start off with a certain set of tokenization strategies for a certain set of modalities. We might want to expand that in the future. When that happens, I don't want to have to edit code in the PyTorch Module that implements our model code. I'd rather be able to add a new class of modality that implements a "tokenization" signature.

In [None]:
class ModalToken(torch.Tensor):
    pass
class ModalTarget(torch.Tensor):
    pass
class ModalEmbedding(torch.Tensor):
    pass
class TextToken(ModalToken):
    pass
class ImageToken(ModalToken):
    pass
class DiscreteToken(ModalToken):
    pass
class ContinuousToken(ModalToken):
    pass
class TextTarget(ModalTarget):
    pass
class ImageTarget(ModalTarget):
    pass
class DiscreteTarget(ModalTarget):
    pass
class ContinuousTarget(ModalTarget):
    pass
class TextEmbedding(ModalEmbedding):
    pass
class ImageEmbedding(ModalEmbedding):
    pass
class DiscreteEmbedding(ModalEmbedding):
    pass
class ContinuousEmbedding(ModalEmbedding):
    pass

def text_token(data):
    return TextToken(data).to(torch.long)
def image_token(data):
    return ImageToken(data)
def discrete_token(data):
    return DiscreteToken(data).to(torch.long)
def continuous_token(data):
    return ContinuousToken(data)
def text_target(data):
    return TextTarget(data).to(torch.long)
def image_target(data):
    return ImageTarget(data)
def discrete_target(data):
    return DiscreteTarget(data).to(torch.long)
def continuous_target(data):
    return ContinuousTarget(data)
def text_embedding(data):
    return TextEmbedding(data)
def image_embedding(data):
    return ImageEmbedding(data)
def discrete_embedding(data):
    return DiscreteEmbedding(data)
def continuous_embedding(data):
    return ContinuousEmbedding(data)

In [None]:
class Modality:
    tokens: ModalToken
    targets: ModalTarget
    attention_mask: torch.Tensor
    embeddings: ModalEmbedding = ModalEmbedding([])

    def __repr__(self):
        return f"{type(self).__name__}(\n\t{self.origin!r}\n\t{self.tokens!r}\n\t{self.targets!r}\n\t{self.attention_mask!r}\n\t{self.embeddings!r}\n)"
        
    @property
    def size(self):
        return self.tokens.size(0)

    def __len__(self):
        return self.size

    def to(self, device):
        """Return new instance of class with all fields moved to device."""
        return type(self)(
            tokens=type(self.tokens)(self.tokens.to(device)),
            targets=type(self.targets)(self.targets.to(device)),
            attention_mask=type(self.attention_mask)(self.attention_mask.to(device)),
            embeddings=type(self.embeddings)(self.embeddings.to(device)),
        )

    def clone(self):
        return type(self)(
            tokens=type(self.tokens)(self.tokens.clone()),
            targets=type(self.targets)(self.targets.clone()),
            attention_mask=type(self.attention_mask)(self.attention_mask.clone()),
            embeddings=type(self.embeddings)(self.embeddings.clone()),
        )

class TextModality(Modality):
    origin: str
    tokens: TextToken
    targets: TextTarget
    attention_mask: torch.Tensor
    embeddings: TextEmbedding = TextEmbedding([])

    def __init__(self, origin, tokens, targets, attention_mask, embeddings=None):
        embeddings = TextEmbedding(embeddings) if embeddings is not None else TextEmbedding(torch.tensor([]))
        self.origin, self.tokens, self.targets, self.attention_mask, self.embeddings = origin, TextToken(tokens), TextTarget(targets), attention_mask, embeddings

class ImageModality(Modality):
    origin: torch.Tensor
    tokens: ImageToken
    targets: ImageTarget
    attention_mask: torch.Tensor
    embeddings: ImageEmbedding = ImageEmbedding([])

    def __init__(self, origin, tokens, targets, attention_mask, embeddings=None):
        embeddings = ImageEmbedding(embeddings) if embeddings is not None else ImageEmbedding(torch.tensor([]))
        self.origin, self.tokens, self.targets, self.attention_mask, self.embeddings = origin, ImageToken(tokens), ImageTarget(targets), attention_mask, embeddings    

class DiscreteModality(Modality):
    origin: torch.Tensor
    tokens: DiscreteToken
    targets: DiscreteTarget
    attention_mask: torch.Tensor
    embeddings: DiscreteEmbedding = DiscreteEmbedding([])

    def __init__(self, origin, tokens, targets, attention_mask, embeddings=None):
        embeddings = DiscreteEmbedding(embeddings) if embeddings is not None else DiscreteEmbedding(torch.tensor([]))
        self.origin, self.tokens, self.targets, self.attention_mask, self.embeddings = origin, DiscreteToken(tokens), DiscreteTarget(targets), attention_mask, embeddings    

class ContinuousModality(Modality):
    origin: torch.Tensor
    tokens: ContinuousToken
    targets: ContinuousTarget
    attention_mask: torch.Tensor
    embeddings: ContinuousEmbedding = ContinuousEmbedding([])

    def __init__(self, origin, tokens, targets, attention_mask, embeddings=None):
        embeddings = ContinuousEmbedding(embeddings) if embeddings is not None else ContinuousEmbedding(torch.tensor([]))
        self.origin, self.tokens, self.targets, self.attention_mask, self.embeddings = origin, ContinuousToken(tokens), ContinuousTarget(targets), attention_mask, embeddings    

In [None]:
@dataclass
class Episode:
    @property
    def size(self):
        return sum(getattr(self, field.name).size for field in fields(self))

    def __len__(self):
        return self.size

    def __getitem__(self, key):
        if isinstance(key, int):
            i = 0
            for modality in [getattr(self, field.name) for field in fields(self)]:
                if i + modality.size <= key:
                    i += modality.size
                else:
                    break
            return type(modality)(
                origin=modality.origin[key - i],
                tokens=modality.tokens[[key - i]],
                targets=modality.targets[[key - i]],
                attention_mask=modality.attention_mask[[key - i]],
                embeddings=modality.embeddings[[key - i]] if len(modality.embeddings) > key - i else modality.embeddings,
            )
        elif isinstance(key, slice):
            start, stop, step = key.start or 0, key.stop or len(self), key.step or 1
            i, results = start, []
            while i < stop:
                results.append(self[i])
                i += step
            return results
        else:
            raise TypeError(f"Invalid argument type `{type(key)}`.")

    def clone(self):
        return type(self)(
            **{
                field.name: getattr(self, field.name).clone()
                for field in fields(self)
            }
        )

class AgentEpisode(Episode):
    pass

@dataclass
class FourRoomsEpisode(AgentEpisode):
    mission: TextModality
    image: ImageModality    
    direction: DiscreteModality
    action: DiscreteModality

@dataclass
class GenericTextEpisode(Episode):
    text: TextModality

@dataclass
class GenericVQAEpisode(Episode):
    question: TextModality
    image: ImageModality
    answer: TextModality

#### Example

In [2]:
mission = TextModality("012", torch.arange(3), torch.arange(3), torch.arange(3))
direction = TextModality(torch.arange(3), torch.arange(3)+5, torch.arange(3)+5, torch.arange(3)+5)
image = ImageModality(torch.arange(3), torch.randn((3, 8, 8)), torch.randn((3,)), torch.randn((3,)))
action = DiscreteModality(torch.arange(3), torch.arange(3)+10, torch.arange(3)+10, torch.arange(3)+10)
ep = FourRoomsEpisode(mission, direction, image, action)

NameError: name 'TextModality' is not defined

In [3]:
[ep[i] for i in range(ep.size)]

NameError: name 'ep' is not defined

Sweet. Now we can index into an "Episode" as if the episode were sequenced.

## Series

In [190]:
@dataclass
class Series:
    episodes: List[Episode]
    sequence: torch.Tensor = torch.tensor([])

    def __len__(self):
        return self.size

    def __getitem__(self, key):
        if isinstance(key, int):
            i = 0
            for episode in self.episodes:
                if i + episode.size <= key:
                    i += episode.size
                else:
                    break
            return episode[key - i]
        elif isinstance(key, slice):
            start, stop, step = key.start or 0, key.stop or len(self), key.step or 1
            i, results = start, []
            while i < stop:
                results.append(self[i])
                i += step
            return results
        else:
            raise TypeError(f"Invalid argument type `{type(key)}`.")

    @property
    def size(self):
        return sum(episode.size for episode in self.episodes)

    @property
    def n_episodes(self):
        return len(self.episodes)

Sweet. Now we can index into a series as if it were sequenced.

## Tokenization

In [301]:
def images_to_patches(images, patch_size=16):
    return rearrange(images, 'c (h s1) (w s2) -> (h w) (c s1 s2)', s1=patch_size, s2=patch_size)

def normalize_to_between_minus_one_plus_one(t: torch.Tensor):
    min_val, max_val = t.min(), t.max()
    if min_val == max_val:
        return torch.zeros_like(t)
    normalized = 2 * (t - min_val) / (max_val - min_val) - 1
    return normalized

def apply_along_dimension(func, dim, tensor):
    tensor = tensor.transpose(0, dim)
    shape = tensor.shape
    tensor = tensor.reshape(shape[0], -1)
    result = torch.stack([func(tensor[:, i]) for i in range(tensor.size(1))], dim=1)
    result = result.reshape(shape).transpose(0, dim)
    return result

def mu_law_encode(x, M=256, mu=100):
    M = torch.tensor(M, dtype=x.dtype)
    mu = torch.tensor(mu, dtype=x.dtype)
    x_mu = torch.sign(x) * torch.log(torch.abs(x) * mu + 1.0)
    x_mu = x_mu / torch.log(M * mu + 1.0)
    return x_mu

def mu_law_decode(x_mu, M=256, mu=100):
    M = torch.tensor(M, dtype=x_mu.dtype)
    mu = torch.tensor(mu, dtype=x_mu.dtype)
    x = torch.sign(x_mu) * (torch.exp(torch.abs(x_mu) * torch.log(M * mu + 1.0)) - 1.0) / mu
    return x

class Tokenizer:
    def __init__(self, text_tokenizer):
        self.text_tokenizer = text_tokenizer
        self.n_text = text_tokenizer.n_vocab
        self.n_discrete = 1024
        self.eod_token = 1023

    def encode_text(self, text, attend=True):
        assert len(text) > 1
        tokens = self.text_tokenizer.encode(text)
        attention_mask = torch.ones(len(tokens)+1) if attend else torch.zeros(len(tokens)+1)        
        return TextModality(
            origin=[self.text_tokenizer.decode([self.text_tokenizer.eot_token])] + [self.text_tokenizer.decode([x]) for x in tokens],
            tokens=text_token([self.text_tokenizer.eot_token] + tokens),
            targets=text_target(tokens + [self.text_tokenizer.eot_token]),
            attention_mask=attention_mask,
            embeddings=text_embedding(torch.tensor([]))
        )

    def decode_text(self, tokens):
        return self.text_tokenizer.decode(tokens.tolist())

    def encode_discrete(self, data, attend=False):
        attention_mask = torch.ones(len(data)+1) if attend else torch.zeros(len(data)+1)
        return DiscreteModality(
            origin=data,
            tokens=discrete_token([self.eod_token] + data),
            targets=discrete_target(data + [self.eod_token]),
            attention_mask=attention_mask,
            embeddings = discrete_embedding(torch.tensor([]))
        )

    def decode_discrete(self, tokens):
        return tokens.tolist()

    def encode_image(self, image, attend=False):
        patches = images_to_patches(image, patch_size=16)
        # TODO: Hardcoding as a reminder to do something smarter
        SQUARE_ROOT_OF_PATCH_SIZE = 3.464
        xs = (
            apply_along_dimension(
                normalize_to_between_minus_one_plus_one, 1, patches
            )
            / SQUARE_ROOT_OF_PATCH_SIZE
        )
        # We don't predict images, but we need ys
        # becaues these image ys will be in our
        # concatenated ys of text/image/action/etc...
        ys = torch.zeros(len(patches))
        ms = torch.zeros(len(patches))
        return ImageModality(
            origin=patches,
            tokens=ImageToken(xs),
            targets=ImageTarget(ys),
            attention_mask=ms,
            embeddings=ImageEmbedding(torch.tensor([]))
        )

In [302]:
text_encoding = tiktoken.get_encoding("r50k_base")
tokenizer = Tokenizer(text_encoding)

## Datasets

In [303]:
four_rooms_dataset = minari.load_dataset('D4RL/minigrid/fourrooms-v0', download=True)

In [304]:
class TransformDataset(Dataset):
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.transform(self.dataset[idx])

In [305]:
# Some FourRooms/Minigrid-specific stuff to turn
# a 7x7x3 non-pixel observation into an pixel/image observation.
lut = np.zeros((256, 3), dtype=np.uint8)
for idx, color_name in minigrid.core.constants.IDX_TO_COLOR.items():
    lut[idx] = minigrid.core.constants.COLORS[color_name]

def four_rooms_to_rgb(image):
    """Convert discrete "image" observations into actual images.
    I'm expecting this will improve our image modality while not losing
    much. The downside is we can fit less in our context window. Note:
    We might need to overlay the color/type image (index 1) with the
    state image (index 2), if we really don't want to lose any info."""
    # Apply lookup to second channel
    image = lut[image[:, :, 1]]
    # Convert to PyTorch tensor and permute
    image = torch.from_numpy(image).permute(2, 0, 1)
    return image

In [306]:
image_transform = transforms.Compose([
    transforms.ToDtype(torch.float32, scale=True),    
    transforms.RandomResizedCrop((192, 192), (0.6, 1.0)),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [307]:
SEQUENCE_LENGTH = 1024

In [308]:
def four_rooms_transform(episode):
    missions = [tokenizer.encode_text(mission, attend=False) for mission in episode.observations["mission"][:-1]]
    images = [tokenizer.encode_image(image_transform(four_rooms_to_rgb(image))) for image in episode.observations["image"][:-1]]
    directions = [tokenizer.encode_discrete([d]) for d in episode.observations["direction"][:-1].tolist()]
    actions = [tokenizer.encode_discrete([d], attend=True) for d in episode.actions.tolist()]
    series = Series([
        FourRoomsEpisode(mission=mission, image=image, direction=direction, action=action)
        for mission, image, direction, action in zip(missions, images, directions, actions)
    ])
    return series

def four_rooms_collate_fn(batch):
    pass

In [309]:
four_rooms_dataset_xf = TransformDataset(four_rooms_dataset, four_rooms_transform)

In [310]:
four_rooms_dataset[0]

EpisodeData(id=0, total_steps=19, observations={direction: ndarray of shape (20,) and dtype int64, image: ndarray of shape (20, 7, 7, 3) and dtype uint8, mission: ['reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal', 'reach the goal']}, actions=ndarray of shape (19,) and dtype int64, rewards=ndarray of 19 floats, terminations=ndarray of 19 bools, truncations=ndarray of 19 bools, infos=dict with the following keys: [])

In [311]:
series = four_rooms_dataset_xf[0]
len(series), series.n_episodes, series.episodes[0].direction.tokens.shape, series.episodes[0].action.tokens.shape

(2888, 19, torch.Size([2]), torch.Size([2]))

In [312]:
Series(series.episodes[1:4]).size

456

In [313]:
[m for m in series][:5]

[TextModality(
 	'<|endoftext|>'
 	TextToken([50256])
 	TextTarget([16250])
 	tensor([0.])
 	TextEmbedding([])
 ),
 TextModality(
 	'reach'
 	TextToken([16250])
 	TextTarget([262])
 	tensor([0.])
 	TextEmbedding([])
 ),
 TextModality(
 	' the'
 	TextToken([262])
 	TextTarget([3061])
 	tensor([0.])
 	TextEmbedding([])
 ),
 TextModality(
 	' goal'
 	TextToken([3061])
 	TextTarget([50256])
 	tensor([0.])
 	TextEmbedding([])
 ),
 ImageModality(
 	tensor([ 2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,
          2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,
          2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,
          2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,
          2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,
          2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,
          2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.2489,  2.24

In [314]:
four_rooms_dataloader = DataLoader(four_rooms_dataset_xf, batch_size=4)

In [315]:
next(iter(four_rooms_dataloader))

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class '__main__.Series'>

In [None]:
vqa_dataset = datasets.load_dataset("eihli/micro-ok-vqa")

In [None]:
def vqa_transform(sample):
    image = tokenizer.encode_image(image_transform(pil_to_tensor(sample["image"])))
    answer = tokenizer.encode_text(random.choice(sample["answers"])["answer"])
    question = tokenizer.encode_text(sample["question"], attend=False)
    episode = GenericVQAEpisode(question=question, image=image, answer=answer)
    series = Series([episode])
    return series

In [316]:
vqa_dataset_xf = TransformDataset(vqa_dataset["train"], vqa_transform)

In [317]:
series = vqa_dataset_xf[0]
len(series), series.n_episodes, series.episodes[0].question.tokens.shape, series.episodes[0].answer.attention_mask, series.episodes[0].image.tokens.shape

(159, 1, torch.Size([11]), tensor([1., 1., 1., 1.]), torch.Size([144, 768]))

## Embedding

In [318]:
EMBEDDING_DIMS = 768

In [319]:
embed_image = ResNetV2(layers=[3, 4, 6, 3], num_classes=EMBEDDING_DIMS)
embed_text = torch.nn.Embedding(tokenizer.n_text, EMBEDDING_DIMS)
embed_discrete = torch.nn.Embedding(tokenizer.n_discrete, EMBEDDING_DIMS)

In [320]:
image_tokens = series.episodes[0].image.tokens
patch_tokens = image_tokens.view(-1, 3, 16, 16)
image_embedding(patch_tokens).shape

torch.Size([144, 3, 16, 16])

In [321]:
question_tokens = series.episodes[0].question.tokens
text_embedding(question_tokens).shape

torch.Size([11])

In [322]:
series = four_rooms_dataset_xf[0]
len(series), series.n_episodes, series.episodes[0].direction.tokens.shape, series.episodes[0].action.tokens.shape

(2888, 19, torch.Size([2]), torch.Size([2]))

In [323]:
class Embedder:
    def __init__(self, embed_text, embed_image, embed_discrete):
        self._embed_text = embed_text
        self._embed_image = embed_image
        self._embed_discrete = embed_discrete

    def embed(self, modality):
        if isinstance(modality, TextModality):
            return self.embed_text(modality)
        elif isinstance(modality, DiscreteModality):
            return self.embed_discrete(modality)
        elif isinstance(modality, ImageModality):
            return self.embed_image(modality)
        elif isinstance(modality, Series):
            return self.embed_series(modality)
        else:
            raise NotImplementedError()

    def embed_text(self, text: TextModality):     
        return TextModality(
            origin=text.origin,
            tokens=text.tokens,
            targets=text.targets,
            attention_mask=text.attention_mask,
            embeddings=self._embed_text(text.tokens)
        )

    def embed_discrete(self, discrete: DiscreteModality):     
        return DiscreteModality(
            origin=discrete.origin,
            tokens=discrete.tokens,
            targets=discrete.targets,
            attention_mask=discrete.attention_mask,
            embeddings=self._embed_discrete(discrete.tokens)
        )

    def embed_continuous(self, continuous: DiscreteModality):     
        return ContinuousModality(
            origin=continuous.origin,
            tokens=continuous.tokens,
            targets=continuous.targets,
            attention_mask=continuous.attention_mask,
            embeddings=self._embed_discrete(continuous.tokens)
        )

    def embed_image(self, image: ImageModality):     
        return ImageModality(
            origin=image.origin,
            tokens=image.tokens,
            targets=image.targets,
            attention_mask=image.attention_mask,
            embeddings=self._embed_image(image.tokens.view(-1, 3, 16, 16))
        )

    def embed_series(self, series):
        fs = fields(

In [324]:
embedder = Embedder(embed_text, embed_image, embed_discrete)

In [325]:
[embedder.embed(s) for s in series[:5]]

[TextModality(
 	'<|endoftext|>'
 	TextToken([50256])
 	TextTarget([16250])
 	tensor([0.])
 	TextEmbedding([[ 1.3377,  0.3706,  0.2434,  0.3984, -1.2438, -0.3400, -1.7282,
                  0.5273,  0.7739, -0.3889,  0.3851, -0.5963,  0.4818, -2.2235,
                 -0.9197, -0.7364,  0.6347, -1.3343, -1.6939,  0.1525, -1.5823,
                 -1.1443, -1.0015, -0.6222, -1.7076, -0.6444, -0.7920, -0.4780,
                  0.3504,  0.6423,  2.0598, -0.1107,  1.2266, -1.2802, -1.6681,
                  1.3518, -1.1298,  1.7424, -1.3395, -0.2191,  1.4888,  0.8040,
                  0.5412, -0.1471, -0.9120,  0.0892,  0.2964, -0.9585, -0.5082,
                  0.4095,  1.4428, -0.8667, -0.5916,  0.1458,  0.2127, -0.6352,
                  1.4051, -2.3234, -0.8480, -1.2410, -1.6334,  0.6584, -0.6934,
                 -1.8706, -0.7050, -0.3073,  0.2233, -0.6093,  0.0982, -1.1815,
                  0.4024, -0.6770, -1.3918,  1.0473, -1.3251,  0.4577,  0.4819,
                  1.4207, -0

In [352]:
transformer_config = GPTConfig(n_head=8, n_embd=512)
gpt = GPT(transformer_config)

number of parameters: 63.59M


In [355]:
gpt.config.n_embd

512

In [369]:
@dataclass
class MiniGatoConfig:
    embedding_dim: int
    sequence_length: int
    tokenizer: Tokenizer
    transformer_config: GPTConfig
    transformer: GPT

def init_default_config() -> MiniGatoConfig:
    transformer_config = GPTConfig(n_head=8, n_embd=512)
    text_encoding = tiktoken.get_encoding("r50k_base")
    tokenizer = Tokenizer(text_encoding)
    return MiniGatoConfig(
        embedding_dim=768,
        sequence_length=1024,
        tokenizer=tokenizer,
        transformer_config=transformer_config,
        transformer=GPT(transformer_config),
    )

default_config = init_default_config()

class MiniGato(torch.nn.Module):
    def __init__(self, config: MiniGatoConfig=default_config):
        super().__init__()
        self.config = config
        self.sequence_length = self.config.sequence_length
        self.embed_image = ResNetV2(layers=[3, 4, 6, 3], num_classes=EMBEDDING_DIMS)
        self.embed_text = torch.torch.nn.Embedding(tokenizer.n_text, EMBEDDING_DIMS)
        self.embed_discrete = torch.torch.nn.Embedding(tokenizer.n_discrete, EMBEDDING_DIMS)
        self.embedder = Embedder(embed_text=self.embed_text, embed_image=self.embed_image, embed_discrete=self.embed_discrete)
        self.transformer = self.config.transformer
        self.lm_head = torch.nn.Linear(self.transformer.config.n_embd, self.config.tokenizer.n_text + self.config.tokenizer.n_discrete)    

    def forward(self, batch, targets=None):
        batch = [
            sample.embed(self.embedder).sequence(self.sequence_length) for sample in batch
        ]
        xs, ys, ms = map(torch.stack, zip(*batch))
        xs, ys, ms = [x.to(device) for x in [xs, ys, ms]]
        out, _ = self.transformer(emb=xs)
        if targets is not None:
            logits = self.lm_head(out.last_hidden_state)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            predicted = self.lm_head(out.last_hidden_state)
        return predicted, ys, ms

number of parameters: 63.59M


In [370]:
def infinite_dataloader(fn):
    it = iter(fn())
    while True:
        try:
            yield next(it)
        except StopIteration:
            it = iter(fn())

In [371]:
BATCH_SIZE=4

In [372]:
dataloaders = [
    infinite_dataloader(partial(DataLoader, four_rooms_dataset_xf, batch_size=BATCH_SIZE, collate_fn=lambda x: x)),
    infinite_dataloader(partial(DataLoader, vqa_dataset_xf, batch_size=BATCH_SIZE, collate_fn=lambda x: x)),
]
dl_it = cycle(dataloaders)

In [373]:
len(vqa_dataset["train"])

80

In [374]:
config = init_default_config()
model = MiniGato(config)

number of parameters: 63.59M


In [385]:
dl = next(dl_it)
batch = next(dl)

In [386]:
embed_image = ResNetV2(layers=[3, 4, 6, 3], num_classes=EMBEDDING_DIMS)
embed_text = torch.torch.nn.Embedding(tokenizer.n_text, EMBEDDING_DIMS)
embed_discrete = torch.torch.nn.Embedding(tokenizer.n_discrete, EMBEDDING_DIMS)
embedder = Embedder(embed_text=embed_text, embed_image=embed_image, embed_discrete=embed_discrete)

In [393]:
series = batch[0]
questions = [episode.question.tokens for episode in series.episodes]
images = [episode.image.tokens for episode in series.episodes]
answers = [episode.answer.tokens for episode in series.episodes]

In [398]:
len(series.episodes)

1

In [396]:
series.episodes[0].question.tokens

TextToken([50256,  2061,   318,   262, 42658,  2349,   286,   262, 32749,  1444,
              30])

In [409]:
embed_text(torch.concat(questions))
embed_image(torch.concat(images).view(-1, 3, 16, 16))
embed_text(torch.concat(answers))

TextToken([[ 2.0528,  1.5425,  0.3466,  ...,  0.5679,  0.0906,  0.7212],
           [ 0.6527,  2.0921,  1.0566,  ...,  0.5998, -0.4714, -0.1266],
           [-0.6124, -0.4198,  0.2808,  ...,  0.1288, -0.1279, -0.4592],
           [ 0.9156, -0.3364,  0.9264,  ...,  2.0148,  1.0986,  2.0607]],
          grad_fn=<AliasBackward0>)

In [388]:
foo = [embedder.embed(token) for token in batch[0]]

In [379]:
batch[0]

Series(episodes=[FourRoomsEpisode(mission=TextModality(
	['<|endoftext|>', 'reach', ' the', ' goal']
	TextToken([50256, 16250,   262,  3061])
	TextTarget([16250,   262,  3061, 50256])
	tensor([0., 0., 0., 0.])
	TextEmbedding([])
), image=ImageModality(
	tensor([[ 2.2489,  2.2489,  2.2489,  ..., -1.8044, -1.8044, -1.8044],
        [ 2.2489,  2.2489,  2.2489,  ..., -1.8044, -1.8044, -1.8044],
        [ 2.2489,  2.2489,  2.2489,  ..., -1.8044, -1.8044, -1.8044],
        ...,
        [-0.4054, -0.4054, -0.4054,  ..., -0.0615, -0.0615, -0.0615],
        [-0.4054, -0.4054, -0.4054,  ..., -0.0615, -0.0615, -0.0615],
        [-0.4054, -0.4054, -0.4054,  ..., -0.0615, -0.0615, -0.0615]])
	ImageToken([[ 0.2887,  0.2887,  0.2887,  ..., -0.2575, -0.2575, -0.2575],
            [ 0.2887,  0.2887,  0.2887,  ..., -0.2575, -0.2575, -0.2575],
            [ 0.2887,  0.2887,  0.2887,  ..., -0.2575, -0.2575, -0.2575],
            ...,
            [-0.2887, -0.2887, -0.2887,  ...,  0.2887,  0.2887,  0.2887]

In [None]:
optimizer.zero_grad()
predicted, targets, attention_mask = model(batch)

In [None]:
model.train()
for i in tqdm(range(self.num_iterations)):
    dl = next(dl_it)
    batch = next(dl)
    optimizer.zero_grad()
    predicted, targets, attention_mask = model(batch)
    self.losses.append(loss.item())
    loss.backward()
    if self.scheduler:
        self.scheduler.step()
    self.optimizer.step()