In [1]:
# Import all needed functions from Lesson 1 and 2

import helpers
from helpers import init_batch, generate_next_token
from helpers import merge_batches, filter_batch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [3]:
model_name = "gpt2"
# Use the first CUDA device when one is available, otherwise run on CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Placement is done by exactly one mechanism: an explicit move to `device`,
# the same device the tokenized inputs are moved to in the cells below.
# `device_map="auto"` is deliberately not combined with `.to(device)`:
# automatic dispatch may place modules on devices other than `device`
# (e.g. on a multi-GPU host), which would then disagree with the inputs.
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

In [4]:
# Reuse the end-of-sequence token (id 50256) as the padding token,
# both in the tokenizer and in the model config.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# Padding and truncation both happen on the left, which keeps the
# real tokens right-aligned so new tokens can be appended on the right.
tokenizer.padding_side = tokenizer.truncation_side = "left"

In [5]:
# A small batch of prompts that tokenize to different lengths,
# all sent through the model together.
prompts = [
    "Sean was a writer who lived in a small",
    "The quick brown fox jumped over the",
    "I am only a machine, but I can",
]

# padding=True makes the tokenizer insert the padding token so that
# every row has the same length; each resulting tensor is then moved
# to `device` (GPU when available, CPU otherwise).
inputs = {
    name: tensor.to(device)
    for name, tensor in tokenizer(
        prompts, padding=True, return_tensors="pt"
    ).items()
}
inputs

