[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dougc333/Colab-Notebooks/blob/main/Attn_practice.ipynb)


# **Attention implementation: dense vs. paged KV-cache attention**

In [1]:
# Mount Google Drive (Colab-only API) at /content/drive; the next cell
# switches into MyDrive so files written later are stored on Drive.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [3]:
# Run the rest of the notebook from MyDrive: the relative paths used later
# (toy_vllm_model/, toy_trained/, toy_vllm_trained/) resolve under it.
%cd '/content/drive/MyDrive'

/content/drive/MyDrive


In [None]:
import math
import torch
import torch.nn as nn

def dense_attention(q, k, v, valid_lens):
    """
    Reference implementation: standard single-query attention assuming
    each sequence has a contiguous [0..L-1] KV buffer.

    q: [B, H, Dh]                (last token queries)
    k, v: [B, Lmax, H, Dh]
    valid_lens: [B] integers, sequence lengths, each in [0, Lmax]

    Returns:
        out: [B, H, Dh] with q's dtype, on q's device. Rows whose
        valid_lens[b] == 0 are zeros.

    Raises:
        ValueError: if some valid_lens[b] is negative or exceeds Lmax
            (slicing would otherwise silently truncate to Lmax tokens or,
            for negative values, wrap around and drop tokens from the end).
    """
    B, H, Dh = q.shape
    Lmax = k.shape[1]
    # Allocate on q's device so CUDA inputs don't yield a CPU result.
    out = torch.zeros(B, H, Dh, dtype=q.dtype, device=q.device)
    scale = 1.0 / math.sqrt(Dh)  # loop-invariant

    for b in range(B):
        L = int(valid_lens[b])
        if L < 0 or L > Lmax:
            raise ValueError(f"valid_lens[{b}]={L} is outside [0, {Lmax}]")
        if L == 0:
            continue  # nothing to attend to; keep the zero row

        kb = k[b, :L]          # [L, H, Dh]
        vb = v[b, :L]

        scores = torch.einsum("hd,lhd->hl", q[b] * scale, kb)  # [H, L]
        attn = torch.softmax(scores, dim=-1)                   # [H, L]
        out[b] = torch.einsum("hl,lhd->hd", attn, vb)          # [H, Dh]

    return out


def paged_attention(q, key_cache, value_cache, block_table, seq_lens, block_size):
    """
    Simple 'paged' attention: single-query attention where each sequence's
    KV history lives in fixed-size blocks of a shared cache.

    q: [B, H, Dh]                         (last token queries)
    key_cache:   [N_blocks, block_size, H, Dh]
    value_cache: [N_blocks, block_size, H, Dh]
    block_table: [B, max_blocks]
        - block_table[b, i] = physical block index for
          logical block i of sequence b
        - entries past a sequence's last used block may hold a sentinel
          such as -1; entries that are actually used must be >= 0
    seq_lens: [B] sequence lengths (in tokens)
    block_size: int, tokens per block

    Returns:
        out: [B, H, Dh] with q's dtype, on q's device. Rows whose
        seq_lens[b] == 0 are zeros.

    Raises:
        AssertionError: if cache/table shapes disagree with q or block_size.
        ValueError: if a length is negative, needs more blocks than the
            table row holds, or a used block-table entry is negative
            (torch indexing would silently wrap it to the end of the cache).
    """
    B, H, Dh = q.shape
    _, cache_block_size, Hc, Dhc = key_cache.shape

    assert cache_block_size == block_size
    assert Hc == H and Dhc == Dh
    assert value_cache.shape == key_cache.shape
    assert block_table.shape[0] == B

    max_blocks = block_table.shape[1]
    scale = 1.0 / math.sqrt(Dh)  # loop-invariant
    # Allocate on q's device so CUDA inputs don't yield a CPU result.
    out = torch.zeros(B, H, Dh, dtype=q.dtype, device=q.device)

    for b in range(B):
        L = int(seq_lens[b])
        if L < 0:
            raise ValueError(f"seq_lens[{b}]={L} is negative")
        if L == 0:
            continue

        # How many logical blocks are needed for this sequence? (ceil division)
        num_logical_blocks = (L + block_size - 1) // block_size
        if num_logical_blocks > max_blocks:
            raise ValueError(
                f"seq_lens[{b}]={L} needs {num_logical_blocks} blocks but "
                f"block_table has only {max_blocks} per sequence"
            )

        # Step 1: follow the block table to find the physical blocks
        blocks_idx = block_table[b, :num_logical_blocks]   # [num_logical_blocks]
        if (blocks_idx < 0).any():
            raise ValueError(
                f"block_table[{b}] has a negative entry among its first "
                f"{num_logical_blocks} (used) blocks: {blocks_idx.tolist()}"
            )

        # Step 2: gather those blocks from the global KV cache
        # k_blocks, v_blocks: [num_logical_blocks, block_size, H, Dh]
        k_blocks = key_cache[blocks_idx]
        v_blocks = value_cache[blocks_idx]

        # Step 3: flatten blocks into one logical sequence and trim the
        # unused tail slots of the (possibly partial) last block.
        k_flat = k_blocks.reshape(num_logical_blocks * block_size, H, Dh)[:L]
        v_flat = v_blocks.reshape(num_logical_blocks * block_size, H, Dh)[:L]

        # Step 4: standard scaled dot-product attention over this sequence
        scores = torch.einsum("hd,lhd->hl", q[b] * scale, k_flat)  # [H, L]
        attn = torch.softmax(scores, dim=-1)                       # [H, L]
        out[b] = torch.einsum("hl,lhd->hd", attn, v_flat)          # [H, Dh]

    return out

