# LLaMA From Scratch

**References**
- *Coding LLaMA 2 from scratch in PyTorch - KV Cache, Grouped Query Attention, Rotary PE, RMSNorm: [Youtube Video](https://youtu.be/oM4VmoabDAI?si=JtlNl00nZeIOkWxx), [Code](https://github.com/hkproj/pytorch-llama)*
- *LLaMA explained: KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU: [Youtube Video](https://youtu.be/Mn_9W1nCFLo?si=4xJy4OzpPX5YxGqx)*
- *RoFormer: Enhanced Transformer with Rotary Position Embedding: [Paper](https://arxiv.org/abs/2104.09864)*
- *Root Mean Square Layer Normalization: [Paper](https://arxiv.org/abs/1910.07467)*
- *Rotary Embeddings: A Relative Revolution: [Blog](https://blog.eleuther.ai/rotary-embeddings/)*
- *Transformers Optimization: Part 1 - KV Cache: [Blog](https://r4j4n.github.io/blogs/posts/kv/)*
- *The Secret Sauce of LLaMA🦙 : A Deep Dive!: [Blog](https://r4j4n.github.io/blogs/posts/llama/)*

## Imports

In [2]:
from dataclasses import dataclass
from typing import Optional, List
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from pathlib import Path
import json
from sentencepiece import SentencePieceProcessor
from tqdm import tqdm
import time

# Training

### LLaMA Family

**Model Architecture**

<table>
    <tr>
        <td></td>
        <td><strong>Training Data</strong></td>
        <td><strong>Params</strong></td>
        <td><strong>Context length</strong></td>
        <td><strong>GQA</strong></td>
        <td><strong>Token count</strong></td>
    </tr>
    <tr>
        <td rowspan="4">Llama 1</td>
        <td rowspan="4">See Touvron et al. (2023)</td>
        <td>7B</td>
        <td>2k</td>
        <td>❌</td>
        <td>1T+</td>
    </tr>
    <tr>
        <td>13B</td>
        <td>2k</td>
        <td>❌</td>
        <td>1T+</td>
    </tr>
    <tr>
        <td>33B</td>
        <td>2k</td>
        <td>❌</td>
        <td>1.4T+</td>
    </tr>
    <tr>
        <td>65B</td>
        <td>2k</td>
        <td>❌</td>
        <td>1.4T+</td>
    </tr>
    <tr>
        <td rowspan="4">Llama 2</td>
        <td rowspan="4">A new mix of publicly available online data.</td>
        <td>7B</td>
        <td>4k</td>
        <td>❌</td>
        <td rowspan="4">2T+</td>
    </tr>
    <tr>
        <td>13B</td>
        <td>4k</td>
        <td>❌</td>
    </tr>
    <tr>
        <td>34B</td>
        <td>4k</td>
        <td>✔️</td>
    </tr>
    <tr>
        <td>70B</td>
        <td>4k</td>
        <td>✔️</td>
    </tr>
    <tr>
        <td rowspan="2">Llama 3</td>
        <td rowspan="2">A new mix of publicly available online data.</td>
        <td>8B</td>
        <td>8k</td>
        <td>✔️</td>
        <td rowspan="2">15T+</td>
    </tr>
    <tr>
        <td>70B</td>
        <td>8k</td>
        <td>✔️</td>
    </tr>
</table>

### Model Arguments



In [2]:
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    # * Unlike the original Transformer, LLaMA does not need the same number of heads for q, k and v
    n_heads: int = 32  # number of heads for the queries
    n_kv_heads: Optional[int] = None  # Number of heads for the keys and values
    vocab_size: int = -1  # will be set when we load the tokenizer
    # * since grouped query attention heads are reduced,
    # * the number of params in the FFN is increased to keep the total number of parameters the same
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5  # epsilon added inside RMSNorm for numerical stability

    # needed for KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048

    device: Optional[str] = None

## Rotary Positional Embedding

**Absolute Positional Encodings vs Relative Positional Encodings?**
- Absolute positional encodings are fixed vectors that are added to the embedding of a token to represent its absolute position in the sentence. So, it deals with one token at a time. You can think of it as the pair (latitude, longitude) on a map: each point on earth will have a unique pair.
- Relative positional encodings, on the other hand, deal with two tokens at a time and come into play when we calculate the attention: since the attention mechanism captures the "intensity" of how much two words are related to each other, relative positional encodings tell the attention mechanism the distance between the two words involved. So, given two tokens, we create a vector that represents their distance.

**The drawbacks of absolute or relative position information**
- The vanilla positional encoding is designed for a fixed maximum sequence length. If you have a more extended sequence than the maximum length used during training, handling it becomes problematic. You might need to truncate, split, or find another way to fit it within the maximum length. A model trained with a particular maximum sequence length may not generalize well to sequences of very different lengths, even if they’re within the allowed range. The positional encoding for these lengths might be outside the distribution seen during training.
- The sinusoidal nature of the positional encoding might not always be optimal for capturing very long-term dependencies in long sequences. While self-attention theoretically allows for such connections, in practice, the model might still struggle due to the fixed nature of the encoding.

**Rotary Positional Embedding (RoPE)**

Rotary Positional Embedding (RoPE) is a new type of position encoding that unifies absolute and relative approaches.

For dot-product attention, the rotary encoding makes the attention score depend only on the relative position $n - m$:

$$\bold{q}_m^\top \bold{k}_n = (\bold{R}_{\Theta, m}^d \bold{W}_q \bold{x}_m)^\top (\bold{R}_{\Theta, n}^d \bold{W}_k \bold{x}_n) = \bold{x}_m^\top \bold{W}_q^\top \bold{R}_{\Theta, n-m}^d \bold{W}_k \bold{x}_n$$


![Rotary position embedding Overview](images/rotary-position-embedding-overview.png)

**Intuition**

We would like to find a positional encoding function $f(\bold{x}, l)$ for an item $\bold{x}$ and its position $l$ such that, for two items $\bold{q}$ and $\bold{k}$ at positions $m$ and $n$, the inner product between $f(\bold{q}, m)$ and $f(\bold{k}, n)$ is sensitive only to the values of $\bold{q}$, $\bold{k}$, and their relative position $m-n$. This is related in spirit to the kernel trick: we are searching for a feature map such that its kernel has certain properties. A key piece of information is the geometric definition of the dot product between Euclidean vectors: $\bold{q} \cdot \bold{k} = \| \bold{q} \| \| \bold{k} \| \cos(\theta_{qk})$

In plain English, the intuition behind RoPE is that we can represent the token embeddings as complex numbers and their positions as pure rotations that we apply to them. If we shift both the query and the key by the same amount, changing the absolute position but not the relative position, both representations are rotated further by the same angle. The angle between them therefore stays the same, and so does their dot product. By exploiting the nature of rotations, the dot product used in self-attention has the property we are looking for: it preserves relative positional information while discarding absolute position.

**The RoPE Solution further explained in simpler terms**

1. Imagine Words as Positions on a Circle: Think of each word's embedding (its numerical representation) as a point on a circle.
2. Positions as Rotations: Instead of adding a separate position vector to each word, RoPE rotates each word's embedding by an angle proportional to its position. Words that are close together end up rotated by similar angles, while words that are far apart differ by a larger rotation.
3. Shifting Together Keeps Things Relative: Since both the "query" (looking for information) and "key" (holding information) embeddings are rotated by the same amount when considering relative position, the angle between them stays the same. This, in turn, keeps the "dot product" (a measure of similarity) between them unchanged.
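
The three points above can be checked numerically. The snippet below is a minimal sketch, not the notebook's implementation: `rotate` is a hypothetical helper, and the head dimension and positions are arbitrary illustrative values. It rotates a query and a key by their positions and compares the resulting dot products.

```python
import torch

torch.manual_seed(0)


def rotate(x: torch.Tensor, pos: int, theta: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive pairs of x by the angles pos * theta (hypothetical helper)."""
    x_complex = torch.view_as_complex(x.reshape(-1, 2))          # (head_dim / 2,)
    rotation = torch.polar(torch.ones_like(theta), pos * theta)  # e^{i * pos * theta}
    return torch.view_as_real(x_complex * rotation).reshape(-1)  # (head_dim,)


head_dim = 8  # illustrative size
theta = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
q, k = torch.randn(head_dim), torch.randn(head_dim)

# Same relative distance (4) at two different absolute offsets.
score_a = torch.dot(rotate(q, 3, theta), rotate(k, 7, theta))
score_b = torch.dot(rotate(q, 103, theta), rotate(k, 107, theta))
# Different relative distance (5).
score_c = torch.dot(rotate(q, 3, theta), rotate(k, 8, theta))

print(score_a.item(), score_b.item(), score_c.item())
```

`score_a` and `score_b` agree up to floating-point error because only the distance between the two positions enters the score, while `score_c` will generally differ.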

**How is this different from the sinusoidal embeddings used in "Attention is All You Need"?**

There are two ways that rotary embeddings are different from sinusoidal embeddings:
- Sinusoidal embeddings apply to each coordinate individually, while rotary embeddings mix pairs of coordinates
- Sinusoidal embeddings add a $\cos (m \theta)$ or $\sin (m \theta)$ term, while rotary embeddings use a multiplicative factor.

### Precompute Theta Positional Frequencies

Below are the steps involved in precomputing theta positional frequencies:

![Precompute Theta Positional Frequencies Steps](images/theta-pos-freq-steps.png)

In [3]:
def precompute_theta_pos_frequencies(
    head_dim: int, seq_len: int, device: str, theta: float = 10000.0
):
    # theta 10000.0 is the default value in the paper
    # As written in the paragraph 3.2.2 of the paper
    # >> In order to generalize our results in 2D to any xi ∈ Rd where **d is even**, [...]
    assert (
        head_dim % 2 == 0
    ), "Dimension must be even since rotary embedding can't be applied to odd."

    # Build the theta parameter
    # According to the formula theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ..., dim/2]
    theta_numerator = torch.arange(0, head_dim, 2).float()  # (head_dim / 2)
    theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)  # (dim / 2)
    # construct the positions (the "m" parameter)
    m = torch.arange(seq_len, device=device)  # (seq_len)
    # Multiply each theta by each position using the outer product.
    # (seq_len), outer_product*(head_dim/2) -> (seq_len,head_dim/2)
    freqs = torch.outer(m, theta).float()
    # we can compute complex numbers in the polar form c = R * exp(i * m * theta), where R = 1, as follows:
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_complex
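
As a sanity check on the steps above, the following standalone sketch rebuilds the same quantities with small, illustrative sizes (`head_dim = 8` and `seq_len = 16` are assumptions chosen for readability) and inspects the result.

```python
import torch

head_dim, seq_len, base = 8, 16, 10000.0  # illustrative sizes, not the model's

# theta_i = base^(-2(i-1)/head_dim) for i = 1 .. head_dim/2
theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim/2,)
m = torch.arange(seq_len).float()                                          # (seq_len,)
angles = torch.outer(m, theta)                                             # (seq_len, head_dim/2)
freqs_complex = torch.polar(torch.ones_like(angles), angles)               # e^{i * m * theta}

print(freqs_complex.shape)          # torch.Size([16, 4])
print(freqs_complex.abs().max())    # every entry lies on the unit circle
print(freqs_complex[3].angle())     # the angles at position 3, i.e. 3 * theta
```

Each row holds one unit-modulus complex number per coordinate pair, and position 0 is the identity rotation. The first pair rotates fastest and the last pair slowest.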

### Rotary Embeddings

The Steps in calculating the Rotary Embedding:
![The Steps in calculating the Rotary Embedding](images/rotary-embedding-steps.png)

Figure 1: Implementation of Rotary Position Embedding(RoPE):
![Implementation of Rotary Position Embedding](images/implementation-of-rope.png)

**Practical Considerations**
- The rotary position embeddings are only applied to the query and the keys, but not the values.
- The rotary position embeddings are applied after the vector q and k have been multiplied by the W matrix in the attention mechanism, while in the vanilla transformer they're applied before.


In [4]:
def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    # * OP 1 & 2 >>
    # separate the last dimension into pairs of 2 values, representing the real & imaginary parts of the complex number
    # two consecutive values will become a single complex number
    # (B,seq_len,H,head_dim) -> (B,seq_len,H,head_dim/2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # reshape the freqs_complex tensor to match the shape of the x_complex tensor.
    # So we need to add the batch dimension and the head dimension.
    # (seq_len,head_dim/2) -> (1,seq_len,1,head_dim/2)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    # * OP 3 >>
    # Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor
    # which results in the rotation of the complex number as shown in the Figure 1 of the paper.
    # (B,seq_len,H,head_dim/2)*(1,seq_len,1,head_dim/2) -> (B,seq_len,H,head_dim/2)
    x_rotated = x_complex * freqs_complex
    # * OP 4 >> convert the complex number back to the real number
    # (B,seq_len,H,head_dim/2) -> (B,seq_len,H,head_dim/2,2)
    x_out = torch.view_as_real(x_rotated)
    # * OP 5 >> Flattening to the shape of the original tensor
    # (B,seq_len,H,head_dim/2,2) -> (B,seq_len,H,head_dim)
    x_out = x_out.reshape(*x.shape)
    return x_out.type_as(x).to(device)
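
To see why the complex multiplication is a rotation, the sketch below compares it with an explicit 2D rotation of every `(even, odd)` coordinate pair. It is a standalone illustration with arbitrary small tensor sizes, not a replacement for the function above.

```python
import torch

torch.manual_seed(0)
B, seq_len, H, head_dim = 2, 5, 3, 8  # illustrative sizes

x = torch.randn(B, seq_len, H, head_dim)
theta = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), theta)  # (seq_len, head_dim/2)

# Path 1: complex multiplication, following the same steps as apply_rotary_embeddings.
freqs_complex = torch.polar(torch.ones_like(angles), angles)
x_complex = torch.view_as_complex(x.reshape(B, seq_len, H, -1, 2))
out_complex = torch.view_as_real(x_complex * freqs_complex[None, :, None, :]).reshape(x.shape)

# Path 2: explicit rotation matrix [[cos, -sin], [sin, cos]] applied to each pair.
cos = angles.cos()[None, :, None, :]
sin = angles.sin()[None, :, None, :]
x_even, x_odd = x[..., 0::2], x[..., 1::2]
out_real = torch.stack(
    (x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1
).reshape(x.shape)

print(torch.allclose(out_complex, out_real, atol=1e-5))
```

Both paths give the same tensor, and because a rotation preserves length, the norm of every head vector is unchanged.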

## Root Mean Square Normalization

**Internal covariate shift**

Internal covariate shift refers to the *gradual change in the distribution of data* as it flows through the network's layers. As training progresses, the weights in the earlier layers are updated based on the input data. These weight changes alter the way the data is transformed between layers. Consequently, the distribution of the data (activation values) at each layer changes compared to the initial distribution. This shift in distribution across layers can disrupt the learning process in deeper networks. Neurons in later layers have to constantly adapt to the changing inputs they receive, making it harder for the network to converge on a stable set of weights. Normalization techniques, like batch normalization and layer normalization, address internal covariate shift by essentially standardizing the data at each layer.
- Batch Normalization: This technique normalizes the activations of each mini-batch of data presented to a layer. With batch normalization we *normalize by columns (features)*.
- Layer Normalization: This technique normalizes the activations of each neuron within a layer, independent of the mini-batch. With layer normalization we *normalize by rows (data items)*.
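
The difference between normalizing by columns and by rows can be made concrete with a tiny matrix. This is a minimal sketch that leaves out the learned scale and shift parameters, and the 4x3 input is an arbitrary example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3) * 5 + 2  # 4 data items (rows), 3 features (columns)
eps = 1e-5

# Batch normalization: one mean/variance per feature, computed over the batch (dim=0).
batch_normed = (x - x.mean(dim=0, keepdim=True)) / torch.sqrt(
    x.var(dim=0, unbiased=False, keepdim=True) + eps
)

# Layer normalization: one mean/variance per data item, computed over the features (dim=1).
layer_normed = (x - x.mean(dim=1, keepdim=True)) / torch.sqrt(
    x.var(dim=1, unbiased=False, keepdim=True) + eps
)

print(batch_normed.mean(dim=0))  # close to 0 for every column
print(layer_normed.mean(dim=1))  # close to 0 for every row
```

The row-wise version matches `F.layer_norm(x, (3,))`, which is why layer normalization does not depend on the other items in the batch.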

**Root Mean Square Normalization**

LayerNorm works because of its re-centering and re-scaling invariance property. Re-centering enables the model to be insensitive to shift noises on both inputs and weights, and re-scaling keeps the output representations intact when both inputs and weights are randomly scaled. 

The RMS Normalization paper hypothesizes that re-scaling invariance, rather than re-centering invariance, is the reason for the success of LayerNorm. The authors propose RMSNorm, which focuses only on re-scaling invariance and regularizes the summed inputs simply according to the root mean square (RMS) statistic:

$$
\bar{a}_i= \frac{a_i}{RMS(a)} g_i \\
\text{where} \ RMS(a) = \sqrt{\frac{1}{n} \sum_{i=1}^n a_{i}^2}
$$

Intuitively, RMSNorm simplifies LayerNorm by totally removing the mean statistic at the cost of sacrificing the invariance that mean normalization affords. This helps in reducing the computation cost compared to Layer Normalization and also works well in practice.
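
The trade-off can be checked directly. The sketch below uses a hypothetical `rms_norm` helper without the learned gain (that is, $g_i = 1$) and with `eps = 0` so the invariance is exact. It compares how RMSNorm and LayerNorm react to re-scaling and re-centering the input.

```python
import torch
import torch.nn.functional as F


def rms_norm(a: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    """RMS normalization without the learned gain (hypothetical helper)."""
    return a * torch.rsqrt(a.pow(2).mean(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
a = torch.randn(2, 6)

scaled = rms_norm(3.0 * a)   # re-scaled input
shifted = rms_norm(a + 3.0)  # re-centered input

print(torch.allclose(scaled, rms_norm(a), atol=1e-5))   # RMSNorm ignores the scale
print(torch.allclose(shifted, rms_norm(a), atol=1e-5))  # but not the shift
print(torch.allclose(F.layer_norm(a + 3.0, (6,)), F.layer_norm(a, (6,)), atol=1e-4))
```

RMSNorm keeps the re-scaling invariance and gives up the re-centering invariance, which LayerNorm retains at the cost of also computing the mean.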

In [5]:
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # The gamma parameter

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt: 1 / sqrt(x)
        # (B,seq_len,dim)*(B,seq_len,1) -> (B,seq_len,dim)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (dim)*(B,seq_len,dim) -> (B,seq_len,dim)
        return self.weight * self._norm(x.float()).type_as(x)

## Feed Forward

### SwiGLU Activation Function

**SwiGLU**

SwiGLU is a combination of the Swish activation function and the Gated Linear Unit (GLU) concept. It was introduced in the paper “GLU Variants Improve Transformer” (Noam Shazeer, 2020). The author proposes several variations of the standard GLU and reports that some of them improve quality over the usual ReLU or GELU activations when used in the feed-forward sublayers of a Transformer model.

The SwiGLU variant is defined as:
$$\mathit{SwiGLU}(x, x') = x \odot \mathit{Swish}(x')$$

where $\odot$ is the element-wise multiplication operation, and x' is the transformed input (generally, a linear transformation of the input x).

**Swish**

Swish is a smooth version of ReLU with a non-zero gradient for negative values. It is a smooth, non-monotonic function that consistently matches or outperforms ReLU.

Simply put, Swish is a generalization of the SiLU activation function, whose formula is $f(x) = x \cdot \mathit{sigmoid}(x)$. Swish adds a trainable $\beta$ parameter, giving $f(x) = x \cdot \mathit{sigmoid}(\beta x)$; with $\beta = 1$ it reduces to SiLU.

In contrast to ReLU, which is piecewise linear and sets every negative input to zero, Swish is a smooth, continuous function that lets small negative values pass through. It is also non-monotonic, meaning it does not consistently increase or decrease in value, and this property is particularly beneficial in deep neural networks.

The trainable parameter $\beta$ enables the activation function to be fine-tuned more effectively to optimize information propagation and push for smoother gradients.
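
A few lines are enough to see these properties. In this minimal sketch `swish` is a hypothetical helper that takes $\beta$ as a plain float rather than as a trainable parameter.

```python
import torch
import torch.nn.functional as F


def swish(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """Swish with a fixed (non-trainable) beta, for illustration only."""
    return x * torch.sigmoid(beta * x)


x = torch.linspace(-5.0, 5.0, steps=11)  # -5, -4, ..., 5

print(swish(x))             # small negative outputs for negative inputs
print(F.relu(x))            # exactly zero for negative inputs
print(swish(x, beta=50.0))  # a large beta makes Swish behave like ReLU
```

With $\beta = 1$ the helper matches `F.silu`, which is the function the feed-forward block in this notebook uses. The output at $x = -1$ is lower than at both $x = -5$ and $x = 0$, which is the non-monotonic dip described above.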

**Gated Linear Unit (GLU)**

GLU (Gated Linear Units) is a layer within a neural network, rather than a strict activation function. It involves a linear transformation followed by a gating process. This gating process is controlled by a sigmoid function that manages the information flow from the linear transformation.

$$h_l (\mathbf{X}) = (\mathbf{X} \mathbf{W} + \mathbf{b}) \otimes \sigma (\mathbf{X} \mathbf{V} + \mathbf{c})$$

$\sigma$ means the sigmoid function. So we have two sets of weights W and V, and two biases, b, and c. The idea is simple. I want to allow the network to decide how much information should flow through a given path, like a logical gate, hence the name. How?
- If we multiply X by 0, nothing passes.
- If we multiply X by 1, everything passes.
- If we multiply X by 0.5, half of it passes.
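
The gating formula can be written out in a few lines. This sketch uses random weights and arbitrary small sizes purely for illustration, and checks the result against PyTorch's built-in `F.glu`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_in, d_out = 4, 3  # illustrative sizes
X = torch.randn(2, d_in)
W, V = torch.randn(d_in, d_out), torch.randn(d_in, d_out)
b, c = torch.randn(d_out), torch.randn(d_out)

gate = torch.sigmoid(X @ V + c)  # values in (0, 1): how much of each entry passes
glu_out = (X @ W + b) * gate     # element-wise gating of the linear path

# F.glu splits its input in half along `dim` and returns first_half * sigmoid(second_half).
reference = F.glu(torch.cat([X @ W + b, X @ V + c], dim=-1), dim=-1)

print(gate)
print(torch.allclose(glu_out, reference))
```

Because the gate is computed from the input itself, the network learns per entry how much of the linear path to let through.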



In [6]:
class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        hidden_dim = 4 * args.dim
        hidden_dim = int(2 * hidden_dim / 3)
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
        # Round hidden_dim up to the next multiple of the multiple_of parameter
        hidden_dim = args.multiple_of * (
            (hidden_dim + args.multiple_of - 1) // args.multiple_of
        )

        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        swish = F.silu(self.w1(x))  # (B, seq_len, dim) -> (B, seq_len, hidden_dim)
        x_v = self.w3(x)  # (B, seq_len, dim) -> (B, seq_len, hidden_dim)
        # (B, seq_len, hidden_dim) * (B, seq_len, hidden_dim) -> (B, seq_len, hidden_dim)
        x = swish * x_v
        x = self.w2(x)  # (B, seq_len, hidden_dim) -> (B, seq_len, dim)
        return x
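
The hidden size arithmetic in `__init__` is easier to follow with numbers. The sketch below repeats it in a standalone function (`ffn_hidden_dim` is a hypothetical name) and evaluates it for the default `dim = 4096` and `multiple_of = 256`.

```python
from typing import Optional


def ffn_hidden_dim(
    dim: int, multiple_of: int = 256, ffn_dim_multiplier: Optional[float] = None
) -> int:
    """Same arithmetic as FeedForward.__init__, extracted for illustration."""
    hidden_dim = 4 * dim                  # 16384 for dim = 4096
    hidden_dim = int(2 * hidden_dim / 3)  # 10922
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # round up to the next multiple of multiple_of
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


dim = 4096
hidden = ffn_hidden_dim(dim)
swiglu_params = 3 * dim * hidden       # w1, w2 and w3
vanilla_params = 2 * dim * (4 * dim)   # two matrices with a 4 * dim hidden layer

print(hidden)          # 11008
print(swiglu_params)   # 135266304
print(vanilla_params)  # 134217728
```

The 2/3 factor compensates for the third weight matrix: three matrices of width 11008 hold roughly as many parameters as the two matrices of a conventional feed-forward layer with a `4 * dim` hidden size.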

## Attention

### KV Cache

**What is KV Cache?**

A common technique for speeding up inference in large models is to cache the keys and values computed at previous decoding steps (the KV cache) and reuse them. This improves inference performance and reduces end-to-end latency without affecting accuracy.

**Why KV Cache?**

While generating text (tokens) in autoregressive language models like GPT, all the previously generated tokens are fed into the network when generating a new token. Here, the hidden representation of the previously generated tokens needs to be recalculated each time a new token is generated. This causes a lot of computational waste.

As the input tokens for each inference process become longer, it increases inference FLOPs (floating point operations). KV cache solves this problem by storing hidden representations of previously computed key-value pairs while generating a new token.
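
The claim that caching does not change the result can be verified on a toy example. The sketch below assumes a single attention head with random projection matrices and no positional encoding, all chosen for illustration. It compares the output for the last token when keys and values are recomputed from scratch with the output obtained by appending to a cache one token at a time.

```python
import torch

torch.manual_seed(0)
seq_len, dim = 6, 8  # illustrative sizes
x = torch.randn(seq_len, dim)  # one row per token
W_q, W_k, W_v = (torch.randn(dim, dim) for _ in range(3))


def attend(q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Scaled dot-product attention for a single head."""
    weights = torch.softmax(q @ K.T / dim**0.5, dim=-1)
    return weights @ V


# Without a cache: project every token again to get the keys and values.
full_last = attend(x[-1:] @ W_q, x @ W_k, x @ W_v)

# With a cache: each step projects only the new token and appends it.
cache_K = torch.empty(0, dim)
cache_V = torch.empty(0, dim)
for t in range(seq_len):
    token = x[t : t + 1]
    cache_K = torch.cat([cache_K, token @ W_k], dim=0)
    cache_V = torch.cat([cache_V, token @ W_v], dim=0)
    cached_out = attend(token @ W_q, cache_K, cache_V)

print(torch.allclose(cached_out, full_last, atol=1e-4))
```

At every step only one row is projected, so the cost of the key and value projections stays constant per token instead of growing with the length of the sequence.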

Consider a transformer architecture with 12 attention heads and a KV cache. The following figure represents the transformer state while generating the 9th token of the input sequence.

![Multi-headed Attention with KV Cache](images/Multi-headed-Attention-with-KV-Cache.png)


Self-Attention during Next Token Prediction Task at Inference T=1:

![Self-Attention during Next Token Prediction Task at T1](images/Self-Attention-during-NTP-Task-T1.png)

Self-Attention during Next Token Prediction Task at Inference T=4:

![Self-Attention during Next Token Prediction Task at T4](images/Self-Attention-during-NTP-Task-T4.png)

Where KV Cache is useful:

![Where KV Cache is useful](images/where-kv-cache-come-in.png)

Self-Attention with KV-Cache at Inference T=1:

![Self-Attention with KV-Cache at T1](images/Self-Attention-with-KV-Cache-T1.png)

Self-Attention with KV-Cache at Inference T=4:

![Self-Attention with KV-Cache at T4](images/Self-Attention-with-KV-Cache-T4.png)

In [7]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch_size, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]  # (B, seq_len, n_kv_heads, 1, head_dim)
        .expand(
            batch_size, seq_len, n_kv_heads, n_rep, head_dim
        )  # (B, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(
            batch_size, seq_len, n_kv_heads * n_rep, head_dim
        )  # (B, seq_len, n_kv_heads * n_rep, head_dim)
    )
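
The intended behaviour of `repeat_kv` is that every key-value head is copied `n_rep` times in place, so that query head `i` lines up with key-value head `i // n_rep`. The standalone check below assumes small illustrative sizes. It inserts the new axis right after the head axis and compares the result with `torch.repeat_interleave`.

```python
import torch

torch.manual_seed(0)
B, seq_len, n_kv_heads, head_dim, n_rep = 2, 3, 2, 4, 3  # illustrative sizes
x = torch.randn(B, seq_len, n_kv_heads, head_dim)

repeated = (
    x[:, :, :, None, :]                                 # (B, seq_len, n_kv_heads, 1, head_dim)
    .expand(B, seq_len, n_kv_heads, n_rep, head_dim)    # no copy yet, just a broadcast view
    .reshape(B, seq_len, n_kv_heads * n_rep, head_dim)  # (B, seq_len, n_kv_heads * n_rep, head_dim)
)

print(repeated.shape)  # torch.Size([2, 3, 6, 4])
print(torch.equal(repeated, torch.repeat_interleave(x, n_rep, dim=2)))
```

Heads 0 to 2 of the output are copies of key-value head 0 and heads 3 to 5 are copies of key-value head 1.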

### Attention Variations

#### Vanilla batched Multi-Head Attention

- Multihead Attention as presented in the original paper "Attention is all you need".
- By setting 𝑚 = 𝑛 (sequence length of query = seq. length of keys and values)
- The number of arithmetic operations performed is $O(bnd^2)$
- The total memory involved in the operations, given by the sum of all the tensors involved in the calculations (including the derived ones!) is $O(bnd + bhn^2 + d^2)$ 
- The ratio between the total memory and the number of arithmetic operations is $O (\frac{1}{k} + \frac{1}{bn})$
- In this case, the ratio is much smaller than 1, which means that the number of memory accesses is much smaller than the number of arithmetic operations, so memory access is not the bottleneck here.

In [1]:
def MultiHeadAttentionBatched():
    d, m, n, b, h, k, v = 512, 10, 10, 32, 8, (512 // 8), (512 // 8)

    X = torch.rand(b, n, d)  # Query
    M = torch.rand(b, m, d)  # Key and Value
    mask = torch.rand(b, h, n, m)
    P_q = torch.rand(h, d, k)  # W_q
    P_k = torch.rand(h, d, k)  # W_k
    P_v = torch.rand(h, d, v)  # W_v
    P_o = torch.rand(h, d, v)  # W_o

    Q = torch.einsum("bnd,hdk->bhnk", X, P_q)
    K = torch.einsum("bmd,hdk->bhmk", M, P_k)
    V = torch.einsum("bmd,hdv->bhmv", M, P_v)

    logits = torch.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = torch.softmax(logits + mask, dim=-1)

    O = torch.einsum("bhnm,bhmv->bhnv", weights, V)
    Y = torch.einsum("bhnv,hdv->bnd", O, P_o)
    return Y

#### Batched Multi-Head Attention with KV cache

- Uses the KV cache to reduce the number of operations performed.
- By setting 𝑚 = 𝑛 (sequence length of query = seq. length of keys and values)
- The number of arithmetic operations performed is $O(bnd^2)$
- The total memory involved in the operations, given by the sum of all the tensors involved in the calculations (including the derived ones!) is $O(bn^2d + nd^2)$
- The ratio between the total memory and the number of arithmetic operations is $O (\frac{n}{d} + \frac{1}{b})$
- When 𝑛 ≈ 𝑑 (the sequence length is close to the size of the embedding vector) or 𝑏 ≈ 1 (the batch size is 1), the ratio approaches 1 and memory access becomes the bottleneck of the algorithm. The batch size is usually not a problem, since it is generally much higher than 1, but to shrink the 𝑛/𝑑 term we would need to reduce the sequence length. But there’s a better way...


In [None]:
def MultiHeadSelfAttentionIncremental():
    d, b, h, k, v = 512, 32, 8, (512 // 8), (512 // 8)

    m = 5  # Suppose we have already cached "m" tokens
    prev_K = torch.rand(b, h, m, k)
    prev_V = torch.rand(b, h, m, v)

    X = torch.rand(b, d)  # Query
    M = torch.rand(b, d)  # Key and Value
    P_q = torch.rand(h, d, k)  # W_q
    P_k = torch.rand(h, d, k)  # W_k
    P_v = torch.rand(h, d, v)  # W_v
    P_o = torch.rand(h, d, v)  # W_o

    q = torch.einsum("bd,hdk->bhk", X, P_q)
    K = torch.concat([prev_K, torch.einsum("bd,hdk->bhk", M, P_k).unsqueeze(2)], axis=2)
    V = torch.concat([prev_V, torch.einsum("bd,hdv->bhv", M, P_v).unsqueeze(2)], axis=2)

    logits = torch.einsum("bhk,bhmk->bhm", q, K)  # a single query position, so there is no n axis
    weights = torch.softmax(logits, dim=-1)

    O = torch.einsum("bhm,bhmv->bhv", weights, V)
    Y = torch.einsum("bhv,hdv->bd", O, P_o)
    return Y, K, V

#### Multi-Query Attention with KV cache

- We remove the ℎ dimension from the 𝐾 and the 𝑉, while keeping it for the 𝑄. This means that all the different query heads will share the same keys and values.
- The number of arithmetic operations performed is $O(bnd^2)$
- The total memory involved in the operations, given by the sum of all the tensors involved in the calculations (including the derived ones!) is $O(bnd + bn^2k + nd^2)$
- The ratio between the total memory and the number of arithmetic operations is $O (\frac{1}{d} + \frac{n}{dh} + \frac{1}{b})$
- Comparing with the previous approach, we have reduced the expensive term 𝑛/𝑑 by a factor of h.
- The performance gains are important, while the model's quality degrades only a little bit.
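
To get a feel for the three ratios quoted in this section, the sketch below plugs in concrete numbers. The values `b = 32`, `n = 2048`, `d = 4096` and `h = 32` are assumptions chosen to resemble the default model arguments, and constant factors hidden by the $O(\cdot)$ notation are ignored.

```python
def memory_to_compute_ratios(b: int, n: int, d: int, h: int) -> dict:
    """Evaluate the asymptotic memory / arithmetic ratios for the three variants."""
    k = d // h  # per-head dimension
    return {
        "multi_head": 1 / k + 1 / (b * n),
        "multi_head_kv_cache": n / d + 1 / b,
        "multi_query_kv_cache": 1 / d + n / (d * h) + 1 / b,
    }


ratios = memory_to_compute_ratios(b=32, n=2048, d=4096, h=32)
for name, ratio in ratios.items():
    print(f"{name}: {ratio:.4f}")
# multi_head: 0.0078
# multi_head_kv_cache: 0.5312
# multi_query_kv_cache: 0.0471
```

With these numbers the incremental multi-head variant is dominated by the 𝑛/𝑑 term, and sharing the keys and values across heads cuts that term by a factor of ℎ.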


In [None]:
def MultiquerySelfAttentionIncremental():
    d, b, h, k, v = 512, 32, 8, (512 // 8), (512 // 8)

    m = 5  # Suppose we have already cached "m" tokens
    prev_K = torch.rand(b, m, k)
    prev_V = torch.rand(b, m, v)

    X = torch.rand(b, d)  # Query
    M = torch.rand(b, d)  # Key and Value
    P_q = torch.rand(h, d, k)  # W_q
    P_k = torch.rand(d, k)  # W_k
    P_v = torch.rand(d, v)  # W_v
    P_o = torch.rand(h, d, v)  # W_o (kept per head: only the keys and values are shared)

    q = torch.einsum("bd,hdk->bhk", X, P_q)
    K = torch.concat([prev_K, torch.einsum("bd,dk->bk", M, P_k).unsqueeze(1)], axis=1)
    V = torch.concat([prev_V, torch.einsum("bd,dv->bv", M, P_v).unsqueeze(1)], axis=1)

    logits = torch.einsum("bhk,bmk->bhm", q, K)
    weights = torch.softmax(logits, dim=-1)

    O = torch.einsum("bhm,bmv->bhv", weights, V)
    Y = torch.einsum("bhv,hdv->bd", O, P_o)
    return Y, K, V

#### Grouped Multi-Query Attention

Grouped Multi-Query Attention is a compromise between Multi-Head Attention and Multi-Query Attention:

![Grouped Multi-Query Attention compared](images/Grouped-Multi-Query-Attention-compared.png)

GQA can be thought of as a way to optimize the attention mechanism in transformer-based models. Instead of giving every query head its own key and value heads, GQA lets a group of query heads share one key head and one value head. The attention scores are still computed for every query head, but fewer keys and values have to be stored in the KV cache and loaded from memory at each decoding step, leading to faster inference times.

However, while MQA drastically speeds up decoder inference, it can lead to quality degradation. To address this, GQA was introduced as a generalization of MQA, using an intermediate number of key-value heads, which is more than one but less than the number of query heads.

In GQA, query heads are divided into groups, each of which shares a single key head and value head. This approach allows GQA to interpolate between multi-head and multi-query attention, achieving a balance between quality and speed. For instance, GQA with a single group (and therefore a single key and value head) is equivalent to MQA, while GQA with groups equal to the number of heads is equivalent to MHA.
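
The interpolation between MHA and MQA can be illustrated with the default sizes from `ModelArgs` (`dim = 4096`, `n_heads = 32`, `n_layers = 32`, `max_seq_len = 2048`). In this sketch the helper names are hypothetical and the choice of 8 key-value heads for the GQA row is an assumption made for illustration.

```python
n_heads, n_layers, head_dim, seq_len = 32, 32, 4096 // 32, 2048


def kv_cache_elements(n_kv_heads: int) -> int:
    """Cached key and value elements for one sequence of seq_len tokens."""
    return 2 * n_layers * seq_len * n_kv_heads * head_dim


def kv_head_for_query(n_kv_heads: int) -> list:
    """Index of the key-value head that each query head reads from."""
    n_rep = n_heads // n_kv_heads
    return [q_head // n_rep for q_head in range(n_heads)]


# 8 key-value heads is a hypothetical choice for the GQA row.
for name, n_kv_heads in (("MHA", 32), ("GQA", 8), ("MQA", 1)):
    print(name, kv_head_for_query(n_kv_heads)[:8], kv_cache_elements(n_kv_heads))
# MHA [0, 1, 2, 3, 4, 5, 6, 7] 536870912
# GQA [0, 0, 0, 0, 1, 1, 1, 1] 134217728
# MQA [0, 0, 0, 0, 0, 0, 0, 0] 16777216
```

The cache shrinks in proportion to the number of key-value heads, which is where the memory bandwidth savings during decoding come from.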

**What are some common methods for implementing Grouped Query Attention?**

Common methods for implementing Grouped Query Attention (GQA) include:
- Grouping query heads by index — The grouping is fixed ahead of time rather than computed from the data: with n_rep query heads per group, query head i uses key-value head i // n_rep, so no similarity metric between queries is involved.
- Dividing query heads into groups — In GQA, query heads are divided into groups, each of which shares a single key head and value head. This approach allows GQA to interpolate between multi-head and multi-query attention, achieving a balance between quality and speed.
- Using an intermediate number of key-value heads — GQA strikes a balance between multi-query attention (MQA) and multi-head attention (MHA) by using an intermediate number of key-value heads, which is more than one but less than the number of query heads.
- Repeating key-value pairs for computational efficiency — In GQA, key-value pairs are repeated to optimize performance while maintaining quality. This is achieved by repeating key-value pairs n_rep times, where n_rep corresponds to the number of query heads that share the same key-value pair.

These methods can be combined and adapted to suit the specific requirements of a given task or model architecture.

**What are some challenges associated with Grouped Query Attention?**

There are several challenges associated with GQA:
- Quality Degradation and Training Instability — GQA is an evolution of Multi-Query Attention (MQA), which uses multiple query heads but a single key and value head. While MQA speeds up decoder inference, it can lead to quality degradation and training instability. GQA attempts to mitigate this by using an intermediate number of key-value heads (more than one but fewer than the query heads), but the balance between speed and quality is a challenge.
- Memory Bandwidth Overhead — Autoregressive decoder inference is a severe bottleneck for Transformer models due to the memory bandwidth overhead from loading decoder weights and all attention keys and values at every decoding step. GQA attempts to address this by dividing query heads into groups, each of which shares a single key head and value head. However, managing this memory bandwidth overhead is a significant challenge.
- Complexity of Implementation — Implementing GQA within the context of an autoregressive decoder using a Transformer model can be complex. It involves repeating key-value pairs for computational efficiency, managing cached key-value pairs, and performing scaled dot-product attention computation.
- Group Division — The query heads have to split evenly into groups, so the number of query heads must be a multiple of the number of key-value heads. Otherwise the repeated keys and values no longer line up with the query heads.
- Hyperparameter Tuning — Achieving optimal performance with GQA requires careful tuning of hyperparameters. For instance, the number of groups into which the query heads are divided can significantly impact the model's performance and efficiency.




In [8]:
class Attention(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        # indicates the number of heads for the keys and values
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # indicates the number of heads for the queries
        self.n_heads_q = args.n_heads
        # indicates how many times the keys and values should be repeated
        self.n_rep = self.n_heads_q // self.n_kv_heads
        # indicates the dimension of each head, i.e the part of the embedding that each head will be responsible for
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)
        )
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)
        )

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        batch_size, seq_len, _ = x.shape  # (B, 1, dim)
        xq = self.wq(x)  # (B, 1, dim) -> (B, 1, H_Q * head_dim)
        xk = self.wk(x)  # (B, 1, dim) -> (B, 1, H_KV * head_dim)
        xv = self.wv(x)  # (B, 1, dim) -> (B, 1, H_KV * head_dim)

        # (B, 1, H_Q * Head_Dim) -> (B, 1, H_Q, Head_Dim)
        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        # Size is the same for xk & xv: (B, 1, H_KV * Head_Dim) -> (B, 1, H_KV, Head_Dim)
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

        # Size doesn't change for xq & xk: (B, 1, H_Q, head_dim) -> (B, 1, H_Q, head_dim)
        xq = apply_rotary_embeddings(xq, freqs_complex, x.device)
        xk = apply_rotary_embeddings(xk, freqs_complex, x.device)

        # replace the entry in the cache for this token
        self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
        self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv

        # retrieve all the cached keys and values so far
        # Size is the same for keys & values: (B, seq_len_kv, H_KV, head_dim)
        keys = self.cache_k[:batch_size, : start_pos + seq_len]
        values = self.cache_v[:batch_size, : start_pos + seq_len]

        # since every group of Q shares the same K & V heads,
        # just repeat the K & V heads for every Q in the same group.
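        # For models without GQA (n_kv_heads unset or equal to n_heads, e.g. the 7B model),
        # n_rep is 1, repeat_kv changes nothing, and this reduces to standard multi-head attention.
        # Grouped query attention only takes effect when n_kv_heads < n_heads (e.g. the 70B model).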
        # Size is the same for keys & values:
        # (B, seq_len_kv, H_KV, head_dim) -> (B, seq_len_kv, H_Q, head_dim)
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        xq = xq.transpose(1, 2)  # (B, 1, H_Q, head_dim) -> (B, H_Q, 1, head_dim)
        # Size is the same for keys & values:
        # (B, seq_len_kv, H_Q, head_dim) -> (B, H_Q, seq_len_kv, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # (B, H_Q, 1, head_dim) @ (B, H_Q, head_dim, seq_len_kv) -> (B, H_Q, 1, seq_len_kv)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        # (B, H_Q, 1, seq_len_kv) -> (B, H_Q, 1, seq_len_kv)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # (B, H_Q, 1, seq_len_kv) @ (B, H_Q, seq_len_kv, head_dim) -> (B, H_Q, 1, head_dim)
        output = torch.matmul(scores, values)
        # (B, H_Q, 1, head_dim) -> (B, 1, H_Q, head_dim) -> (B, 1, dim)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(output)  # (B, 1, dim)
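
To make the `n_rep` bookkeeping concrete, here is a small dependency-free sketch of how query heads map onto shared key/value heads. It assumes `repeat_kv` repeats each K/V head `n_rep` times consecutively along the head axis (the way `torch.repeat_interleave` would). The helper name `kv_head_for_query_head` and the 8-query-head / 2-KV-head configuration are made up for illustration and are not part of the notebook.

```python
def kv_head_for_query_head(q_head: int, n_heads_q: int, n_kv_heads: int) -> int:
    """Return the index of the K/V head that a given query head attends with."""
    assert n_heads_q % n_kv_heads == 0, "query heads must split evenly into KV groups"
    n_rep = n_heads_q // n_kv_heads  # same quantity as self.n_rep above
    return q_head // n_rep


# Multi-head attention: n_kv_heads == n_heads_q, so n_rep == 1 and nothing is shared.
mha_mapping = [kv_head_for_query_head(q, 8, 8) for q in range(8)]

# Grouped-query attention (hypothetical 8 query heads, 2 KV heads): n_rep == 4.
gqa_mapping = [kv_head_for_query_head(q, 8, 2) for q in range(8)]

print(mha_mapping)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(gqa_mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

When `n_kv_heads` is left unset, the constructor above falls back to `n_heads`, which corresponds to the first mapping.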

## Encoder Block

In [9]:
class EncoderBlock(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        self.n_head = args.n_heads
        self.dim = args.dim
        self.head_dim = self.dim // self.n_head

        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)

        # Normalization before the attention block
        self.attention_norm = RMSNorm(self.dim, eps=args.norm_eps)
        # Normalization before the feed-forward block
        self.ffn_norm = RMSNorm(self.dim, eps=args.norm_eps)

    def forward(
        self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor
    ) -> torch.Tensor:
        # (B,seq_len,dim) + (B,seq_len,dim) -> (B,seq_len,dim)
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_complex)
        # (B,seq_len,dim) + (B,seq_len,dim) -> (B,seq_len,dim)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

## LLaMA Model: Training

![Transformer vs LLaMA](images/Transformer-vs-LLaMA.png)

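- LLaMA is a decoder-only model (despite the `EncoderBlock` class name used in this notebook).
- RMSNorm is applied before the attention and feed-forward blocks (pre-normalization), unlike the vanilla Transformer, which normalizes after each block.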
- The block inside Nx is repeated N times.
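
As a rough illustration of the pre-normalization pattern, the sketch below mirrors the structure of `EncoderBlock.forward` on plain Python lists. It is a simplification under stated assumptions: `rms_norm` omits the learned gain, and `toy_attention` / `toy_feed_forward` are hypothetical stand-ins for the real sublayers.

```python
import math
from typing import Callable, List

Vector = List[float]


def rms_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """RMSNorm without the learned gain: divide by the root mean square."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def pre_norm_block(
    x: Vector,
    attention: Callable[[Vector], Vector],
    feed_forward: Callable[[Vector], Vector],
) -> Vector:
    """Normalize, apply the sublayer, then add the residual (twice)."""
    h = [a + b for a, b in zip(x, attention(rms_norm(x)))]
    return [a + b for a, b in zip(h, feed_forward(rms_norm(h)))]


# Hypothetical stand-ins for the real attention and feed-forward sublayers.
def toy_attention(v: Vector) -> Vector:
    return [0.5 * e for e in v]


def toy_feed_forward(v: Vector) -> Vector:
    return [-e for e in v]


x = [3.0, 4.0]
normed = rms_norm(x, eps=0.0)
out = pre_norm_block(x, toy_attention, toy_feed_forward)
print(normed)  # the mean of the squares is 1 after normalization
print(out)
```

Because the residual path bypasses the normalization, a block whose sublayers output zeros returns its input unchanged.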


In [10]:
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        assert args.vocab_size != -1, "vocab_size must be set"

        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers  # represents Nx in the figure above: 32 layers
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(EncoderBlock(args))

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

        self.freqs_complex = precompute_theta_pos_frequencies(
            self.args.dim // self.args.n_heads,
            self.args.max_seq_len * 2,
            device=self.args.device,
        )

    def forward(self, tokens: torch.Tensor, start_pos: int):
        batch_size, seq_len = tokens.shape  # (B, seq_len)
        assert seq_len == 1, "Only one token at a time can be processed."

        h = self.tok_embeddings(tokens)  # (B, seq_len) -> (B, seq_len, dim)
        # retrieve the pairs (m, theta) corresponding to the positions [start_pos, start_pos + seq_len]
        freqs_complex = self.freqs_complex[start_pos : start_pos + seq_len]

        # consecutively apply all the encoder layers
        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        output = self.output(h).float()
        return output
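
The one-token-at-a-time constraint in `forward` relies on the KV cache inside `Attention`. The sketch below is a minimal single-head, single-sequence version in plain Python, meant only to show the mechanics. `ToyKVCache` and the token vectors are hypothetical, and the query, key and value are set equal for brevity. No causal mask is needed, because at step `start_pos` the cache slice only contains positions up to `start_pos`.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


class ToyKVCache:
    """Single-head, single-sequence stand-in for cache_k / cache_v."""

    def __init__(self, max_seq_len: int, head_dim: int) -> None:
        self.keys = [[0.0] * head_dim for _ in range(max_seq_len)]
        self.values = [[0.0] * head_dim for _ in range(max_seq_len)]

    def attend(self, q: Vector, k: Vector, v: Vector, start_pos: int) -> Vector:
        # Write the new token's key/value at its position ...
        self.keys[start_pos] = k
        self.values[start_pos] = v
        # ... then attend over everything cached so far (positions 0..start_pos).
        keys = self.keys[: start_pos + 1]
        values = self.values[: start_pos + 1]
        scale = math.sqrt(len(q))
        scores = [sum(a * b for a, b in zip(q, key)) / scale for key in keys]
        weights = softmax(scores)
        return [sum(w * val[d] for w, val in zip(weights, values)) for d in range(len(q))]


cache = ToyKVCache(max_seq_len=4, head_dim=2)
token_vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical inputs
outputs = [cache.attend(t, t, t, start_pos=pos) for pos, t in enumerate(token_vectors)]
print(outputs[0])  # [1.0, 0.0]: the first token can only attend to itself
```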

# Inference

## Inferencing strategies

### Logits

- The output of the last linear layer in the Transformer model is called the logits. Logits are unnormalized scores, not probabilities: they can be negative and do not sum to 1.
- The softmax rescales the logits so that they are all positive and sum to 1.
- The output of the softmax is thus a probability distribution over the vocabulary: each token in the vocabulary has a probability associated with it.
- But how do we choose the next token, given this distribution? There are many strategies...
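
For reference, the logits-to-probabilities step can be written in a few lines of plain Python. This is only a sketch of the standard softmax (the notebook itself uses `torch.softmax`), applied to the same three logits used in the temperature example further down.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Convert unnormalized logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [-2.5, -3.0, -0.6]
probs = softmax(logits)
print([round(p, 4) for p in probs])  # [0.1206, 0.0731, 0.8063]
print(sum(probs))
```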

### Greedy

- At every step, we choose the token with the maximum probability, which is appended to the input to generate the next token and so on...
- If the initial tokens happen to be the wrong ones, it’s very likely that the next ones will be wrong as well.
- It is easy to implement but performs poorly in practice.
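
A small sketch of greedy decoding, assuming a hypothetical bigram "model" (`TOY_MODEL`) that maps the last token to next-token probabilities. The numbers are invented so that the locally best first choice leads to a less probable full sequence.

```python
from typing import Dict, List

# Hypothetical bigram "model": last token -> next-token probabilities.
TOY_MODEL: Dict[str, Dict[str, float]] = {
    "<s>": {"the": 0.6, "a": 0.4},
    "the": {"cat": 0.5, "dog": 0.3, "</s>": 0.2},
    "a": {"dog": 0.9, "</s>": 0.1},
    "cat": {"</s>": 1.0},
    "dog": {"</s>": 1.0},
}


def greedy_decode(model: Dict[str, Dict[str, float]], max_len: int = 10) -> List[str]:
    tokens = ["<s>"]
    while tokens[-1] != "</s>" and len(tokens) < max_len:
        next_probs = model[tokens[-1]]
        # Always take the single most probable next token.
        tokens.append(max(next_probs, key=next_probs.get))
    return tokens


def sequence_prob(model: Dict[str, Dict[str, float]], tokens: List[str]) -> float:
    prob = 1.0
    for prev, nxt in zip(tokens, tokens[1:]):
        prob *= model[prev][nxt]
    return prob


greedy_path = greedy_decode(TOY_MODEL)
alternative = ["<s>", "a", "dog", "</s>"]
print(greedy_path)                                     # ['<s>', 'the', 'cat', '</s>']
print(round(sequence_prob(TOY_MODEL, greedy_path), 2))  # 0.3
print(round(sequence_prob(TOY_MODEL, alternative), 2))  # 0.36
```

In this toy setup the greedy path commits to "the" at the first step and never recovers the more probable sequence that starts with "a".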

### Beam Search with K=2

- Beam search has a parameter K, the beam width. With K = 2, the two most probable partial sequences are kept at every step.

Beam Search with K=2 at T=1

![Beam Search with K=2 at T=1](images/Beam-Search-with-K2-at-T1.png)

Beam Search with K=2 at T=2

![Beam Search with K=2 at T=2](images/Beam-Search-with-K2-at-T2.png)

Beam Search with K=2 at T=3

![Beam Search with K=2 at T=3](images/Beam-Search-with-K2-at-T3.png)

### Beam Search

- At every step, we keep the top K paths alive; all the others are discarded.
- It increases inference time, since K candidate paths must be expanded at every step.
- Generally, it performs better than the greedy strategy.
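
The procedure can be sketched in plain Python as follows. This is a minimal version under stated assumptions: `TOY_MODEL` is a hypothetical bigram table, paths are scored by summed log-probabilities, and there is no length normalization.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical bigram "model": last token -> next-token probabilities.
TOY_MODEL: Dict[str, Dict[str, float]] = {
    "<s>": {"the": 0.6, "a": 0.4},
    "the": {"cat": 0.5, "dog": 0.3, "</s>": 0.2},
    "a": {"dog": 0.9, "</s>": 0.1},
    "cat": {"</s>": 1.0},
    "dog": {"</s>": 1.0},
}

Beam = Tuple[List[str], float]  # (tokens, log-probability of the whole path)


def beam_search(model: Dict[str, Dict[str, float]], k: int, max_len: int = 10) -> List[Beam]:
    beams: List[Beam] = [(["<s>"], 0.0)]
    for _ in range(max_len):
        candidates: List[Beam] = []
        for tokens, logp in beams:
            if tokens[-1] == "</s>":
                candidates.append((tokens, logp))  # finished paths carry over unchanged
                continue
            for tok, p in model[tokens[-1]].items():
                candidates.append((tokens + [tok], logp + math.log(p)))
        # Keep the K most probable paths; all others are dropped.
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:k]
        if all(tokens[-1] == "</s>" for tokens, _ in beams):
            break
    return beams


best_tokens, best_logp = beam_search(TOY_MODEL, k=2)[0]
print(best_tokens)                     # ['<s>', 'a', 'dog', '</s>']
print(round(math.exp(best_logp), 2))   # 0.36
print(beam_search(TOY_MODEL, k=1)[0][0])  # ['<s>', 'the', 'cat', '</s>']
```

With K = 1 the search degenerates to greedy decoding. With K = 2 the path starting with "a" survives the first step and ends up as the most probable complete sequence.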

### Temperature

- The idea is to scale the logits (divide them by the temperature) before applying the softmax.
- A **low temperature** makes the model **more confident** (the gap between low and high probabilities increases).
- A **high temperature** makes the model **less confident** (the gap between low and high probabilities reduces).
    ```python
    logits = torch.Tensor([-2.5, -3, -0.6])
    torch.softmax(logits, dim=0)  # No temperature
    >>> tensor([0.1206, 0.0731, 0.8063])

    torch.softmax(logits / 0.4, dim=0)  # Low temperature = 0.4 -> More confident
    >>> tensor([0.0086, 0.0025, 0.9890])

    torch.softmax(logits / 5, dim=0)  # High temperature = 5 -> Less confident
    >>> tensor([0.2970, 0.2687, 0.4343])
    ```

### Random Sampling

- We sample the next token from the probability distribution output by the softmax.
    ```python
    logits = torch.Tensor([-2.5, -3, -0.6])
    distribution = torch.softmax(logits, dim=0)
    distribution
    >>> tensor([0.1206, 0.0731, 0.8063])
    ```
- The first token will be chosen with a probability of 12.06%, the second with a probability of 7.31%, and the last one with a probability of 80.63%.
- The higher a token's probability, the more likely it is to be chosen.
- **Problem**: with a small but non-zero probability, we may choose tokens that are total nonsense.
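
To see these percentages emerge empirically, the sketch below draws many samples from the distribution in the example above using only the standard library. The notebook itself samples with `torch.multinomial`; the seed and sample count here are arbitrary choices.

```python
import random
from collections import Counter

probs = [0.1206, 0.0731, 0.8063]  # softmax output from the example above

rng = random.Random(42)  # fixed seed so the run is repeatable
samples = rng.choices(range(len(probs)), weights=probs, k=10_000)

counts = Counter(samples)
frequencies = [counts[i] / len(samples) for i in range(len(probs))]
print(frequencies)  # close to probs, but not identical
```

Even the least likely token (index 1) is drawn several hundred times out of 10,000, which is exactly the behaviour that Top K and Top P try to limit.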


### Top K

- In the random sampling strategy, it may happen that we choose words that have very little probability, which usually indicates that the token is unrelated to the previous ones.
- With Top K, we keep only the top k highest probabilities, so that tokens with very low probabilities will never be chosen.
    ```python
    # sort the logits in decreasing order
    # we can sort the logits directly because the softmax is a monotonic function
    logits, _ = torch.sort(torch.Tensor([-2.5, -3, -2.8, -0.5, -0.6]), descending=True)
    k = 2
    top_k_logits = logits[:k]
    distribution = torch.softmax(top_k_logits, dim=0)
    distribution
    >>> tensor([0.5250, 0.4750])
    ```
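- **Problem**: k is fixed, so low-probability tokens can still make their way into the top k. Compare the top 2 tokens (in bold) of the following distributions:
- Distribution 1: **0.5, 0.4**, 0.05, 0.025, 0.025 (both kept tokens are plausible)
- Distribution 2: **0.9, 0.05**, 0.025, 0.020, 0.005 (the second kept token has only 5% probability)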

### Top P

- With Top P, we keep only the highest-probability tokens, adding them in decreasing order of probability until their cumulative probability is greater than or equal to the parameter p. This way, we keep more tokens for distributions that are more “flat” and fewer tokens for distributions with a very prominent mode.
    ```python
    # sort the logits in descending order
    p = 0.5
    logits, _ = torch.sort(torch.Tensor([-2.5, -3, -2.8, -0.5, -0.6]), descending=True)
    probs = torch.softmax(logits, dim=0)
    print(f"Probabilities: {probs}")

    probs_cumulative = torch.cumsum(probs, dim=0)
    print(f"Cumulative Probabilities: {probs_cumulative}")

    mask = probs_cumulative - probs > p  # Mask for non-top-p positions
    print(f"Cumulative Probabilities (shifted): {probs_cumulative - probs}")
    print(f"mask: {mask}")

    probs[mask] = 0.0  # zero out all non-top-p tokens
    probs.div_(probs.sum(dim=-1, keepdim=True))  # redistribute probabilities among surviving tokens
    print(f"Top P: {probs}")

    >>> Probabilities: tensor([0.4499, 0.4071, 0.0609, 0.0451, 0.0369])
    >>> Cumulative Probabilities: tensor([0.4499, 0.8571, 0.9180, 0.9631, 1.0000])
    >>> Cumulative Probabilities (shifted): tensor([0.0000, 0.4499, 0.8571, 0.9180, 0.9631])
    >>> mask: tensor([False, False,  True,  True,  True])
    >>> Top P: tensor([0.5250, 0.4750, 0.0000, 0.0000, 0.0000])
    ```
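
As a plain-Python sketch of the same masking rule, the hypothetical helper `top_p_filter` below is applied to the two distributions from the Top K section. The threshold `p = 0.8` is chosen only for illustration. A token is dropped when the probability mass of the tokens before it already exceeds p, which matches the `probs_cumulative - probs > p` mask above.

```python
from typing import List


def top_p_filter(sorted_probs: List[float], p: float) -> List[float]:
    """Zero out tokens once the mass before them exceeds p, then renormalize.

    Expects probabilities sorted in descending order.
    """
    kept = []
    mass_before = 0.0  # cumulative probability of the tokens before the current one
    for prob in sorted_probs:
        kept.append(0.0 if mass_before > p else prob)
        mass_before += prob
    total = sum(kept)
    return [v / total for v in kept]


dist_1 = [0.5, 0.4, 0.05, 0.025, 0.025]    # two plausible tokens
dist_2 = [0.9, 0.05, 0.025, 0.020, 0.005]  # one dominant token

kept_1 = top_p_filter(dist_1, p=0.8)
kept_2 = top_p_filter(dist_2, p=0.8)
print(sum(v > 0 for v in kept_1))  # 2
print(sum(v > 0 for v in kept_2))  # 1
```

Unlike Top K with k = 2, the number of surviving tokens adapts to the shape of the distribution: the 5% token in the second distribution is no longer kept.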

## LLaMA Model: Inference

In [11]:
class LLaMA:
    def __init__(
        self,
        model: Transformer,
        tokenizer: SentencePieceProcessor,
        model_args: ModelArgs,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.args = model_args

    @staticmethod
    def build(
        checkpoints_dir: str,
        tokenizer_path: str,
        load_model: bool,
        max_seq_len: int,
        max_batch_size: int,
        device: str,
    ):
        prev_time = time.time()
        if load_model:
            checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
            assert len(checkpoints) > 0, f"No checkpoints found in {checkpoints_dir}"
            ckpt_path = checkpoints[0]
            print(f"Loading model from checkpoint: {ckpt_path}")
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            print(f"Loaded checkpoint in {time.time() - prev_time:.2f} seconds")
            prev_time = time.time()

        with open(Path(checkpoints_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            device=device,
            **params,
        )

        tokenizer = SentencePieceProcessor()
        tokenizer.load(tokenizer_path)
        model_args.vocab_size = tokenizer.vocab_size()

        if device == "cuda":
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        else:
            torch.set_default_tensor_type(torch.BFloat16Tensor)

        model = Transformer(model_args).to(device)

        if load_model:
            # The only unmatched key in the checkpoint is rotary positional encoding freqs
            # because we will create it during inference, so we can remove it
            del checkpoint["rope.freqs"]
            model.load_state_dict(checkpoint, strict=True)
            print(f"Loaded state dict in {time.time() - prev_time:.2f} seconds")

        return LLaMA(model, tokenizer, model_args)

    def text_completion(
        self,
        prompts: List[str],
        device: str,
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
    ):
        if max_gen_len is None:
            max_gen_len = self.args.max_seq_len - 1

        # convert each prompt into tokens
        prompt_tokens = [
            self.tokenizer.encode(prompt, out_type=int, add_bos=True, add_eos=False)
            for prompt in prompts
        ]

        # Make sure the batch size is not too large
        batch_size = len(prompt_tokens)
        assert (
            batch_size <= self.args.max_batch_size
        ), f"Batch size {batch_size} must be less than or equal to the Max batch size: {self.args.max_batch_size}"

        max_prompt_len = max(len(prompt) for prompt in prompt_tokens)
        # Make sure the prompt length is not larger than the max sequence length
        assert (
            max_prompt_len <= self.args.max_seq_len
        ), f"Prompt length {max_prompt_len} must be less than or equal to the Max sequence length: {self.args.max_seq_len}"

        total_len = min(self.args.max_seq_len, max_gen_len + max_prompt_len)

        # create the list that will contain the generated tokens, along with the initial prompt tokens
        pad_id = self.tokenizer.pad_id()
        tokens = torch.full(
            (batch_size, total_len), pad_id, dtype=torch.long, device=device
        )
        for k, t in enumerate(prompt_tokens):
            # populate the initial tokens with the prompt tokens
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)

        eos_reached = torch.tensor([False] * batch_size, device=device)
        # True if the token is a prompt token, False otherwise
        prompt_tokens_mask = tokens != pad_id

        cur_iterator = tqdm(range(1, total_len), desc="Generating tokens...")
        for cur_pos in cur_iterator:
            with torch.no_grad():
                logits = self.model.forward(tokens[:, cur_pos - 1 : cur_pos], cur_pos)
            if temperature > 0:
                # The temperature is applied before the softmax
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = self._sample_top_p(probs, top_p)
            else:
                # greedily select the token with the max probability
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if it is a padding token
            next_token = torch.where(
                prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token

            # EOS is reached only if we found an EOS token for a padding position.
            # `|=` accumulates the flag in place; `eos_id()` is a method on SentencePieceProcessor.
            eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id()
            )
            if all(eos_reached):
                break

        out_tokens = []
        out_text = []
        for prompt_index, current_prompt_tokens in enumerate(tokens.tolist()):
            # cut to the EOS token, if present
            if self.tokenizer.eos_id() in current_prompt_tokens:
                eos_idx = current_prompt_tokens.index(self.tokenizer.eos_id())
                current_prompt_tokens = current_prompt_tokens[:eos_idx]
            out_tokens.append(current_prompt_tokens)
            out_text.append(self.tokenizer.decode(current_prompt_tokens))

        return (out_tokens, out_text)

    def _sample_top_p(self, probs: torch.Tensor, p: float):
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)  # (B, vocab_size)
        # (subtracting "probs_sort" shifts the cumulative sum by 1 position to the right before masking)
        mask = probs_sum - probs_sort > p  # (B, vocab_size)
        # zero out all the probabilities of tokens that are not selected by the top_p
        probs_sort[mask] = 0.0
        # redistribute the probabilities so that they sum up to 1.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        # sample a token (its index) from the top p distribution
        next_token = torch.multinomial(probs_sort, num_samples=1)
        # get the token position in the vocabulary corresponding to the sampled index
        next_token = torch.gather(probs_idx, -1, next_token)
        return next_token
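
The prompt-masking logic in `text_completion` can be hard to follow in tensor form, so here is a plain-Python sketch of the same idea on a toy batch. `PAD_ID`, the token ids and `fake_next_token` are hypothetical stand-ins for the tokenizer's padding id, real prompt tokens and the model plus sampler.

```python
from typing import List

PAD_ID = -1  # hypothetical padding id for this toy example

prompt_tokens = [[1, 11, 12, 13], [1, 21]]  # two prompts of different lengths
total_len = 6

# Right-pad every prompt to total_len, as text_completion does with torch.full.
tokens = [p + [PAD_ID] * (total_len - len(p)) for p in prompt_tokens]
# True where the position holds a prompt token, False where it is padding.
prompt_mask = [[t != PAD_ID for t in row] for row in tokens]


def fake_next_token(row: List[int], cur_pos: int) -> int:
    """Stand-in for the model + sampler: returns a recognizable dummy id."""
    return 900 + cur_pos


for cur_pos in range(1, total_len):
    for b, row in enumerate(tokens):
        sampled = fake_next_token(row, cur_pos)
        # Element-wise equivalent of torch.where(mask, tokens, next_token):
        # keep the prompt token where one exists, otherwise write the sample.
        row[cur_pos] = row[cur_pos] if prompt_mask[b][cur_pos] else sampled

print(tokens[0])  # [1, 11, 12, 13, 904, 905]
print(tokens[1])  # [1, 21, 902, 903, 904, 905]
```

The shorter prompt starts receiving generated tokens earlier, while the longer prompt is still being fed through the model position by position.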

## Inference Run

In [12]:
# NOTE: If you want to rerun this with limited GPU memory (I am using an RTX 3080 Ti Laptop GPU with 16175MiB),
# you will need to restart the kernel, because the Jupyter notebook doesn't release the GPU memory.
# I tried `torch.cuda.empty_cache()` to release all the GPU memory that can be freed
# and `%reset -f` to clear the notebook variables, but neither worked.

torch.manual_seed(42)

model_version = "llama-2-7b"
allow_cuda = True
device = "cuda" if torch.cuda.is_available() and allow_cuda else "cpu"

prompts = [
    "Simply put, the theory of relativity states that ",
    "If Google was an Italian company founded in Milan, it would",
    # Few shot prompt
    """Translate English to French:
    sea otter => loutre de mer
    peppermint => menthe poivrée
    plush girafe => girafe peluche
    cheese =>""",
    # Zero shot prompt
    """Tell me if the following person is actually a Jedi knight disguised as human:
    Name: Mukesh Mithrakumar
    Decision: 
    """,
]

model = LLaMA.build(
    checkpoints_dir=f"models/{model_version}",
    tokenizer_path=f"models/{model_version}/tokenizer.model",
    load_model=True,
    max_seq_len=1024,
    max_batch_size=len(prompts),
    device=device,
)

out_tokens, out_texts = model.text_completion(prompts, device, max_gen_len=64)
assert len(out_texts) == len(prompts)
for i in range(len(out_texts)):
    print(f"{out_texts[i]}\n{'-'*50}")

Loading model from checkpoint: models\llama-2-7b\consolidated.00.pth
Loaded checkpoint in 5.99 seconds




Loaded state dict in 9.10 seconds


Generating tokens...: 100%|██████████| 112/112 [00:08<00:00, 12.54it/s]

Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial reference frames, and 2) the laws of physics are the same for all inertial reference frames.Ћ The theory of relativity has two versions: the special theory of relativity and the general theory of relativity. The special theory of relativity applies to all physical phenomena in inertial reference frames. The general theory of relativity applies to the phenomena in non-inertial reference frames. The theory
--------------------------------------------------
If Google was an Italian company founded in Milan, it would be the largest company in the world, with a market capitalization of $ 713 billion. The company was founded in 1998 by Larry Page and Sergey Brin, two PhD students at Stanford University.
The company's name is a reference to the number of letters in the word "googol", which is a mathematical term for the number 1 followed by 100 zeros.
The company has its headquarters in Mountain


