In [8]:
import os
from dotenv import load_dotenv

In [None]:
# Let python-dotenv populate the process environment, then insist that
# ROOT_PATH is present and non-empty before anything else runs.
load_dotenv()
ROOT_PATH = os.getenv("ROOT_PATH")
if ROOT_PATH is None or ROOT_PATH == "":
    raise ValueError("ROOT_PATH environment variable not set. Please set it in your .env file.")


In [None]:
#!/usr/bin/env python3
"""
MLX LM Batch Inference Script
Processes multiple prompts with an mlx-community Llama-3.2 Instruct 4-bit model
(main() runs the 1B variant; load_model() defaults to the 3B variant)
"""

import json
import time
from typing import List, Dict, Any
from mlx_lm import load, generate

def load_model(model_path: str = "mlx-community/Llama-3.2-3B-Instruct-4bit", adapter_path: str = None):
    """Load a model/tokenizer pair through ``mlx_lm.load``.

    Args:
        model_path: Model identifier or path handed to ``load``.
        adapter_path: Optional adapter location. When truthy it is forwarded
            to ``load`` as ``adapter_path``; otherwise ``load`` is called
            with the model path alone.

    Returns:
        The ``(model, tokenizer)`` pair produced by ``load``.
    """
    print(f"Loading model: {model_path}")

    # Only pass the adapter keyword when an adapter was actually requested.
    extra_kwargs = {}
    done_message = "Model loaded successfully!"
    if adapter_path:
        print(f"Loading with adapter: {adapter_path}")
        extra_kwargs["adapter_path"] = adapter_path
        done_message = "Model and adapter loaded successfully!"

    model, tokenizer = load(model_path, **extra_kwargs)
    print(done_message)

    return model, tokenizer

def process_batch(
    prompts: List[str], 
    model, 
    tokenizer,
    max_tokens: int = 512
) -> List[Dict[str, Any]]:
    """Generate a response for each prompt, one at a time, and collect the results.

    Args:
        prompts: Prompt strings, processed in order.
        model: Model object passed through to ``generate``.
        tokenizer: Tokenizer object passed through to ``generate``.
        max_tokens: Forwarded to ``generate`` as ``max_tokens``.

    Returns:
        One dict per prompt, in input order, with the keys ``prompt_index``,
        ``prompt``, ``response``, ``generation_time`` (seconds) and ``status``
        (``"success"`` or ``"error"``). Failed entries have ``response`` set
        to ``None``, ``generation_time`` set to ``0`` and an extra ``error``
        key holding the exception message.
    """
    results = []
    total = len(prompts)

    for i, prompt in enumerate(prompts):
        print(f"Processing prompt {i+1}/{total}")

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()

        try:
            # Keep the try body to the single call that is expected to fail.
            response = generate(model, tokenizer, prompt=prompt, max_tokens=max_tokens, verbose=False)
        except Exception as e:
            # Deliberate best-effort: one failing prompt must not abort the
            # rest of the batch, so the failure is recorded and reported.
            result = {
                "prompt_index": i,
                "prompt": prompt,
                "response": None,
                "generation_time": 0,
                "status": "error",
                "error": str(e)
            }
            print(f"Error processing prompt {i+1}: {e}")
        else:
            result = {
                "prompt_index": i,
                "prompt": prompt,
                "response": response,
                "generation_time": time.perf_counter() - start_time,
                "status": "success"
            }

        results.append(result)
        print(f"Completed in {result['generation_time']:.2f}s")
        print("-" * 50)

    return results

def save_results(results: List[Dict[str, Any]], output_file: str = "inference_results.sql"):
    """Write each result's response text to ``output_file``, one entry per result.

    Args:
        results: Result dicts; the ``response`` value of each is written,
            followed by a newline.
        output_file: Destination path, opened for writing as UTF-8 (an
            existing file is overwritten).

    Results whose ``response`` is ``None`` or missing (failed generations)
    are written as an empty line instead of raising ``TypeError``, so the
    file still contains one entry per result in the original order.
    """
    with open(output_file, "w", encoding="utf-8") as f:
        for result in results:
            response = result.get("response")
            # Failed generations carry response=None; emit an empty entry
            # rather than crashing on None + "\n".
            f.write(("" if response is None else response) + "\n")
    print(f"Results saved to {output_file}")

