# Lesson 3 - Continuous Batching

In this lesson, we'll discuss "continuous batching", the approach to batching used in production LLM inference.

- The key idea behind continuous batching is to constantly swap requests that have completed generation out of the batch, replacing them with requests from the queue that are waiting to be processed.

### Import required packages and load the LLM

In [1]:
# Import all needed functions from Lesson 1 and 2

import helpers
from helpers import init_batch, generate_next_token, generate_batch_tokens_with_past
from helpers import merge_batches, filter_batch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import torch
import torch.nn.functional as F

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [3]:
# Name of the pretrained checkpoint; the tokenizer and the model are both
# loaded from it so that their vocabularies match.
model_name = "gpt2"

# Tokenizer and causal language model used throughout the rest of the notebook.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [4]:
# Inspect the loaded model's configuration (layer/head counts, context size,
# vocabulary size and special token ids).
model.config

GPT2Config {
  "_name_or_path": "gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.35.2",
  "use_cache": true,
  "vocab_size": 50257
}

### Add padding tokens to the model to prepare batches of prompts

In [5]:
# Define PAD Token = EOS Token = 50256
# The configuration printed above lists bos/eos token ids but no pad token id,
# so the EOS token is reused as the padding token for both the tokenizer and
# the model config.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# pad on the left so we can append new tokens on the right
tokenizer.padding_side = "left"
# NOTE(review): truncation side is set to match, but the tokenizer calls
# visible in this notebook do not enable truncation.
tokenizer.truncation_side = "left"

In [6]:
# multiple prompts of varying lengths to send to the model at once
prompts = [
    "The quick brown fox jumped over the",
    "The rain in Spain falls",
    "What comes up must",
]

# note: padding=True ensures the padding token will be inserted into the tokenized tensors
# so that every row has the same length; return_tensors="pt" returns PyTorch tensors
inputs = tokenizer(prompts, padding=True, return_tensors="pt")

In [7]:
# Show the tokenized batch: in the output below the shorter prompts are
# left-padded with id 50256 and those padded positions have attention_mask 0.
inputs

{'input_ids': tensor([[  464,  2068,  7586, 21831, 11687,   625,   262],
        [50256, 50256,   464,  6290,   287,  8602,  8953],
        [50256, 50256, 50256,  2061,  2058,   510,  1276]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1]])}

### Define needed functions for batching

In [8]:
def generate_batch(inputs, max_tokens):
    """Generate `max_tokens` new tokens for every prompt in a padded batch.

    Every row runs for the same number of decode steps; there is no early
    stopping for rows that would have finished sooner.

    Args:
        inputs: mapping with at least "input_ids" and "attention_mask"
            tensors, one row per prompt (as produced by the tokenizer with
            padding enabled).
        max_tokens: number of decode steps to run for the whole batch.

    Returns:
        A list with one string per batch row, obtained by concatenating the
        token strings decoded at each step.
    """
    # create a list of tokens for every input in the batch
    generated_tokens = [[] for _ in range(inputs["input_ids"].shape[0])]

    # Positions count only the attended (non-padding) tokens: the cumulative
    # sum of the mask minus one. Padding slots get a dummy position of 1.
    attention_mask = inputs["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    next_inputs = {
        "position_ids": position_ids,
        **inputs
    }

    for _ in range(max_tokens):
        next_token_ids, past_key_values = generate_batch_tokens_with_past(next_inputs)

        prev_mask = next_inputs["attention_mask"]
        next_inputs = {
            # only the newly generated token is fed back; '-1' lets reshape infer the batch dim
            "input_ids": next_token_ids.reshape((-1, 1)),
            # increment the last position of every row, discard the rest
            "position_ids": next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1,
            # Extend the mask by one attended column. `new_ones` keeps the dtype
            # and device of the existing mask; a plain `torch.ones` would create
            # a float32 tensor on the default device and silently change the
            # mask's dtype on concatenation.
            "attention_mask": torch.cat(
                [prev_mask, prev_mask.new_ones((next_token_ids.shape[0], 1))],
                dim=1,
            ),
            "past_key_values": past_key_values,
        }

        # decode this step's token of every row and append it to that row's output
        next_tokens = tokenizer.batch_decode(next_token_ids)
        for i, token in enumerate(next_tokens):
            generated_tokens[i].append(token)

    return ["".join(tokens) for tokens in generated_tokens]

In [9]:
# seed the random number generator so our results are deterministic
# NOTE(review): no call into `random` is visible in this notebook; the seed
# presumably only matters if the imported helpers sample - confirm.
random.seed(42)

# constants
queue_size = 32  # total number of requests to process
batch_size = 8   # number of requests processed together