In [None]:
# show dense attn == paged attn
def demo(seq_lens=None):
    """Check that paged attention reproduces dense attention on random data.

    Dense K/V for each sequence are copied into a block-paged KV cache
    (sequence b gets its own contiguous run of physical blocks). Both
    attention functions are then evaluated on the same queries.

    Args:
        seq_lens: optional [B] integer tensor of sequence lengths, each in
            [0, 8]. Defaults to ``torch.tensor([8, 8])``. Lengths need not be
            multiples of the block size; the last block of a sequence is then
            only partially filled (the remaining slots stay zero).

    Returns:
        float: max |dense - paged| over all outputs (also printed).

    Raises:
        ValueError: if a length is outside [0, 8].
    """
    torch.manual_seed(0)

    if seq_lens is None:
        seq_lens = torch.tensor([8, 8])  # all full-length by default
    B, Lmax, H, Dh = len(seq_lens), 8, 4, 8
    block_size = 4

    for b, L in enumerate(seq_lens.tolist()):
        if not 0 <= L <= Lmax:
            raise ValueError(f"seq_lens[{b}]={L} is outside [0, {Lmax}]")

    # Fake data
    q = torch.randn(B, H, Dh)
    k_dense = torch.randn(B, Lmax, H, Dh)
    v_dense = torch.randn(B, Lmax, H, Dh)

    # Dense baseline
    dense_out = dense_attention(q, k_dense, v_dense, seq_lens)

    # Build a paged KV cache. Ceil division: a partially filled last block
    # still occupies a whole physical block.
    blocks_needed = [(int(L) + block_size - 1) // block_size for L in seq_lens]
    max_blocks = (Lmax + block_size - 1) // block_size

    key_cache = torch.zeros(sum(blocks_needed), block_size, H, Dh)
    value_cache = torch.zeros_like(key_cache)
    block_table = torch.full((B, max_blocks), -1, dtype=torch.long)

    next_block = 0
    for b, num_blocks in enumerate(blocks_needed):
        L = int(seq_lens[b])
        first, last = next_block, next_block + num_blocks

        # logical blocks 0..num_blocks-1 map to physical blocks first..last-1
        block_table[b, :num_blocks] = torch.arange(first, last)

        # Token t of sequence b goes to slot t of its flattened block run.
        # The slice of a contiguous cache is itself contiguous, so .view()
        # aliases the cache and the writes land in it. An explicit size is
        # used because .view(-1, ...) is ambiguous for empty runs.
        n_slots = num_blocks * block_size
        key_cache[first:last].view(n_slots, H, Dh)[:L] = k_dense[b, :L]
        value_cache[first:last].view(n_slots, H, Dh)[:L] = v_dense[b, :L]

        next_block = last

    # Paged attention output
    paged_out = paged_attention(
        q,
        key_cache,
        value_cache,
        block_table,
        seq_lens,
        block_size,
    )

    max_diff = (dense_out - paged_out).abs().max().item()
    print("Max |dense - paged|:", max_diff)
    return max_diff

if __name__ == "__main__":
    demo()

# vLLM **implementation**: serving a custom Transformers-style model

In [None]:
!pip install "vllm>=0.5.0" "transformers>=4.45.0" torch

In [None]:
import os
import json
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    AutoTokenizer,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

# In recent HF, ALL_ATTENTION_FUNCTIONS is shared between backends.
# If this import fails, search in your transformers installation for ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import ALL_ATTENTION_FUNCTIONS

from vllm import LLM


# -----------------------------
# 1. Config
# -----------------------------
class ToyConfig(PretrainedConfig):
    """Configuration for a tiny GPT-style decoder-only language model.

    Args:
        vocab_size: token vocabulary size (default matches the GPT-2 tokenizer).
        hidden_size: model width; must be divisible by ``num_attention_heads``.
        num_hidden_layers: number of transformer blocks.
        num_attention_heads: attention heads per block.
        intermediate_size: hidden width of each block's MLP.
        max_position_embeddings: rows in the learned position table, i.e. the
            longest sequence the model can embed.
        **kwargs: forwarded to ``PretrainedConfig``.
    """

    model_type = "toy-transformer"

    def __init__(
        self,
        vocab_size: int = 50257,   # match GPT-2 tokenizer
        hidden_size: int = 128,
        num_hidden_layers: int = 2,
        num_attention_heads: int = 4,
        intermediate_size: int = 256,
        max_position_embeddings: int = 512,
        **kwargs,
    ):
        # Decoder-only causal LM. setdefault (instead of assigning after
        # super().__init__) lets values restored from a saved config.json win.
        kwargs.setdefault("is_decoder", True)
        kwargs.setdefault("is_encoder_decoder", False)
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings

        # The attention implementation is intentionally NOT hard-coded here.
        # Support for vLLM's backend is advertised by the model class
        # (`_supports_attention_backend`); vLLM's Transformers backend chooses
        # its own "vllm" implementation when it loads the model. Hard-coding
        # "vllm" makes plain HF training look up a kernel that is not
        # registered outside vLLM.


# -----------------------------
# 2. Attention block compatible with vLLM backend
# -----------------------------
class ToyAttention(nn.Module):
    """Multi-head causal self-attention with pluggable attention kernels.

    When ``config._attn_implementation`` names an external kernel registered in
    ``ALL_ATTENTION_FUNCTIONS`` (e.g. the one vLLM's Transformers backend
    installs), that kernel computes attention and handles masking/caching.
    For unset/"eager"/"sdpa", or any unregistered name, a local
    ``F.scaled_dot_product_attention`` path is used. That path applies a proper
    causal (+ padding) mask, so the module also works for plain HF training.
    """

    # decoder-only causal LM, so attention is causal
    is_causal = True

    def __init__(self, config: ToyConfig, layer_idx: int = 0):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.config = config
        # Position of this layer in the decoder stack; kernels that keep
        # per-layer state look it up on the module.
        self.layer_idx = layer_idx
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        # No grouped-query attention: each query head has its own K/V head.
        self.num_key_value_groups = 1
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

    @staticmethod
    def _causal_padding_mask(
        attention_mask: Optional[torch.Tensor], seqlen: int
    ) -> Optional[torch.Tensor]:
        """Expand a 2-D (bsz, seqlen) 1/0 padding mask into a 4-D boolean mask.

        The result has shape (bsz, 1, seqlen, seqlen) and is True where attention
        is allowed: key j <= query i and key j is not padding. Each position may
        always attend to itself, so no row is fully masked (an all-False row
        makes SDPA emit NaN, which would then leak through V into real tokens).
        None and already-4-D masks are returned unchanged.
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return attention_mask
        device = attention_mask.device
        causal = torch.ones(seqlen, seqlen, dtype=torch.bool, device=device).tril()
        keys_present = attention_mask[:, None, None, :].bool()  # (bsz, 1, 1, seqlen)
        self_only = torch.eye(seqlen, dtype=torch.bool, device=device)
        return (causal & keys_present) | self_only

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Self-attention over ``hidden_states`` of shape (bsz, seqlen, hidden).

        Returns ``(attn_output, attn_weights)``. attn_output is (bsz, seqlen,
        hidden); attn_weights is whatever the kernel returns (None on the
        local path).
        """
        bsz, seqlen, _ = hidden_states.shape

        # (bsz, seqlen, hidden) -> (bsz, num_heads, seqlen, head_dim)
        def shape(x):
            return (
                x.view(bsz, seqlen, self.num_heads, self.head_dim)
                .transpose(1, 2)
                .contiguous()
            )

        q = shape(self.q_proj(hidden_states))
        k = shape(self.k_proj(hidden_states))
        v = shape(self.v_proj(hidden_states))

        impl = getattr(self.config, "_attn_implementation", None)
        if impl not in (None, "eager", "sdpa") and impl in ALL_ATTENTION_FUNCTIONS:
            # External kernel (e.g. vLLM). These functions take
            # (module, query, key, value, attention_mask) positionally and
            # manage masking/caching themselves.
            attention_interface = ALL_ATTENTION_FUNCTIONS[impl]
            attn_output, attn_weights = attention_interface(
                self,
                q,
                k,
                v,
                attention_mask,
                scaling=self.scaling,
                dropout=0.0,
                **kwargs,
            )
        else:
            mask = self._causal_padding_mask(attention_mask, seqlen)
            # With no mask, rely on SDPA's built-in causal masking (q and k have
            # equal length here). With a mask, it already encodes causality.
            attn_output = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=mask is None
            )
            # (bsz, num_heads, seqlen, head_dim) -> (bsz, seqlen, num_heads, head_dim)
            attn_output = attn_output.transpose(1, 2)
            attn_weights = None

        # Kernels return heads-last layouts ((bsz, seqlen, num_heads, head_dim)
        # or already flattened), so reshape without transposing back.
        attn_output = attn_output.reshape(bsz, seqlen, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class ToyMLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Expands ``hidden_size`` to ``intermediate_size`` and projects back.
    """

    def __init__(self, config: ToyConfig):
        super().__init__()
        d_model, d_ff = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.fc1(x)
        activated = F.gelu(expanded)
        return self.fc2(activated)


class ToyBlock(nn.Module):
    """Pre-LayerNorm transformer layer.

    Computes ``h = x + attn(ln1(x))`` and then returns ``h + mlp(ln2(h))``.
    """

    def __init__(self, config: ToyConfig):
        super().__init__()
        self.self_attn = ToyAttention(config)
        self.mlp = ToyMLP(config)
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Attention sub-layer; the second return value (weights) is unused.
        attn_out = self.self_attn(
            self.ln1(hidden_states), attention_mask=attention_mask, **kwargs
        )[0]
        after_attn = hidden_states + attn_out

        # Feed-forward sub-layer with its own residual connection.
        return after_attn + self.mlp(self.ln2(after_attn))


# -----------------------------
# 3. Base model (decoder LM) + LM head
# -----------------------------
class ToyModel(PreTrainedModel):
    """Tiny GPT-style decoder-only LM.

    Token + learned position embeddings, a stack of pre-LN ``ToyBlock``s, a
    final LayerNorm and an untied LM head producing vocabulary logits.

    NOTE(review): vLLM's Transformers backend is documented to load the
    backbone via ``AutoModel`` and apply its own LM head; this class returns
    logits rather than hidden states. Confirm against the vLLM version in use.
    """

    config_class = ToyConfig
    _supports_attention_backend = True  # required for vLLM backend

    def __init__(self, config: ToyConfig):
        super().__init__(config)
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_positions = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )

        self.layers = nn.ModuleList(
            [ToyBlock(config) for _ in range(config.num_hidden_layers)]
        )
        # Give every attention module its own index in the stack. Kernels with
        # per-layer state need distinct ids; e.g. vLLM's Transformers backend
        # indexes its per-layer attention objects by `module.layer_idx`.
        for idx, layer in enumerate(self.layers):
            layer.self_attn.layer_idx = idx
        self.ln_f = nn.LayerNorm(config.hidden_size)

        # Important for generation: LM head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and, when ``labels`` are given, the next-token loss.

        Args:
            input_ids: (bsz, seqlen) token ids.
            attention_mask: optional mask forwarded unchanged to every block.
            position_ids: optional (bsz, seqlen) positions; defaults to
                0..seqlen-1 for every row.
            labels: optional (bsz, seqlen) targets aligned with ``input_ids``;
                shifted by one internally; targets equal to -100 are ignored.
            **kwargs: ``return_dict`` overrides ``config.use_return_dict``;
                everything else is forwarded to the attention layers.

        Returns:
            ``CausalLMOutputWithPast``, or ``(loss, logits)`` / ``(logits,)``
            when return_dict is False.

        Raises:
            ValueError: if default positions would exceed the position table.
        """
        return_dict = kwargs.pop("return_dict", None)
        if return_dict is None:
            return_dict = self.config.use_return_dict

        bsz, seqlen = input_ids.shape

        if position_ids is None:
            if seqlen > self.config.max_position_embeddings:
                raise ValueError(
                    f"sequence length {seqlen} exceeds max_position_embeddings "
                    f"({self.config.max_position_embeddings})"
                )
            position_ids = torch.arange(
                seqlen, dtype=torch.long, device=input_ids.device
            )
            position_ids = position_ids.unsqueeze(0).expand(bsz, seqlen)

        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(position_ids)
        hidden_states = inputs_embeds + pos_embeds

        # Each block applies causal self-attention (ToyAttention.is_causal)
        # followed by its MLP.
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                **kwargs,
            )

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard LM shift: position t predicts token t+1. Targets equal
            # to -100 are skipped (CrossEntropyLoss's default ignore_index).
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


