In [1]:
import torch
import time
import os
from typing import List, Any, Dict

# --- Dummy Classes ---

class DummyTokenizer:
    def __init__(self, model_max_length=2048):
        self.model_max_length = model_max_length
    
    def tokenize(self, text: str) -> List[str]:
        # Simple whitespace tokenization.
        return text.split()
    
    def __call__(self, texts, truncation=True, max_length=None, padding="max_length", return_tensors="pt"):
        # For each text, simulate tokenization as a list of "tokens" (dummy: each word becomes token id 1).
        batch_size = len(texts)
        token_lists = [text.split()[:max_length] for text in texts]
        lengths = [len(t) for t in token_lists]
        # Create a tensor with fixed max_length
        input_ids = torch.zeros((batch_size, max_length), dtype=torch.long)
        attention_mask = torch.zeros((batch_size, max_length), dtype=torch.long)
        for i, tokens in enumerate(token_lists):
            l = len(tokens)
            input_ids[i, :l] = torch.ones(l, dtype=torch.long)
            attention_mask[i, :l] = 1
        return {"input_ids": input_ids, "attention_mask": attention_mask}

class DummyModel:
    def __init__(self, max_length=2048):
        self.max_length = max_length
    
    def generate(self, input_ids, **kwargs):
        # For testing, simply append a fixed number of tokens (e.g., ones) to simulate generation.
        max_new_tokens = kwargs.get("max_new_tokens", 10)
        batch_size, seq_len = input_ids.shape
        generated = torch.cat([input_ids, torch.ones((batch_size, max_new_tokens), dtype=torch.long)], dim=1)
        return generated

class DummyAccelerator:
    def __init__(self, device="cpu", process_index=0):
        self.device = torch.device(device)
        self.process_index = process_index
        self.local_process_index = process_index
    
    def print(self, *args, **kwargs):
        print(*args, **kwargs)
    
    def wait_for_everyone(self):
        # In this dummy, simply pass.
        pass

class DummyExperimentConfig:
    def __init__(self):
        self.max_input_tokens = 2048   # The truncation limit for tokenization.
        self.max_output_tokens = 50    # Number of new tokens to generate.
        self.decoder_temperature = 1.0
        self.save_outputs = True
        self.batching_options = {
            "fixed_max_batch_size": 12,
            "adaptive_batching": True,
            "adaptive_max_tokens": 100   # Batching is based on estimated token counts.
        }

# --- Functions Under Test ---

def adaptive_batching(prompts: List[str],
                      tokenizer: Any,
                      adaptive_max_tokens: int,
                      max_prompt_tokens: int,
                      max_batch_size: int = None) -> List[List[str]]:
    batches = []
    current_batch = []
    current_tokens = 0

    for prompt in prompts:
        raw_token_count = len(tokenizer.tokenize(prompt))
        token_count = min(raw_token_count, max_prompt_tokens)
        
        if max_batch_size is not None and len(current_batch) >= max_batch_size:
            batches.append(current_batch)
            current_batch = []
            current_tokens = 0
        
        if current_batch and (current_tokens + token_count > adaptive_max_tokens):
            batches.append(current_batch)
            current_batch = [prompt]
            current_tokens = token_count
        else:
            current_batch.append(prompt)
            current_tokens += token_count

    if current_batch:
        batches.append(current_batch)
    
    return batches

def batch_tokenise_truncate(prompts: List[str], tokenizer: Any, max_input_tokens: int, batch_size: int = 32) -> Dict[str, torch.Tensor]:
    all_input_ids = []
    all_attention_mask = []
    
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i : i + batch_size]
        encoded = tokenizer(
            batch,
            truncation=True,
            max_length=max_input_tokens,
            padding="max_length", 
            return_tensors="pt"
        )
        # Extra safeguard: slice in case truncation isn’t applied.
        encoded["input_ids"] = encoded["input_ids"][:, :max_input_tokens]
        if "attention_mask" in encoded:
            encoded["attention_mask"] = encoded["attention_mask"][:, :max_input_tokens]
            
        all_input_ids.append(encoded["input_ids"])
        if "attention_mask" in encoded:
            all_attention_mask.append(encoded["attention_mask"])
    
    tokenised_inputs = {"input_ids": torch.cat(all_input_ids, dim=0)}
    if all_attention_mask:
        tokenised_inputs["attention_mask"] = torch.cat(all_attention_mask, dim=0)
    
    return tokenised_inputs

def calculate_inference_metrics(num_input_prompts, latencies, total_input_tokens, total_generated_tokens):
    return {
        "num_input_prompts": num_input_prompts,
        "avg_latency_ms": sum(latencies)/len(latencies) if latencies else None,
        "total_input_tokens": total_input_tokens,
        "total_generated_tokens": total_generated_tokens,
    }

