# GPT-2 with MAX on GPU

Run GPT-2 inference using Modular's MAX framework.

**Setup:** In Google Colab, select Runtime → Change runtime type → T4 GPU before running the cells below.

In [None]:
# Install dependencies
# General-purpose packages come from pip's default index; `-q` keeps the output short.
!pip install -q numpy torch transformers rich
# MAX and Mojo are installed from Modular's nightly package index; `--pre` allows
# pip to pick pre-release builds.
!pip install -q max mojo --index-url https://dl.modular.com/public/nightly/python/simple/ --pre

In [None]:
# Check GPU availability
# Queries the NVIDIA driver for each GPU's name and total memory, printed as CSV.
!nvidia-smi --query-gpu=name,memory.total --format=csv

In [None]:
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from max.driver import CPU, GPU
from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor, TensorType
from max.graph import DeviceRef
from max.nn.module_v3 import Embedding, Linear, Module, Sequential

In [None]:
# GPT-2 Config
class GPT2Config:
    """Architecture hyperparameters for the GPT-2 model built in this notebook.

    All values are plain class attributes, so they can be read either from the
    class itself or from an instance.
    """

    # Number of rows in the token embedding table / width of the LM head.
    vocab_size: int = 50257
    # Number of rows in the position embedding table.
    n_positions: int = 1024
    # Width of the embeddings and of every transformer block.
    n_embd: int = 768
    # Number of stacked transformer blocks.
    n_layer: int = 12
    # Number of attention heads per block.
    n_head: int = 12
    # Epsilon handed to every LayerNorm.
    layer_norm_epsilon: float = 1e-5


# One-line summary of the configuration for the notebook output.
print(f"Config: {GPT2Config.n_layer} layers, {GPT2Config.n_head} heads, {GPT2Config.n_embd} dim")

In [None]:
# Layer Normalization
class LayerNorm(Module):
    """Layer normalization with a learnable scale (`weight`) and shift (`bias`).

    Args:
        dim: Length of the `weight` and `bias` vectors.
        eps: Epsilon forwarded to `F.layer_norm`.
    """

    def __init__(self, dim, eps=1e-5):
        # Initialise the Module base class before registering any attributes,
        # exactly as the other Module subclasses in this notebook do.
        super().__init__()
        self.eps = eps
        # Scale starts at one and shift at zero.
        self.weight = Tensor.ones([dim])
        self.bias = Tensor.zeros([dim])

    def __call__(self, x):
        """Return `F.layer_norm` of `x` using this layer's weight, bias and eps."""
        return F.layer_norm(x, gamma=self.weight, beta=self.bias, epsilon=self.eps)


# Causal Mask
@F.functional
def causal_mask(seq_len, dtype, device):
    """Build a square additive attention mask of side `seq_len`.

    A scalar `-inf` constant is broadcast to a `(seq_len, seq_len)` tensor and
    then passed through `F.band_part(..., num_lower=None, num_upper=0,
    exclude=True)`.

    NOTE(review): with `exclude=True` the band (diagonal and everything below
    it) is presumably zeroed, leaving `-inf` only strictly above the diagonal
    -- confirm against the `band_part` documentation.

    Args:
        seq_len: Side length of the mask.
        dtype: Element type of the returned tensor.
        device: Device on which the mask is created.
    """
    from max.graph import Dim

    neg_inf = Tensor.constant(float("-inf"), dtype=dtype, device=device)
    full_square = F.broadcast_to(neg_inf, shape=(seq_len, Dim(seq_len)))
    return F.band_part(full_square, num_lower=None, num_upper=0, exclude=True)

