In [None]:
import os
import json
import sys
import time
import logging
import uuid
from datetime import datetime
from contextlib import redirect_stdout
from io import StringIO
from dataclasses import dataclass, field
import psutil
import torch
import torch.nn as nn
import torch.distributed as dist

from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator, notebook_launcher
from codecarbon import EmissionsTracker
from fvcore.nn import FlopCountAnalysis

# Optional: if using optimum benchmark branch
from optimum_benchmark import Benchmark, BenchmarkConfig, TorchrunConfig, InferenceConfig, PyTorchConfig
from optimum_benchmark.logging_utils import setup_logging

from model_wrapper import ModelWrapper
from energy_tracking import start_energy_tracking, stop_energy_tracking
from metrics import get_compute_performance_metrics, detect_cpu_vendor
from experiment_utils import (
    load_model_tokenizer_backend,
    prep_distributed_env,
    extract_experiment_setup,
    extract_experiment_results,
    save_results,
    aggregate_experiments,
    get_persistent_unique_id
)


# -----------------------------------------------------------------------------
# Configuration dataclass
# -----------------------------------------------------------------------------
    
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Dict


class TaskType(str, Enum):
    """Task categories an experiment can be labelled with.

    The ``str`` mix-in makes every member compare equal to its plain string
    value, so code that receives either form can treat them alike.
    """
    TEXT_GENERATION = "text_generation"
    TRANSLATION = "translation"
    SUMMARIZATION = "summarization"

class InferenceType(str, Enum):
    """Inference modes, as string-valued enum members.

    NOTE(review): ``ExperimentConfig.inference_type`` is annotated as ``str``
    and defaults to the raw string "purely_generative" rather than to a member
    of this enum; the values here match those raw strings.
    """
    PURELY_GENERATIVE = "purely_generative"
    REASONING = "reasoning"
    
class IsEncoderDecoder(str, Enum):
    """Architecture label recorded with an experiment.

    ``NA`` is the default used by ``ExperimentConfig`` when no architecture
    is specified.
    """
    NA = "na"
    ENCODER_DECODER = "encoder_decoder"
    DECODER_ONLY = "decoder_only"

@dataclass
class ExperimentConfig:
    """Parameters describing a single benchmark experiment.

    ``ExperimentRunner`` reads these fields to load the model, drive the
    inference function, and record the experiment variables next to the
    measured results. Only ``model_name`` is required.
    """
    # Model identifier handed to the model/tokenizer loader.
    model_name: str
    # Task label; its plain value is passed to the setup extraction and to save_results.
    task_type: TaskType = TaskType.TEXT_GENERATION
    # Prompts are truncated to at most this many tokens before generation.
    max_input_tokens: int = 512
    # Forwarded to generate() as max_new_tokens.
    max_output_tokens: int = 128
    # Number of prompts sent to generate() per call.
    batch_size: int = 8
    # GPU indices handed to prep_distributed_env.
    gpu_list: list = field(default_factory=lambda: [0, 1]) 
    # Forwarded to the inference function as the sampling temperature.
    decoder_temperature: float = 1.0
    # Batches per second; the inference function sleeps 1 / query_rate between batches.
    query_rate: float = 1.0
    # Precision name handed to the model/tokenizer loader.
    fp_precision: str = "float16"
    # "reasoning" turns sampling on in the inference function; other values decode greedily.
    inference_type: str = "purely_generative"
    # The next three fields are only recorded in the experiment variables by the
    # code visible in this file — NOTE(review): confirm whether anything else reads them.
    quantisation: bool = False
    batching_options: dict = field(default_factory=dict)
    sharding_config: dict = field(default_factory=dict)
    # Architecture label; recorded (as its plain value) in the experiment variables.
    is_encoder_decoder: IsEncoderDecoder = IsEncoderDecoder.NA

