In [2]:
import matplotlib.pyplot as plt
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

In [3]:
# Load the tokenizer and the causal language model for the checkpoint
# named by `model_name`; both are reused by every cell below.
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [None]:
# Display the loaded model's repr.
model

# KV Caching

In [None]:
# Text prompt used for the single-sequence decoding experiments below.
prompt = "The quick brown fox jumped over the"

In [None]:
# Tokenize the prompt into PyTorch tensors (return_tensors='pt') and display the result.
inputs = tokenizer(prompt, return_tensors='pt')
inputs

In [None]:
# Run one forward pass with gradient tracking disabled and keep the logits.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
logits

In [None]:
# Inspect the logits shape (indexed below as [batch, position, vocab entry]).
logits.shape

In [None]:
# Logits at the last sequence position of the first (only) batch element,
# shown together with their length.
last_logits = logits[0, -1, :]
last_logits, len(last_logits)

In [None]:
# Greedy choice: the index of the largest logit is the next token id.
next_token_id = last_logits.argmax()
next_token_id

In [None]:
# Convert the chosen token id back to text.
tokenizer.decode(next_token_id)

In [None]:
# The ten largest logits and their indices (candidate next tokens).
top_k = torch.topk(last_logits, k=10)
top_k

In [None]:
# Decode each of the top-k candidate ids into its text form.
tokens = list(map(tokenizer.decode, top_k.indices))
tokens

In [None]:
# Build the inputs for the following step by hand: append the chosen token
# id to the ids and a matching 1 to the attention mask.
next_inputs = dict(
    input_ids=torch.cat(
        [inputs["input_ids"], next_token_id.reshape((1, 1))], dim=1
    ),
    attention_mask=torch.cat(
        [inputs["attention_mask"], torch.tensor([[1]])], dim=1
    ),
)

In [None]:
# Extended token ids and their shape.
next_inputs["input_ids"], next_inputs["input_ids"].shape

In [None]:
# Extended attention mask and its shape.
next_inputs["attention_mask"], next_inputs["attention_mask"].shape

## without KV Cache

In [None]:
def generate_token(inputs):
    """Greedily pick the next token id for a single sequence.

    Args:
        inputs: keyword arguments forwarded to the global `model`.

    Returns:
        0-d tensor holding the argmax over the logits at the last position
        of batch element 0.
    """
    with torch.no_grad():
        final_position_logits = model(**inputs).logits[0, -1, :]
    return final_position_logits.argmax()

In [None]:
# Greedy decoding WITHOUT a KV cache: every step feeds the whole sequence
# generated so far back through the model and records how long the step took.
generated_tokens = []
next_inputs = inputs
durations_s = []

for _ in range(10):
    t0 = time.time()
    next_token_id = generate_token(next_inputs)
    durations_s.append(time.time() - t0)

    # Extend the running sequence (`next_inputs`), not the original prompt
    # (`inputs`); otherwise the input never grows beyond prompt + 1 token and
    # each step predicts from the prompt plus only the latest token.
    next_inputs = {
        "input_ids": torch.cat(
            [next_inputs["input_ids"], next_token_id.reshape((1, 1))],
            dim=1,
        ),
        "attention_mask": torch.cat(
            [next_inputs["attention_mask"], torch.tensor([[1]])],
            dim=1,
        ),
    }
    next_token = tokenizer.decode(next_token_id)
    generated_tokens.append(next_token)

# The printed label is Korean for "elapsed time".
print(f"소요시간: {sum(durations_s)}")
print(generated_tokens)

In [None]:
# Per-step decoding time without the KV cache.
plt.plot(durations_s)
plt.show()

## with KV Cache

In [None]:
def generate_token_with_past(inputs):
    """Greedily pick the next token id and return the model's KV cache.

    Args:
        inputs: keyword arguments forwarded to the global `model`.

    Returns:
        Tuple of (0-d tensor with the argmax over the last-position logits of
        batch element 0, the `past_key_values` from the model output).
    """
    with torch.no_grad():
        result = model(**inputs)
    best_id = torch.argmax(result.logits[0, -1, :])
    return best_id, result.past_key_values