In [None]:
# Multi-head Attention
class GPT2Attention(Module):
    """Causal multi-head self-attention with a fused query/key/value projection.

    Args:
        config: Object exposing `n_head` and `n_embd`; `n_embd` is split
            evenly across the heads.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        # A single projection produces queries, keys and values side by side.
        self.c_attn = Linear(self.n_embd, 3 * self.n_embd, bias=True)
        self.c_proj = Linear(self.n_embd, self.n_embd, bias=True)

    def _split_heads(self, t, B, T):
        """Reshape `[B, T, n_embd]` to `[B, n_head, T, head_dim]`."""
        return t.reshape([B, T, self.n_head, self.head_dim]).transpose(-3, -2)

    def __call__(self, x):
        """Apply causal self-attention to `x` and project the result.

        Args:
            x: Rank-3 tensor; its shape is unpacked as `(B, T, C)`.

        Returns:
            The output of `c_proj`, applied to the merged heads reshaped back
            to `[B, T, C]`.
        """
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = F.split(qkv, [self.n_embd, self.n_embd, self.n_embd], axis=2)

        q = self._split_heads(q, B, T)
        k = self._split_heads(k, B, T)
        v = self._split_heads(v, B, T)

        # Scaled dot-product scores, shape [B, n_head, T, T].
        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-1, -2)) * scale
        # causal_mask takes (seq_len, dtype, device); an extra positional
        # argument here would collide with the `dtype` keyword.
        mask = causal_mask(T, dtype=attn.dtype, device=attn.device)
        attn = attn + mask
        attn = F.softmax(attn)
        out = attn @ v

        # Undo the head split: [B, n_head, T, head_dim] -> [B, T, C].
        out = out.transpose(-3, -2).reshape([B, T, C])
        return self.c_proj(out)

In [None]:
# MLP
class GPT2MLP(Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back.

    Args:
        config: Object exposing `n_embd`, the input and output width.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = Linear(width, hidden, bias=True)
        self.c_proj = Linear(hidden, width, bias=True)

    def __call__(self, x):
        """Return `c_proj(gelu(c_fc(x)))` using the tanh GELU approximation."""
        expanded = self.c_fc(x)
        activated = F.gelu(expanded, approximate="tanh")
        return self.c_proj(activated)


# Transformer Block
class GPT2Block(Module):
    """One transformer block: attention and MLP, each behind its own LayerNorm.

    Both sub-layers are applied to a normalised copy of the input and their
    output is added back onto the un-normalised input (residual connection).

    Args:
        config: Object exposing `n_embd` and `layer_norm_epsilon`, plus
            whatever `GPT2Attention` and `GPT2MLP` read from it.
    """

    def __init__(self, config):
        super().__init__()
        dim = config.n_embd
        eps = config.layer_norm_epsilon
        self.ln_1 = LayerNorm(dim, eps=eps)
        self.attn = GPT2Attention(config)
        self.ln_2 = LayerNorm(dim, eps=eps)
        self.mlp = GPT2MLP(config)

    def __call__(self, x):
        """Return `x` after the attention and MLP residual updates."""
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))

In [None]:
# Full GPT-2 Model
class GPT2(Module):
    """GPT-2 language model: embeddings, a stack of blocks, final norm, LM head.

    Args:
        config: Object exposing `vocab_size`, `n_positions`, `n_embd`,
            `n_layer` and `layer_norm_epsilon`.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.wte = Embedding(config.vocab_size, dim=width)
        self.wpe = Embedding(config.n_positions, dim=width)
        blocks = [GPT2Block(config) for _ in range(config.n_layer)]
        self.h = Sequential(*blocks)
        self.ln_f = LayerNorm(width, eps=config.layer_norm_epsilon)
        self.lm_head = Linear(width, config.vocab_size, bias=False)

    def __call__(self, input_ids):
        """Return the LM-head output for every position of `input_ids`.

        Args:
            input_ids: Rank-2 tensor of token ids; its shape is unpacked as
                `(batch, seq_len)`.
        """
        _, seq_len = input_ids.shape
        token_vectors = self.wte(input_ids)
        # Position ids 0..seq_len-1, created with the ids' dtype and device.
        positions = Tensor.arange(seq_len, dtype=input_ids.dtype, device=input_ids.device)
        hidden = token_vectors + self.wpe(positions)
        hidden = self.ln_f(self.h(hidden))
        return self.lm_head(hidden)

In [None]:
# Text generation function
def generate(model, tokenizer, device, prompt, max_tokens=30, temperature=0.8):
    """Autoregressively extend `prompt` and return the decoded text.

    Each step feeds the whole sequence generated so far through `model`, takes
    the logits of the last position of the first batch row, and picks the
    next token. Progress is printed every 10 tokens.

    Args:
        model: Callable mapping a `[1, seq]` int64 id tensor to logits that
            are indexed as `logits[0, -1, :]`.
        tokenizer: Object providing `encode`, `decode` and `eos_token_id`.
        device: Device on which the id tensors are created.
        prompt: Text to continue.
        max_tokens: Maximum number of tokens to append.
        temperature: If greater than 0, logits are divided by it and the next
            token is sampled from the softmax; otherwise the argmax is taken.

    Returns:
        The decoded prompt plus generated tokens (including the EOS token if
        generation stopped on it).

    Raises:
        ValueError: If `prompt` encodes to no tokens, since there would be no
            last position to predict from.
    """
    # Host-side copy of the sequence; used for decoding so the id tensor never
    # has to be copied back from the device.
    tokens = list(tokenizer.encode(prompt))
    if not tokens:
        raise ValueError("prompt must encode to at least one token")
    input_ids = Tensor.constant([tokens], dtype=DType.int64, device=device)

    print(f"Prompt: {prompt}")
    print("-" * 40)

    for i in range(max_tokens):
        logits = model(input_ids)
        next_logits = logits[0, -1, :]

        if temperature > 0:
            next_logits = next_logits / Tensor.constant(temperature, dtype=next_logits.dtype, device=device)
            probs = F.softmax(next_logits)
            probs_np = np.from_dlpack(probs.to(CPU()))
            next_id = int(np.random.choice(len(probs_np), p=probs_np))
        else:
            # .item() extracts the single element; int() on a non-0-d array is
            # deprecated in recent NumPy releases.
            next_id = int(np.from_dlpack(F.argmax(next_logits).to(CPU())).item())

        tokens.append(next_id)
        next_tensor = Tensor.constant([[next_id]], dtype=DType.int64, device=device)
        input_ids = F.concat([input_ids, next_tensor], axis=1)

        if next_id == tokenizer.eos_token_id:
            break

        # Print progress every 10 tokens
        if (i + 1) % 10 == 0:
            current = tokenizer.decode(tokens)
            print(f"[{i+1}] {current}")

    return tokenizer.decode(tokens)

