In [16]:
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import scipy as sp

from Old_Code.AllCode import SolverV2_opt
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [17]:
def build_profs(r, t, size=128, seed=None, out_device=None):
    """Sample a batch of randomized input profiles and return them as float64 tensors.

    Every profile family is a fixed functional form whose parameters are drawn
    independently per sample from a truncated normal distribution around a
    base value (see the nested ``RandNormal``).

    Parameters
    ----------
    r : array-like
        Radial grid on which the spatial profiles are evaluated.
    t : array-like
        Time grid on which the temporal source modulation is evaluated.
    size : int, default 128
        Number of samples (batch size) to generate.
    seed : int or None, default None
        When given, all draws come from ``np.random.default_rng(seed)`` so the
        batch is reproducible. ``None`` keeps the previous behaviour
        (SciPy falls back to NumPy's global random state).
    out_device : torch.device or str or None, default None
        Device for the returned tensors. ``None`` uses the notebook-level
        ``device`` variable, as before.

    Returns
    -------
    tuple of torch.Tensor
        ``(SR, ST, D, V, N0, A)`` where, for 1-D ``r`` / ``t`` grids:
        SR (spatial source), D (diffusion), V (convection) and N0 have one row
        per sample evaluated on ``r``; ST (time source) has one row per sample
        evaluated on ``t``; A holds one signed boundary value per sample.

    Raises
    ------
    ValueError
        From ``RandNormal`` if a relative-mode width lies outside [0, 1].
    """
    # random_state=None makes scipy use NumPy's global RNG (the old behaviour).
    rng = None if seed is None else np.random.default_rng(seed)
    target = device if out_device is None else out_device

    def RandNormal(center: float, width: float, size: int, dist_tightness=0.25, Absolute=0):
        """Draw ``size`` truncated-normal samples spread around ``center``.

        The underlying variate lives on [-1, 1]: ``truncnorm(-b, b, scale=1/b)``
        clips a normal of standard deviation ``1/b`` at ``+-b`` standard
        deviations, with ``b = t/(1-t)`` for ``t = dist_tightness``. A larger
        tightness therefore concentrates the samples nearer the centre.

        Absolute mode returns ``center + width*x``; relative mode returns
        ``center*(1 + width*x)`` (which is identically zero when center == 0).
        """
        if (not Absolute) and ((width<0) or (width>1)):
            raise ValueError(f"width must be between 0 and 1 in relative mode. Width - {width}")
        b = dist_tightness/(1 - dist_tightness)
        rv = sp.stats.truncnorm(-b, b, loc=0, scale=1/b)
        draws = rv.rvs(size=size, random_state=rng)
        if Absolute:
            return center + width*draws
        return center*(1 + width*draws)

    def sample_columns(centers, widths, tightness):
        """Sample one column per (center, width) pair; rows index the samples."""
        return np.column_stack([
            RandNormal(c, w, size, dist_tightness=tightness, Absolute=1)
            for c, w in zip(centers, widths)
        ])

    # Tightness shared by the spatial-source and convection parameters.
    T = 0.7

    # ---- Spatial source: Gaussian bump A*exp(-(r-rc)^2/o^2) ----
    # Disabled extra term (negative well beyond rho~1.1), kept for reference:
    #   + np.where(r>1.07, -1.5*A*(r-1.07)/(0.1), 0)
    def SRFunc(r, A, rc, o):
        return A*np.exp(-(r - rc)**2/(o**2))

    Abase = 1.75e23
    rcbase = 1
    SRtightness = T
    A_SR = RandNormal(Abase, 1.25e23, size, dist_tightness=SRtightness, Absolute=1)
    rc_SR = RandNormal(rcbase, 0.1, size, dist_tightness=SRtightness, Absolute=1)
    o_SR = RandNormal(0.125, 0.075, size, dist_tightness=SRtightness, Absolute=1)

    # ---- Time source: sinusoidal modulation around 1 with frequency 10**f ----
    def STFunc(t, A, f, p):
        return A*np.sin(2*(10**f)*np.pi*(t + p)) + 1

    STtightness = 0.4
    STtightness2 = 0.25
    A_ST = RandNormal(0.125, 0.125, size, dist_tightness=STtightness, Absolute=1)
    f_ST = RandNormal(0, 1, size, dist_tightness=STtightness2, Absolute=1)
    p_ST = RandNormal(0, np.pi, size, dist_tightness=STtightness2, Absolute=1)

    # ---- Diffusion: PCHIP through 4 knots ----
    # Parameter order: A_D, AL_D, A0_D, Rc_D, RwL_D, RwR_D
    D_params = (0.5, 0.6, 0.105, 1, 0.125, 0.125)
    Dwidths = np.array([0.5, 0.4, 0.095, 0.1, 0.075, 0.075])
    Dtightness = 0.75

    def DFunc(r, D_params):
        A_D, AL_D, A0_D, Rc_D, RwL_D, RwR_D = D_params
        A_D = 10**A_D  # amplitude is sampled in log10 space
        Dknots = np.array([0.0, Rc_D - RwL_D, Rc_D, Rc_D + RwR_D])
        Dvals = A_D*np.array([AL_D, AL_D, A0_D, 1])
        return sp.interpolate.PchipInterpolator(Dknots, Dvals, extrapolate=True)(r)

    D_params_gen = sample_columns(D_params, Dwidths, Dtightness)

    # ---- Convection: PCHIP through 4 knots, optionally rectified / sign-flipped ----
    # Parameter order: A_V, A1_V, A2_V, R1_V, R2_V, R3_V, Flip_V, Bounce_V
    V_params = (0.2870156338638594, 0.275, 1.625, 0.475, 0.75, 1.05, 0, 0)
    Vwidths = np.array([0.5880456295278407, 0.175, 1.375, 0.125, 0.1, 0.1, 1, 1])
    Vtightness = T  # TODO (from original note): 6 knots?

    def VFunc(r, V_params):
        A_V, A1_V, A2_V, R1_V, R2_V, R3_V, Flip_V, Bounce_V = V_params
        A_V = 10**A_V  # amplitude is sampled in log10 space
        Vknots = np.array([0.0, R1_V, R2_V, R3_V])
        Vvals = A_V*np.array([0, A1_V, A2_V, -1])
        V_interp = sp.interpolate.PchipInterpolator(Vknots, Vvals, extrapolate=True)(r)
        if Bounce_V > 0:
            V_interp = np.abs(V_interp)
        return -V_interp if Flip_V > 0 else V_interp

    # Fix: the convection draws previously reused Dtightness, leaving
    # Vtightness defined but unused.
    V_params_gen = sample_columns(V_params, Vwidths, Vtightness)

    # ---- N0 profile: A*f((xs - r)/H, a) + B with a tanh-like step f ----
    def N0Func(r, a, xs, H, A, B):
        x = (xs - r)/H
        step = ((1 + a*x)*np.exp(x) - np.exp(-x))/(np.exp(x) + np.exp(-x))
        return A*step + B

    N0tightness = 0.25
    N0tightness2 = 0.8
    a_N0 = RandNormal(0.011, 0.005, size, dist_tightness=N0tightness, Absolute=1)
    xs_N0 = RandNormal(1, 0.1, size, dist_tightness=N0tightness2, Absolute=1)
    H_N0 = RandNormal(0.012, 0.006, size, dist_tightness=N0tightness, Absolute=1)
    A_N0 = RandNormal(7.25e19, 1.5e19, size, dist_tightness=N0tightness, Absolute=1)
    B_N0 = RandNormal(1.025e20, 0.175e20, size, dist_tightness=N0tightness, Absolute=1)

    # ---- Boundary values: log-uniform-ish magnitude times a random sign ----
    A_mag = 10**(RandNormal(20, 2, size, dist_tightness=0.6, Absolute=1))
    # Fix: relative mode with center=0 yields 0*(...) == 0 for every draw, which
    # made np.sign() return 0 and zeroed every boundary value. Absolute mode
    # gives a symmetric variate on [-1, 1] whose sign is what we want.
    A_sign = RandNormal(0, 1, size, dist_tightness=0.6, Absolute=1)
    A_bounds = A_mag*np.sign(A_sign)

    # ---- Evaluate one profile per sample ----
    SR_profs = np.array([SRFunc(r, A_SR[i], rc_SR[i], o_SR[i]) for i in range(size)])
    ST_profs = np.array([STFunc(t, A_ST[i], f_ST[i], p_ST[i]) for i in range(size)])
    D_profs = np.array([DFunc(r, row) for row in D_params_gen])
    V_profs = np.array([VFunc(r, row) for row in V_params_gen])
    N0_profs = np.array([N0Func(r, a_N0[i], xs_N0[i], H_N0[i], A_N0[i], B_N0[i]) for i in range(size)])

    # ---- Convert to tensors ----
    def as_tensor(values):
        return torch.tensor(values, dtype=torch.float64, device=target)

    return (
        as_tensor(SR_profs),
        as_tensor(ST_profs),
        as_tensor(D_profs),
        as_tensor(V_profs),
        as_tensor(N0_profs),
        as_tensor(A_bounds),
    )

