In [1]:
import sys
# Put the directory two levels above the working directory at the front of the
# module search path so the imports below can be resolved from it.
# NOTE(review): presumably this is the repository root containing `candle` — confirm.
sys.path.insert(0, '../../')

In [2]:
import time
import numpy as np
import candle

## (1) Initialize Target and Draft Models

In [3]:
# Target model: the 'gpt2-large' checkpoint. Used as the target of speculative
# sampling and as the sole model for the autoregressive baseline.
target_model = candle.models.gpt.GPT.from_pretrained('gpt2-large')
# Draft model: the base 'gpt2' checkpoint, handed to speculative sampling as the draft model.
draft_model = candle.models.gpt.GPT.from_pretrained('gpt2')
# GPT-2 BPE tokenizer used to encode prompts and decode generated token ids.
tokenizer = candle.models.gpt.GPT2BPETokenizer()

Loading file from cache: /home/johnma2006/.cache/candle/gpt2_encoder.json
Loading file from cache: /home/johnma2006/.cache/candle/gpt2_vocab.bpe


## (2) Compare Speculative vs Autoregressive Sampling Speed

In [5]:
def _decode_and_time(generator):
    """Drain `generator`, decode the concatenated output, and time the whole step.

    Returns a `(response, elapsed_seconds)` tuple. The timed region covers
    draining the generator, concatenating its chunks, and decoding them with the
    notebook-level `tokenizer`.
    """
    # perf_counter is a monotonic, high-resolution clock intended for measuring
    # intervals; unlike time.time() it is unaffected by system clock adjustments.
    start_time = time.perf_counter()
    response = tokenizer.decode(np.concatenate(list(generator)))
    elapsed = time.perf_counter() - start_time
    return response, elapsed


def _print_report(title, elapsed, response):
    """Print the titled timing/response summary for one decoding method."""
    print(f'\n{title}')
    print('====================')
    print(f'Generation time: {elapsed:.1f} sec')
    print(f'Response: {response}')


def compare_speculative_sampling_vs_autoregressive_decoding(prompt: str,
                                                            n_tokens_to_gen: int,
                                                            temperature: float,
                                                            K: int = 4,
                                                            top_k: int = 40,
                                                            top_p: float = 0.95) -> None:
    """Generate from `prompt` with both decoding methods and print time and text for each.

    The prompt is first run through `candle.nlp.speculative_sample` (using the
    notebook-level `target_model` and `draft_model`), then through
    `candle.nlp.generation.batch_generation` (using `target_model` only). The KV
    caches of the models involved are cleared before each run.

    Args:
        prompt: Text to condition generation on; it is printed and then encoded
            with the notebook-level `tokenizer`.
        n_tokens_to_gen: Forwarded to both generation calls as `n_tokens_to_gen`.
        temperature: Forwarded to both generation calls as `temperature`.
        K: Forwarded to `speculative_sample` as `K`.
            NOTE(review): presumably the number of draft tokens proposed per
            step — confirm against the candle implementation.
        top_k: Forwarded to both generation calls as `top_k`.
        top_p: Forwarded to both generation calls as `top_p`.

    Returns:
        None. Results are printed.
    """
    print(f'Prompt: {prompt}')
    indices = candle.Tensor([tokenizer.encode(prompt)])

    # Speculative sampling

    target_model.clear_kv_cache()
    draft_model.clear_kv_cache()
    # NOTE(review): the generator is built before the timer starts, so the
    # measurement assumes the work happens lazily while draining — confirm.
    generator = candle.nlp.speculative_sample(
        target_model,
        draft_model,
        K=K,
        indices=indices,
        n_tokens_to_gen=n_tokens_to_gen,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )
    response, elapsed = _decode_and_time(generator)
    _print_report('SPECULATIVE SAMPLING', elapsed, response)

    # Autoregressive decoding

    # Only the target model takes part here, so only its cache is reset.
    target_model.clear_kv_cache()
    generator = candle.nlp.generation.batch_generation(
        target_model,
        indices=indices,
        n_tokens_to_gen=n_tokens_to_gen,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )
    response, elapsed = _decode_and_time(generator)
    _print_report('AUTOREGRESSIVE DECODING', elapsed, response)

> **Note:** We do not see a speedup from speculative sampling in the experiments below. This is likely because the key observation from the paper — "the latency of parallel scoring of short continuations, generated by a faster but less powerful draft model, is comparable to that of sampling a single token from the larger target model" — does not hold in my single-threaded, throughput-bound, CPU-based setup. When compute rather than memory bandwidth is the bottleneck, scoring several draft tokens at once costs roughly proportionally more than scoring one, so verifying the draft is not "free".

In [6]:
# Timing comparison at temperature 0.0.
compare_speculative_sampling_vs_autoregressive_decoding(
    prompt='The meaning of life is ',
    n_tokens_to_gen=50,
    temperature=0.0,
)

Prompt: The meaning of life is 

SPECULATIVE SAMPLING
Generation time: 14.8 sec
Response:  the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life

AUTOREGRESSIVE DECODING
Generation time: 13.1 sec
Response:  the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life is the meaning of life


In [7]:
# Timing comparison at temperature 0.0 on a second prompt.
compare_speculative_sampling_vs_autoregressive_decoding(
    prompt='Today on the Daily Show, ',
    n_tokens_to_gen=50,
    temperature=0.0,
)

Prompt: Today on the Daily Show, 

SPECULATIVE SAMPLING
Generation time: 20.5 sec
Response:  Jon Stewart and Stephen Colbert took a look at the "war on Christmas" and how it's been used to justify the government's war on Christmas.
The Daily Show: "War on Christmas"
The Daily Show: "War on Christmas"

AUTOREGRESSIVE DECODING
Generation time: 13.0 sec
Response:  Jon Stewart and Stephen Colbert took a look at the "war on Christmas" and how it's been used to justify the government's war on Christmas.
The Daily Show: "War on Christmas"
The Daily Show: "War on Christmas


In [8]:
# Timing comparison at temperature 0.0 on a third prompt.
compare_speculative_sampling_vs_autoregressive_decoding(
    prompt='GPT2 is the ',
    n_tokens_to_gen=50,
    temperature=0.0,
)

Prompt: GPT2 is the 

SPECULATIVE SAMPLING
Generation time: 14.4 sec
Response:  GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of

AUTOREGRESSIVE DECODING
Generation time: 13.2 sec
Response:  GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of the GPT2 of


In [9]:
# Timing comparison at temperature 1.0 (the only run in this notebook that
# does not use temperature 0.0).
compare_speculative_sampling_vs_autoregressive_decoding(
    prompt='Dolphins are ',
    n_tokens_to_gen=50,
    temperature=1.0,
)

Prompt: Dolphins are 

SPECULATIVE SAMPLING
Generation time: 25.9 sec
Response: ileocarboxylic acids that contain approximately 50% to 85% propyl fatty acids and are derived from the seed oils of the citrus peels of the fruit.

The most common phytochemicals found in coconut are laur

AUTOREGRESSIVE DECODING
Generation time: 13.0 sec
Response: icky.

The best part? They're not the worst thing to have on a car, unless you happen to live in Los Angeles.

"I've always been a fan of blue and white and I like the idea of a blue
