# Batching, multi-GPU, and multi-node inference for large data and large models

We've seen how to run inference with LLMs with a high degree of control over the model inputs and outputs. The goal of this last notebook is to discuss ways to scale up the inference process to large data and large models.

There are three primary tools we will use:
1. Batching
2. Multi-GPU inference
3. Multi-node inference

We'll discuss each of these in turn.

## Batching

Batching means passing multiple inputs through the model at once. This is a common technique in deep learning, because the GPU can process the inputs in a batch in parallel. The `transformers` library has built-in support for batching, and we can use it to speed up inference with minimal code changes.
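To make the idea concrete, here is a minimal, library-free sketch of what happens when inputs of different lengths are combined into one batch for a decoder-only model: shorter sequences are left-padded to a common length, and an attention mask marks which positions hold real tokens. The helper name `left_pad_batch`, the pad id `0`, and the token ids are illustrative assumptions, not part of `transformers`.

```python
def left_pad_batch(sequences, pad_id=0):
    """Left-pad token-id lists to equal length and build attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        # Padding goes on the left so the last real token sits in the final column
        input_ids.append([pad_id] * n_pad + list(seq))
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


# Three "tokenized" inputs of different lengths (made-up token ids)
batch_ids, batch_mask = left_pad_batch([[5, 6, 7], [8], [9, 10]])
print(batch_ids)   # [[5, 6, 7], [0, 0, 8], [0, 9, 10]]
print(batch_mask)  # [[1, 1, 1], [0, 0, 1], [0, 1, 1]]
```

Left padding keeps the end of every prompt aligned, which is where generation continues from. This is why the tokenizer later in this notebook is loaded with `padding_side='left'`.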

First, we'll load a collection of texts that we want to process using an LLM. Then, we'll process them in batches and time the run, so that you can compare different batch sizes (including processing the texts one at a time).

In [None]:
# Get a list of texts from the 20 newsgroups dataset
# Each text is a post from a newsgroup
from sklearn.datasets import fetch_20newsgroups
docs = fetch_20newsgroups(subset='test', data_home='/project/rcde/cehrett/running_llms_workshop/data')['data'][:64]
print(f'Number of documents: {len(docs)}')
for i, doc in enumerate(docs[:3]):
    print(f'\n\nDOCUMENT {i+1}:\n{doc}\n')

Suppose we want some piece of information about each of these newsgroup posts, and what we want cannot be easily extracted in an automated way using traditional NLP techniques. An LLM might be a good choice for such a task.

For example, we might want a one-sentence summary of each post. We can craft a prompt that asks the model to generate such a summary.

In [None]:
import time

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507", padding_side='left')
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B-Instruct-2507", device_map="auto")

device = model.device

system_prompt = "The user will supply a post from an online newsgroup. Summarize the post in a single, very short sentence."

# Define a function that will generate summaries for a batch of posts
def generate_summaries(texts, batch_size):
    results = []
    total_batches = (len(texts) + batch_size - 1) // batch_size
    with tqdm(total=total_batches, desc="Processing batches", leave=True, bar_format="{l_bar}{bar} | {n_fmt}/{total_fmt}") as pbar:
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_messages = [[{"role": "system", "content": system_prompt}, 
                               {"role": "user", "content": text}] for text in batch]
                        
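            # Tokenize the messages using the chat template.
            # Note: truncation cuts from the right by default, so a post longer than
            # max_length also loses the trailing generation prompt.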
            model_inputs = tokenizer.apply_chat_template(
                batch_messages,
                add_generation_prompt=True,
                return_tensors="pt",
                padding=True,
                return_dict=True,
                truncation=True,
                max_length=512
            ).to(device)

            # Generate without tracking gradients (no_grad already disables them)
            with torch.no_grad():
                model.eval()
                torch.backends.cuda.matmul.allow_tf32 = True

                outputs = model.generate(
                    **model_inputs,
                    max_new_tokens=100,
                    return_dict_in_generate=True,
                    pad_token_id=tokenizer.eos_token_id
                )
            
            # Decode output
            prompt_length = model_inputs["input_ids"].shape[1]
            generated_sequences = outputs.sequences[:, prompt_length:]
            decoded_outputs = tokenizer.batch_decode(generated_sequences, skip_special_tokens=True)
            results.extend(decoded_outputs)

            pbar.update(1)
    return results

start = time.time()

# Generate summaries for the documents
summaries = generate_summaries(docs, batch_size=16)

end = time.time()
print(f"Total time taken: {end - start:.2f} seconds")
torch.cuda.empty_cache()


## Why vLLM is faster

vLLM is an inference engine designed for high-throughput LLM generation. It typically achieves significantly higher throughput than `transformers` for three main reasons:

1. **PagedAttention (optimized KV-cache management)**  
   vLLM stores the key–value (KV) attention cache in fixed-size blocks that are allocated on demand, instead of reserving one large contiguous tensor per sequence. This reduces wasted memory and fragmentation, so more sequences fit on the GPU at once.

2. **Continuous batching**  
   New requests are dynamically added to running batches without waiting for all sequences in the batch to finish. This eliminates “straggler” slowdown and keeps the GPU saturated.

3. **CUDA graph + compilation optimizations**  
   vLLM captures decoding operations into CUDA graphs, reducing Python overhead and kernel-launch latency. It also fuses kernels and uses optimized scheduling.

Together, these features keep the GPU heavily utilized during generation. In practice vLLM is often around 2–5× faster than standard `transformers` inference, although the exact speedup depends on the model, the hardware, and the workload.
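The benefit of continuous batching can be illustrated with a toy simulation. The sketch below counts decode steps only, assuming every active sequence emits one token per step. It ignores prefill, memory limits, and everything else a real scheduler handles, so it is not vLLM's scheduler. The function names and the output lengths are made up for illustration.

```python
from collections import deque


def static_batching_steps(lengths, batch_size):
    """Decode steps when each batch must finish before the next one starts."""
    steps = 0
    for i in range(0, len(lengths), batch_size):
        # A static batch runs until its longest sequence is done
        steps += max(lengths[i:i + batch_size])
    return steps


def continuous_batching_steps(lengths, batch_size):
    """Decode steps when a finished sequence is replaced immediately."""
    waiting = deque(lengths)
    running = []  # remaining tokens for each active sequence
    steps = 0
    while waiting or running:
        # Fill free slots from the queue before every step
        while waiting and len(running) < batch_size:
            running.append(waiting.popleft())
        steps += 1
        # Each active sequence emits one token; drop the finished ones
        running = [r - 1 for r in running if r > 1]
    return steps


# Hypothetical output lengths (in tokens) for eight requests
lengths = [100, 5, 5, 5, 100, 5, 5, 5]
print(static_batching_steps(lengths, batch_size=4))      # 200
print(continuous_batching_steps(lengths, batch_size=4))  # 105
```

With static batching, the short requests in each batch wait for the one long request. With continuous batching, their slots are reused right away, so the two long requests overlap almost completely.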


In [None]:
from vllm import LLM, SamplingParams

llm = LLM("Qwen/Qwen3-4B-Instruct-2507", gpu_memory_utilization=0.70)

sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=100,
)

In [None]:
import time

def vllm_generate_summaries(texts, batch_size):
    system_prompt = "Summarize the post in one very short sentence."

    # Build prompts by hand in Qwen's ChatML chat format
    # (llm.generate takes one plain string per prompt)
    prompts = [
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{text}<|im_end|>\n"
        f"<|im_start|>assistant\n"
        for text in texts
    ]

    results = []
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i:i+batch_size]
        outputs = llm.generate(batch, sampling_params)
        results.extend([o.outputs[0].text for o in outputs])
    
    return results

start = time.time()
summaries = vllm_generate_summaries(docs, batch_size=64)
end = time.time()
print(f"Total time taken: {end - start:.2f} seconds")

In [None]:
summaries

In [None]:
from utils import create_answer_box

create_answer_box(
    question=(
        "What considerations affect the choice of batch size? "
        "What downsides are there to setting an enormous batch size? "
        "(Unsure? Try out different batch sizes and see!)"
    ),
    question_id="nb3-01"
)

In [None]:
# Clear the model from memory
import torch
del model
del llm
torch.cuda.empty_cache()

## Multi-GPU inference

Thankfully, `transformers` makes multi-GPU inference easy.

There are different kinds of parallelism you might want to use with multiple GPUs. If you just want to speed up your LLM inference, and your model can fit on a single GPU, you can use *data parallelism*: each GPU holds a full copy of the model and processes a different portion of the data.

If your model is too large to fit on a single GPU, you can use *model parallelism*, in which the different GPUs each hold a different part of the model. Luckily, `transformers` makes it easy to use model parallelism: passing `device_map="auto"` when loading the model spreads it across the available GPUs.
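As a conceptual illustration of splitting a model across GPUs, the toy function below places consecutive layers on GPUs in order until each one is full. This greedy strategy, the function name `assign_layers`, and the layer and memory sizes are all assumptions made for the example. The real placement logic used by `device_map` lives in the `accelerate` library and may distribute layers differently.

```python
def assign_layers(layer_sizes_gb, gpu_budgets_gb):
    """Greedily place consecutive layers on GPUs, filling each GPU in order."""
    placement = {}
    gpu = 0
    used = 0.0
    for layer, size in enumerate(layer_sizes_gb):
        # Move to the next GPU when the current one cannot hold this layer
        while gpu < len(gpu_budgets_gb) and used + size > gpu_budgets_gb[gpu]:
            gpu += 1
            used = 0.0
        if gpu == len(gpu_budgets_gb):
            raise MemoryError(f"layer {layer} does not fit on the remaining GPUs")
        placement[layer] = gpu
        used += size
    return placement


# Hypothetical model: 8 layers of 3 GB each, on two GPUs with 16 GB free
print(assign_layers([3.0] * 8, [16.0, 16.0]))
# {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 1, 6: 1, 7: 1}
```

After loading a real model with `device_map`, printing `model.hf_device_map` shows the placement that was actually chosen.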

In [None]:
import time

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm

model_name = "Qwen/Qwen3-4B-Instruct-2507" # "meta-llama/Llama-3.3-70B-Instruct"

# Load tokenizer and model; device_map="auto" spreads the model across the available GPUs
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", dtype=torch.bfloat16)

device = model.device

system_prompt = "The user will supply a post from an online newsgroup. Summarize the post in a single, very short sentence."

# Start the timer after loading, so that only generation is timed (as in the batching cell above)
start = time.time()

# Generate summaries for the documents
summaries = generate_summaries(docs, batch_size=8)

end = time.time()
print(f"Total time taken: {end - start:.2f} seconds")

summaries

In [None]:
# Clear the model from memory
import torch
del model
torch.cuda.empty_cache()

## Multi-node inference

What if you have multiple nodes available, and want to use them all to speed up your inference? Several kinds of *parallelism* are possible with multi-node inference.

For example, you can use *data parallelism*, in which you split the data across the nodes, and each node processes a different part of the data. You can also use *model parallelism*, in which you split the model across the nodes, and each node processes a different part of the model. The former is for speeding up inference, and the latter is for when you have a model that's too large to fit on a single node.

We will implement data parallelism. The code is in the scripts `multinode_infer.slurm` and `multinode_infer.py`. These files work together to enable distributed inference:

- `multinode_infer.slurm`: The SLURM job submission script that requests and configures computing resources (in this example, 2 nodes, each with 2 A100 GPUs, 32 GB of memory, etc.)

- `multinode_infer.py`: The main Python script that performs the actual inference. 

The workflow is to use `sbatch` to submit `multinode_infer.slurm`, which calls `multinode_infer.py` in a coordinated way across all nodes.
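The core idea of data parallelism across processes can be sketched in a few lines: each process works out its own rank and takes a disjoint slice of the data. The sketch below is not the contents of `multinode_infer.py`, which is not shown here. The helper `shard_for_rank` and the placeholder data are hypothetical, and it assumes the processes are launched by SLURM, which sets `SLURM_PROCID` and `SLURM_NTASKS` for each task.

```python
import os


def shard_for_rank(items, rank, world_size):
    """Return the slice of `items` handled by process `rank` (strided split)."""
    return items[rank::world_size]


# SLURM sets these for each launched task; fall back to a single process otherwise
rank = int(os.environ.get("SLURM_PROCID", 0))
world_size = int(os.environ.get("SLURM_NTASKS", 1))

example_docs = [f"post {i}" for i in range(10)]  # placeholder data
my_docs = shard_for_rank(example_docs, rank, world_size)
print(f"rank {rank} of {world_size} processes {len(my_docs)} documents")
```

Each process would then run inference on its own shard and write its results to a separate file, to be merged once all processes have finished.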

In [None]:
from utils import create_answer_box

create_answer_box(
    question=(
        "Thank you for participating in today's workshop! Please record here any thoughts you have about how this workshop or others like it could be of greater use to you in the future."
    ),
    question_id="nb3-02"
)