# -----------------------------------------------------------------------------
# Inference function that measures performance metrics.
# -----------------------------------------------------------------------------
def run_gen_inference_with_metrics(model, tokenizer, accelerator, prompts,
                                   max_input_tokens, max_output_tokens, batch_size):
    """Run greedy batched generation over ``prompts`` and collect timing metrics.

    NOTE(review): a second function with this exact name is defined later in
    this file and rebinds the name, so this variant is only reachable if that
    later definition is removed or renamed.

    Args:
        model: Object exposing ``generate`` directly, or through a ``module``
            attribute (as distributed wrappers do).
        tokenizer: Callable tokenizer that also provides ``decode`` and
            ``tokenize``.
        accelerator: Object whose ``device`` attribute names the target device.
        prompts: Iterable of prompt strings.
        max_input_tokens: Upper bound on tokens kept from each prompt.
        max_output_tokens: Passed to ``generate`` as ``max_new_tokens``.
        batch_size: Number of prompts per ``generate`` call; must be positive.

    Returns:
        dict with the keys ``avg_latency_ms`` (mean per-batch generate time),
        ``throughput_qps`` (prompts per second of generate time),
        ``tokens_per_sec``, ``total_generated_tokens``, ``num_runs`` (number
        of prompts), ``total_time`` (seconds spent inside ``generate``) and
        ``total_input_tokens`` (prompt tokens, padding excluded whenever the
        tokenizer returns an attention mask).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # Truncate every prompt by round-tripping it through the tokenizer.
    truncated_prompts = [
        tokenizer.decode(
            tokenizer(p, truncation=True, max_length=max_input_tokens, return_tensors="pt").input_ids[0],
            skip_special_tokens=True,
        )
        for p in prompts
    ]

    # Sort prompts by token length so each batch needs as little padding as possible.
    sorted_prompts = sorted(truncated_prompts, key=lambda text: len(tokenizer.tokenize(text)))

    # Distributed wrappers keep the real model (and its generate()) under `.module`.
    generator = getattr(model, "module", model)
    device = accelerator.device

    latencies = []
    total_tokens = 0
    total_input_tokens = 0

    for start in range(0, len(sorted_prompts), batch_size):
        batch = sorted_prompts[start:start + batch_size]
        encoded = tokenizer(batch, return_tensors="pt", padding=True,
                            truncation=True, max_length=max_input_tokens)
        input_ids = encoded.input_ids.to(device)

        generate_kwargs = {"max_new_tokens": max_output_tokens, "do_sample": False}
        attention_mask = getattr(encoded, "attention_mask", None)
        if attention_mask is not None:
            # Without the mask, generate() would attend to the padding of the
            # shorter prompts in the batch.
            attention_mask = attention_mask.to(device)
            generate_kwargs["attention_mask"] = attention_mask
            # Count real prompt tokens only; padding is not input.
            total_input_tokens += int(attention_mask.sum().item())
        else:
            total_input_tokens += input_ids.numel()

        started = time.perf_counter()
        outputs = generator.generate(input_ids, **generate_kwargs)
        latencies.append((time.perf_counter() - started) * 1000.0)

        # Each returned row is the (padded) prompt followed by the new tokens.
        # NOTE(review): rows that stop early are padded to the batch maximum,
        # so this count includes that trailing padding.
        prompt_len = input_ids.shape[1]
        total_tokens += sum(seq.shape[0] - prompt_len for seq in outputs)

    avg_latency_ms = sum(latencies) / len(latencies) if latencies else 0.0
    total_time_sec = sum(latencies) / 1000.0
    throughput_qps = len(sorted_prompts) / total_time_sec if total_time_sec > 0 else 0.0
    tokens_per_sec = total_tokens / total_time_sec if total_time_sec > 0 else 0.0

    return {
        "avg_latency_ms": avg_latency_ms,
        "throughput_qps": throughput_qps,
        "tokens_per_sec": tokens_per_sec,
        "total_generated_tokens": total_tokens,
        "num_runs": len(sorted_prompts),
        "total_time": total_time_sec,
        "total_input_tokens": total_input_tokens,
    }


def run_gen_inference_with_metrics(model, tokenizer, accelerator, prompts,
                                   max_input_tokens, max_output_tokens, batch_size,
                                   decoder_temperature=1.0,
                                   inference_type="purely_generative",
                                   query_rate=1.0):
    """Run batched generation over ``prompts`` and collect timing metrics.

    Prompts are truncated to ``max_input_tokens``, sorted by token length and
    generated in batches of ``batch_size``. Only the time spent inside
    ``generate`` is measured; the pacing sleeps are excluded from all metrics.

    Args:
        model: Object exposing ``generate`` directly, or through a ``module``
            attribute (as distributed wrappers do).
        tokenizer: Callable tokenizer that also provides ``decode`` and
            ``tokenize``.
        accelerator: Object whose ``device`` attribute names the target device.
        prompts: Iterable of prompt strings.
        max_input_tokens: Upper bound on tokens kept from each prompt.
        max_output_tokens: Passed to ``generate`` as ``max_new_tokens``.
        batch_size: Number of prompts per ``generate`` call; must be positive.
        decoder_temperature: Sampling temperature; only forwarded when
            sampling is enabled.
        inference_type: ``"reasoning"`` enables sampling; any other value
            decodes greedily.
        query_rate: Batches per second. When positive, the function sleeps
            ``1 / query_rate`` seconds between consecutive batches.

    Returns:
        dict with the keys ``avg_latency_ms`` (mean per-batch generate time),
        ``throughput_qps`` (prompts per second of generate time),
        ``tokens_per_sec``, ``total_generated_tokens``, ``num_runs`` (number
        of prompts), ``total_time`` (seconds spent inside ``generate``) and
        ``total_input_tokens`` (prompt tokens, padding excluded whenever the
        tokenizer returns an attention mask).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # Truncate every prompt by round-tripping it through the tokenizer.
    truncated_prompts = [
        tokenizer.decode(
            tokenizer(p, truncation=True, max_length=max_input_tokens, return_tensors="pt").input_ids[0],
            skip_special_tokens=True,
        )
        for p in prompts
    ]

    # Sort prompts by token length so each batch needs as little padding as possible.
    sorted_prompts = sorted(truncated_prompts, key=lambda text: len(tokenizer.tokenize(text)))

    # Distributed wrappers keep the real model (and its generate()) under `.module`.
    generator = getattr(model, "module", model)
    device = accelerator.device

    # Sampling is only enabled for the "reasoning" inference type.
    do_sample = inference_type == "reasoning"
    pause_sec = 1.0 / query_rate if query_rate > 0 else 0.0

    latencies = []
    total_tokens = 0
    total_input_tokens = 0
    batch_starts = range(0, len(sorted_prompts), batch_size)

    for batch_index, start in enumerate(batch_starts):
        batch = sorted_prompts[start:start + batch_size]
        encoded = tokenizer(batch,
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=max_input_tokens)
        input_ids = encoded.input_ids.to(device)

        generate_kwargs = {"max_new_tokens": max_output_tokens, "do_sample": do_sample}
        if do_sample:
            # Temperature has no effect on greedy decoding, so it is only
            # forwarded when sampling is actually on.
            generate_kwargs["temperature"] = decoder_temperature

        attention_mask = getattr(encoded, "attention_mask", None)
        if attention_mask is not None:
            # Without the mask, generate() would attend to the padding of the
            # shorter prompts in the batch.
            attention_mask = attention_mask.to(device)
            generate_kwargs["attention_mask"] = attention_mask
            # Count real prompt tokens only; padding is not input.
            total_input_tokens += int(attention_mask.sum().item())
        else:
            total_input_tokens += input_ids.numel()

        started = time.perf_counter()
        outputs = generator.generate(input_ids, **generate_kwargs)
        latencies.append((time.perf_counter() - started) * 1000.0)

        # Each returned row is the (padded) prompt followed by the new tokens.
        # NOTE(review): rows that stop early are padded to the batch maximum,
        # so this count includes that trailing padding.
        prompt_len = input_ids.shape[1]
        total_tokens += sum(seq.shape[0] - prompt_len for seq in outputs)

        # Simulate the inter-arrival delay between queries. Nothing arrives
        # after the final batch, so no trailing sleep is added there.
        if pause_sec and batch_index < len(batch_starts) - 1:
            time.sleep(pause_sec)

    avg_latency_ms = sum(latencies) / len(latencies) if latencies else 0.0
    total_time_sec = sum(latencies) / 1000.0
    throughput_qps = len(sorted_prompts) / total_time_sec if total_time_sec > 0 else 0.0
    tokens_per_sec = total_tokens / total_time_sec if total_time_sec > 0 else 0.0

    return {
        "avg_latency_ms": avg_latency_ms,
        "throughput_qps": throughput_qps,
        "tokens_per_sec": tokens_per_sec,
        "total_generated_tokens": total_tokens,
        "num_runs": len(sorted_prompts),
        "total_time": total_time_sec,
        "total_input_tokens": total_input_tokens,
    }

