In [None]:
import gc
from time import time

import torch
from openai import Model
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model checkpoints to benchmark.
models = ["tiiuae/falcon-7b"]

# Drop the reference to any model left over from a previous run of this cell,
# collect it, and hand cached CUDA memory back before a new model is loaded.
model = None
gc.collect()
torch.cuda.empty_cache()

# Loading / sampling settings consumed by generate_samples.
config = dict(
    load_in_8bit=True,
    torch_dtype=torch.bfloat16,
    temperature=0.9,
    max_length=1024,
)


def generate_samples(model_name: str, config: dict, num_samples: int = 1) -> dict:
    """Load ``model_name``, sample from it ``num_samples`` times and log metrics.

    Each sample answers a fixed prompt. The function prints and collects these
    metrics for every sample:

    - the number of newly generated tokens (prompt excluded);
    - the currently allocated CUDA memory in GiB;
    - the wall-clock generation time;
    - the resulting tokens/second.

    The collected metrics are passed to ``log_metrics_to_csv`` and returned.

    Args:
        model_name: Model id or local path passed to ``from_pretrained``.
        config: Settings dict. ``load_in_8bit`` and ``torch_dtype`` are
            required. ``temperature`` (default 0.9) and ``max_length``
            (default 1024; total length including the prompt) control sampling.
        num_samples: Number of generations to run.

    Returns:
        Dict mapping each metric name to a list with one value per sample.
    """
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_8bit=config["load_in_8bit"],
        torch_dtype=config["torch_dtype"],
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    # Read the input device before wrapping; with device_map="auto" this is the
    # device of the model's first parameters, so no "cuda" assumption is needed.
    device = model.device
    # NOTE(review): ``model.generate`` is looked up through the compiled
    # wrapper and dispatches to the original module, so this compile may not
    # affect generation speed at all. Confirm before relying on it in benchmarks.
    model = torch.compile(model)

    # Take sampling settings from ``config`` so that the config logged below
    # matches what was actually run. The defaults are the previous hard-coded values.
    temperature = config.get("temperature", 0.9)
    max_length = config.get("max_length", 1024)

    # Tokenize inputs
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text = "Question: Tell me a history of WW2 in 3 or 4 paragraphs.\nAnswer: "
    encoded = tokenizer(text, return_tensors="pt").to(device)
    input_tokens = encoded["input_ids"]
    # Pass the mask explicitly: pad == eos here, so generate cannot infer it.
    attention_mask = encoded.get("attention_mask")
    prompt_len = input_tokens.shape[-1]

    metrics = {
        "output_tokens": [],
        "gpu_mem_usage": [],
        "generate_time": [],
        "tokens_per_second": [],
    }

    for _ in range(num_samples):
        # Generate
        time0 = time()
        with torch.no_grad():
            output = model.generate(
                input_tokens,
                attention_mask=attention_mask,
                do_sample=True,
                temperature=temperature,
                max_length=max_length,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        time1 = time()

        # Collect metrics. The returned sequence includes the prompt, so subtract
        # it to count (and compute throughput from) only the generated tokens.
        output_tokens = output.shape[-1] - prompt_len
        gpu_mem_usage = torch.cuda.memory_allocated() / 1024**3
        generate_time = time1 - time0
        tokens_per_second = output_tokens / generate_time

        # Log metrics
        metrics["output_tokens"].append(output_tokens)
        metrics["gpu_mem_usage"].append(gpu_mem_usage)
        metrics["generate_time"].append(generate_time)
        metrics["tokens_per_second"].append(tokens_per_second)

        # Print metrics
        print(f"Output tokens: {output_tokens}")
        print(f"GPU mem usage: {gpu_mem_usage:.2f} GB")
        print(f"Generate time: {generate_time:.2f} s")
        print(f"Tokens per second: {tokens_per_second:.2f}")

    # Save metrics
    log_metrics_to_csv(model_name, config, metrics)
    return metrics