In [None]:
# Greedy decoding WITH a KV cache: after the first step only the newest token
# is fed to the model, together with the cached keys/values; the attention
# mask still grows by one entry per step.
durations_cached_s = []
generated_tokens = []
next_inputs = inputs

for _ in range(10):
    t0 = time.time()
    next_token_id, past_key_values = generate_token_with_past(next_inputs)
    durations_cached_s.append(time.time() - t0)

    grown_inputs = {
        "past_key_values": past_key_values,
        "attention_mask": torch.cat(
            [next_inputs["attention_mask"], torch.tensor([[1]])], dim=1
        ),
        "input_ids": next_token_id.reshape((1, 1)),
    }
    next_inputs = grown_inputs

    next_token = tokenizer.decode(next_token_id)
    generated_tokens.append(next_token)

# The printed label is Korean for "elapsed time".
print(f"소요시간: {sum(durations_cached_s)}")
print(generated_tokens)

In [None]:
# Overlay per-step decoding time without (first line) and with (second line) the KV cache.
plt.plot(durations_s)
plt.plot(durations_cached_s)
plt.show()

# Batching - issues with multiple inputs

## Single Input

In [None]:
# Re-create the single-prompt inputs for the batching section.
prompt = "The quick brown fox jumped over the"
inputs = tokenizer(prompt, return_tensors='pt')

def generate_token_with_past(inputs):
    """Single-sequence greedy step that also returns the KV cache.

    NOTE(review): this re-declares a function with the same name and
    behavior that already appears earlier in the notebook.

    Args:
        inputs: keyword arguments forwarded to the global `model`.

    Returns:
        Tuple of (0-d argmax tensor over the last-position logits of batch
        element 0, `past_key_values` from the model output).
    """
    with torch.no_grad():
        forward = model(**inputs)
    kv_cache = forward.past_key_values
    tail_logits = forward.logits[0, -1, :]
    return tail_logits.argmax(), kv_cache

def generate(inputs, max_tokens):
    """Greedily generate `max_tokens` tokens for one sequence using the KV cache.

    Args:
        inputs: mapping with at least an "attention_mask" entry; passed
            unchanged to `generate_token_with_past` on the first step.
        max_tokens: number of decoding steps to run.

    Returns:
        The decoded tokens concatenated into a single string.
    """
    pieces = []
    step_inputs = inputs
    for _ in range(max_tokens):
        token_id, kv_cache = generate_token_with_past(step_inputs)
        pieces.append(tokenizer.decode(token_id))

        # Only the new token is fed next; the mask grows by one column.
        extended_mask = torch.cat(
            [step_inputs["attention_mask"], torch.tensor([[1]])], dim=1
        )
        step_inputs = {
            "input_ids": token_id.reshape((1, 1)),
            "attention_mask": extended_mask,
            "past_key_values": kv_cache,
        }
    return "".join(pieces)
    
# Generate ten tokens for the single prompt; `tokens` is the joined string.
tokens = generate(inputs, max_tokens=10)
tokens

## Multiple inputs

In [None]:
# Reuse the end-of-sequence token as the padding token, on both the tokenizer
# and the model config (PAD Token = EOS Token = 50256).
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

In [None]:
# Pad (and truncate) on the left so new tokens can be appended on the right.
# The attribute is `truncation_side`; a misspelled name would only create an
# unused attribute and leave the tokenizer's truncation side unchanged.
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

In [None]:
# Multiple prompts of varying lengths to send to the model at once.
prompts = [
    "The quick brown fox jumped over the",
    "The rain in Spain falls",
    "What comes up must"
]
# Note: padding=True ensures the padding tokens
# will be inserted into the tokenized tensors.
inputs = tokenizer(prompts, padding=True, return_tensors="pt")
inputs

In [None]:
# Shape of the padded batch of token ids.
inputs["input_ids"].shape

In [None]:
# position_ids tell the transformer the ordinal position of each token in the input sequence.
# For single input inference this is just [0, ..n] for n tokens, but for batch inference
# with padding the count must follow the attention mask so that the first
# attended token of every row gets position 0.

attention_mask = inputs["attention_mask"]
position_ids = attention_mask.long().cumsum(-1) - 1  # running count of attended tokens, minus 1 everywhere
print(position_ids)
position_ids.masked_fill_(attention_mask == 0, 1)  # in place: entries where the attention mask is 0 are set to 1 (not 0)