def run_gen_inference(model, experiment_config, prompts, tokenizer, accelerator):
    max_input_tokens = experiment_config.max_input_tokens 
    max_output_tokens = experiment_config.max_output_tokens
    decoder_temperature = experiment_config.decoder_temperature
    fixed_max_batch_size = experiment_config.batching_options.get("fixed_max_batch_size", 8)
    use_adaptive = experiment_config.batching_options.get("adaptive_batching", False)
        
    token_id_outputs = []
    latencies = []
    total_generated_tokens = 0
    total_input_tokens = 0  
    all_input_ids_batches = []
    device = accelerator.device

    if use_adaptive:
        adaptive_max_tokens = experiment_config.batching_options.get("adaptive_max_tokens", max_input_tokens)
        batches = adaptive_batching(prompts=prompts, 
                                    tokenizer=tokenizer, 
                                    adaptive_max_tokens=adaptive_max_tokens, 
                                    max_prompt_tokens=max_input_tokens, 
                                    max_batch_size=fixed_max_batch_size)
        accelerator.print(f"Using adaptive batching: created {len(batches)} batches.")
    else:
        batches = [prompts[i:i+fixed_max_batch_size] for i in range(0, len(prompts), fixed_max_batch_size)]
        accelerator.print(f"Using fixed batching (non-adaptive): created {len(batches)} batches.")
    
    for batch_idx, batch in enumerate(batches):
        tokenised_batch = batch_tokenise_truncate(
            prompts=batch,
            tokenizer=tokenizer,
            max_input_tokens=max_input_tokens,
            batch_size=len(batch)
        )
        batch_input_ids = tokenised_batch["input_ids"]
        total_input_tokens += batch_input_ids.numel()
        all_input_ids_batches.append(batch_input_ids)
        
        if "attention_mask" in tokenised_batch:
            batch_encoded = {
                "input_ids": batch_input_ids.to(device),
                "attention_mask": tokenised_batch["attention_mask"].to(device)
            }
        else:
            batch_encoded = {"input_ids": batch_input_ids.to(device)}
        
        gpu_id = accelerator.device.index if accelerator.device.type == 'cuda' else 0
        print(f"[Process {os.getpid()}][GPU {gpu_id}] — Completed tokenisation of batch {batch_idx + 1}/{len(batches)}")

        if decoder_temperature is not None and decoder_temperature > 0:
            generation_kwargs = {"max_new_tokens": max_output_tokens, "do_sample": True, "temperature": decoder_temperature}
        else:
            generation_kwargs = {"max_new_tokens": max_output_tokens, "do_sample": False}
        
        start_time = time.perf_counter()
        with torch.no_grad():
            token_id_batch_output = model.generate(batch_encoded["input_ids"], **generation_kwargs)
        if device.type == 'cuda':
            torch.cuda.synchronize(device)
        end_time = time.perf_counter()
        latencies.append((end_time - start_time) * 1000.0)
        print(f"[Process {os.getpid()}][GPU {gpu_id}] — Completed batch inference {batch_idx + 1}/{len(batches)}")
        
        for j in range(batch_input_ids.size(0)):
            prompt_len = batch_input_ids[j].shape[0]
            gen_len = token_id_batch_output[j].shape[0] - prompt_len
            total_generated_tokens += gen_len
        
        if experiment_config.save_outputs:
            token_id_outputs.append(token_id_batch_output)
    
    concatenated_input_ids = torch.cat(all_input_ids_batches, dim=0)
    
    inference_results = calculate_inference_metrics(
        num_input_prompts=len(prompts),
        latencies=latencies,
        total_input_tokens=total_input_tokens,
        total_generated_tokens=total_generated_tokens
    )
    
    if not experiment_config.save_outputs:
        token_id_outputs = None
        
    return token_id_outputs, concatenated_input_ids, inference_results

# --- Testing Script ---

dummy_tokenizer = DummyTokenizer(model_max_length=2048)
dummy_model = DummyModel(max_length=2048)
dummy_accelerator = DummyAccelerator(device="cpu", process_index=0)
dummy_experiment_config = DummyExperimentConfig()

# Create a list of dummy prompts.
prompts = [
    "This is a short prompt.",
    "Here is a slightly longer prompt that should still be under the adaptive limit.",
    "This is a prompt that is intentionally made very long " * 50  # repeat to create a very long prompt.
]

# Run the inference function.
outputs, concatenated_input_ids, inference_results = run_gen_inference(
    model=dummy_model,
    experiment_config=dummy_experiment_config,
    prompts=prompts,
    tokenizer=dummy_tokenizer,
    accelerator=dummy_accelerator
)

print("Inference results:", inference_results)
print("Concatenated input_ids shape:", concatenated_input_ids.shape)


Using adaptive batching: created 2 batches.
[Process 2671768][GPU 0] — Completed tokenisation of batch 1/2
[Process 2671768][GPU 0] — Completed batch inference 1/2
[Process 2671768][GPU 0] — Completed tokenisation of batch 2/2
[Process 2671768][GPU 0] — Completed batch inference 2/2
Inference results: {'num_input_prompts': 3, 'avg_latency_ms': 0.10345893679186702, 'total_input_tokens': 6144, 'total_generated_tokens': 150}
Concatenated input_ids shape: torch.Size([3, 2048])
