In [5]:
import torch
import numpy as np
import time
import math
from typing import List, Dict

# --- Import from your project files ---
try:
    from AllCode import ParamDef, ParamSpace
    from AllCode import SceneBuilder
    from AllCode import SolverV2_opt
except ImportError as e:
    print("Error: Could not import project files. Make sure .py files are in the same directory.")
    print(f"Details: {e}")
    # Abort with a non-zero status. Unlike the interactive `exit()` helper,
    # SystemExit does not depend on the `site` module and, inside Jupyter,
    # stops the cell without shutting the kernel down.
    raise SystemExit(1) from e

# --- Hyperparameters for the Benchmark ---

# Target number of samples solved per batch size. run_benchmark rounds this
# up to a whole number of batches. Set it high enough for a stable measurement.
TOTAL_SAMPLES = 100_000

# Batch sizes to test. The driver stops at the first size that fails
# (e.g. out of memory), so list them in increasing order.
BATCH_SIZES_TO_TEST = [10_000, 25_000, 50_000, 100_000]

# --- Global Setup (mimicking your main script) ---
# Single definition of the compute device used by every tensor below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
print("--- Solver-Only Benchmark ---")
print(f"Using device: {device}")
print(f"Total samples per run: {TOTAL_SAMPLES:,}")
print(f"Testing batch sizes: {BATCH_SIZES_TO_TEST}")
print("-" * 30)

# --- 1. Define the Parameter Space (Must match build_scene.py) ---
# Parameter table: (name, low, high, scale). Each row becomes one ParamDef,
# in exactly this order, so the order defines the parameter-vector layout.
# NOTE(review): must stay in sync with build_scene.py.
space = ParamSpace([
    ParamDef(name=pname, low=lo, high=hi, scale=scale)
    for pname, lo, hi, scale in (
        # Spatial Source
        ("A_SR",     0.5e23,  3e23,      "linear"),
        ("Rc_SR",    0.85,    1.10,      "linear"),
        ("o_SR",     0.05,    0.20,      "linear"),
        # Temporal Source
        ("A_ST",     0,       0.25,      "linear"),
        ("f_ST",     0.1,     10,        "log10"),
        ("p_ST",     0,       2*math.pi, "linear"),
        # Diffusion Profile
        ("A_D",      1,       10,        "log10"),
        ("AL_D",     0.20,    1.00,      "linear"),
        ("A0_D",     0.01,    0.20,      "linear"),
        ("Rc_D",     0.85,    1.10,      "linear"),
        ("RwL_D",    0.025,   0.200,     "linear"),
        ("RwR_D",    0.025,   0.200,     "linear"),
        # Convection Profile
        ("A_V",      0.5,     7.5,       "linear"),
        ("A1_V",     0.10,    0.45,      "linear"),
        ("A2_V",     0.25,    3.00,      "linear"),
        ("R1_V",     0.35,    0.60,      "linear"),
        ("R2_V",     0.65,    0.85,      "linear"),
        ("R3_V",     0.95,    1.15,      "linear"),
        ("Flip_V",   -1.0,    1.0,       "linear"),
        ("Bounce_V", -1.0,    1.0,       "linear"),
        # Initial Density Profile
        ("A_N0",     5.5e19,  9e19,      "linear"),
        ("Xs_N0",    0.85,    1.10,      "linear"),
        ("H_N0",     0.006,   0.018,     "linear"),
        ("a_N0",     0.006,   0.016,     "linear"),
        ("B_N0",     0.85e20, 1.20e20,   "linear"),
        # Edge Boundary Condition
        ("A_mag",    1e18,    1e22,      "log10"),
        ("A_sign",   -1.0,    1.0,       "linear"),
    )
])

# --- 2. Set up Grids, Builder, and Solver ---
# Evaluation grids handed to the scene builder. The benchmark uses reduced
# resolutions. NOTE(review): 601 / 641 points appear to be the full-size
# settings — confirm before changing.
rho_grid = torch.linspace(start=1e-3, end=1.2, steps=151,
                          device=device, dtype=DTYPE)
time_grid = torch.linspace(start=1e-3, end=0.065317, steps=161,
                           device=device, dtype=DTYPE)

# Builder and solver are constructed in this order (builder first).
builder = SceneBuilder(rho=rho_grid, time=time_grid)
solver = SolverV2_opt()

# --- 3. Pre-build Input Data (NOT TIMED) ---
print("Pre-building input data template...")
try:
    # One "template" batch is built at the largest size. Each benchmark run
    # slices its leading rows instead of rebuilding inputs.
    MAX_BATCH_SIZE = max(BATCH_SIZES_TO_TEST)

    # Every sample uses the unit-space midpoint (0.5 in every dimension),
    # mapped to physical parameters by the ParamSpace.
    U_mean = torch.full((1, space.dim), 0.5, device=device, dtype=DTYPE)
    X_phys_mean = space.unit_to_phys(U_mean)
    plist_mean = space.dict_from_vector(X_phys_mean[0])

    # The same object is repeated (not copied) MAX_BATCH_SIZE times.
    # NOTE(review): assumes build_batch does not mutate its inputs — confirm.
    plist_template = [plist_mean] * MAX_BATCH_SIZE

    # Build the *input* to the solver: a dict whose values are passed as
    # keyword arguments to solver.solve().
    case_template = builder.build_batch(plist_template)
    print(f"Template built for max batch size {MAX_BATCH_SIZE}.")

except torch.cuda.OutOfMemoryError:
    print(f"ERROR: CUDA Out of Memory just trying to *build* the template batch of size {MAX_BATCH_SIZE}.")
    print("Try reducing the largest batch size in BATCH_SIZES_TO_TEST.")
    # Non-zero exit status. SystemExit also stops a Jupyter cell without
    # killing the kernel, unlike the interactive `exit()` helper.
    raise SystemExit(1)
