In [1]:
from tvtk.api import tvtk, write_data
import pyvista as pv 
import os
import numpy as np

def pv_plot(B, vtk_path='./evaluation.vtk', points=((7, 64, 8), (7, 64, 8)), title='LL', overwrite=False):

    if not os.path.exists(vtk_path) or overwrite:
        dim = B.shape[:-1]
        pts = np.stack(np.mgrid[0:dim[0], 0:dim[1], 0:dim[2]], -1).astype(np.float32)
        pts = pts.transpose(2, 1, 0, 3)
        pts = pts.reshape((-1, 3))
        vectors = B.transpose(2, 1, 0, 3)
        vectors = vectors.reshape((-1, 3))
        sg = tvtk.StructuredGrid(dimensions=dim, points=pts)
        sg.point_data.vectors = vectors
        sg.point_data.vectors.name = 'B'
        write_data(sg, str(vtk_path))

    mesh = pv.read(vtk_path)
    xindmax, yindmax, zindmax = mesh.dimensions
    xcenter, ycenter, zcenter = mesh.center

    p = pv.Plotter()
    p.add_mesh(mesh.outline())

    sargs_B = dict(
        title='Bz [G]',
        title_font_size=15,
        height=0.25,
        width=0.05,
        vertical=True,
        position_x = 0.05,
        position_y = 0.05,
    )
    dargs_B = dict(
        scalars='B', 
        component=2, 
        clim=(-150, 150), 
        scalar_bar_args=sargs_B, 
        show_scalar_bar=False, 
        lighting=False
    )
    p.add_mesh(mesh.extract_subset((0, xindmax, 0, yindmax, 0, 0)), 
            cmap='gray', **dargs_B)

    def draw_streamlines(pts):
        stream, src = mesh.streamlines(
            return_source=True,
            start_position = pts,
            integration_direction='both',
            max_time=1000,
        )
        # print(pts)
        key = pts[0]*pts[1] + (pts[0]//pts[1]) + (pts[0] - pts[1])
        # print(key)
        np.random.seed(key)
        colors = np.random.rand(3)
        # if pts[0] == 16 and pts[1] == 48:
        #     colors = 'white'
        # print(colors)
        p.add_mesh(stream.tube(radius=0.2), lighting=False, color=colors)
        p.add_mesh(src, point_size=7, color=colors)

    xrange = points[0]
    yrange = points[1]
    for i in np.arange(*xrange):
        for j in np.arange(*yrange):
            try: 
                draw_streamlines((i, j, 0))
            except:
                print(i, j)

    p.camera_position = 'xy'
    p.show_bounds()
    # p.add_title(title)

    return p

In [2]:
import numpy as np 
import pandas as pd

from pinf.metric import vector_norm, curl, divergence

def metrics(B, b, B_potential, b_potential):
    """
    B is the numerical solution
    b is the reference magnetic field
    """

    c_vec = np.sum((B * b).sum(-1)) / np.sqrt((B ** 2).sum(-1).sum() * (b ** 2).sum(-1).sum())

    M = np.prod(B.shape[:-1])
    c_cs = 1 / M * np.sum((B * b).sum(-1) / vector_norm(B) / vector_norm(b))

    E_n = vector_norm(B - b).sum() / vector_norm(b).sum()

    E_m = 1 / M * (vector_norm(B - b) / vector_norm(b)).sum()

    eps = (vector_norm(B) ** 2).sum() / (vector_norm(b) ** 2).sum()

    # B_potential = get_potential(B[:, :, 0, 2], B.shape[-1])
    eps_p = (vector_norm(B) ** 2).sum() / (vector_norm(B_potential) ** 2).sum()

    # b_potential = get_potential(b[:, :, 0, 2], b.shape[-1])
    # eps_p_b = (vector_norm(b) ** 2).sum() / (vector_norm(b_potential) ** 2).sum()

    j = curl(B)
    sig_J = (vector_norm(np.cross(j, B, -1)) / vector_norm(B)).sum() / vector_norm(j).sum()
    L1 = (vector_norm(np.cross(j, B, -1)) ** 2 / vector_norm(B) ** 2).mean()
    L2 = (divergence(B) ** 2).mean()

    # j_b = curl(b)
    # sig_J_b = (vector_norm(np.cross(j_b, b, -1)) / vector_norm(b)).sum() / vector_norm(j_b).sum()
    # L1_b = (vector_norm(np.cross(j_b, b, -1)) ** 2 / vector_norm(b) ** 2).mean()
    # L2_b = (divergence(b) ** 2).mean()

    key = ["C_vec", "C_cs", "1-En", "1-Em", "eps", "eps_p", "sig_J", "L1", "L2"]
    metric = [c_vec, c_cs, 1-E_n, 1-E_m, eps, eps_p, sig_J, L1, L2]
    return dict(zip(key, metric))

In [3]:
import os
import glob
import numpy as np

from pinf.analytical_field import get_analytic_b_field
from pinf.unpack import load_cube
from pinf.potential_field import get_potential

In [4]:
vtk_path = './output2/vtk'
metric_path = './output2/eval'
plot_path = './output2/plot'
os.makedirs(vtk_path, exist_ok=True)
os.makedirs(metric_path, exist_ok=True)
os.makedirs(plot_path, exist_ok=True)

In [5]:
b = get_analytic_b_field(n=1, m=1, l=0.3, psi=np.pi/2, resolution=64, bounds=[-1, 1, -1, 1, 0, 2])

In [6]:
b_potential = get_potential(b[:, :, 0, 2], b.shape[-1])

Potential Field: 100%|██████████| 6/6 [00:00<00:00,  8.24it/s]


In [7]:
title = 'LL'
vtk_file = os.path.join(vtk_path, f'{title}.vtk')
p = pv_plot(B=b, vtk_path=vtk_file, points=((16, 49, 8), (16, 49, 8)), title=title)

xy_path = os.path.join(plot_path, f'{title}_xy.pdf')
yz_path = os.path.join(plot_path, f'{title}_yz.pdf')
xz_path = os.path.join(plot_path, f'{title}_xz.pdf')
xz_tilted_path = os.path.join(plot_path, f'{title}_xz_tilted.pdf')

if not os.path.exists(xy_path):
    p.camera_position = 'xy'
    p.save_graphic(xy_path)

if not os.path.exists(yz_path):
    p.camera_position = 'yz'
    p.save_graphic(yz_path)

if not os.path.exists(xz_path):
    p.camera_position = 'xz'
    p.save_graphic(xz_path)  

if not os.path.exists(xz_tilted_path):
    p.camera_position = 'xz'
    p.camera.azimuth = -30
    p.camera.elevation = 25
    p.save_graphic(xz_tilted_path)

In [8]:
metric = metrics(B=b, b=b, B_potential=b_potential, b_potential=b_potential)
iterinfo = {'iter': -1}
metric = {**iterinfo, **metric}

In [9]:
df = pd.DataFrame.from_dict([metric])
df.to_csv(os.path.join(metric_path, 'metric.csv'), index=False)
df

Unnamed: 0,iter,C_vec,C_cs,1-En,1-Em,eps,eps_p,sig_J,L1,L2
0,-1,1.0,1.0,1.0,1.0,1.0,0.130138,0.01308,0.002065,0.002024


In [10]:
for file_path in sorted(glob.glob('run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_*.nf2')):
    iters = os.path.basename(file_path).split('.')[0][7:]
    title = 'PINN' + '_' + iters
    B = load_cube(file_path)
    
    metric = metrics(B=B, b=b, B_potential=b_potential, b_potential=b_potential)
    iterinfo = {'iter': int(iters)}
    metric = {**iterinfo, **metric}
    
    df = pd.concat([df, pd.DataFrame([metric])], ignore_index=True)
    print(file_path)

df.to_csv(os.path.join(metric_path, 'metric.csv'), index=False)
df

run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000000.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000001.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000100.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000200.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000300.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000400.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000500.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000600.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000700.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000800.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000900.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001000.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001100.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001200.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001300.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001400.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001500.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001600.nf2
run2/dim256_bin1_pfTrue_ld0.

Unnamed: 0,iter,C_vec,C_cs,1-En,1-Em,eps,eps_p,sig_J,L1,L2
0,-1,1.000000,1.000000,1.000000,1.000000,1.000000,0.130138,0.013080,0.002065,0.002024
1,0,0.060634,0.268943,-1.421001,-6.143196,0.743979,0.096820,0.883638,0.000160,0.000357
2,1,-0.046547,-0.283600,-0.559940,-2.791768,0.128930,0.016779,0.757738,0.000444,0.000408
3,100,0.208082,0.505417,0.020215,-0.285618,0.034750,0.004522,0.776417,0.007800,0.000265
4,200,0.412154,0.648507,0.017317,-0.622371,0.189378,0.024645,0.724366,0.109216,0.024384
...,...,...,...,...,...,...,...,...,...,...
498,49600,0.993660,0.909454,0.776006,0.527431,0.965616,0.125663,0.066401,0.003089,0.002433
499,49700,0.993762,0.925673,0.775532,0.522425,0.962079,0.125203,0.065954,0.003196,0.002562
500,49800,0.993538,0.921283,0.771459,0.512363,0.965637,0.125666,0.066792,0.003142,0.002394
501,49900,0.993783,0.929982,0.774804,0.515479,0.964135,0.125471,0.065764,0.003044,0.002539


In [15]:
for file_path in sorted(glob.glob('run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_*.nf2')):
    iters = os.path.basename(file_path).split('.')[0][7:]
    title = 'PINN' + '_' + iters
    B = load_cube(file_path)
    
    vtk_file = os.path.join(vtk_path, f'{title}.vtk')
    p = pv_plot(B=B, vtk_path=vtk_file, points=((16, 49, 8), (16, 49, 8)), title=title)

    xy_path = os.path.join(plot_path, f'{title}_xy.pdf')
    yz_path = os.path.join(plot_path, f'{title}_yz.pdf')
    xz_path = os.path.join(plot_path, f'{title}_xz.pdf')
    xz_tilted_path = os.path.join(plot_path, f'{title}_xz_tilted.pdf')

    if not os.path.exists(xy_path):
        p.camera_position = 'xy'
        p.save_graphic(xy_path)

    if not os.path.exists(yz_path):
        p.camera_position = 'yz'
        p.save_graphic(yz_path)

    if not os.path.exists(xz_path):
        p.camera_position = 'xz'
        p.save_graphic(xz_path)  

    if not os.path.exists(xz_tilted_path):
        p.camera_position = 'xz'
        p.camera.azimuth = -30
        p.camera.elevation = 25
        p.save_graphic(xz_tilted_path)

    print(file_path)

run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000000.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000001.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000100.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000200.nf2
24 24
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000300.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000400.nf2
40 32
48 16
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000500.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000600.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000700.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000800.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_000900.nf2
48 16
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001000.nf2
48 16
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001100.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001200.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001300.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001400.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001500.nf2
run2/dim256_bin1_pfTrue_ld0.1_lf0.1/fields_001600.nf