In [1]:
import torch
import numpy as np
from pathlib import Path
import argparse
from neuralop.models import UNO
from magplot.base import create_mesh, mag_plotter
import pyvista as pv
# pv.start_xvfb()
pv.set_jupyter_backend('static')

import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from rtmag.test.analytical_field import get_analytic_b_field
from skimage.transform import resize

from rtmag.test.eval_plot import plot_sample

from rtmag.test.eval import evaluate

In [28]:
b_true = np.load('/mnt/d/low_lou/test/case1.npz')['b'].astype(np.float32)
b_true.shape

(64, 64, 64, 3)

In [29]:
meta_path = Path("/home/usr/workspace/uno_pi_cc_hnorm_unit_aug_ccc_square_energy_lowlou2")
checkpoint = torch.load(meta_path / "best_model.pt", map_location=device)

args = argparse.Namespace()
info = np.load(meta_path / 'args.npy', allow_pickle=True).item()
for key, value in info.items():
        args.__dict__[key] = value

b_norm = args.data["b_norm"]

model = UNO(
        hidden_channels = args.model["hidden_channels"],
        in_channels = args.model["in_channels"],
        out_channels = args.model["out_channels"],
        lifting_channels = args.model["lifting_channels"],
        projection_channels = args.model["projection_channels"],
        n_layers = args.model["n_layers"],

        factorization = args.model["factorization"],
        implementation = args.model["implementation"],
        rank = args.model["rank"],

        uno_n_modes = args.model["uno_n_modes"], 
        uno_out_channels = args.model["uno_out_channels"],
        uno_scalings = args.model["uno_scalings"],
    ).to(device)

checkpoint = torch.load(meta_path / 'best_model.pt')

model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [30]:
b_bottom = b_true[:, :, 0, :]
b_bottom.shape

(64, 64, 3)

In [31]:
model_input = b_bottom
model_input = model_input[None, :, :, None, :]
model_input = model_input.transpose(0, 3, 2, 1, 4) / b_norm
model_input = torch.from_numpy(model_input).to(device) 
model_input.shape

torch.Size([1, 1, 64, 64, 3])

In [32]:
model_output = model(model_input)
model_output.shape

torch.Size([1, 64, 64, 64, 3])

In [33]:
b = model_output.detach().cpu().numpy().transpose(0, 3, 2, 1, 4)[0]
divi = (b_norm / np.arange(1, b.shape[2] + 1)).reshape(1, 1, -1, 1)
b = b * divi
# b = b[8:-8, 8:-8, :, :]
b.shape

(64, 64, 64, 3)

In [35]:
evaluate(b, b_true)

C_vec     : 0.9981
C_cs      : 0.9751
E_n'      : 0.9048
E_m'      : 0.8340
eps       : 0.9982
CW_sin    : 0.3123
L_f       : 0.1240
L_d       : 0.1006
l2_err    : 0.0615


{'C_vec': 0.9981094018838822,
 'C_cs': 0.9750930628427174,
 "E_n'": 0.9047551746330212,
 "E_m'": 0.8340436591152522,
 'eps': 0.99821068563295,
 'CW_sin': 0.3123187328865913,
 'L_f': 0.12402395475678948,
 'L_d': 0.10059799215715001,
 'l2_err': 0.061472072208964064}