except Exception as e:
    print(f"Error during pre-build: {e}")
    raise SystemExit(1) from e

def run_benchmark(batch_size: int, total_samples: int = TOTAL_SAMPLES) -> "tuple[float, float, float]":
    """
    Time ``solver.solve`` on repeated batches and report throughput.

    The input is the first ``batch_size`` rows of every value in the
    pre-built ``case_template``. This assumes every value is sliceable by
    sample along its leading dimension — TODO confirm for any shared grids.
    One untimed warm-up solve runs first. Then ``ceil(total_samples /
    batch_size)`` solves are timed back to back. On CUDA the timer brackets
    ``torch.cuda.synchronize()`` calls so queued kernels are included.

    Args:
        batch_size: Samples per solver call. Must be in ``1..MAX_BATCH_SIZE``
            because the input is a slice of the template.
        total_samples: Target sample count, rounded up to whole batches.

    Returns:
        ``(samples_per_sec, total_time_s, actual_samples)`` on success.
        ``(-1.0, -1.0, -1.0)`` if the solver raised (CUDA OOM or any other
        exception); the cause is printed.

    Raises:
        ValueError: If ``batch_size`` is outside ``1..MAX_BATCH_SIZE``.
    """
    # Slicing past the end clamps silently. A batch_size above the template
    # size would therefore solve fewer samples than reported, inflating
    # throughput.
    if not 0 < batch_size <= MAX_BATCH_SIZE:
        raise ValueError(
            f"batch_size must be in 1..{MAX_BATCH_SIZE}, got {batch_size}"
        )

    num_batches = math.ceil(total_samples / batch_size)
    actual_samples = num_batches * batch_size

    print(f"  Testing {actual_samples:,} samples in {num_batches} batches of {batch_size}...")

    solve_kwargs = dict(conv_bc_outer='dirichlet', assert_conservation=False, dtype=DTYPE)

    try:
        # --- 1. Input slice for this batch size (not timed) ---
        batch_input = {k: v[:batch_size] for k, v in case_template.items()}

        # --- 2. Warmup run: JIT compilation, CUDA context/allocator warm-up ---
        # Results are deliberately not bound to a name. Binding them (e.g.
        # `_ = ...`) would keep the previous output alive while the next solve
        # allocates, roughly doubling peak output memory (unless the solver
        # keeps its own reference).
        solver.solve(**batch_input, **solve_kwargs)

        # Make sure the warm-up has actually finished before starting the clock.
        if device.type == 'cuda':
            torch.cuda.synchronize()

        # --- 3. Timed run: only the solver calls are measured ---
        start_time = time.perf_counter()

        for _ in range(num_batches):
            solver.solve(**batch_input, **solve_kwargs)

        # CUDA launches are asynchronous; wait for all kernels before stopping.
        if device.type == 'cuda':
            torch.cuda.synchronize()

        end_time = time.perf_counter()

        # --- 4. Results ---
        total_time = end_time - start_time
        sps = actual_samples / total_time
        return sps, total_time, actual_samples

    except torch.cuda.OutOfMemoryError:
        print(f"  ERROR: CUDA Out of Memory with batch size {batch_size}. Stopping.")
        return -1.0, -1.0, -1.0
    except Exception as e:
        print(f"  ERROR: An exception occurred: {e}")
        return -1.0, -1.0, -1.0

# --- Main Execution ---
if __name__ == "__main__":
    # (batch_size, samples_per_second) for every batch size that completed.
    results = []

    print("\n--- Starting Benchmark ---")
    print(f"{'Batch Size':<12} | {'Samples/Sec':<15} | {'Total Time (s)':<15} | {'Total Samples':<15}")
    print("-" * 58)

    for bs in BATCH_SIZES_TO_TEST:
        sps, total_time, actual_samples = run_benchmark(bs)

        # run_benchmark returns negative sentinels for *any* failure (CUDA OOM
        # or another exception) and has already printed the specific cause.
        # The row is therefore labelled generically rather than always "OOM".
        if sps < 0:
            print(f"{bs:<12} | {'--- FAILED ---':<15} | {'--- FAILED ---':<15} | {'--- FAILED ---':<15}")
            # Larger batch sizes are not attempted after a failure.
            break

        results.append((bs, sps))
        print(f"{bs:<12} | {sps:<15.2f} | {total_time:<15.4f} | {actual_samples:<15,}")

    # --- Final Report ---
    if results:
        best_bs, best_sps = max(results, key=lambda item: item[1])
        print("\n--- Benchmark Complete ---")
        print(f"Optimal Batch Size: {best_bs}")
        print(f"   Max Throughput: {best_sps:,.2f} Samples/Sec")

        # Compare to the 300k samples/minute baseline (= 5,000 samples/sec).
        baseline_sps = 300_000 / 60
        improvement = (best_sps - baseline_sps) / baseline_sps
        print(f"  vs. 5,000 SPS: {improvement:+.1%}")
    else:
        print("\n--- Benchmark Failed ---")
        print("No valid results were recorded.")

--- Solver-Only Benchmark ---
Using device: cuda
Total samples per run: 100,000
Testing batch sizes: [10000, 25000, 50000, 100000]
------------------------------
Using device: cuda
Pre-building input data template...
Template built for max batch size 100000.

--- Starting Benchmark ---
Batch Size   | Samples/Sec     | Total Time (s)  | Total Samples  
----------------------------------------------------------
  Testing 100,000 samples in 10 batches of 10000...


KeyboardInterrupt: 