In [21]:
import time, torch

def benchmark_batch_size(B, nr, nt, device, n_warmup=3, n_iters=5):
    """Time ``SolverV2_opt.solve`` on a freshly generated batch of ``B`` samples.

    Parameters
    ----------
    B : int
        Batch size (number of samples solved per call).
    nr, nt : int
        Number of radial and time grid points; both grids span [0, 1].
    device : torch.device or str
        Device on which the grids and all profile tensors are placed.
    n_warmup : int, default 3
        Untimed calls made first so one-off start-up costs are excluded.
    n_iters : int, default 5
        Number of timed calls that are averaged.

    Returns
    -------
    (time_per_call, time_per_sample) : tuple of float
        Mean wall-clock seconds per ``solve`` call, and that value divided by B.

    Raises
    ------
    ValueError
        If ``n_iters`` is not positive.
    """
    if n_iters <= 0:
        raise ValueError(f"n_iters must be positive, got {n_iters}")

    dev = torch.device(device)

    def _sync():
        # CUDA kernels run asynchronously, so wait for them before reading the
        # clock. On CPU there is nothing to wait for, and calling
        # torch.cuda.synchronize() without a GPU raises.
        if dev.type == "cuda":
            torch.cuda.synchronize(dev)

    # --- construct inputs of the right shape ---
    rho = torch.linspace(0, 1, nr, device=dev).unsqueeze(0).repeat(B, 1)
    time_grid = torch.linspace(0, 1, nt, device=dev).unsqueeze(0).repeat(B, 1)

    profiles = build_profs(rho[0].cpu().numpy(), time_grid[0].cpu().numpy(), size=B)
    # build_profs places its tensors on the notebook-level device; move them so
    # every input lives on the device that was requested here.
    SR, ST, D, V, N0, A = (p.to(dev) for p in profiles)

    solver = SolverV2_opt()

    def _run_once():
        return solver.solve(rho, time_grid, SR, ST, D, V, N0, A, assert_conservation=False)

    # --- warmup ---
    for _ in range(n_warmup):
        _run_once()
    _sync()

    # --- timed section ---
    t0 = time.perf_counter()
    for _ in range(n_iters):
        _run_once()
    _sync()
    t1 = time.perf_counter()

    time_per_call = (t1 - t0) / n_iters
    time_per_sample = time_per_call / B
    return time_per_call, time_per_sample


