# KV cache

The goal of caching the Key (K) and Value (V) states is to speed up inference for autoregressive decoders like GPT.

The goal of this practical is to adapt the code of [minGPT](https://github.com/karpathy/minGPT/) from [Karpathy](https://karpathy.ai/) in order to incorporate KV-caching. We will only need the two main files [`model.py`](https://github.com/karpathy/minGPT/blob/master/mingpt/model.py) and [`trainer.py`](https://github.com/karpathy/minGPT/blob/master/mingpt/trainer.py) from this repo.

Using [Named Tensor Notation](https://hackmd.io/@mlelarge/HkVlvrc8j), we write (see the paper by [Chiang, Rush and Barak](https://arxiv.org/abs/2102.13196))
\begin{align*}
\newcommand{\namedtensorstrut}{\vphantom{fg}}
\newcommand{\nfun}[2]{\mathop{\underset{\substack{#1}}{\namedtensorstrut\mathrm{#2}}}}
\newcommand{\name}[1]{\mathsf{\namedtensorstrut #1}}
\newcommand{\ndef}[2]{\newcommand{#1}{\name{#2}}}
\ndef{\ax}{ax}
\ndef{\bx}{bx}
\newcommand{\reals}{\mathbb{R}}
\ndef{\batch}{batch}
\ndef{\layer}{layer}
\ndef{\chans}{chans}
\ndef{\key}{key}
\ndef{\seq}{seq}
\ndef{\val}{val}
\ndef{\heads}{heads}
\ndef{\hidden}{hidden}
\ndef{\height}{height}
\ndef{\width}{width}
\newcommand{\nbin}[2]{\mathbin{\underset{\substack{#1}}{\namedtensorstrut #2}}}
\newcommand{\ndot}[1]{\nbin{#1}{\odot}}
\text{Attention} \colon \mathbb{R}^{\key} \times \mathbb{R}^{\seq \times\key} \times \mathbb{R}^{\seq \times\val} &\rightarrow \mathbb{R}^{\val} \\
  \text{Attention}(Q,K,V) &= \left( \nfun{\seq}{softmax} \frac{Q \ndot{\key} K}{\sqrt{|\key|}} \right) \ndot{\seq} V.
\end{align*}

During inference, when we compute the attention for the $t$-th token of a sequence, we get:
\begin{align*}
\text{Attention} \colon \mathbb{R}^{\key} \times \mathbb{R}^{\seq(t-b:t) \times\key} \times \mathbb{R}^{\seq(t-b:t) \times\val} &\rightarrow \mathbb{R}^{\val} \\
  \text{Attention}(Q_t,K_t,V_t) &= \left( \nfun{\seq}{softmax} \frac{Q_t \ndot{\key} K_t}{\sqrt{|\key|}} \right) \ndot{\seq} V_t,
\end{align*}
where $b$ is the size of a block and $t-b$ should be interpreted as $\max(t-b,0)$.

For the computation at time $t+1$, we see that to obtain $K_{t+1}$ and $V_{t+1}$ from $K_t$ and $V_t$, we only need to compute the last index of $\seq(t-b+1:t+1)$, provided we stored all the other indices $\seq(t-b+1:t)$. This is exactly what we need to do!

![](https://miro.medium.com/v2/resize:fit:1400/format:webp/1*uyuyOW1VBqmF5Gtv225XHQ.gif)

In [None]:
import hashlib
import math
import time
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

## Modifying Self-attention

We start from the code from Karpathy

In [None]:
# source: https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
class CausalSelfAttention(nn.Module):
    """
    Multi-head self-attention with a causal (lower-triangular) mask, followed by
    a linear output projection. Everything is spelled out explicitly instead of
    delegating to torch.nn.MultiheadAttention, so each step can be inspected.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # one linear layer yields the queries, keys and values of every head at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # projection applied to the concatenated head outputs
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # dropout on the attention weights and on the projected output
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # lower-triangular matrix of ones: position i may only look at positions <= i
        size = config.block_size
        lower_triangle = torch.tril(torch.ones(size, size))
        self.register_buffer("bias", lower_triangle.view(1, 1, size, size))

    def forward(self, x):
        """Apply causal self-attention to x of shape (B, T, n_embd); the output has the same shape."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs): the head axis becomes a batch-like axis
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        # scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        scale = 1.0 / math.sqrt(head_dim)
        scores = (query @ key.transpose(-2, -1)) * scale
        # -inf at future positions so that softmax gives them zero weight
        visible = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # weighted sum of values, then put the heads back side by side: (B, T, C)
        context = weights @ value
        merged = context.transpose(1, 2).contiguous().view(batch, seq_len, width)

        # output projection
        return self.resid_dropout(self.c_proj(merged))

In [None]:
@dataclass
class Config:
    """Hyperparameters for the small attention layers built in this notebook."""
    # The annotations are what turn these into dataclass fields: without them
    # @dataclass generates an __init__ that takes no arguments, so overriding a
    # value with Config(n_head=...) is impossible.
    n_head: int = 3       # number of attention heads (must divide n_embd)
    n_embd: int = 15      # embedding width
    block_size: int = 11  # maximum sequence length, also the side of the causal mask
    # dropout hyperparameters
    embd_pdrop: float = 0.1
    resid_pdrop: float = 0.1
    attn_pdrop: float = 0.1
    
config = Config()
csa = CausalSelfAttention(config)

In [None]:
bs = 6
x = torch.randn(bs, config.block_size, config.n_embd)
out = csa(x)

In [None]:
out.shape

In [None]:
csa.bias.shape

Now, we need to modify the code in order to add kv-cache. We propose a simple modification where the forward pass takes as input, in addition to `x`, the `kv_cache` as a list of tensors `[k, v]`, and returns the output `y` and the updated `kv_cache`:

In [None]:
class CausalSelfAttention_kv(nn.Module):
    """
    Exercise skeleton: the multi-head masked self-attention layer above, extended
    so that `forward` also accepts and returns a key/value cache.
    The constructor is complete; the attention computation inside `forward` is
    left to be written.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # kept on the module (unlike the vanilla layer) -- presumably so the cache
        # can be cropped to the context window; confirm when writing forward
        self.block_size = config.block_size

    def forward(self, x, kv_cache=None):
        """
        Attend over `x`, optionally reusing cached keys and values.

        x: tensor of size (B, T, n_embd).
        kv_cache: None, or a list `[k, v]` of tensors from a previous call.
            NOTE(review): the test cells slice the cache as `kv[0][:, :-1, :]`,
            which suggests the time axis is dim 1 -- confirm against your layout.
        Returns the pair `(y, kv_cache)`: the attention output and the updated cache.
        """
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)

        ###
        # your code here
        ####
        # the code written above must define `y` and rebind `kv_cache`; both are
        # used below and `y` does not exist until then

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y, kv_cache

In [None]:
config = Config()
csa = CausalSelfAttention_kv(config)
csa.eval()

In [None]:
out, kv = csa(x)

In [None]:
x.shape

Check the shape of the kv cache.

In [None]:
kv[0][:,:-1,:].shape

In [None]:
first = x[:,:10,:]
last = x[:,[10],:]

In [None]:
out_kv, kv_cache = csa(last, kv_cache=[kv[0][:,:-1,:], kv[1][:,:-1,:]])

In [None]:
torch.isclose(out[:,-1,:], out_kv[:,0,:])

In [None]:
# Feed the last k tokens together with a cache holding the first T-k positions,
# and compare the first recomputed position with the full forward pass.
# Note that k = 0 is a degenerate case: `-0` is just 0, so `x[:,-0:,:]` is the
# whole input and `kv[0][:,:-0,:]` is an empty cache.
for k in range(10):
    out_kv, kv_cache = csa(x[:,-k:,:], kv_cache=[kv[0][:,:-k,:], kv[1][:,:-k,:]])
    print(k, torch.allclose(out[:,-k,:], out_kv[:,0,:], rtol=1e-4))

## Modifying the Block

Here is the original code of Karpathy:
```python
class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd),
            c_proj  = nn.Linear(4 * config.n_embd, config.n_embd),
            act     = NewGELU(),
            dropout = nn.Dropout(config.resid_pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlpf(self.ln_2(x))
        return x
```

and how it is used in the GPT class:
```python
class GPT(nn.Module):
    def __init__(self, config):
        ...
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.embd_pdrop),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        ...
        
    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

```

You first need to adapt the `Block` to include kv-cache. Provide some tests for your code.

In [None]:
from mingpt.model import NewGELU

class Block_kv(nn.Module):
    """ an unassuming Transformer block whose attention layer is CausalSelfAttention_kv """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention_kv(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd),
            c_proj  = nn.Linear(4 * config.n_embd, config.n_embd),
            act     = NewGELU(),
            dropout = nn.Dropout(config.resid_pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward

    def forward(self, x, kv_cache=None):
        """
        Exercise placeholder: run the block on `x`, passing `kv_cache` to the attention layer.

        The test cells unpack the result as `out, kv = bkv(x)`, so the finished
        method is expected to return the block output together with the
        attention layer's updated cache.
        """
        ###
        # your code here
        #

In [None]:
bkv = Block_kv(config)

In [None]:
bkv.eval()
out, kv = bkv(x)

In [None]:
first = x[:,:10,:]
last = x[:,[10],:]

In [None]:
out_first, kv_first = bkv(first)

In [None]:
out_kv, kv_cache = bkv(last, kv_cache=kv_first)

In [None]:
out_kv, kv_cache = bkv(last, kv_cache=[kv[0][:,:-1,:], kv[1][:,:-1,:]])

In [None]:
kv[0].shape

In [None]:
out_kv.shape

In [None]:
torch.isclose(out[:,-1,:], out_kv[:,0,:])

## Modifying the GPT class

Now we need to adapt the main class to include kv-cache. The only change in the `init` has been done and consists in using `Block_kv` instead of `Block`.
Then you need to override the methods `forward` (see above) and `generate` below:
```python
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # either sample from the distribution or take the most likely element
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
```

In [None]:
from mingpt.model import GPT

class GPT_kv(GPT):
    """
    Exercise skeleton: minGPT's `GPT` with its transformer blocks replaced by
    `Block_kv`, so that `forward` and `generate_kv` can use a key/value cache.
    """

    def __init__(self, config):
        # build the parent model first, then replace its transformer stack
        super().__init__(config)
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.embd_pdrop),
            h = nn.ModuleList([Block_kv(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.n_layer = config.n_layer
        # the replacement modules above are new, so the initialisation is run again:
        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def forward(self, idx, targets=None, kv_cache=None, compute_first=False):
        """
        Compute logits (and optionally the loss) for token indices `idx` of shape (b, t).

        targets: optional tensor of target indices; entries equal to -1 are ignored by the loss.
        kv_cache: None, or a per-layer cache. The test cells start from
            `[None] * n_layer`, i.e. one entry per transformer block.
        compute_first: NOTE(review): not used by the skeleton. The test cells pass
            `compute_first=True` when filling an empty cache from a multi-token
            prompt -- define its exact meaning when writing the exercise code.

        Returns `(logits, loss)` when `kv_cache` is None, otherwise
        `(logits, loss, new_kv_cache)`.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        ###
        # your code here
        ###
        # the code written above must define `x` (the output of the last block)
        # and, when a cache is supplied, `new_kv_cache`
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        # the number of returned values depends on whether a cache was supplied
        if kv_cache is None:
            return logits, loss
        else:
            return logits, loss, new_kv_cache

    @torch.no_grad()
    def generate_kv(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        This is the kv-cache counterpart of `generate`; the body is left as an exercise.
        """
        ###
        # your code here
        ###

In [None]:
# create a GPT instance
model_config = GPT.get_default_config()
model_config.model_type = 'gpt-nano'
model_config.vocab_size = 3
model_config.block_size = 100
model = GPT_kv(model_config)
model.eval();

Here is a sample of length 7 to make some tests for the forward method.

In [None]:
inp = torch.tensor([[0, 0, 2, 1, 0, 1, 2]], dtype=torch.long)
inp.shape

In [None]:
logits, _ = model(inp)

In [None]:
kv_cache = [None] * model_config.n_layer
logits_kv, _, kv_cache = model(inp[:,[0]], kv_cache=kv_cache)

In [None]:
torch.isclose(logits[:,0,:], logits_kv[:,0,:])

In [None]:
logits_kv, _, kv_cache = model(inp[:,0:2], kv_cache=kv_cache)

In [None]:
torch.isclose(logits[:,1,:], logits_kv[:,0,:])

In [None]:
logits_kv, _, kv_cache = model(inp[:,0:3], kv_cache=kv_cache)

In [None]:
torch.isclose(logits[:,2,:], logits_kv[:,0,:])

In [None]:
logits_kv[:,0,:].shape

Another test related to the `forward` method before testing `generate`:

In [None]:
kv_cache = [None] * model_config.n_layer
logits_kv1, _, kv_cache1 = model(inp[:,0:2], kv_cache=kv_cache, compute_first=True) #you might want to modify this line 

In [None]:
logits_kv2, _, kv_cache2 = model(inp[:,0:3], kv_cache=kv_cache1)

In [None]:
torch.isclose(logits_kv2[:,0,:], logits_kv[:,0,:])

In [None]:
with torch.no_grad():
    cat = model.generate_kv(inp, 10, do_sample=False)                                       
cat

In [None]:
cat.shape

In [None]:
inp

In [None]:
out, _ = model(cat)

In [None]:
out.shape

## Learning to sort

We use the [demo](https://github.com/karpathy/minGPT/blob/master/demo.ipynb) to check that our code is running fine!

In [None]:
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import set_seed
set_seed(3407)

In [None]:
import pickle

class SortDataset(Dataset):
    """
    Dataset for the Sort problem. E.g. for problem length 6:
    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
    Which will feed into the transformer concatenated as:
    input:  0 0 2 1 0 1 0 0 0 1 1
    output: I I I I I 0 0 0 1 1 2
    where I is "ignore", as the transformer is reading the input sequence

    Examples are generated randomly on every access; the index passed to
    `__getitem__` is not used. Whether an example belongs to 'train' or 'test'
    is a fixed function of its digits, so the two splits never overlap.
    """

    def __init__(self, split, length=6, num_digits=3):
        assert split in {'train', 'test'}
        self.split = split
        self.length = length
        self.num_digits = num_digits

    def __len__(self):
        return 10000 # ...

    def get_vocab_size(self):
        return self.num_digits

    def get_block_size(self):
        # the length of the sequence that will feed into transformer,
        # containing concatenated input and the output, but -1 because
        # the transformer starts making predictions at the last input element
        return self.length * 2 - 1

    @staticmethod
    def _split_of(digits):
        """Return 'test' for about 25% of digit lists and 'train' for the others."""
        # The builtin hash() of bytes is salted per interpreter process, so a
        # split based on it changes between runs and may differ between worker
        # processes. A SHA-256 digest depends only on the digits themselves.
        payload = ",".join(str(d) for d in digits).encode("ascii")
        h = int.from_bytes(hashlib.sha256(payload).digest()[:8], "big")
        return 'test' if h % 4 == 0 else 'train'

    def __getitem__(self, idx):

        # use rejection sampling to generate an input example from the desired split
        while True:
            # generate some random integers
            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
            # half of the time let's try to boost the number of examples that
            # have a large number of repeats, as this is what the model seems to struggle
            # with later in training, and they are kind of rare
            if torch.rand(1).item() < 0.5:
                if inp.unique().nelement() > self.length // 2:
                    # too many unique digits, re-sample
                    continue
            # figure out if this generated example is train or test based on its digits
            if self._split_of(inp.tolist()) == self.split:
                break # ok

        # solve the task: i.e. sort
        sol = torch.sort(inp)[0]

        # concatenate the problem specification and the solution
        cat = torch.cat((inp, sol), dim=0)

        # the inputs to the transformer will be the offset sequence
        x = cat[:-1].clone()
        y = cat[1:].clone()
        # we only want to predict at output locations, mask out the loss at the input locations
        y[:self.length-1] = -1
        return x, y

In [None]:
# print an example instance of the dataset
train_dataset = SortDataset('train')
test_dataset = SortDataset('test')
x, y = train_dataset[0]
for a, b in zip(x,y):
    print(int(a),int(b))

In [None]:
model_config = GPT.get_default_config()
model_config.model_type = 'gpt-nano'
model_config.vocab_size = train_dataset.get_vocab_size()
model_config.block_size = train_dataset.get_block_size()
model = GPT_kv(model_config)

In [None]:
# create a Trainer object
from mingpt.trainer import Trainer

train_config = Trainer.get_default_config()
train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster
train_config.max_iters = 1000
train_config.num_workers = 0
trainer = Trainer(train_config, model, train_dataset)

In [None]:
def batch_end_callback(trainer):
    """Print the iteration time and training loss once every 100 iterations."""
    if trainer.iter_num % 100 != 0:
        return
    elapsed_ms = trainer.iter_dt * 1000
    loss_value = trainer.loss.item()
    print(f"iter_dt {elapsed_ms:.2f}ms; iter {trainer.iter_num}: train loss {loss_value:.5f}")
trainer.set_callback('on_batch_end', batch_end_callback)

trainer.run()

In [None]:
# now let's perform some evaluation
model.eval();

In [None]:
loader = DataLoader(train_dataset, batch_size=10, num_workers=0, drop_last=False)
x, y = next(iter(loader))
n = train_dataset.length
x = x.to(trainer.device)
y = y.to(trainer.device)
# isolate the input pattern alone
inp = x[:, :n]
sol = y[:, -n:]
# let the model sample the rest of the sequence
cat = model.generate(inp, n, do_sample=False)

In [None]:
def eval_split(trainer, split, max_batches):
    """
    Measure how often the module-level `model` sorts a whole sequence correctly.

    trainer: only `trainer.device` is used, to move each batch to the right device.
    split: 'train' or 'test', selecting `train_dataset` or `test_dataset`.
    max_batches: stop after this many batches of 100 examples; None means no limit.

    Prints up to 3 wrong predictions and a summary line, and returns the number
    of fully correct sequences as a 0-dim float tensor.
    """
    dataset = {'train':train_dataset, 'test':test_dataset}[split]
    n = dataset.length # problem length of the split actually being evaluated
    results = []
    mistakes_printed_already = 0
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
    for b, (x, y) in enumerate(loader):
        x = x.to(trainer.device)
        y = y.to(trainer.device)
        # isolate the input pattern alone
        inp = x[:, :n]
        sol = y[:, -n:]
        # let the model sample the rest of the sequence
        cat = model.generate_kv(inp, n, do_sample=False) # using greedy argmax, not sampling
        sol_candidate = cat[:, -n:] # isolate the filled in sequence
        # compare the predicted sequence to the true sequence: a row counts as
        # correct only if every one of its n positions matches
        correct = (sol == sol_candidate).all(1).cpu()
        for i in range(x.size(0)):
            results.append(int(correct[i]))
            if not correct[i] and mistakes_printed_already < 3: # only print up to 3 mistakes to get a sense
                mistakes_printed_already += 1
                print("GPT claims that %s sorted is %s but gt is %s" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))
        if max_batches is not None and b+1 >= max_batches:
            break
    rt = torch.tensor(results, dtype=torch.float)
    print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))
    return rt.sum()

# run a lot of examples from both train and test through the model and verify the output correctness
with torch.no_grad():
    train_score = eval_split(trainer, 'train', max_batches=50)
    test_score  = eval_split(trainer, 'test',  max_batches=50)

In [None]:
cat.shape

In [None]:
# let's run a random given sequence through the model as well
n = train_dataset.length # naugy direct access shrug
inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)
assert inp[0].nelement() == n
with torch.no_grad():
    cat = model.generate_kv(inp, n, do_sample=False)
sol = torch.sort(inp[0])[0]
sol_candidate = cat[:, n:]
print('input sequence  :', inp.tolist())
print('predicted sorted:', sol_candidate.tolist())
print('gt sort         :', sol.tolist())
print('matches         :', bool((sol == sol_candidate).all()))

In [None]:
inp = torch.tensor([[0, 0, 2, 1, 0, 1, 2]], dtype=torch.long)
model_config = GPT.get_default_config()
model_config.model_type = 'gpt-mini'
model_config.vocab_size = 9
model_config.block_size = 500 
model = GPT_kv(model_config)
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)
inp = inp.to(device)
print("running on device", device)
model.eval();

In [None]:
# Compare the generation time of the plain `generate` with `generate_kv`.
n = 1000  # number of new tokens generated per run
for use_kv in (False, True):
    generate_fn = model.generate_kv if use_kv else model.generate
    times = []
    for _ in range(10):  # measuring 10 generations
        # perf_counter is a monotonic, high-resolution clock meant for measuring
        # intervals; time.time() is wall-clock time and can jump if the system
        # clock is adjusted
        start = time.perf_counter()
        with torch.no_grad():
            cat = generate_fn(inp, n, do_sample=False)
        times.append(time.perf_counter() - start)
        # NOTE(review): on a GPU the work may still be running asynchronously when
        # the clock is read -- consider synchronizing the device before timing
    print(f"{'with' if use_kv else 'without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")