In [None]:
from utils import *
# Directory of the trained model to load and export. Earlier runs kept for reference:
# 'data/2637759/', 'data/2637760/', 'data/2637771/', 'data/2637772/'
BEST_MODEL_DIR = 'data/3048577/'

net_path = model_path(FX, BEST_MODEL_DIR)

ds = LiuqeDataset(EVAL_DS_PATH)

# load normalization parameters; the context manager closes the .npz file handle
with np.load(f'{BEST_MODEL_DIR}/x_mean_std.npz') as norm_npz:
    x_mean_std = to_tensor(norm_npz['x_mean_std'])

m = FullNet(InputNet(x_mean_std), PtsEncoder(), FHead(3), FHead(1), LCFSHead())
# NOTE(review): torch.load unpickles the checkpoint -- only load files from trusted sources
m.load_state_dict(torch.load(net_path, map_location=torch.device("cpu"))) # load pretrained model
m.eval() # inference mode for export (also applies to the submodules shared with rt_net)

# convert model to onnx
rt_net = LiuqeRTNet(m.input_net, m.pts_enc, m.rt_head)
convert_to_onnx_static(rt_net, save_dir=[BEST_MODEL_DIR, SAVE_DIR])
# convert_to_onnx_dyn(rt_net, save_dir=[BEST_MODEL_DIR, SAVE_DIR])

In [None]:
# Test the model on the demo dataset: compare the RT net's Fx/Br/Bz predictions
# with the stored reference values for the first few points of one example.
with np.load(f'{DS_DIR}/demo.npz') as demo_npz:
    d = {key: demo_npz[key] for key in demo_npz.files} # read each array once, then close the file
n_examples = d[PTS].shape[0]
n_pts = d[PTS].shape[1]
n_ctrl_pts = 25 # number of points fed to the net

rand_i = 0 # fixed example index, so the printout is reproducible

phys = d[PHYS][rand_i]
r = d[PTS][rand_i, :n_ctrl_pts, 0]
z = d[PTS][rand_i, :n_ctrl_pts, 1]

rt_net.eval() # inference mode: submodules are in training mode unless eval() was called
with torch.no_grad(): # plain inference, no autograd graph needed
    rt = rt_net(to_tensor(phys), to_tensor(r), to_tensor(z)) # forward pass through the RT net

rt_np = rt.detach().cpu().numpy()
Fx_pred, Br_pred, Bz_pred = rt_np[:, 0], rt_np[:, 1], rt_np[:, 2]
Fx_true, Br_true, Bz_true = d[FX][rand_i], d[BR][rand_i], d[BZ][rand_i]

k = 8 # number of values to print
np.set_printoptions(precision=4, suppress=True, linewidth=1000)
print(f'Fx_pred -> {Fx_pred[:k]}\nFx_true -> {Fx_true[:k]}')
print(f'Br_pred -> {Br_pred[:k]}\nBr_true -> {Br_true[:k]}')
print(f'Bz_pred -> {Bz_pred[:k]}\nBz_true -> {Bz_true[:k]}')

In [None]:
# # plot
# if os.path.exists('test/imgs'): os.system('rm test/imgs/*') # remove old images
# # plot_lcfs_net_out(ds, LCFSNet(m.input_net, m.lcfs_head), save_dir='test', nplt=1) # plot LCFS outputs
# plot_network_outputs(ds, m, save_dir='test', nplt=20) # plot network outputs

## Test: computing Br and Bz from the gradient of Fx

In [None]:
# TEST PINN version, i.e. with gradients on the input
def forward_with_Br_Bz_grad(phys, r, z, net=None):
    """Evaluate the RT net and also reconstruct Br, Bz from the spatial gradient of its Fx output.

    Args:
        phys: physics inputs of one example (converted with `to_tensor`).
        r, z: coordinates of the query points (converted with `to_tensor`).
        net: network to evaluate; defaults to the notebook-level `rt_net`.

    Returns:
        Tuple of numpy arrays (Fx_pred, Br_pred, Bz_pred, Br_pinn, Bz_pinn):
        columns 0, 1, 2 of the net output, plus
        Br_pinn = -(dFx/dz) / (2*pi*r) and Bz_pinn = (dFx/dr) / (2*pi*r).
    """
    if net is None:
        net = rt_net
    net.eval() # set the model to evaluation mode

    phys_t = to_tensor(phys)
    r_t = to_tensor(r).requires_grad_(True) # enable gradients w.r.t. r
    z_t = to_tensor(z).requires_grad_(True) # enable gradients w.r.t. z

    rt = net(phys_t, r_t, z_t) # forward pass through the RT net
    Fx = rt[:, 0] # Fx output

    # Differentiate sum(Fx) w.r.t. the inputs. This gives per-point dFx/dr, dFx/dz only if
    # each Fx[i] depends solely on its own (r[i], z[i]) -- presumably true here; TODO confirm.
    # torch.autograd.grad returns the input gradients directly and, unlike .backward(),
    # does not accumulate .grad on the network parameters on every call.
    dF_dr, dF_dz = torch.autograd.grad(Fx.sum(), (r_t, z_t))
    dF_dr = dF_dr.detach().cpu().numpy() # dFx/dr
    dF_dz = dF_dz.detach().cpu().numpy() # dFx/dz

    # magnetic field components from the flux gradient
    r_np = r_t.detach().cpu().numpy()
    Br_pinn = -dF_dz / (2 * np.pi * r_np)
    Bz_pinn = dF_dr / (2 * np.pi * r_np)

    rt_np = rt.detach().cpu().numpy()
    Fx_pred, Br_pred, Bz_pred = rt_np[:, 0], rt_np[:, 1], rt_np[:, 2]
    return Fx_pred, Br_pred, Bz_pred, Br_pinn, Bz_pinn


