In [None]:
import time
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import polars as pl 
import json

from decoding_strategies_over_custom_gpt import GenerativeModel
from gpt import CharTokenizer, get_persons

In [None]:
import warnings

# Blanket filter: keeps benchmark output readable, but note that it also
# hides deprecation notices from every library used in this notebook.
warnings.filterwarnings('ignore')

# Show generated text in full and never truncate the result tables.
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', None)

# Seed PyTorch's RNG so the sampling experiments (do_sample=True) produce
# the same outputs on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Build the tokenizer, then overwrite its vocabulary with the one saved on disk.
tokenizer = CharTokenizer()

with open("tokenizer_vocab.json", mode="r", encoding="utf-8") as vocab_file:
    raw_vocab = json.load(vocab_file)

# JSON object keys are always strings, so convert them back to integers.
tokenizer._vocab = {int(key): entry for key, entry in raw_vocab.items()}

In [None]:
# Hyperparameters used to build the model in the next cell.
# NOTE(review): these presumably have to match the configuration that
# "my_gpt_weights.pt" was trained with — confirm before changing any value.
VOCAB_SIZE = len(tokenizer.vocab)
BATCH_SIZE = 1024
MAX_SEQ_LEN = 200
N_LAYERS = 6
EMBEDDING_SIZE = 128
NUM_HEADS = 8
NUM_KV_GROUPS = 2
NUM_EXPERTS = 16
NUM_EXPERTS_PER_TOKEN = 2

# The per-head size is computed with floor division; fail loudly instead of
# silently truncating if the sizes are ever changed to non-divisible values.
assert EMBEDDING_SIZE % NUM_HEADS == 0, "EMBEDDING_SIZE must be divisible by NUM_HEADS"
HEAD_EMBEDDING_SIZE = EMBEDDING_SIZE // NUM_HEADS
FCCN_HIDDEN_SIZE = EMBEDDING_SIZE * 4
n_epoch = 20

In [None]:
# Keyword arguments for GenerativeModel, assembled from the constants above.
model_config = dict(
    vocab_size=VOCAB_SIZE,
    n_layers=N_LAYERS,
    embedding_size=EMBEDDING_SIZE,
    num_heads=NUM_HEADS,
    num_kv_groups=NUM_KV_GROUPS,
    num_experts=NUM_EXPERTS,
    num_experts_per_token=NUM_EXPERTS_PER_TOKEN,
    head_embedding_size=HEAD_EMBEDDING_SIZE,
    fcnn_hidden_size=FCCN_HIDDEN_SIZE,
    dropout=0.15,
)

generator = GenerativeModel(**model_config)

# SECURITY: torch.load unpickles the file. weights_only=True restricts the
# unpickler to tensors and plain containers, which is all a state dict needs,
# so a tampered checkpoint cannot execute arbitrary code on load.
generator.load_state_dict(
    torch.load("my_gpt_weights.pt", map_location=device, weights_only=True)
)

generator.to(device)
# Switch to inference mode before benchmarking generation.
generator.eval()

In [None]:
def benchmark_parameter(param_name, param_values, prompt, fixed_params):
    """
    Generate text repeatedly while varying a single generation parameter.

    Relies on the notebook-level ``generator``, ``tokenizer`` and ``device``.

    Args:
        param_name: Name of the ``generator.generate`` keyword argument to vary.
        param_values: Iterable of values to try for ``param_name``.
        prompt: Text passed to ``generator.generate`` as ``text_input``.
        fixed_params: Dict with the remaining generation keyword arguments.
            It is copied for every run and never mutated.

    Returns:
        pandas.DataFrame with one row per tested value and the columns
        "Parameter", "Value", "Time (sec)", "Output" and "Length (chars)".
    """
    # CUDA work is queued asynchronously, so the clock must only be read after
    # the device has finished; otherwise the measured time can be too short.
    on_cuda = str(device).startswith("cuda") and torch.cuda.is_available()

    results = []

    print(f"\nTesting param: {param_name}")

    for val in tqdm(param_values):
        current_params = fixed_params.copy()
        current_params[param_name] = val

        if on_cuda:
            torch.cuda.synchronize()
        # perf_counter is monotonic and high resolution, unlike time.time().
        start_time = time.perf_counter()

        # Gradients are never needed for generation; skip building the graph.
        with torch.no_grad():
            output_text = generator.generate(
                text_input=prompt,
                tokenizer=tokenizer,
                device=device,
                **current_params
            )

        if on_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        results.append({
            "Parameter": param_name,
            "Value": val,
            "Time (sec)": round(elapsed, 4),
            "Output": output_text,
            "Length (chars)": len(output_text)
        })

    return pd.DataFrame(results)

