In [None]:
import importlib.util
import sys
import types
from typing import Callable, Optional

# Register a minimal stand-in for the legacy `imp` module, presumably so that a dependency
# doing `import imp` keeps working on Python versions where `imp` no longer exists (3.12+).
# NOTE(review): `find_module` below returns an importlib ModuleSpec (or None), not the
# (file, pathname, description) tuple the real `imp.find_module` returned — confirm the
# importing dependency only needs the attribute to exist / be truthy.
def _imp_find_module(name, path=None):
    # `path` is accepted for signature compatibility and ignored.
    return importlib.util.find_spec(name)


imp = types.SimpleNamespace(reload=importlib.reload, find_module=_imp_find_module)
sys.modules["imp"] = imp

import torch
import torch.nn.functional as F


def get_device() -> str:
    """
    Return the name of the best available torch backend.

    Preference order: "cuda" (NVIDIA GPU), then "mps" (Apple Metal), then "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


device = torch.device(get_device())
print(f"Using device: {device}")

# Every decoding strategy in this notebook except greedy draws random samples
# (torch.multinomial); seeding makes Restart & Run All reproduce the same outputs.
SEED = 0
torch.manual_seed(SEED)

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Pre-trained GPT-2 checkpoint: the LM-head model (moved to the selected device)
# together with its matching tokenizer.
MODEL_NAME = "gpt2"
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)

In [None]:
def generate_n_tokens(
    input_ids: torch.Tensor,
    n: int,
    sampling_function: Callable[[torch.Tensor], torch.Tensor],
    lm: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
    """
    Autoregressively append ``n`` tokens to ``input_ids``.

    At each step the language model is run on the whole sequence generated so far,
    the logits of the last position are passed to ``sampling_function``, and the
    returned token ids are appended.

    Args:
        input_ids: (batch, seq_len) tensor of token ids; it is not modified.
        n: number of tokens to generate (0 returns a copy of ``input_ids``).
        sampling_function: maps last-position logits of shape (batch, vocab) to the
            chosen token ids of shape (batch,).
        lm: model to call; defaults to the notebook-level ``model``. Any callable whose
            result has a ``.logits`` tensor of shape (batch, seq_len, vocab) works.

    Returns:
        (batch, seq_len + n) tensor: the prompt followed by the generated ids.
    """
    lm = model if lm is None else lm
    generated = input_ids.clone()
    # Pure inference: no autograd graph is needed for any step.
    with torch.no_grad():
        for _ in range(n):
            # No KV cache: each step re-encodes the full sequence (quadratic in n),
            # which is acceptable for the short generations in this notebook.
            next_logits = lm(generated).logits[:, -1, :]
            next_token = sampling_function(next_logits)
            generated = torch.cat([generated, next_token.unsqueeze(-1)], dim=-1)
    return generated


def sample_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """
    Draw one token index per row from the categorical distribution given by ``logits``.

    The logits are softmax-normalised over the last dimension and a single sample is
    drawn; the trailing sample dimension is squeezed away, so (batch, vocab) logits
    give a (batch,) tensor of indices.
    """
    distribution = logits.softmax(dim=-1)
    drawn = torch.multinomial(distribution, num_samples=1)
    return drawn.squeeze(-1)

In [None]:
# Toy vocabulary of ten placeholder tokens: "token1" ... "token10"
sample_vocab = [f"token{i}" for i in range(1, 11)]
vocabulary_size = len(sample_vocab)

# Toy logits: one row per test case, one column per vocabulary entry
_rising = [float(v) for v in range(1, vocabulary_size + 1)]
_spike = [1.0] * vocabulary_size
_spike[4] = 10.0
sample_logits = torch.tensor(
    [
        _rising,  # strictly increasing: token10 scores highest
        _rising[::-1],  # strictly decreasing: token1 scores highest
        [5.0] * vocabulary_size,  # flat: every token ties
        _spike,  # one dominant token (token5)
    ]
)
del _rising, _spike

In [None]:
# Function to convert token indices to vocabulary tokens
def indices_to_tokens(indices, vocab=None):
    """
    Map token indices to their string tokens.

    Args:
        indices: iterable of integer indices (e.g. a 1-D integer tensor or a list of ints).
        vocab: sequence to index into; defaults to the notebook's ``sample_vocab``.

    Returns:
        List of tokens, one per index, in the same order.
    """
    vocab = sample_vocab if vocab is None else vocab
    return [vocab[i] for i in indices]

def greedy_search(logits: torch.Tensor) -> torch.Tensor:
    """
    Select the token with the largest logit
    """
    #### Your code here ####
    return torch.argmax(logits, dim=1)
    ########################

# Test greedy search
greedy_results = greedy_search(sample_logits)
print("Greedy Search Results:", indices_to_tokens(greedy_results))
# Greedy decoding is deterministic, so pin the expected answer; torch.argmax returns
# the first maximal index on ties, hence token1 for the flat third row.
assert indices_to_tokens(greedy_results) == ["token10", "token1", "token1", "token5"]

In [None]:
def top_k_sampling(logits: torch.Tensor, k: int) -> torch.Tensor:
    """
    Returns new logits with all values, except for the k largest, set to -inf.

    Operates on the last dimension, so any leading batch dimensions are supported.
    If ``k`` exceeds the vocabulary size every logit is kept. Exactly ``k`` entries
    survive per row even when values tie (which tied entry survives is decided by
    ``torch.topk``).

    Args:
        logits: tensor of shape (..., vocab_size); it is not modified.
        k: number of highest-scoring entries to keep; must be >= 1.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    assert k >= 1, f"k was set to {k}, k must be positive"

    #### Your code here ####
    k = min(k, logits.size(-1))  # torch.topk raises if k > vocab size
    top_values, top_indices = torch.topk(logits, k, dim=-1)

    # Start from all -inf and write the k largest logits back at their original
    # positions. The scatter source must be the top-k *values* (aligned element-wise
    # with top_indices); scattering `logits` itself would copy the first k entries of
    # each row to those positions instead of the top-k ones.
    filtered = torch.full_like(logits, float("-inf"))
    filtered.scatter_(-1, top_indices, top_values)
    return filtered

