# In [None]:
import torch.multiprocessing as mp
# Select the "spawn" start method for multiprocessing; force=True overrides a
# start method that may already have been set in this interpreter session.
# This call is deliberately placed ahead of the remaining imports.
# NOTE(review): presumably chosen because CUDA state does not survive fork();
# confirm it still matters for notebook_launcher below, which may pick its own
# start method for the worker processes.
mp.set_start_method('spawn', force=True)

from accelerate import notebook_launcher
from config.experiment_config import ExperimentConfig
from experiment.experiment_runner import ExperimentRunner
from datasets import load_dataset

# Experiment settings, declared inline here (they could also be loaded from YAML).
experiment_config = ExperimentConfig(
    # --- Model, task and backend ---
    model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    is_encoder_decoder="decoder_only",
    task_type="text_generation",
    inference_type="purely_generative",
    backend="pytorch",
    # --- Token limits for input and output ---
    max_input_tokens=512,
    max_output_tokens=50,
    # --- Devices and parallelism ---
    # num_processes equals the number of entries in gpu_list; keep them in sync.
    gpu_list=[0, 1, 2, 3],
    num_processes=4,
    sharding_config={"fsdp_config": {}},
    # --- Batching and request rate ---
    batching_options={"adaptive_batching": True, "max_batch_size": 8},
    query_rate=1,
    # --- Decoding and numeric precision ---
    decoder_temperature=1,
    fp_precision="float16",
    quantisation=True,
)

# Prompt source (example dataset): the "arxiv" configuration of
# lighteval/pile_helm. Only the first five records of its "test" split are
# kept, and each record's "text" field becomes one prompt.
ds = load_dataset("lighteval/pile_helm", "arxiv")["test"].select(range(5))
prompts = [record["text"] for record in ds]

def _launch_experiment():
    """Build an ExperimentRunner from the notebook globals and run it.

    Takes no arguments: ``experiment_config``, ``prompts`` and
    ``text_gen_runinf`` are looked up in the module namespace when the
    function is called, not when it is defined.

    Returns:
        Whatever ``ExperimentRunner.run()`` returns.
    """
    # NOTE(review): text_gen_runinf is neither imported nor defined in the
    # visible part of this file; it must already exist in the notebook
    # namespace by the time this function runs -- confirm where it comes from.
    runner = ExperimentRunner(
        experiment_config,
        prompts,
        inference_fn=text_gen_runinf,
        use_optimum=False,
    )
    return runner.run()


# Hand the entry point to accelerate, requesting as many processes as the
# experiment configuration specifies.
notebook_launcher(_launch_experiment, num_processes=experiment_config.num_processes)