# requests waiting to be processed
# requests are tuples (prompt, max_tokens)
# every `batch_size`-th request (i = 0, 8, 16, 24) asks for 100 tokens and all
# others for 10, so each consecutive group of 8 holds exactly one long request
request_queue = [
    (prompts[0], 100 if i % batch_size == 0 else 10) for i in range(queue_size)
]

In [10]:
# Peek at the first 8 requests: one 100-token request followed by seven 10-token ones.
request_queue[:8]

[('The quick brown fox jumped over the', 100),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10)]

In [11]:
# Static batching: split the queue into consecutive chunks of `batch_size`
# requests. The final chunk would be shorter if the queue length were not a
# multiple of `batch_size`.
batches = [
    request_queue[i: i + batch_size]
    for i in range(0, len(request_queue), batch_size)
]

In [12]:
# Number of static batches: 32 requests in chunks of 8.
len(batches)

4

In [13]:
# Contents of the first static batch: the single long request sets the pace for the whole batch.
batches[0]

[('The quick brown fox jumped over the', 100),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10)]

### Define the requests to be processed

### Processing batches

In [14]:
# generate tokens for all batches and record duration
# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed intervals; unlike time.time() it is unaffected by system clock updates
t0 = time.perf_counter()

with tqdm(total=len(batches), desc=f"Batch size={batch_size}") as pbar:
    for batch in batches:
        # to accommodate all the requests with our
        # current implementation, we take the max of
        # all the tokens to generate among the requests
        batch_max_tokens = [b[1] for b in batch]
        max_tokens = max(batch_max_tokens)
        pbar.set_postfix({"max_tokens": max_tokens})

        # every request in the batch is decoded for `max_tokens` steps, even
        # those that asked for fewer; the generated text itself is discarded
        # because only the elapsed time is of interest here
        batch_prompts = [b[0] for b in batch]
        inputs = tokenizer(batch_prompts, padding=True, return_tensors="pt")
        generate_batch(inputs=inputs, max_tokens=max_tokens)

        pbar.update(1)

duration_s = time.perf_counter() - t0
print("duration (seconds): ", duration_s)

Batch size=8: 100%|██████████| 4/4 [00:10<00:00,  2.67s/it, max_tokens=100]

duration (seconds):  10.69694471359253





### Let's try Continuous Batching

- This time, rather than processing each batch to completion, you will use continuous batching to dynamically swap in and out inputs from the queue.

In [15]:
# seed the random number generator so our results are deterministic
random.seed(42)

# constants
queue_size = 32
batch_size = 8

# requests waiting to be processed
# requests are (prompt, max_tokens) tuples, the same workload as the static
# batching run; the queue is rebuilt here because the loop below consumes it,
# which keeps this cell re-runnable on its own
request_queue = [
    (prompts[0], 100 if i % batch_size == 0 else 10)
    for i in range(queue_size)
]

# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed intervals; unlike time.time() it is unaffected by system clock updates
t0 = time.perf_counter()
with tqdm(total=len(request_queue), desc=f"Batch size: {batch_size}") as pbar:
    # first, let's seed the initial cached_batch
    # with the first `batch_size` inputs
    # and run the initial prefill step
    # NOTE(review): init_batch / generate_next_token / merge_batches /
    # filter_batch come from the `helpers` module and are not visible here;
    # the step descriptions below follow the names and call pattern - confirm
    # against helpers.
    batch = init_batch(request_queue[:batch_size])
    cached_batch = generate_next_token(batch=batch)
    request_queue = request_queue[batch_size:]

    # continue until both the request queue is
    # fully drained and every input
    # within the cached_batch has completed generation
    while (len(request_queue) > 0 or cached_batch["input_ids"].size(0) > 0):
        # number of free slots in the running batch
        batch_capacity = batch_size - cached_batch["input_ids"].size(0)

        if batch_capacity > 0 and len(request_queue) > 0:
            # prefill: start as many queued requests as there are free slots
            new_batch = init_batch(requests=request_queue[:batch_capacity])
            new_batch = generate_next_token(batch=new_batch)
            request_queue = request_queue[batch_capacity:]

            # merge the freshly prefilled requests into the running batch
            cached_batch = merge_batches(batch1=cached_batch, batch2=new_batch)

        # decode: one more token for every request currently in the batch
        cached_batch = generate_next_token(batch=cached_batch)

        # remove any inputs that have finished generation
        cached_batch, removed_indices = filter_batch(batch=cached_batch)

        # the progress bar counts completed requests, not decode steps
        pbar.update(len(removed_indices))

duration_s = time.perf_counter() - t0
print(f"duration (seconds): {duration_s}")

Batch size: 8: 100%|██████████| 32/32 [00:03<00:00,  9.23it/s]

duration (seconds): 3.469820976257324