In [None]:
# Test top-k sampling for a couple of k values (k=1 leaves a single candidate per row)
for k in (1, 3):
    top_k_logits = top_k_sampling(sample_logits, k)
    top_k_results = sample_from_logits(top_k_logits)
    print(f"Top-{k} Sampling Results:", indices_to_tokens(top_k_results))

In [None]:
def top_p_sampling(logits: torch.Tensor, p: float) -> torch.Tensor:
    """
    Perform top-p (nucleus) sampling on logits.

    Keeps the smallest set of most-likely tokens whose cumulative probability
    reaches ``p`` and sets every other logit to -inf. The single most likely token
    is always kept, so at least one logit per row stays finite (small ``p``
    therefore degenerates to greedy selection).

    Args:
    logits: torch.Tensor of shape (..., vocab_size); it is not modified
    p: float, cumulative probability threshold (larger p keeps more tokens)

    Returns:
    torch.Tensor of the same shape as logits, with values outside the top-p set to -inf
    """
    #### Your code here ####
    # Probabilities sorted from most to least likely, with their original positions.
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)

    # Probability mass of all strictly more likely tokens (exclusive cumulative sum).
    # Once that mass already reaches p, the nucleus is complete without this token.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    sorted_remove = mass_before >= p
    sorted_remove[..., 0] = False  # always keep the most likely token

    # sorted_indices is a permutation, so this scatter overwrites every position and
    # maps the removal flags from sorted order back to vocabulary order.
    remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)

    return logits.masked_fill(remove, float("-inf"))