# -----------------------------
# 4. Save as a HF-style model directory
# -----------------------------
def save_toy_model(model_dir: str = "toy_vllm_model"):
    """Save a randomly initialised ToyModel and the GPT-2 tokenizer to model_dir.

    The directory receives the tokenizer files, the model weights and a
    config.json that additionally carries ``auto_map`` and ``architectures``.

    Args:
        model_dir: output directory (created if missing).
    """
    os.makedirs(model_dir, exist_ok=True)

    # use GPT-2 tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.save_pretrained(model_dir)

    config = ToyConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
        max_position_embeddings=512,
    )

    model = ToyModel(config)
    # save_pretrained (re)writes config.json from model.config, so extra keys
    # must be added AFTER it; writing them first gets them overwritten.
    model.save_pretrained(model_dir)

    # HF-style auto_map so Transformers/vLLM know which classes to import.
    # "AutoModel" is the entry vLLM's Transformers backend is documented to use
    # for the backbone; "AutoModelForCausalLM" is kept for HF generation.
    # NOTE(review): "__main__.X" only resolves inside this Python process.
    # Loading from disk with trust_remote_code normally needs "<module>.<Class>"
    # with <module>.py shipped in model_dir — confirm before relying on it.
    auto_map = {
        "AutoConfig": "__main__.ToyConfig",
        "AutoModel": "__main__.ToyModel",
        "AutoModelForCausalLM": "__main__.ToyModel",
    }

    config_path = os.path.join(model_dir, "config.json")
    with open(config_path) as f:
        cfg_dict = json.load(f)
    cfg_dict["auto_map"] = auto_map
    cfg_dict["architectures"] = ["ToyModel"]
    with open(config_path, "w") as f:
        json.dump(cfg_dict, f, indent=2)

    print(f"Saved toy model + tokenizer to: {model_dir}")