In [None]:
# Load HuggingFace model and tokenizer
# Both are loaded from the pretrained "gpt2" checkpoint: `hf_model` is the source
# of the weights copied into the MAX model, `tokenizer` converts text <-> token ids.
print("Loading GPT-2 from HuggingFace...")
hf_model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
print("Done!")

In [None]:
# Create MAX model and load weights
# Use GPU() for T4, or CPU() for CPU-only
try:
    device = GPU()
except Exception:
    # Catch Exception rather than using a bare `except:` so KeyboardInterrupt
    # and SystemExit still propagate.
    device = CPU()
    print(f"GPU not available, using CPU: {device}")
else:
    print(f"Using GPU: {device}")

config = GPT2Config()
model = GPT2(config)

print("Loading weights...")
# Hugging Face stores the transformer body under a "transformer." prefix
# ("transformer.wte.weight", "transformer.h.0.ln_1.weight", ...), while the
# GPT2 module above exposes wte / wpe / h / ln_f at the top level. Strip the
# prefix so the names line up. Older transformers releases also keep the
# causal-mask buffers ("...attn.bias", "...attn.masked_bias") in the state
# dict; they have no counterpart in this model, so drop them.
state_dict = {
    name.removeprefix("transformer."): tensor
    for name, tensor in hf_model.state_dict().items()
    if not name.endswith((".attn.bias", ".attn.masked_bias"))
}
model.load_state_dict(state_dict)
model.to(device)

# Transpose Conv1D weights to Linear format
# (Hugging Face's GPT-2 uses Conv1D for these layers, whose weight is the
# transpose of what Linear expects; lm_head is left untouched.)
for name, child in model.descendents:
    if isinstance(child, Linear):
        if any(n in name for n in ["c_attn", "c_proj", "c_fc"]):
            child.weight = child.weight.T

print("Weights loaded!")

In [None]:
# Compile model for faster inference
# The input is described as an int64 tensor with two named dimensions
# ("batch", "seq") on the selected device.
# NOTE(review): the named dimensions are presumably symbolic, so a single
# compiled model serves any prompt length -- confirm.
print("Compiling model (this may take a minute)...")
token_type = TensorType(DType.int64, ("batch", "seq"), device=DeviceRef.from_device(device))
compiled = model.compile(token_type)
print("Compilation done!")

In [None]:
# Generate text!
# Samples up to 50 tokens at temperature 0.8 with the compiled model, then prints
# the full decoded text under a separator line.
result = generate(compiled, tokenizer, device, "The meaning of life is", max_tokens=50, temperature=0.8)
print("\n" + "=" * 40)
print(f"Result: {result}")

In [None]:
# Try different prompts!
# Each prompt is continued independently (up to 30 tokens, temperature 0.7).
prompts = [
    "Once upon a time",
    "The future of AI is",
    "In a galaxy far away",
]

for prompt in prompts:
    # Separator line between the outputs of consecutive prompts.
    print("\n" + "=" * 50)
    result = generate(compiled, tokenizer, device, prompt, max_tokens=30, temperature=0.7)
    print(f"\nFinal: {result}")

In [None]:
# Benchmark: measure tokens per second
import time

prompt = "Hello world"
num_tokens = 50

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
start = time.perf_counter()
result = generate(compiled, tokenizer, device, prompt, max_tokens=num_tokens, temperature=0.8)
elapsed = time.perf_counter() - start

# NOTE: generate() stops early when it samples the EOS token, so num_tokens is
# an upper bound on the tokens actually produced. The elapsed time also
# includes generate()'s own progress printing.
print(f"\nGenerated {num_tokens} tokens in {elapsed:.2f}s")
print(f"Speed: {num_tokens / elapsed:.1f} tokens/sec")