In [1]:
import sys
sys.path.insert(0, '../../')

In [2]:
import numpy as np
import time
import matplotlib.pyplot as plt
from typing import List, Tuple, Union
import pandas as pd
from IPython.display import display, Markdown

import candle
import experiments.textgenutils as gutils

## (1) Initialize Model with Pre-trained Weights

In [3]:
# One of ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
#    gpt2:         124M params
#    gpt2-medium:  354M params
#    gpt2-large:   774M params
#    gpt2-xl:    1,557M params

model = candle.models.gpt.GPT.from_pretrained('gpt2-large')

## (2) Profile Speedup and Memory Consumption from KV Caching

In [4]:
def profile_kv_cache_generation_time_and_memory_consumption(prompt: str,
                                                            n_tokens_to_generate: int):
    tokenizer = candle.models.gpt.GPT2BPETokenizer()
    indices = tokenizer.encode(prompt)

    print('Prompt:', prompt)
    print('Tokens to generate:', n_tokens_to_generate)
    
    model_param_fp32s = int(sum([np.prod(p.shape) for p in model.parameters().values()]))
    print('Model Param Memory Consumption:', f'{model_param_fp32s / 1e6 * 4:.1f} MB')
    print()

    for use_kv_cache in [False, True]:
        if use_kv_cache:
            print('USING KV CACHE')
            print('==============')
        else:
            print('NOT USING KV CACHE')
            print('==================')

        model.clear_kv_cache()

        generator = gutils.generate_text(model, tokenizer, prompt,
                                         n_tokens_to_generate=n_tokens_to_generate,
                                         beam_size=1,
                                         top_k=100,
                                         top_p=0.95,
                                         sample=False,
                                         use_kv_cache=use_kv_cache)

        start_time = time.time()
        response = ''.join(list(generator))
        end_time = time.time()

        if use_kv_cache:
            kv_cache_fp32s = np.prod(model.decoder_blocks[0].attn.kv_cache[0].shape)
            peak_kv_cache_fp32s = (len(indices) + n_tokens_to_generate) * kv_cache_fp32s // len(indices)
        else:
            peak_kv_cache_fp32s = 0

        print('Generation time:'.ljust(28), f'{end_time - start_time:.1f} sec')
        print('KV Cache Memory Consumption:'.ljust(28), f'{peak_kv_cache_fp32s / 1e6 * 4:.1f} MB')
        print()

    print('Response:', response.lstrip())

>     Note: KV cache memory figures are only on single batch and a relatively small context length,
    and in practice will be much larger.


In [5]:
profile_kv_cache_generation_time_and_memory_consumption(
    'Note: KV cache memory figures are only on single batch and a relatively small context length, '
    'and in practice will be much larger.',
    n_tokens_to_generate = 100
)

Prompt: Note: KV cache memory figures are only on single batch and a relatively small context length, and in practice will be much larger.
Tokens to generate: 100
Model Param Memory Consumption: 3096.1 MB

NOT USING KV CACHE
Generation time:             168.0 sec
KV Cache Memory Consumption: 0.0 MB

USING KV CACHE
Generation time:             28.5 sec
KV Cache Memory Consumption: 0.7 MB

Response: The KV cache is a memory-mapped file that is used to store the data that is to be processed by the kernel. The KV cache is a special type of memory that is used to store data that is to be processed by the kernel. The KV cache is a special type of memory that is used to store data that is to be processed by the kernel. The KV cache is a special type of memory that is used to store data that is to be processed


In [6]:
profile_kv_cache_generation_time_and_memory_consumption(
    'The young hacker stumbles upon a secret that could change the fate of both worlds.',
    n_tokens_to_generate = 100
)

Prompt: The young hacker stumbles upon a secret that could change the fate of both worlds.
Tokens to generate: 100
Model Param Memory Consumption: 3096.1 MB

NOT USING KV CACHE
Generation time:             142.4 sec
KV Cache Memory Consumption: 0.0 MB

USING KV CACHE
Generation time:             28.1 sec
KV Cache Memory Consumption: 0.6 MB

Response: The Dark Crystal: Age of Resistance

In this prequel to the fantasy classic, three young Gelflings inspire a rebellion against the cruel emperor when they discover a horrifying secret.

The Curious Creations of Christine McConnell

Wickedly talented baker and artist Christine McConnell fills her home with haunting confections, creepy crafts -- and wildly inappropriate creatures.

The Kindergarten Teacher

A devoted teacher takes interest in a young student's creative potential after hearing his


In [7]:
profile_kv_cache_generation_time_and_memory_consumption(
    'A lone astronaut floats weightlessly in the vast expanse of space, their small capsule '
    'dwarfed by the swirling nebulae and distant galaxies. The astronaut\'s gaze drifts towards '
    'the Earth, a tiny blue marble suspended in the cosmic void, their home planet now a distant memory.',
    n_tokens_to_generate = 50
)

Prompt: A lone astronaut floats weightlessly in the vast expanse of space, their small capsule dwarfed by the swirling nebulae and distant galaxies. The astronaut's gaze drifts towards the Earth, a tiny blue marble suspended in the cosmic void, their home planet now a distant memory.
Tokens to generate: 50
Model Param Memory Consumption: 3096.1 MB

NOT USING KV CACHE
Generation time:             87.4 sec
KV Cache Memory Consumption: 0.0 MB

USING KV CACHE
Generation time:             15.3 sec
KV Cache Memory Consumption: 0.5 MB

Response: The astronaut's eyes are closed, their body weightless, their mind free to wander. They are alone in the vastness of space.

The astronaut's eyes are closed, their body weightless, their mind free to wander.


In [8]:
profile_kv_cache_generation_time_and_memory_consumption(
    'What is an Apple iPhone?',
    n_tokens_to_generate = 50
)

Prompt: What is an Apple iPhone?
Tokens to generate: 50
Model Param Memory Consumption: 3096.1 MB

NOT USING KV CACHE
Generation time:             32.0 sec
KV Cache Memory Consumption: 0.0 MB

USING KV CACHE
Generation time:             13.6 sec
KV Cache Memory Consumption: 0.3 MB

Response: The iPhone is a mobile phone that is used to make calls, send and receive text messages, and access the Internet. It is also used to make and receive phone calls, send and receive text messages, and access the Internet.