In [None]:
# Same forward pass as before, but with the explicit position_ids included.
with torch.no_grad():
    outputs = model(position_ids=position_ids, **inputs)
logits = outputs.logits
logits

In [None]:
# Last-position logits for every row of the batch.
last_logits = logits[:, -1, :]
last_logits

In [None]:
# Greedy choice per row: argmax along dimension 1 of the last-position logits.
next_token_ids = last_logits.argmax(dim=1)
next_token_ids

In [None]:
# Decode each row's chosen id into text.
next_tokens = tokenizer.batch_decode(next_token_ids)
next_tokens

In [None]:
def generate_batch_tokens_with_past(inputs):
    """Greedy next-token step for a whole batch, returning the KV cache too.

    Args:
        inputs: keyword arguments forwarded to the global `model`.

    Returns:
        Tuple of (1-D tensor with one argmax token id per batch row, taken
        over the last-position logits, and the model's `past_key_values`).
    """
    with torch.no_grad():
        result = model(**inputs)
    final_step_logits = result.logits[:, -1, :]
    return final_step_logits.argmax(dim=1), result.past_key_values

In [None]:
def generate_batch(inputs, max_tokens):
    """Greedily generate `max_tokens` tokens for every sequence in a batch.

    Args:
        inputs: mapping with "input_ids" and "attention_mask" tensors whose
            first dimension is the batch. Assumes padding on the left, as
            configured on the tokenizer in this notebook.
        max_tokens: number of decoding steps to run for the whole batch.

    Returns:
        List with one string per batch row: that row's decoded tokens joined
        together.
    """
    num_sequences = inputs["input_ids"].shape[0]
    generated_tokens = [[] for _ in range(num_sequences)]

    # Positions count attended tokens only; entries under padding are set to 1.
    attention_mask = inputs["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    next_inputs = {
        "position_ids": position_ids,
        **inputs
    }

    for _ in range(max_tokens):
        next_token_ids, past_key_values = generate_batch_tokens_with_past(next_inputs)

        previous_mask = next_inputs["attention_mask"]
        next_inputs = {
            "input_ids": next_token_ids.reshape((-1, 1)),
            # Each row continues from its own last position.
            "position_ids": next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1,
            # new_ones keeps the mask's dtype and device; a plain torch.ones
            # is float32 and would promote the whole mask to float on cat.
            "attention_mask": torch.cat([
                previous_mask,
                previous_mask.new_ones((next_token_ids.shape[0], 1))
            ], dim=1),
            "past_key_values": past_key_values,
        }
        next_tokens = tokenizer.batch_decode(next_token_ids)
        for i, token in enumerate(next_tokens):
            generated_tokens[i].append(token)
    return ["".join(tokens) for tokens in generated_tokens]

In [None]:
# Generate ten tokens for each prompt in the padded batch.
generated_tokens = generate_batch(inputs, max_tokens=10)
generated_tokens

In [None]:
# Print each prompt followed by its continuation; the \x1b[31m ... \x1b[0m
# ANSI escape codes render the generated part in red.
for prompt, generated in zip(prompts, generated_tokens):
    print(prompt, f"\x1b[31m{generated}\x1b[0m\n")

## Throughput vs. Latency

In [None]:
# constants
max_tokens = 10

# observations: the measurement loop appends one entry per batch size
durations, throughputs, latencies = [], [], []

In [None]:
# Powers of two from 1 up to 128.
batch_sizes = [1 << exponent for exponent in range(8)]
batch_sizes

In [None]:
# For each batch size: build a batch by cycling through `prompts`, time
# tokenization plus generation, and record duration, throughput and latency.
for batch_size in batch_sizes:
    print(f"bs: {batch_size}")

    # The timed region covers prompt assembly, tokenization and generation.
    t0 = time.time()
    batch_prompts = [prompts[i % len(prompts)] for i in range(batch_size)]
    inputs = tokenizer(batch_prompts, padding=True, return_tensors="pt")
    generated_tokens = generate_batch(inputs, max_tokens=max_tokens)
    duration_s = time.time() - t0
    durations.append(duration_s)

    # Tokens produced across the whole batch per second of wall time.
    ntokens = batch_size * max_tokens
    throughput = ntokens / duration_s
    throughputs.append(throughput)

    # Average wall time per decoding step.
    avg_latency = duration_s / max_tokens
    latencies.append(avg_latency)

    print(f"ntokens: {ntokens}")
    print(f"duration_s: {duration_s}")
    print(f"throughput: {throughput}")
    print(f"avg_latency: {avg_latency}")
    print()