# -----------------------------
# 5. Load with vLLM and generate
# -----------------------------
def run_with_vllm(model_dir: str = "toy_vllm_model"):
    """Load the saved toy model with vLLM and print a completion per prompt.

    Args:
        model_dir: directory written by ``save_toy_model``.
    """
    # LLM.generate takes sampling options through a SamplingParams object,
    # not as keyword arguments such as max_tokens=.
    from vllm import SamplingParams

    # Force Transformers backend; we're using a Transformers-style custom model
    llm = LLM(model=model_dir, model_impl="transformers", trust_remote_code=True)

    prompts = ["Hello world", "Once upon a time"]
    sampling_params = SamplingParams(max_tokens=20)
    outputs = llm.generate(prompts, sampling_params)

    for prompt, out in zip(prompts, outputs):
        print(f"\nPrompt: {prompt!r}")
        print("Completion:", out.outputs[0].text)


if __name__ == "__main__":
    save_toy_model()
    run_with_vllm()

In [None]:
# training toymodel
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

# ToyModel / ToyConfig are the classes defined in the vLLM cell above.
# `toy_vllm_model` is the model *directory* written by save_toy_model()
# (tokenizer files, config.json, weights); it holds no Python source for
# these classes, so they are not imported from it.

# GPT-2 has no pad token, but the collator must pad ragged batches;
# reuse EOS as the padding token.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Load config + model (vocab sized from the tokenizer actually used)
config = ToyConfig(vocab_size=tokenizer.vocab_size)
model = ToyModel(config)

