In [1]:
%load_ext autoreload
%autoreload

In [2]:
import importlib
import gato.policy.mini_gato as mg
from datasets import load_dataset
import requests
import torch
import torch.nn.functional as F
from torchvision.transforms import ToTensor, Resize, RandomCrop
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

pygame 2.5.2 (SDL 2.28.2, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
import os

# Set both NCCL "disable" switches (peer-to-peer and InfiniBand) in one call.
os.environ.update(
    NCCL_P2P_DISABLE="1",
    NCCL_IB_DISABLE="1",
)

# Working demo

First, a quick demo that this works. The datasets and model parameters are hardcoded in `mini_gato.py` for now.

In [9]:
model = mg.init_model()

In [10]:
model, lm_head, optimizer, accelerator, text_dataloader, vqa_dataloader = mg.train(model)

Epoch [0/40], Loss: 10.968055725097656
Epoch [10/40], Loss: 6.313449859619141
Epoch [20/40], Loss: 7.09780216217041
Epoch [30/40], Loss: 6.6844635009765625


# § 2.1 Tokenization

## Text

> There are infinite possible ways to transform data into tokens, including directly using the raw underlying byte stream. Below we report the tokenization scheme we found to produce the best results for Gato at the current scale using contemporary hardware and model architectures.
> Text is encoded via SentencePiece (Kudo & Richardson, 2018) with 32000 subwords into the integer range [0, 32000).
> ...

For this example, we'll use GPT2. The only thing to note as you change tokenizers is that discrete/continuous values get tokenized to the 1024 numbers after the vocab size (32000 to 33024 in the case of SentencePiece). So, you'll need to make that update as you change tokenizers.

GPT2 has a vocab size of 50257 (token ids 0 through 50256, where 50256 is the end-of-text token), so our discrete/continuous values will tokenize to the range 50257 to 51281.

### Example text dataset/dataloader

Here's a couple of example text datasets/dataloaders.

#### Dataset

How you get the dataset doesn't much matter. All that matters is:

- It's an iterator (we expect to be using datasets too large to fit in memory).
- It has train/valid/test splits.

In [11]:
wikitext_dataset = load_dataset(path="wikitext", name="wikitext-2-v1", streaming=True)

In [12]:
wikitext_dataset

IterableDatasetDict({
    test: IterableDataset({
        features: ['text'],
        n_shards: 1
    })
    train: IterableDataset({
        features: ['text'],
        n_shards: 1
    })
    validation: IterableDataset({
        features: ['text'],
        n_shards: 1
    })
})

#### Cleaning and transforming the dataset

The Wikitext dataset contains a lot of samples that are empty.

We can remove those with a call to `.filter(lambda x: x["text"] != '')`.

In [14]:
next(iter(wikitext_dataset["train"]))

{'text': ''}

In [16]:
next(
    iter(
        wikitext_dataset["train"]
          .filter(lambda x: x["text"] != '')
    )
)

{'text': ' = Valkyria Chronicles III = \n'}

Remember, though, each dataset is unique. 

This filter is necessary and works for wikitext, but it might not be the right filter to use for some other dataset. That's why it's important to have flexible APIs, like `filter` and `map`, and a solid set of _composable_ utility functions, like `is_empty` and `not`.

It's debatable whether we should tokenize here, at the stage where we're working with the Dataset, or somewhere else. The dimensions you might need to consider are performance, complexity, and customizability. I'm choosing to tokenize at the Dataset-level for now. But keep in mind that it might not be a hard requirement. As we proceed, consider the consequences of doing so.

Tokenizing can be a simple utility function that we can pass to `map`.

In [18]:
from transformers import GPT2TokenizerFast
# Load the pretrained GPT-2 fast tokenizer from the Hugging Face hub.
text_tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
# Reuse the end-of-sequence token as the padding token so that
# padding="max_length" can be used in the next cell; in the recorded output
# the padded positions carry id 50256 with attention_mask 0.
text_tokenizer.pad_token = text_tokenizer.eos_token

In [21]:
next(
    iter(
        wikitext_dataset["train"]
          .filter(lambda x: x["text"] != '')
          .map(lambda x: text_tokenizer(x["text"], truncation=True, padding="max_length", max_length=16))
    )
)

{'text': ' = Valkyria Chronicles III = \n',
 'input_ids': [796,
  569,
  18354,
  7496,
  17740,
  6711,
  796,
  220,
  198,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]}

#### DataLoader

Once we have the dataset, the DataLoader's job is easy. It simply grabs `batch_size` number of samples from the Dataset and "collates" (instead of `[{text: "foo"}, {text: "bar"}, ...]`, `{text: ["foo", "bar", ...]}`).

In [22]:
# Stream wikitext-2, drop samples rejected by mg.not_empty, and tokenize with
# mg.tokenize in batches of 1000 samples.
text_dataset = (
    load_dataset(path="wikitext", name="wikitext-2-v1", streaming=True)
    .filter(mg.not_empty)
    .map(mg.tokenize, batched=True, batch_size=1000)
)
# Batch two samples at a time from the train split, collated by mg.collate_fn.
text_dataloader = DataLoader(
    text_dataset["train"], batch_size=2, collate_fn=mg.collate_fn
)
# Pull a single batch to inspect; the recorded output below shows
# input_ids of shape (2, 1024).
text_batch = next(iter(text_dataloader))

In [23]:
text_batch["input_ids"].shape, text_batch

(torch.Size([2, 1024]),
 {'input_ids': tensor([[  796,   569, 18354,  ..., 50256, 50256, 50256],
          [ 2311,    73, 13090,  ..., 50256, 50256, 50256]]),
  'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0]])})

## Images

Images are first transformed into sequences of non-overlapping 16 × 16 patches in raster order, as done in ViT (Dosovitskiy et al., 2020). Each pixel in the image patches is then normalized between [−1, 1] and divided by the square-root of the patch size (i.e. √16 = 4).

First, let's load a tiny version of a VQA dataset so that we can grab an example and verify we're patching images correctly.


In [None]:
micro_vqa = load_dataset("eihli/micro-ok-vqa")

In [None]:
micro_vqa

In [None]:
img = micro_vqa['train'][0]['image']

In [None]:
img, img.size

What does the image look like? What do we expect to see when we patch it?

In [None]:
plt.imshow(img)

In [None]:
to_tensor = ToTensor()
resize = Resize(256)
random_crop = RandomCrop(256)

In [None]:
img = micro_vqa['train'][0]['image']
img = to_tensor(random_crop(resize(img))).unsqueeze(0)

## Converting to patches

In [None]:
patches = mg.images_to_patches(img)

In [None]:
patches.shape

In [None]:
patches = patches.view(1, 256, 16, 16, 3)

In [None]:
# Show the top-left 4x4 corner of the patch grid. Patches are indexed in raster
# order with 16 patches per row, hence the `row * 16 + col` lookup.
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for row, axes_row in enumerate(axes):
    for col, ax in enumerate(axes_row):
        ax.imshow(patches[0][row * 16 + col])

And then, to verify we can go both ways, let's convert the patches back to the original image.

In [None]:
patches.shape

In [None]:
img = mg.patches_to_image(patches[0], (3, 256, 256))

In [None]:
plt.imshow(img)

In [None]:
text_batch = next(iter(text_dataloader))
text_sequence, text_attention_mask, text_targets = mg.embed_and_sequence_text(text_batch)
text_sequence.shape

In [None]:
x = torch.concat([text_sequence])
y = torch.concat([text_targets])
m = torch.concat([text_attention_mask])

In [None]:
o = model(inputs_embeds=x)

In [None]:
p = lm_head(o.last_hidden_state)

In [None]:
p.shape

In [None]:
# Greedy decoding: take the highest-scoring vocabulary entry at each position.
# softmax is monotonic, so argmax over the raw logits selects the same token
# without the extra normalisation pass.
predicted_tokens = p.argmax(dim=2)

In [None]:
text_batch["input_ids"][1][1:200], predicted_tokens[1][:199]

In [None]:
micro_vqa = load_dataset("eihli/micro-ok-vqa")

In [None]:
img = micro_vqa['train'][0]['image']

In [None]:
img.size

In [None]:
to_tensor = ToTensor()
resize = Resize(256)
random_crop = RandomCrop(256)

In [None]:
img = micro_vqa['train'][0]['image']
img = to_tensor(random_crop(resize(img))).unsqueeze(0)

In [None]:
patches = mg.images_to_patches(img)

In [None]:
importlib.reload(mg)

In [None]:
text_dataset = (
    load_dataset(path="wikitext", name="wikitext-2-v1", streaming=True)
    .filter(mg.not_empty)
    .map(mg.tokenize, batched=True, batch_size=1000)
)
text_dataloader = DataLoader(
    text_dataset["train"], batch_size=2, collate_fn=mg.collate_fn
)
text_batch = next(iter(text_dataloader))

In [None]:
# Stream the micro OK-VQA dataset and have it yield torch tensors.
vqa_dataset = load_dataset("eihli/micro-ok-vqa", streaming=True).with_format("torch")

# Build the train pipeline one stage at a time: per-sample image transform,
# batched question/answer transform, then batched image tokenization, which
# also drops the raw columns that are not needed afterwards.
vqa_train_stream = vqa_dataset["train"].map(mg.vqa_img_transform)
vqa_train_stream = vqa_train_stream.map(mg.vqa_qa_transform, batched=True, batch_size=8)
vqa_train_stream = vqa_train_stream.map(
    mg.vqa_img_tokenize,
    batched=True,
    batch_size=8,
    remove_columns=["answers", "question", "answer_type", "question_type", "confidence"],
)

vqa_dataloader = DataLoader(vqa_train_stream, batch_size=2)
vqa_batch = next(iter(vqa_dataloader))

In [None]:
text_sequence, text_attention_mask, text_targets = mg.embed_and_sequence_text(text_batch)
vqa_sequence, vqa_attention_mask, vqa_targets = mg.embed_and_sequence_vqa(vqa_batch)
text_sequence.shape, vqa_sequence.shape, vqa_targets.shape

In [None]:
importlib.reload(mg)

In [None]:
model = mg.init_model()
lm_head = torch.nn.Linear(model.config.hidden_size, mg.text_tokenizer.vocab_size)

In [None]:
mg.remove_embedding_layer_from_model(model)

In [None]:
# mg.train returns a 6-tuple (see the "Working demo" cell at the top of this
# notebook); unpack it so `model` stays a model rather than the whole tuple.
# NOTE(review): confirm the return signature still matches after reloading mg.
model, lm_head, optimizer, accelerator, text_dataloader, vqa_dataloader = mg.train(model)

In [None]:
params = (
    list(model.parameters())
    + list(mg._lookup_embedding.parameters())
    + list(mg._image_embedding.parameters())
)
optimizer = mg.init_optimizer(params)

In [None]:
from accelerate import Accelerator

accelerator = Accelerator()
device = accelerator.device

# Let accelerate wrap every trainable module, the optimizer and both
# dataloaders. The prepared embeddings are assigned back onto the same `mg`
# attributes they were read from, so the module-level references are replaced
# by their prepared versions.
(
    model,
    mg._lookup_embedding,
    mg._image_embedding,
    lm_head,
    optimizer,
    text_dataloader,
    vqa_dataloader,
) = accelerator.prepare(
    model,
    mg._lookup_embedding,
    mg._image_embedding,
    lm_head,
    optimizer,
    text_dataloader,
    vqa_dataloader,
)

In [None]:
import gc; gc.collect()
torch.cuda.empty_cache()
gc.collect()

In [None]:
text_batch = next(iter(text_dataloader))
vqa_batch = next(iter(vqa_dataloader))
text_sequence, text_attention_mask, text_targets = mg.embed_and_sequence_text(text_batch)
vqa_sequence, vqa_attention_mask, vqa_targets = mg.embed_and_sequence_vqa(vqa_batch)
x = torch.concat([text_sequence, vqa_sequence])
y = torch.concat([text_targets, vqa_targets])
m = torch.concat([text_attention_mask, vqa_attention_mask])
x.device, y.device, m.device

In [None]:
# One manual training step on a mixed text + VQA batch (forward pass and loss
# only; backward() and optimizer.step() are run in the following cells).
text_batch = next(iter(text_dataloader))
vqa_batch = next(iter(vqa_dataloader))
# Each helper returns (embedded sequence, attention mask, targets).
text_sequence, text_attention_mask, text_targets = mg.embed_and_sequence_text(text_batch)
vqa_sequence, vqa_attention_mask, vqa_targets = mg.embed_and_sequence_vqa(vqa_batch)
# torch.concat defaults to dim=0, so text and VQA samples are stacked into one
# batch. Assumes both modalities share the same trailing dimensions — TODO confirm.
x = torch.concat([text_sequence, vqa_sequence])
y = torch.concat([text_targets, vqa_targets])
m = torch.concat([text_attention_mask, vqa_attention_mask])
# Clear gradients left over from any previous step before the forward pass.
optimizer.zero_grad()
# The model is fed embeddings directly (inputs_embeds) instead of token ids.
o = model(inputs_embeds=x)
# Project the final hidden states through lm_head to get per-position predictions.
p = lm_head(o.last_hidden_state)
# Loss over predictions p, targets y and mask m, as implemented by mg.cross_entropy.
loss = mg.cross_entropy(p, y, m)

In [None]:
loss

In [None]:
loss.backward()

In [None]:
optimizer.step()

In [None]:
# Hand-rolled masked cross-entropy on the VQA batch, for comparison against
# mg.cross_entropy.
# `predicted` and its dimensions are computed here so the cell runs on a fresh
# kernel instead of depending on leftover variables.
# Assumes lm_head output is (batch, seq, vocab) — TODO confirm.
predicted = lm_head(model(inputs_embeds=vqa_sequence).last_hidden_state)
B, T, C = predicted.shape

# Flatten everything to one row per token position.
mask = vqa_attention_mask.squeeze(-1).view(-1)
predicted = predicted.view(B * T, C)
target = vqa_targets.view(-1)

# Per-token losses; masked-out positions are zeroed and the mean is taken over
# the unmasked positions only.
losses = F.cross_entropy(predicted, target, reduction="none")
losses_masked = losses * mask
loss = losses_masked.sum() / mask.sum()

In [None]:
loss

In [None]:
losses.shape, losses_masked.shape, mask.shape

In [None]:
predicted.shape

In [None]:
loss = mg.cross_entropy(predicted, vqa_targets, vqa_attention_mask)

In [None]:
predicted

In [None]:
loss.item()

In [None]:
# Imported here so this cell does not depend on a later cell having been run.
import math

# Cross-entropy of a uniform guess over the vocabulary: the loss an untrained
# model should start near.
-math.log(1/mg.text_tokenizer.vocab_size)

In [None]:
import math

In [None]:
loss

In [None]:
import torch.nn.functional as F

In [None]:
vqa_sequence.shape

In [None]:
# Shape experiment: per-position cross-entropy with reduction="none" returns
# one loss per flattened position, which is what the attention mask must match.
# F.cross_entropy needs integer class indices of shape (N,) as the target, so
# random indices are drawn with randint (torch.randn() without a size is an error).
flat_sequence = vqa_sequence.view(-1, vqa_sequence.size(2))
random_targets = torch.randint(
    0, flat_sequence.size(1), (flat_sequence.size(0),), device=flat_sequence.device
)
losses = F.cross_entropy(flat_sequence, random_targets, reduction="none")

In [None]:
losses.shape, vqa_attention_mask.view(-1).shape

In [None]:
text_emb = mg.lookup_embedding(text_batch['input_ids'])

In [None]:
text_emb.shape

In [None]:
vqa_batch.keys()

In [None]:
vqa_batch["question_input_ids"].shape

In [None]:
vqa_batch = mg.sequence_vqa(vqa_batch)

In [None]:
torch.concat([torch.zeros(8, 3), torch.ones(8, 7)], dim=1)

In [None]:
image_emb = mg.image_embedding(vqa_batch['image'])

In [None]:
question_emb = mg.lookup_embedding(vqa_batch["question_input_ids"])
answer_emb = mg.lookup_embedding(vqa_batch["answer_input_ids"])

In [None]:
vqa_emb = torch.concat([image_emb, question_emb, answer_emb], dim=1)

In [None]:
vqa_emb.shape, text_emb.shape

In [None]:
text_emb.shape, vqa_emb.shape

In [None]:
# NOTE(review): `emb` is not assigned in any visible cell, so this raises
# NameError on a fresh kernel. Possibly meant `vqa_emb` or `image_emb` — confirm.
emb.shape, text_emb.shape

In [None]:
patches.shape

In [None]:
patches = patches.view(1, 256, 16, 16, 3)

In [None]:
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for row in range(4):
    for col in range(4):
        axes[row][col].imshow(patches[0][row * 16 + col])

In [None]:
patches = patches.view(1, 256, -1)

In [None]:
img = mg.patches_to_image(patches, (3, 256, 256))

In [None]:
plt.imshow(img[0].permute(1, 2, 0))

In [None]:
import torch
from transformers import GPT2Model

# Back-of-the-envelope training-memory estimate for GPT-2 (small) in float32.
# A separate name is used so the notebook-wide `model` is not replaced by a
# fresh pretrained GPT-2.
gpt2_reference = GPT2Model.from_pretrained('gpt2')

# Number of model parameters
num_params = sum(p.numel() for p in gpt2_reference.parameters())

# Data type size (float32 = 4 bytes)
dtype_size = 4

# Memory for the parameters themselves
param_memory = num_params * dtype_size

# Batch size and sequence length
batch_size = 8
seq_length = 1024

# Hidden size and depth from the GPT-2 config
hidden_size = gpt2_reference.config.hidden_size
num_layers = gpt2_reference.config.n_layer

# Activations: counts one (batch, seq, hidden) tensor per layer. Attention
# scores and MLP intermediates are not included, so this is a rough floor.
activation_memory = batch_size * seq_length * hidden_size * num_layers * dtype_size

# Gradients: one value per parameter
gradient_memory = num_params * dtype_size

# Optimizer states (Adam keeps two values per parameter)
optimizer_memory = num_params * dtype_size * 2

# Total memory estimate in bytes
total_memory = param_memory + activation_memory + gradient_memory + optimizer_memory

# Convert to MiB (1024 ** 2 bytes)
total_memory_mb = total_memory / (1024 ** 2)

print(f"Estimated Total Memory: {total_memory_mb:.2f} MiB")

In [None]:
f"{param_memory / 1e6:.2f}"

In [None]:
import torchtext
import portalocker
import datasets

In [None]:
# Request all three splits explicitly so the call returns three datasets to
# unpack; the second argument is the split selector, not a field name.
train, valid, test = torchtext.datasets.PennTreebank(
    root='./', split=('train', 'valid', 'test')
)
# Peek at the first training example.
it = iter(train)
ex = next(it)

In [None]:
owt = datasets.load_dataset('Skylion007/openwebtext', trust_remote_code=False)

In [None]:
def _penn_treebank_examples():
    """Yield one ``{"text": line}`` example per item of the PennTreebank train split."""
    # Dataset.from_generator needs a callable that yields dict examples, so
    # each raw item is wrapped in a dict.
    # Assumes each item of `train` is a single text line — TODO confirm.
    for line in train:
        yield {"text": line}


ds = datasets.Dataset.from_generator(_penn_treebank_examples)