# Tune RFX SXR camera geometry (pinhole position, fan span) to match the dataset

In [None]:
from utils import *

In [2]:
# load the dataset and rescale every sample by its own peak emissivity
N_DS = 50_000 # number of samples in the dataset
# N_DS = 10_000 # number of samples in the dataset
ds = np.load(f'data/sxr_real_ds_gs{GSIZE}_n{N_DS}.npz')
emiss = ds['emiss']

# peak value of each GSIZE x GSIZE map, as a column vector; dividing by it makes
# every map peak at 1 (i.e. range [0, 1] assuming non-negative emissivity)
ks = emiss.reshape(-1, GSIZE * GSIZE).max(axis=-1, keepdims=True)
emiss = emiss / ks[:, None]

# the SXR signals of each sample are divided by the same per-sample factor
vdi, vdc, vde, hor = (ds[key] / ks for key in ('vdi', 'vdc', 'vde', 'hor'))
sxrs = [vdi, vdc, vde, hor]

In [3]:
# helper functions
def plot_examples(emiss, sxr, csxr, err, avgs, rays, inc_angles, color, *, rows=10, cols=10, seed=0):
    """Plot a grid of randomly chosen samples comparing measured and simulated SXR.

    Panel 1 shows the emissivity of one sampled case with the fan's lines of
    sight overlaid, panel 2 shows the average signal ``avgs``, and every other
    panel overlays the measured (``sxr``, dotted squares), simulated (``csxr``,
    solid) and error (``err``, white) profiles of one sample versus
    ``inc_angles``.

    Args:
        emiss: per-sample emissivity maps, drawn as colours on the (RR, ZZ) grid.
        sxr, csxr, err: per-sample measured, simulated and error signals,
            indexed by the same sample index as ``emiss``.
        avgs: average signal; flattened before plotting.
        rays: line-of-sight polylines, each indexed as ``r[:, 0]``, ``r[:, 1]``.
        inc_angles: x coordinates of the signal profiles.
        color: single-letter matplotlib colour code used in format strings.
        rows, cols: size of the panel grid.
        seed: seed applied to the global NumPy RNG before picking samples.
    """
    if len(sxr) == 0:
        # np.random.choice(0, ...) would raise, e.g. when no sample passed the filter
        print('No samples to plot')
        return
    np.random.seed(seed)  # reseeds the global RNG so every run shows the same samples
    print(f'Plotting {rows*cols} random samples')
    plt.figure(figsize=(1.5*cols, 2*rows))
    # drawn with replacement, so a sample may appear in more than one panel
    for i, idx in enumerate(np.random.choice(len(sxr), rows*cols)):
        plt.subplot(rows, cols, i + 1)
        if i == 0:
            # emissivity map of one sample with the fan's rays on top
            plt.scatter(RR, ZZ, c=emiss[idx], s=900/GSIZE, marker='s')
            for r in rays:
                plt.plot(r[:, 0], r[:, 1], f'{color}-')
            plt.colorbar()
            plt.axis('equal')
            plt.grid(False)
        elif i == 1:
            # average signal (independent of idx)
            plt.plot(inc_angles, avgs.reshape(-1), f'{color}-', label='Average')
            plt.xticks([]); plt.yticks([])
            plt.ylim(0, 0.7)
        else:
            plt.plot(inc_angles, sxr[idx], f'{color}s:', label='Real', markersize=3)
            plt.plot(inc_angles, csxr[idx], f'{color}-', label='Sim')
            plt.plot(inc_angles, err[idx], 'w-', label='Error')
            plt.axhline(y=0, color='w', linestyle='--', linewidth=0.5)
            plt.xticks([]); plt.yticks([])
            plt.ylim(-0.3, 0.3)
    plt.show()
    plt.close()

In [None]:
# SXR
I = 3 # camera index 0 - 3, in the order of `sxrs` (vdi, vdc, vde, hor)

nrays, start_angle, span_angle, pinhole, keep = RFX_SXR_NRAYS[I], RFX_SXR_STARTS[I], RFX_SXR_SPANS[I], RFX_SXR_PINHOLES[I], RFX_SXR_TO_KEEP[I]
sxr = sxrs[I]
name, color = RFX_SXR_NAMES[I], RFX_SXR_COLORS[I]

rays, fan, inc_angles = create_rfx_fan(nrays, start_angle, span_angle, pinhole, keep, ret_all=True)

# filter data: keep samples whose smallest value along axis 1 (the profile
# plotted against inc_angles) exceeds the threshold
mins = np.min(sxr, axis=1)
print(f'mins shape: {mins.shape}')
# idxs_to_keep = np.where(np.sum(sxr, axis=1) > 0.5)[0]
idxs_to_keep = np.where(mins > 0.1)[0]
# idxs_to_keep = np.arange(len(sxr))

print(f'Keeping {len(idxs_to_keep)}/{len(sxr)} samples, {len(idxs_to_keep)/len(sxr)*100:.1f}%')
femiss = emiss[idxs_to_keep]
fsxr = sxr[idxs_to_keep]

# simulate every kept sample; iterate over femiss (not emiss) so that row i of
# csxr corresponds to row i of fsxr
csxr = np.zeros_like(fsxr)
for i, em in enumerate(femiss):
    csxr[i] = eval_rfx_sxr(fan, em)

