## Benchmark

In [None]:
import os

import torch

from llm_benchmarks.generation import generate_samples
from llm_benchmarks.utils import log_metrics_to_csv

# Silence Weights & Biases console output for this session.
os.environ.update(WANDB_SILENT="true")

model_name = "tiiuae/falcon-7b"

# Settings for the 8-bit run. The plot cell colors points by "load_in_8bit",
# so a second run with load_in_8bit=False is presumably needed to get the
# comparison series -- TODO confirm.
config = dict(
    load_in_8bit=True,
    torch_dtype=torch.float16,
    temperature=0.9,
    max_length=150,
)

# The trailing 3 is presumably the number of generation runs, and "results"
# the output location read back by the plot cell -- verify against
# llm_benchmarks.
metrics = generate_samples(model_name, config, 3)
log_metrics_to_csv(model_name, config, metrics, "results")

## Plot

In [None]:
# Load all CSV files in ./results whose filename contains 'falcon' and combine them into a single DataFrame
import glob

import pandas as pd
import plotly.express as px

# Gather every Falcon result file. Sorting makes the concatenation order
# (and therefore the row order of df) independent of filesystem listing order.
csv_paths = sorted(glob.glob("./results/*falcon*.csv"))
if not csv_paths:
    # pd.concat([]) fails with an opaque "No objects to concatenate";
    # fail early with a message that names the actual problem instead.
    raise FileNotFoundError(
        "No result files match './results/*falcon*.csv'; run the benchmark cell first."
    )
df = pd.concat([pd.read_csv(path) for path in csv_paths], ignore_index=True)

# Remove outliers: keep only rows below 100 tokens per second.
df = df[df["tokens_per_second"] < 100]
# df = df[df['output_tokens'] > 500]

# Scatter of throughput against output length, one color per load_in_8bit value.
colors = {True: "blue", False: "red"}
fig = px.scatter(df, x="output_tokens", y="tokens_per_second", color="load_in_8bit", color_discrete_map=colors)
fig.update_traces(marker=dict(size=12))
# Single layout call: the earlier placeholder title was immediately overwritten,
# so only the final title is set here.
fig.update_layout(
    xaxis_title="output_tokens",
    yaxis_title="tokens_per_second",
    title="Falcon 7B: Tokens per Second vs. Output Tokens",
)
fig.show()

# Optional static export (left disabled, as in the original notebook):
# fig.write_image("falcon_compare_8bit_inference.png", format="png", width=700, height=400, scale=5)