# -----------------------------------------------------------------------------
# Experiment runner with aggregation integration.
# -----------------------------------------------------------------------------
class ExperimentRunner:
    """Run one benchmark experiment and persist its aggregated results.

    Two modes are supported:

    * ``use_optimum=True`` launches an optimum-benchmark run for the
      configured model and returns its report as a dict.
    * otherwise the model is loaded, ``inference_fn`` is executed under energy
      tracking, per-process results are gathered, and the main local process
      aggregates and saves them.
    """

    def __init__(self, experiment_config: ExperimentConfig, prompts, inference_fn, backend="pytorch", use_optimum=False, **inference_kwargs):
        """Store the experiment description.

        Args:
            experiment_config: Configuration of the experiment to run.
            prompts: Prompt strings handed to ``inference_fn``.
            inference_fn: Callable invoked as ``inference_fn(model, tokenizer,
                accelerator, prompts, max_input_tokens, max_output_tokens,
                batch_size, decoder_temperature=..., inference_type=...,
                query_rate=..., **inference_kwargs)``; it must return a dict
                containing at least ``num_runs``, ``total_input_tokens`` and
                ``total_generated_tokens``.
            backend: Backend name passed to the model loader.
            use_optimum: Select the optimum-benchmark branch instead of the
                standard experiment.
            **inference_kwargs: Extra keyword arguments for ``inference_fn``.
        """
        self.config = experiment_config
        self.prompts = prompts
        self.inference_fn = inference_fn
        self.backend = backend
        self.use_optimum = use_optimum
        self.inference_kwargs = inference_kwargs

    def run(self):
        """Execute the experiment.

        Returns:
            The optimum-benchmark report dict when ``use_optimum`` is set.
            Otherwise ``{"EXPERIMENT #<id>": aggregated_result}`` on the
            process with local index 0 and ``None`` on every other process.
        """
        if self.use_optimum:
            return self._run_optimum_benchmark()
        return self._run_standard_experiment()

    @staticmethod
    def _plain_value(setting):
        """Return ``setting.value`` for Enum members, ``setting`` itself otherwise."""
        return setting.value if isinstance(setting, Enum) else setting

    def _run_optimum_benchmark(self):
        """Launch an optimum-benchmark inference run and return its report dict."""
        model_name = self.config.model_name
        setup_logging(level="INFO")
        launcher_config = TorchrunConfig(nproc_per_node=1)
        scenario_config = InferenceConfig(latency=True, memory=True, input_shapes={"sequence_length": 128})
        backend_config = PyTorchConfig(model=model_name, device="cuda", device_ids="0", no_weights=True)
        benchmark_config = BenchmarkConfig(
            name=f"{self.backend}_{model_name}",
            scenario=scenario_config,
            launcher=launcher_config,
            backend=backend_config,
        )
        benchmark_report = Benchmark.launch(benchmark_config)
        benchmark_results = benchmark_report.to_dict()
        print(json.dumps({
            "model": model_name,
            "optimum_benchmark_results": benchmark_results
        }, indent=4))
        return benchmark_results

    def _experiment_variables(self, inference_metrics):
        """Build the dict of experiment variables recorded alongside the results."""
        config = self.config
        return {
            "max_input_tokens": config.max_input_tokens,
            "max_output_tokens": config.max_output_tokens,
            "number_runs": inference_metrics["num_runs"],
            "total_token_inputted": inference_metrics["total_input_tokens"],
            "total_tokens_outputted": inference_metrics["total_generated_tokens"],

            "batch_size": config.batch_size,
            "gpu_list": config.gpu_list,
            "decoder_temperature": config.decoder_temperature,
            "query_rate": config.query_rate,
            "fp_precision": config.fp_precision,
            # Enum-valued settings are stored as plain strings so the record
            # does not depend on the enum classes when pickled or serialised.
            "inference_type": self._plain_value(config.inference_type),
            "quantisation": config.quantisation,
            "batching_options": config.batching_options,
            "sharding_config": config.sharding_config,
            "is_encoder_decoder": self._plain_value(config.is_encoder_decoder),
        }

    def _run_standard_experiment(self):
        """Load the model, run inference under energy tracking, gather and save results."""
        model_name = self.config.model_name
        # Use the enum's value if applicable
        task_type = self._plain_value(self.config.task_type)

        model, tokenizer = load_model_tokenizer_backend(
            model_name,
            backend=self.backend,
            fp_precision=self.config.fp_precision
        )
        model, tokenizer, accelerator = prep_distributed_env(model, tokenizer, gpu_list=self.config.gpu_list)

        tracker = start_energy_tracking()
        try:
            # Run inference using parameters from the config
            inference_metrics = self.inference_fn(
                model, tokenizer, accelerator, self.prompts,
                self.config.max_input_tokens,
                self.config.max_output_tokens,
                self.config.batch_size,
                decoder_temperature=self.config.decoder_temperature,
                inference_type=self.config.inference_type,
                query_rate=self.config.query_rate,
                **self.inference_kwargs
            )
        finally:
            # Always stop the tracker, even when inference raises, so it is
            # not left running; the exception still propagates afterwards.
            codecarbon_data = stop_energy_tracking(tracker)

        experiment_results = extract_experiment_results(
            inference_metrics, codecarbon_data,
            model=model, tokenizer=tokenizer, device=accelerator.device
        )

        local_result = {
            "experiment_setup": extract_experiment_setup(model_name, codecarbon_data, accelerator, task_type),
            "experiment_variables": self._experiment_variables(inference_metrics),
            "experiment_results": experiment_results
        }

        # Gather results if using distributed processing
        if dist.is_available() and dist.is_initialized():
            all_results = [None] * dist.get_world_size()
            dist.all_gather_object(all_results, local_result)
        else:
            all_results = [local_result]

        # Only aggregate on the main process (local rank 0)
        if accelerator.local_process_index != 0:
            return None

        aggregated_result = aggregate_experiments(all_results)
        unique_id = get_persistent_unique_id()
        experiment_title = f"EXPERIMENT #{unique_id}"
        benchmark_results = {experiment_title: aggregated_result}
        output_json_path = save_results(task_type, benchmark_results)
        print(f"Aggregated benchmark results saved to {output_json_path}")
        return benchmark_results