In [None]:
# Test top-p sampling with a tiny nucleus and a large one
for p in (0.05, 0.9):
    top_p_logits = top_p_sampling(sample_logits, p)
    top_p_results = sample_from_logits(top_p_logits)
    print(f"Top-p Sampling Results (p={p}):", indices_to_tokens(top_p_results))

In [None]:
def temperature_sampling(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Scales logits by temperature.

    Dividing by a temperature below 1 sharpens the softmax distribution (towards
    greedy); a temperature above 1 flattens it (towards uniform). -inf logits stay
    -inf, so this composes with top-k / top-p filtering.

    Args:
        logits: tensor of logits, any shape; it is not modified.
        temperature: strictly positive scale factor (0 would produce inf/nan logits).

    Returns:
        New tensor ``logits / temperature``.
    """
    assert temperature > 0, f"temperature was set to {temperature}, temperature must be positive"
    #### Your code here ####
    return logits / temperature
    ########################

In [None]:
# Test temperature sampling: a low T sharpens the distribution, a high T flattens it
for temperature in (0.1, 5):
    temp_logits = temperature_sampling(sample_logits, temperature)
    temp_results = sample_from_logits(temp_logits)
    print(
        f"Temperature Sampling Results (T={temperature}):", indices_to_tokens(temp_results)
    )

In [None]:
# Compare the decoding strategies on one shared prompt
n_tokens = 40

text = "Once upon a time, there was a"
input_ids = tokenizer.encode(text, return_tensors="pt").to(device)

# label -> function mapping last-position logits to the next token ids
strategies = {
    "Greedy": greedy_search,
    "Top-k": lambda x: sample_from_logits(top_k_sampling(x, k=5)),
    "Top-p": lambda x: sample_from_logits(top_p_sampling(x, p=0.05)),
    "Temperature": lambda x: sample_from_logits(temperature_sampling(x, temperature=1.5)),
}

# Generate every continuation first (dicts preserve insertion order, so the random
# draws happen in the order listed above), then decode them all.
outputs = {
    label: generate_n_tokens(input_ids, n_tokens, strategy)
    for label, strategy in strategies.items()
}
greedy_output, top_k_output, top_p_output, temp_output = outputs.values()

for label, output_ids in outputs.items():
    print(f"{label}:", tokenizer.decode(output_ids[0], clean_up_tokenization_spaces=True))

# Temperature is often combined with top-p or top-k: the filter first removes the
# unlikely next tokens, then the temperature reshapes the probabilities of the tokens
# that remain (T < 1 concentrates mass on the most likely ones, T > 1 spreads it out).
# Try playing around with the temperature, p and k and see how good an output you can get!

# The prompt (`text`, `input_ids`) and `n_tokens` are reused from the comparison above
# in this same cell; only the sampling knobs are set here.
p = 0.8
k = 20
temperature = 0.15


def temp_top_k(x):
    """Top-k filter (global ``k``), then temperature scaling (global ``temperature``), then sample."""
    kept = top_k_sampling(x, k=k)
    scaled = temperature_sampling(kept, temperature=temperature)
    return sample_from_logits(scaled)


def temp_top_p(x):
    """Top-p filter (global ``p``), then temperature scaling (global ``temperature``), then sample."""
    kept = top_p_sampling(x, p=p)
    scaled = temperature_sampling(kept, temperature=temperature)
    return sample_from_logits(scaled)


# The top-p continuation is generated first, then top-k (this fixes the order of random draws)
temp_top_p_output = generate_n_tokens(input_ids, n_tokens, temp_top_p)
temp_top_k_output = generate_n_tokens(input_ids, n_tokens, temp_top_k)

# Decode and show both, top-k first
for label, output_ids in (
    ("Temperature and Top-k", temp_top_k_output),
    ("Temperature and Top-p", temp_top_p_output),
):
    print(f"{label}:", tokenizer.decode(output_ids[0], clean_up_tokenization_spaces=True))