In [22]:
# Benchmark configuration: grid resolution, batch sizes to compare, and the
# approximate number of samples to push through the solver per batch size.
nr = 151
nt = 160
batch_sizes = [10000, 25000]
total_eval = 100000
times = []  # rows of (batch size, mean seconds per call, seconds per sample)

for B in batch_sizes:
    n_batches = total_eval // B
    if n_batches == 0:
        # Avoids a ZeroDivisionError below when a batch exceeds the budget.
        print(f"Skipping batch size {B}: larger than total_eval={total_eval}")
        continue
    # Samples actually processed; differs from total_eval when B does not
    # divide it evenly, so per-sample time must use this count.
    n_samples = n_batches * B

    print(f"Benchmarking batch size {B}...")
    tB = 0
    for batch_idx in range(n_batches):
        print(f'processing batch {batch_idx+1} of {n_batches}')
        time_per_call, time_per_sample = benchmark_batch_size(B, nr, nt, device)
        tB += time_per_call
    times.append((B, tB / n_batches, tB / n_samples))
    print(f'end=of batch size {B}')
    print(f'Took {tB:.6f} seconds for {n_samples} samples.')
        

Benchmarking batch size 10000...
processing batch 1 of 10
Using device: cuda
processing batch 2 of 10
Using device: cuda
processing batch 3 of 10
Using device: cuda
processing batch 4 of 10
Using device: cuda
processing batch 5 of 10
Using device: cuda
processing batch 6 of 10
Using device: cuda