# Load a text dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

def tokenize(example):
    return tokenizer(example["text"], truncation=True, max_length=128)

tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
# wikitext has many blank lines; they tokenize to empty sequences that carry
# no training signal (an all-empty batch would even produce a NaN loss).
tokenized = tokenized.filter(lambda e: len(e["input_ids"]) > 0)

# Data collator for next-token prediction
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,        # causal LM objective (autoregressive)
)

args = TrainingArguments(
    output_dir="toy_trained",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=3e-4,
    warmup_steps=100,
    num_train_epochs=1,
    logging_steps=10,
    eval_strategy="steps",   # `evaluation_strategy` is the deprecated name
    eval_steps=500,          # otherwise defaults to logging_steps (eval every 10 steps)
    save_strategy="epoch",
    fp16=False,          # you can enable fp16 if you want
)

trainer = Trainer(
    model=model,
    args=args,
    data_collator=collator,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"]
)

trainer.train()

In [None]:
# full example w training

from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

# ToyModel / ToyConfig / save_toy_model are defined in the vLLM cell above.
# `toy_vllm_model` is the model directory save_toy_model() writes, not a
# Python module, so they are not imported from it.

# Step 1: save base toy model
save_toy_model("toy_vllm_model")

# Step 2: load base model for training
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 has no pad token; the collator needs one to pad ragged batches.
tokenizer.pad_token = tokenizer.eos_token
config = ToyConfig(vocab_size=tokenizer.vocab_size)
model = ToyModel(config)