# -----------------------------------------------------------------------------
# Main execution
# -----------------------------------------------------------------------------
from datasets import load_dataset

# Create an ExperimentConfig instance with the desired parameters.
# Enum members are passed (rather than their raw string values) so the fields
# hold the types declared on ExperimentConfig; the members compare equal to
# the strings they replace.
experiment_config = ExperimentConfig(
    model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    is_encoder_decoder=IsEncoderDecoder.DECODER_ONLY,
    task_type=TaskType.TEXT_GENERATION,
    max_input_tokens=512,
    max_output_tokens=50,
    batch_size=8,
)

# Prompts: the first five records of the "test" split of the "arxiv"
# configuration of lighteval/pile_helm.
ds = load_dataset("lighteval/pile_helm", "arxiv")["test"]
ds = ds.select(range(5))
prompts = [sample["text"] for sample in ds]

backend = "pytorch"
use_optimum = False

# Restrict to GPUs 0 and 1
# NOTE(review): this and num_processes below are hard-coded to match the
# default gpu_list of [0, 1]; keep them in sync if gpu_list is changed.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"


def _launch_experiment():
    """Build the runner from the settings above and run it in the current process."""
    return ExperimentRunner(
        experiment_config=experiment_config,
        prompts=prompts,
        inference_fn=run_gen_inference_with_metrics,
        backend=backend,
        use_optimum=use_optimum,
    ).run()


notebook_launcher(_launch_experiment, num_processes=2)

Launching training on 2 GPUs.
Using device: cuda:0 (Local Rank: 0)
Using 2 GPUs: [0, 1]Model is on cuda:1

Model is on cuda:0




Aggregated benchmark results saved to benchmark_results/text_generation_results.json