def load_prompts_from_file(file_path: str) -> List[str]:
    """Load prompts from a JSONL file (one JSON object per line).

    Each non-blank line is parsed as JSON and is expected to be an object
    with a ``"prompt"`` key; that key's value is collected. Blank lines are
    ignored. Lines that are not valid JSON, are not JSON objects, or lack
    the ``"prompt"`` key are skipped with a warning printed to stdout.

    Args:
        file_path: Path to the JSONL file, read as UTF-8.

    Returns:
        The collected ``"prompt"`` values, in file order.
    """
    prompts = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Warning: Invalid JSON on line {line_num}, skipping: {e}")
                continue
            # Valid JSON can still be a list, string or number; a membership
            # test or key lookup on those would raise or misbehave.
            if not isinstance(data, dict):
                print(f"Warning: Line {line_num} is not a JSON object, skipping")
                continue
            if 'prompt' in data:
                prompts.append(data['prompt'])
            else:
                print(f"Warning: Line {line_num} missing 'prompt' field, skipping")
    return prompts

def main():
    """Run batch inference end to end and print a summary.

    Loads prompts from ``valid.jsonl``, loads the model, generates a
    response per prompt, writes the responses to ``inference_results.sql``
    and prints success/failure counts and timing totals. Returns early,
    before the model is loaded, when no prompts were found.
    """
    MODEL_PATH = "mlx-community/Llama-3.2-1B-Instruct-4bit"
    MAX_TOKENS = 512
    PROMPTS_FILE = "valid.jsonl"
    # Single source of truth for the output path so the summary message
    # cannot drift from the file that is actually written.
    OUTPUT_FILE = "inference_results.sql"
    
    prompts = load_prompts_from_file(PROMPTS_FILE)
    if not prompts:
        # Nothing to generate: skip the model load and avoid dividing by
        # zero when computing the per-prompt average below.
        print(f"No prompts loaded from {PROMPTS_FILE}; nothing to do.")
        return
    
    print(f"Starting batch inference with {len(prompts)} prompts")
    print("=" * 60)
    
    model, tokenizer = load_model(MODEL_PATH)
    
    results = process_batch(
        prompts=prompts,
        model=model,
        tokenizer=tokenizer,
        max_tokens=MAX_TOKENS
    )
    
    save_results(results, output_file=OUTPUT_FILE)
    
    successful = sum(1 for r in results if r['status'] == 'success')
    failed = len(results) - successful
    total_time = sum(r['generation_time'] for r in results)
    
    print("=" * 60)
    print("BATCH PROCESSING COMPLETE")
    print(f"Total prompts: {len(prompts)}")
    print(f"Successful: {successful}")
    print(f"Failed: {failed}")
    print(f"Total time: {total_time:.2f}s")
    print(f"Average time per prompt: {total_time/len(prompts):.2f}s")
    print(f"Results saved to {OUTPUT_FILE}")

if __name__ == "__main__":
    main()

Starting batch inference with 5 prompts
Loading model: mlx-community/Llama-3.2-1B-Instruct-4bit


Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 136031.48it/s]


Model loaded successfully!
Processing prompt 1/5
Completed in 8.42s
--------------------------------------------------
Processing prompt 2/5
Completed in 8.09s
--------------------------------------------------
Processing prompt 3/5
Completed in 8.28s
--------------------------------------------------
Processing prompt 4/5
Completed in 5.57s
--------------------------------------------------
Processing prompt 5/5
Completed in 1.25s
--------------------------------------------------
Results saved to inference_results.sql
BATCH PROCESSING COMPLETE
Total prompts: 5
Successful: 5
Failed: 0
Total time: 31.62s
Average time per prompt: 6.32s
Results saved to inference_results.json