# Step 3: load dataset (drop blank wikitext lines -> empty sequences)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
tokenized = dataset.map(
    lambda e: tokenizer(e["text"], truncation=True, max_length=128),
    batched=True,
    remove_columns=["text"],
)
tokenized = tokenized.filter(lambda e: len(e["input_ids"]) > 0)

# Step 4: collator
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# Step 5: training args
args = TrainingArguments(
    output_dir="toy_vllm_trained",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    eval_strategy="steps",   # `evaluation_strategy` is the deprecated name
    eval_steps=500,          # otherwise defaults to logging_steps
    learning_rate=3e-4,
    warmup_steps=50,
    num_train_epochs=1,
    save_strategy="epoch",
    logging_steps=25,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=collator,
)

trainer.train()

# Step 6: save trained model in HF format. The tokenizer was not handed to the
# Trainer, so save it explicitly; a server loading this directory needs it.
# NOTE(review): the saved config.json has no auto_map; loading the custom
# "toy-transformer" model_type with trust_remote_code may need one — verify.
trainer.save_model("toy_vllm_trained")
tokenizer.save_pretrained("toy_vllm_trained")

print("Training complete! Now you can serve it in vLLM:")
print("vllm serve toy_vllm_trained --model-impl transformers --trust-remote-code")

In [None]:
vllm serve toy_vllm_trained --model-impl transformers --trust-remote-code