err = abs(fsxr - csxr) # error
avgs = np.mean(fsxr, axis=0).reshape(1, -1) # average over the kept samples

# Offset by 0.5 for plotting only (alternative: subtract avgs). fsxr and csxr
# themselves stay unshifted so later cells compare measured and simulated
# signals on the same scale.
plot_examples(femiss, fsxr - 0.5, csxr - 0.5, err, avgs, rays, inc_angles, color)
    


In [None]:
# Grid search over the fan geometry: shift the pinhole by (phx, phy) and change
# the fan span by `angle` (symmetrically about its original centre); each
# candidate is scored by the total absolute mismatch between measured and
# simulated signals over the kept samples.
# nx, ny, nangles = 13, 5, 15
nx, ny, nangles = 1, 1, 40
dx, dy, dangles = 0.02, 0.05, π/3
# pinhole offsets in x, y
phxs = np.linspace(-dx, dx, nx, endpoint=True) if nx > 1 else [0]
phys = np.linspace(-dy, dy, ny, endpoint=True) if ny > 1 else [0]
# span changes
angles = np.linspace(-dangles, dangles, nangles, endpoint=True) if nangles > 1 else [0]

# Measured signals of the kept samples, read from the unmodified `sxr` so the
# score does not depend on whether an earlier cell offset `fsxr` for plotting.
target_sxr = sxr[idxs_to_keep]


def simulate_kept(fan):
    """Return the simulated signals of every kept sample (`femiss`), shaped like `target_sxr`."""
    sim = np.zeros_like(target_sxr)
    for si, em in enumerate(femiss):
        sim[si] = eval_rfx_sxr(fan, em)
    return sim


def make_fan(ph, angle):
    """Build the fan with pinhole `ph` and span changed by `angle` about the original centre."""
    return create_rfx_fan(nrays, start_angle - angle/2, span_angle + angle, ph, keep, ret_all=True)


grid_err = np.zeros((nx, ny, nangles))
best_ph, best_angle, best_err = None, None, np.inf
plt.figure(figsize=(10, 10))
for i, phx in enumerate(tqdm(phxs, desc='pinholes')):
    for j, phy in enumerate(phys):
        for k, angle in enumerate(angles):
            ph = pinhole + np.array([phx, phy])
            rays, fan, _ = make_fan(ph, angle)
            grid_err[i, j, k] = np.sum(abs(target_sxr - simulate_kept(fan)))
            if grid_err[i, j, k] < best_err:
                best_err, best_ph, best_angle = grid_err[i, j, k], ph, angle
            # draw the two outermost rays and the pinhole of this candidate
            random_color = np.random.rand(3)
            plt.plot(rays[0][:,0], rays[0][:,1], '-', color=random_color, alpha=0.5)
            plt.plot(rays[-1][:,0], rays[-1][:,1], '-', color=random_color, alpha=0.5)
            plt.plot(ph[0], ph[1], 'o', color=random_color)
# plot FW
plt.plot(FW[:,0], FW[:,1], 'w-')
plt.axis('equal')
plt.show()

if best_ph is None:
    # every score was NaN, so no candidate ever beat np.inf
    raise RuntimeError('No finite grid error found; cannot select a best fan')
print(f'Best error: {best_err} at pinhole: {best_ph}, angle: {best_angle}')

# indexing='ij' gives grids of shape (nx, ny) and (nx, nangles), matching the
# axis order of grid_err, so each point gets its own colour
xy, yx = np.meshgrid(phxs, phys, indexing='ij')
xa, ax = np.meshgrid(phxs, angles, indexing='ij')

plt.figure(figsize=(5, 5))
plt.scatter(xy+pinhole[0], yx+pinhole[1], c=grid_err.mean(axis=2), s=1000/nx, marker='s')
plt.colorbar()
plt.xlabel('x')
plt.ylabel('y')
plt.title('Error')
plt.show()
plt.close()

#plot x / angle
plt.figure(figsize=(5, 5))
plt.scatter(xa+pinhole[0], np.rad2deg(ax), c=grid_err.mean(axis=1), s=1000/nx, marker='s')
plt.colorbar()
plt.xlabel('x')
plt.ylabel('angle')
plt.title('Error')
plt.show()
plt.close()

# plot the best (x axis stays the original inc_angles, i.e. the measured channels)
rays, fan, _ = make_fan(best_ph, best_angle)
csxr = simulate_kept(fan)
err = abs(target_sxr - csxr) # error
plot_examples(femiss, target_sxr - 0.5, csxr - 0.5, err, avgs, rays, inc_angles, color)
print(f'Best error: {err.sum()} at pinhole: {best_ph}, angle: {best_angle}')

# plot the original
rays, fan, _ = create_rfx_fan(nrays, start_angle, span_angle, pinhole, keep, ret_all=True)
csxr = simulate_kept(fan)
err = abs(target_sxr - csxr) # error
plot_examples(femiss, target_sxr - 0.5, csxr - 0.5, err, avgs, rays, inc_angles, color)
print(f'Original error: {err.sum()} at pinhole: {pinhole}')