def plot_results(df, title):
    """
    Plot generation time and output length for one benchmark run, then show
    the generated texts as a table.

    Args:
        df: DataFrame as returned by ``benchmark_parameter``; must contain the
            columns "Parameter", "Value", "Time (sec)", "Length (chars)" and
            "Output", and at least one row.
        title: Figure-level title.

    Raises:
        IndexError: If ``df`` has no rows (the parameter name is read from
            the first row).
    """
    # Every row of a benchmark frame carries the same parameter name.
    param_name = df["Parameter"].iloc[0]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    time_ax, length_ax = axes

    # TIME TAKEN
    sns.barplot(data=df, x='Value', y='Time (sec)', ax=time_ax, palette='viridis')
    time_ax.set_title(f'Time dependence on {param_name}')
    time_ax.set_xlabel(param_name)

    # TEXT LEN
    sns.lineplot(data=df, x='Value', y='Length (chars)', ax=length_ax, marker='o', color='red')
    length_ax.set_title(f'Output length dependence on {param_name}')
    # Label the x-axis with the parameter name, consistent with the left panel.
    length_ax.set_xlabel(param_name)

    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

    # TABLE
    display(df[['Value', 'Time (sec)', 'Output']])

In [None]:
# `common_prompt` is the prompt shared by the beam, temperature and top-k runs.
# `results` does not appear to be read by any later cell in this notebook.
results, common_prompt = [], "Geralt "

## Beam Search
`do_sample=False` → deterministic beam search; `num_beams` is the parameter being varied.

In [None]:
# Settings held constant while the number of beams changes.
base_params_beam = dict(
    max_new_tokens=40,
    repetition_penalty=1.0,
    temperature=1.0,
    top_k=50,
    top_p=0.9,
    do_sample=False,
)

# Beam widths to compare.
beam_values = [1, 2, 4, 8]

df_beams = benchmark_parameter(
    param_name="num_beams",
    param_values=beam_values,
    prompt=common_prompt,
    fixed_params=base_params_beam,
)

In [None]:
# Time / output-length charts and the generated texts for the beam-width run.
plot_results(df=df_beams, title="Influence of num_beams (Beam Search)")

## Temperature (sampling)
`do_sample=True`, `num_beams=1` → plain sampling; `temperature` is the parameter being varied.

In [None]:
# Settings held constant while the temperature changes.
# NOTE(review): top_k=0 and top_p=0.0 are presumably interpreted by
# generate() as "filter disabled" — confirm against its implementation.
base_params_temp = dict(
    max_new_tokens=50,
    num_beams=1,
    top_k=0,
    top_p=0.0,
    repetition_penalty=1.0,
    do_sample=True,
)

# Temperatures to compare.
temp_values = [0.1, 0.5, 0.8, 1.0, 1.5, 3.0]

df_temp = benchmark_parameter(
    param_name="temperature",
    param_values=temp_values,
    prompt=common_prompt,
    fixed_params=base_params_temp,
)

In [None]:
# Time / output-length charts and the generated texts for the temperature run.
plot_results(df=df_temp, title="Influence of temperature (Sampling)")

## Top-K

In [None]:
# Settings held constant while top_k changes.
base_params_k = dict(
    max_new_tokens=40,
    num_beams=1,
    temperature=1.0,
    top_p=0.0,
    repetition_penalty=1.0,
    do_sample=True,
)

# NOTE(review): the trailing 0 presumably means "top-k disabled" — confirm
# against generate().
k_values = [1, 5, 20, 50, 0]

df_k = benchmark_parameter(
    param_name="top_k",
    param_values=k_values,
    prompt=common_prompt,
    fixed_params=base_params_k,
)

In [None]:
# Time / output-length charts and the generated texts for the top_k run.
plot_results(df=df_k, title="Influence of top_k")

## Repetition penalty

In [None]:
# Prompt that already repeats a phrase; presumably chosen to provoke loops.
loop_prompt = "Geralt of Rivia is a witcher and Geralt of Rivia"

# Settings held constant while the repetition penalty changes.
base_params_rep = dict(
    max_new_tokens=50,
    num_beams=1,
    temperature=1.0,
    do_sample=False,  # Greedy search
    top_k=50,
    top_p=0.9,
)

# Penalties to compare.
penalty_values = [1.0, 1.05, 1.1, 1.2, 2.0]

df_rep = benchmark_parameter(
    param_name="repetition_penalty",
    param_values=penalty_values,
    prompt=loop_prompt,
    fixed_params=base_params_rep,
)

In [None]:
# Time / output-length charts and the generated texts for the penalty run.
plot_results(df=df_rep, title="Influence of repetition_penalty")

## Comparison

In [None]:
# One row per decoding strategy, reusing timings measured in the earlier runs.
# Each tuple is (label, benchmark frame, value of the varied parameter to pick).
comparison_data = [
    {"Strategy": label, "Time": frame[frame['Value'] == value]['Time (sec)'].values[0]}
    for label, frame, value in (
        ("Greedy", df_beams, 1),
        ("Beam Search (k=4)", df_beams, 4),
        ("Beam Search (k=8)", df_beams, 8),
        ("Sampling", df_temp, 1.0),
    )
]

df_comp = pd.DataFrame(comparison_data)

In [None]:
# Bar chart of the measured generation time per decoding strategy.
plt.figure(figsize=(10, 6))
ax = sns.barplot(data=df_comp, x='Strategy', y='Time', palette='magma')
ax.set_title('Comparison of generation speed')
ax.set_ylabel('Time (sec)')
plt.xticks(rotation=45)
plt.show()