In [None]:
def render_plot(x, y1, y2, x_label, y1_label, y2_label):
    """Plot two series against a shared, base-2 logarithmic x axis.

    Args:
        x: shared x values.
        y1: series drawn in red against the left y axis.
        y2: series drawn in blue against the right (twin) y axis.
        x_label: label for the x axis.
        y1_label: label for the left y axis.
        y2_label: label for the right y axis.
    """
    def _draw_series(axis, values, label, colour):
        # Label, line and tick labels of one y axis share a single colour.
        axis.set_ylabel(label, color=colour)
        axis.plot(x, values, color=colour)
        axis.tick_params(axis='y', labelcolor=colour)

    fig, left_axis = plt.subplots()
    left_axis.set_xlabel(x_label)
    _draw_series(left_axis, y1, y1_label, 'tab:red')

    left_axis.set_xscale('log', base=2)

    right_axis = left_axis.twinx()
    _draw_series(right_axis, y2, y2_label, 'tab:blue')

    plt.show()
    

In [None]:
# Throughput (left axis) and latency (right axis) as a function of batch size.
render_plot(
    batch_sizes,
    throughputs,
    latencies,
    "Batch Size",
    "Throughput",
    "Latency"
)

# Continuous Batching

In [None]:
import copy
import random
import torch.nn.functional as F

In [None]:
# Reuse EOS as the padding token and pad/truncate on the left so new tokens
# can be appended on the right. The attribute is `truncation_side`; a
# misspelled name would only create an unused attribute.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

In [None]:
# Multiple prompts of varying lengths to send to the model at once.
prompts = [
    "The quick brown fox jumped over the",
    "The rain in Spain falls",
    "What comes up must"
]

# padding=True pads the shorter prompts so the batch forms one tensor.
inputs = tokenizer(prompts, padding=True, return_tensors="pt")
inputs

In [None]:
def generate_batch_tokens_with_past(inputs):
    """Batched greedy step: one next-token id per row, plus the KV cache.

    NOTE(review): this re-declares a function with the same name and
    behavior that already appears earlier in the notebook.

    Args:
        inputs: keyword arguments forwarded to the global `model`.

    Returns:
        Tuple of (1-D tensor of argmax token ids over the last-position
        logits, `past_key_values` from the model output).
    """
    with torch.no_grad():
        forward = model(**inputs)
    kv_cache = forward.past_key_values
    chosen_ids = torch.argmax(forward.logits[:, -1, :], dim=1)
    return chosen_ids, kv_cache

In [None]:
def generate_batch(inputs, max_tokens):
    """Greedily generate `max_tokens` tokens for every sequence in a batch.

    Every row runs for the same number of steps, regardless of how many
    tokens an individual request needs.

    Args:
        inputs: mapping with "input_ids" and "attention_mask" tensors whose
            first dimension is the batch. Assumes padding on the left, as
            configured on the tokenizer in this notebook.
        max_tokens: number of decoding steps to run for the whole batch.

    Returns:
        List with one string per batch row: that row's decoded tokens joined
        together.
    """
    generated_tokens = [
        [] for _ in range(inputs["input_ids"].shape[0])
    ]

    # Positions count attended tokens only; entries under padding are set to 1.
    attention_mask = inputs["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    next_inputs = {
        "position_ids": position_ids,
        **inputs
    }

    for _ in range(max_tokens):
        next_token_ids, past_key_values = generate_batch_tokens_with_past(next_inputs)

        # One extra "attend" column per step, created with the mask's own
        # dtype and device; a plain torch.ones is float32 and would promote
        # the whole mask to float on concatenation.
        mask_so_far = next_inputs["attention_mask"]
        new_column = mask_so_far.new_ones((next_token_ids.shape[0], 1))

        next_inputs = {
            "input_ids": next_token_ids.reshape((-1, 1)),
            "position_ids": next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1,
            "attention_mask": torch.cat([mask_so_far, new_column], dim=1),
            "past_key_values": past_key_values,
        }
        next_tokens = tokenizer.batch_decode(next_token_ids)
        for i, token in enumerate(next_tokens):
            generated_tokens[i].append(token)
    return ["".join(tokens) for tokens in generated_tokens]