{'input_ids': tensor([[26408,   373,   257,  6260,   508,  5615,   287,   257,  1402],
         [50256, 50256,   464,  2068,  7586, 21831, 11687,   625,   262],
         [   40,   716,   691,   257,  4572,    11,   475,   314,   460]],
        device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [6]:
def generate_batch_tokens_with_past(inputs):
    """Run one forward pass and greedily pick the next token for every row.

    Args:
        inputs: mapping of keyword arguments forwarded unchanged to the
            module-level `model` (the callers in this notebook pass
            input_ids, attention_mask, position_ids and, after the first
            step, past_key_values).

    Returns:
        A tuple `(next_token_ids, past_key_values)`: the argmax token id
        taken from the logits at the last sequence position of each row,
        and the `past_key_values` attribute of the model output.
    """
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        model_output = model(**inputs)

    # Only the final sequence position is needed to choose the next token.
    final_position_logits = model_output.logits[:, -1, :]
    greedy_ids = torch.argmax(final_position_logits, dim=1)
    return greedy_ids, model_output.past_key_values

In [7]:
def generate_batch(inputs, max_tokens):
    """Greedily generate `max_tokens` new tokens for every row of a batch.

    Args:
        inputs: mapping containing at least "input_ids" and
            "attention_mask" tensors whose first dimension is the batch.
            It is not modified; a new dict is built for each step.
        max_tokens: number of decode steps to run. Every row receives
            exactly this many tokens; there is no early stop.

    Returns:
        A list with one string per row: the decoded generated tokens
        concatenated in order (the prompt text is not included).
    """
    # one list of decoded token strings per row of the batch
    generated_tokens = [
        [] for _ in range(inputs["input_ids"].shape[0])
    ]

    # Position of each real token = count of real tokens up to it, minus one.
    # Slots where the mask is 0 get the placeholder position 1.
    attention_mask = inputs["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    next_inputs = {
        "position_ids": position_ids,
        **inputs
    }

    for _ in range(max_tokens):
        next_token_ids, past_key_values = \
            generate_batch_tokens_with_past(next_inputs)

        previous_mask = next_inputs["attention_mask"]
        next_inputs = {
            # only the newly chosen token is fed back in; earlier context
            # is carried by past_key_values
            "input_ids": next_token_ids.reshape((-1, 1)),
            "position_ids": next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1,
            # Extend the mask by one column of ones for the new token.
            # new_ones inherits the mask's dtype and device, so the mask
            # keeps its integer dtype instead of being promoted to float.
            "attention_mask": torch.cat([
                previous_mask,
                previous_mask.new_ones((next_token_ids.shape[0], 1)),
            ], dim=1),
            "past_key_values": past_key_values,
        }

        next_tokens = tokenizer.batch_decode(next_token_ids)
        for i, token in enumerate(next_tokens):
            generated_tokens[i].append(token)
    return ["".join(tokens) for tokens in generated_tokens]

In [8]:
# greedily generate 10 new tokens for every prompt in the padded batch
generated_tokens = generate_batch(inputs, max_tokens=10)

In [9]:
# Print each prompt followed by its generated continuation.
# "\x1b[31m" and "\x1b[0m" are ANSI escape sequences: the first switches
# the text colour to red, the second resets it, so the generated part
# stands out from the prompt.
for prompt, generated in zip(prompts, generated_tokens):
    print(prompt, f"\x1b[31m{generated}\x1b[0m\n")

Sean was a writer who lived in a small [31m town in the South Bronx. He was a member[0m

The quick brown fox jumped over the [31m fence and ran to the other side of the fence[0m

I am only a machine, but I can [31m do it. I am a machine, but I[0m



In [10]:
# seed Python's `random` module so our results are deterministic
random.seed(42)

# workload constants: total queued requests, and requests per batch
queue_size, batch_size = 32, 8

# Requests waiting to be processed, as (prompt, max_tokens) tuples.
# Every `batch_size`-th request (indices 0, 8, 16, ...) asks for 100
# tokens; all the others ask for 10.
request_queue = [
    (prompts[0], 10 if i % batch_size else 100)
    for i in range(queue_size)
]

In [11]:
# peek at the first eight requests: one 100-token request, then seven 10-token ones
request_queue[:8]

[('Sean was a writer who lived in a small', 100),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10)]

In [12]:
# Split the queue into consecutive slices of at most `batch_size`
# requests each, preserving queue order.
batches = [
    request_queue[start:start + batch_size]
    for start in range(0, len(request_queue), batch_size)
]

In [13]:
# number of batches the queue was split into
len(batches)

4

In [14]:
# inspect the requests that make up the first batch
batches[0]

[('Sean was a writer who lived in a small', 100),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10),
 ('Sean was a writer who lived in a small', 10)]

In [20]:
# Generate tokens for all batches, one batch after another, and record
# the total duration. perf_counter is a monotonic clock meant for
# measuring elapsed time, so the result is unaffected by system clock
# adjustments.
t0 = time.perf_counter()
with tqdm(total=len(batches), desc=f"bs={batch_size}") as pbar:
    for batch in batches:
        # generate_batch produces the same number of tokens for every
        # row, so to accommodate all the requests the batch runs for the
        # largest max_tokens among them
        batch_max_tokens = [b[1] for b in batch]

        max_tokens = max(batch_max_tokens)
        pbar.set_postfix({'max_tokens': max_tokens})

        batch_prompts = [b[0] for b in batch]
        # A dedicated name is used here so the module-level `inputs`
        # built for the three-prompt example is not overwritten.
        batch_inputs = tokenizer(
            batch_prompts, padding=True, return_tensors="pt")
        # move every tensor to `device` (GPU when available, CPU otherwise)
        batch_inputs = {k: v.to(device) for k, v in batch_inputs.items()}
        # the generated text is discarded; only the elapsed time matters
        generate_batch(batch_inputs, max_tokens=max_tokens)

        pbar.update(1)

duration_s = time.perf_counter() - t0
print("duration", duration_s)

bs=8: 100%|██████████| 4/4 [00:03<00:00,  1.03it/s, max_tokens=100]

duration 3.8741214275360107





In [21]:
# seed Python's `random` module so our results are deterministic
random.seed(42)

# constants
queue_size = 32
batch_size = 8

# Requests waiting to be processed, as (prompt, max_tokens) tuples.
# The queue is rebuilt here because the loop below consumes it, so the
# cell can be re-run without depending on leftover state.
request_queue = [
    (prompts[0], 100 if i % batch_size == 0 else 10)
    for i in range(queue_size)
]

# perf_counter is a monotonic clock meant for measuring elapsed time,
# so the result is unaffected by system clock adjustments.
t0 = time.perf_counter()
with tqdm(total=len(request_queue), desc=f"bs={batch_size}") as pbar:
    # first, let's seed the initial cached_batch
    # with the first `batch_size` inputs
    # and run the initial prefill step
    # (init_batch / generate_next_token come from helpers.py; their
    # exact behaviour is not visible in this notebook)
    batch = init_batch(request_queue[:batch_size])
    cached_batch = generate_next_token(batch)
    request_queue = request_queue[batch_size:]

    # continue until both the request queue is
    # fully drained and every input
    # within the cached_batch has completed generation
    while (
        len(request_queue) > 0 or
        cached_batch["input_ids"].size(0) > 0
    ):
        # free slots = target batch size minus rows still in flight
        batch_capacity = (
            batch_size - cached_batch["input_ids"].size(0)
        )
        if batch_capacity > 0 and len(request_queue) > 0:
            # prefill: take just enough queued requests to fill the
            # free slots and run their first step separately
            new_batch = init_batch(request_queue[:batch_capacity])
            new_batch = generate_next_token(new_batch)
            request_queue = request_queue[batch_capacity:]

            # merge the freshly prefilled rows into the running batch
            cached_batch = merge_batches(cached_batch, new_batch)

        # decode: one more token for every row in the running batch
        cached_batch = generate_next_token(cached_batch)

        # remove any inputs that have finished generation and advance
        # the progress bar by the number of rows removed
        cached_batch, removed_indices = filter_batch(cached_batch)
        pbar.update(len(removed_indices))

duration_s = time.perf_counter() - t0
print("duration", duration_s)

bs=8:   0%|          | 0/32 [00:00<?, ?it/s]

bs=8: 100%|██████████| 32/32 [00:01<00:00, 22.21it/s]

duration 1.443659782409668