# Compare three Br/Bz estimates on one random demo example: LIUQE reference,
# direct network outputs, and the reconstruction from the gradient of Fx.
with np.load(f'{DS_DIR}/demo.npz') as demo_npz:
    d = {key: demo_npz[key] for key in demo_npz.files} # read each array once, then close the file
n_examples = d[PTS].shape[0]
n_pts = d[PTS].shape[1]

rand_i = np.random.randint(0, n_examples) # random example index

phys = d[PHYS][rand_i]
r = d[PTS][rand_i, :, 0]
z = d[PTS][rand_i, :, 1]

Fx_pred, Br_pred, Bz_pred, Br_pinn, Bz_pinn = forward_with_Br_Bz_grad(phys, r, z) # forward pass through the RT net with gradients

Fx_true, Br_true, Bz_true = d[FX][rand_i], d[BR][rand_i], d[BZ][rand_i]

eBr_tn = np.abs(Br_pred - Br_true) # error True/Net
eBz_tn = np.abs(Bz_pred - Bz_true)

# shared colour range for all four error maps, taken from the Net Br error
# (min/max were previously swapped, which gave vmin > vmax)
e_min, e_max = np.min(eBr_tn), np.max(eBr_tn)

eBr_tp = np.abs(Br_pinn - Br_true) # error True/PINN
eBz_tp = np.abs(Bz_pinn - Bz_true)

k = 8 # number of values to print
np.set_printoptions(precision=4, suppress=True, linewidth=1000)
print(f'\nBr_pred -> {Br_pred[:k]}\nBr_true -> {Br_true[:k]} \nBr_pinn -> {Br_pinn[:k]}')
print(f'\nBz_pred -> {Bz_pred[:k]}\nBz_true -> {Bz_true[:k]} \nBz_pinn -> {Bz_pinn[:k]}')

# plot: 2x5 grid, row 1 = Br, row 2 = Bz;
# columns = LIUQE, NET, flux gradient, |True-Net| error, |True-Grad| error
plt.rcParams['image.cmap'] = 'viridis' # set colormap to viridis
marker_size = 5
Br_max, Br_min = np.max(Br_true), np.min(Br_true)
Bz_max, Bz_min = np.max(Bz_true), np.min(Bz_true)
panels = [ # (values, vmin, vmax, title) in subplot order
    (Br_true, Br_min, Br_max, 'LIUQE Br'),
    (Br_pred, Br_min, Br_max, 'NET Br'),
    (Br_pinn, Br_min, Br_max, '-∂Fx/∂z / 2πr'),
    (eBr_tn, e_min, e_max, 'Err Br True/Net'),
    (eBr_tp, e_min, e_max, 'Err Br True/Grad'),
    (Bz_true, Bz_min, Bz_max, 'LIUQE Bz'),
    (Bz_pred, Bz_min, Bz_max, 'NET Bz'),
    (Bz_pinn, Bz_min, Bz_max, '∂Fx/∂r / 2πr'),
    (eBz_tn, e_min, e_max, 'Err Bz True/Net'),
    (eBz_tp, e_min, e_max, 'Err Bz True/Grad'),
]

plt.figure(figsize=(16, 9))
for panel_idx, (values, vmin, vmax, title) in enumerate(panels, start=1):
    plt.subplot(2, 5, panel_idx)
    plt.scatter(r, z, c=values, s=marker_size, vmin=vmin, vmax=vmax)
    plot_vessel()
    plt.colorbar()
    plt.axis('equal')
    plt.title(title)

plt.tight_layout()
plt.show()