In [None]:
random.seed(42)

# constants
queue_size = 32
batch_size = 8

# Requests waiting to be processed, as (prompt, max_tokens) tuples: the first
# request of every group of `batch_size` asks for 100 tokens, the rest for 10.
request_queue = [
    (prompts[0], 10 if i % batch_size else 100)
    for i in range(queue_size)
]

# Fixed-size, consecutive slices of the queue.
batches = [
    request_queue[start:start + batch_size]
    for start in range(0, len(request_queue), batch_size)
]

In [None]:
# First eight queued requests.
request_queue[:8]

In [None]:
# Number of static batches.
len(batches)

In [None]:
# Contents of the first static batch.
batches[:1]

In [None]:
# Static batching baseline: process the pre-built batches one after another
# and measure the total wall time.

t0 = time.time()
with tqdm(total=len(batches), desc=f"bs: {batch_size}") as pbar:
    for batch in batches:
        batch_prompts = [request[0] for request in batch]
        batch_max_tokens = [request[1] for request in batch]

        # generate_batch runs every row for the same number of steps, so the
        # whole batch has to run for the largest token budget among its requests.
        max_tokens = max(batch_max_tokens)
        pbar.set_postfix({'max_tokens': max_tokens})

        inputs = tokenizer(
            batch_prompts, padding=True, return_tensors="pt")
        generate_batch(inputs, max_tokens=max_tokens)

        pbar.update(1)
duration_s = time.time() - t0
duration_s

## 개선해보기

In [None]:
import helpers
from helpers import init_batch, generate_next_token
# The continuous-batching loop calls `filter_batch` (singular), so import that name.
# NOTE(review): assumes helpers.py defines `filter_batch` — confirm against the module.
from helpers import merge_batches, filter_batch

In [None]:
random.seed(42)

# constants
queue_size = 32
batch_size = 8

# requests waiting to be processed
# requests are tuples (prompt, max_tokens)
request_queue = [
    (prompts[0], 100 if i % batch_size == 0 else 10)
    for i in range(queue_size)
]


t0 = time.time()
# The bar advances by the number of finished requests (see pbar.update at the
# end of the loop), so its total is the number of queued requests, not the
# number of static batches built in an earlier cell.
with tqdm(total=len(request_queue), desc=f"bs: {batch_size}") as pbar:
    # first, let's seed the initial cached_batch
    # with the first 'batch_size' inputs
    # and run the initial prefill step
    batch = init_batch(request_queue[:batch_size])
    cached_batch = generate_next_token(batch)
    request_queue = request_queue[batch_size:]

    # continue until both the request queue is fully drained and every input
    # within the cached_batch has completed generation
    while (
        len(request_queue) > 0 or
        cached_batch['input_ids'].size(0) > 0
    ):
        # free slots left in the running batch
        batch_capacity = (
            batch_size - cached_batch["input_ids"].size(0)
        )
        if batch_capacity > 0 and len(request_queue) > 0:
            # prefill the newly admitted requests
            new_batch = init_batch(request_queue[:batch_capacity])
            new_batch = generate_next_token(new_batch)
            request_queue = request_queue[batch_capacity:]

            # merge them into the running batch
            cached_batch = merge_batches(cached_batch, new_batch)

        # decode one step for everything in the running batch
        cached_batch = generate_next_token(cached_batch)

        # remove any inputs that have finished generation
        # NOTE(review): called through the imported `helpers` module because
        # no bare `filter_batch` name is imported in this notebook; assumes
        # helpers.py defines `filter_batch` — confirm against the module.
        cached_batch, removed_indices = helpers.filter_batch(cached_batch)
        pbar.update(len(removed_indices))

duration_s = time.time() - t0
duration_s