KeyboardInterrupt: 

In [None]:
# Tabulate the benchmark results collected in `times`.
print("Batch Size | Time per Call (s) | Time per Sample (s)")
for row in times:
    batch, per_call, per_sample = row
    print(f"{batch:10d} | {per_call:.6f}       | {per_sample:.9f}")

In [None]:
# Double the batch size until the GPU runs out of memory.
B = 1
results = []  # rows of (batch size, seconds per call, seconds per sample)
while True:
    try:
        t_call, t_sample = benchmark_batch_size(B, nr, nt, device="cuda")
    except RuntimeError as e:
        # CUDA OOM surfaces as a RuntimeError (subclass) whose message contains
        # "out of memory". Anything else is a genuine failure and must not be
        # reported as OOM.
        if "out of memory" not in str(e).lower():
            raise
        print(f"OOM at B={B}")
        break
    results.append((B, t_call, t_sample))
    print(f"B={B} ok: {t_sample*1e3:.3f} ms/sample")
    B *= 2

# Hand the cached blocks from the failed attempt back to the driver.
torch.cuda.empty_cache()


In [None]:
def memory_per_sample(nr, nt, device="cuda"):
    """Measure the peak CUDA memory attributable to one B=1 ``solve`` call.

    Parameters
    ----------
    nr, nt : int
        Number of radial and time grid points.
    device : torch.device or str, default "cuda"
        CUDA device to allocate on and to query for memory statistics.

    Returns
    -------
    int
        Peak bytes allocated above the pre-existing baseline while building the
        B=1 inputs and running the solver. Any fixed, batch-independent
        allocation made by the solver is included, so scaling this linearly
        with B is only an estimate.
    """
    dev = torch.device(device)

    torch.cuda.empty_cache()
    # reset_peak_memory_stats() resets the peak to the *currently allocated*
    # amount, not to zero, so tensors already alive in the kernel would be
    # counted as "per sample" unless the baseline is subtracted.
    baseline = torch.cuda.memory_allocated(dev)
    torch.cuda.reset_peak_memory_stats(dev)

    # allocate B=1 case
    rho = torch.linspace(0, 1, nr, device=dev).unsqueeze(0)
    time_grid = torch.linspace(0, 1, nt, device=dev).unsqueeze(0)
    # Profiles are float64 to match what build_profs feeds the benchmarks;
    # float32 stand-ins could under-report the footprint.
    SR = torch.rand(1, nr, dtype=torch.float64, device=dev)
    ST = torch.rand(1, nt, dtype=torch.float64, device=dev)
    D = torch.rand(1, nr, dtype=torch.float64, device=dev)
    V = torch.rand(1, nr, dtype=torch.float64, device=dev)
    N0 = torch.rand(1, nr, dtype=torch.float64, device=dev)
    A = torch.rand(1, dtype=torch.float64, device=dev)

    solver = SolverV2_opt()
    # Same flag as the benchmark calls, so the measured configuration matches.
    _ = solver.solve(rho, time_grid, SR, ST, D, V, N0, A, assert_conservation=False)
    torch.cuda.synchronize(dev)

    peak = torch.cuda.max_memory_allocated(dev)
    return peak - baseline  # bytes for B=1


In [None]:
# Extrapolate the B=1 footprint linearly to a maximum batch size for GPU 0,
# keeping 20% of the card's memory free as headroom.
bytes_per_sample = memory_per_sample(nr, nt)
gpu_props = torch.cuda.get_device_properties(0)
total_bytes = gpu_props.total_memory
usable = total_bytes * 0.8
B_max_est = int(usable // bytes_per_sample)
print("Estimated max B:", B_max_est)
