In [1]:
import numpy as np
from scipy.integrate import solve_bvp


def _differential_equation(mu, u, n, a2):
    """
    The differential equation to solve for P

    :param mu: cos(theta)
    :param u: P function and derivative
    :param n: variable according to Low & Lou (1989)
    :param a2: eigenvalue

    """
    P, dP = u
    dP_dmu = dP
    d2P_dmu2 = -(n * (n + 1) * P + a2 * (1 + n) / n * P ** (1 + 2 / n)) / (1 - mu ** 2 + 1e-8)
    return (dP_dmu, d2P_dmu2)


def get_analytic_b_field(n=1, m=1, l=0.3, psi=np.pi / 4, resolution=64, bounds=[-1, 1, -1, 1, 0, 2]):
    """
    Calculate the analytic NLFF field from Low & Lou (1989).

    :param n: variable see Low & Lou (1989), only works for n=1
    :param m: used for generating a proper initial condition.
    :param a2: eigenvalue
    :param l: depth below the photosphere
    :param psi: angle of the magnetic field relative to the dipol axis
    :param resolution: spatial resolution of the magnetic field in pixels
    :param bounds: dimensions of the volume (x_start, x_end, y_start, y_end, z_start, z_end)
    :return: magnetic field B (x, y, z, v)
    """
    sol_P, a2 = solve_P(n, m)

    resolution = [resolution] * 3 if not isinstance(resolution, list) else resolution
    coords = np.stack(np.meshgrid(np.linspace(bounds[0], bounds[1], resolution[1], dtype=np.float32),
                                  np.linspace(bounds[2], bounds[3], resolution[0], dtype=np.float32),
                                  np.linspace(bounds[4], bounds[5], resolution[2], dtype=np.float32)), -1).transpose(
        [1, 0, 2, 3])

    x, y, z = coords[..., 0], coords[..., 1], coords[..., 2]
    X = x * np.cos(psi) - (z + l) * np.sin(psi)
    Y = y
    Z = x * np.sin(psi) + (z + l) * np.cos(psi)

    # to spherical coordinates
    xy = X ** 2 + Y ** 2
    r = np.sqrt(xy + Z ** 2)
    theta = np.arctan2(np.sqrt(xy), Z)
    phi = np.arctan2(Y, X)

    mu = np.cos(theta)

    P, dP_dmu = sol_P(mu)
    A = P / r ** n
    dA_dtheta = -np.sin(theta) / (r ** n) * dP_dmu
    dA_dr = P * (-n * r ** (-n - 1))
    Q = np.sqrt(a2) * A * np.abs(A) ** (1 / n)

    Br = (r ** 2 * np.sin(theta)) ** -1 * dA_dtheta
    Btheta = - (r * np.sin(theta)) ** -1 * dA_dr
    Bphi = (r * np.sin(theta)) ** -1 * Q

    BX = Br * np.sin(theta) * np.cos(phi) + Btheta * np.cos(theta) * np.cos(phi) - Bphi * np.sin(phi)
    BY = Br * np.sin(theta) * np.sin(phi) + Btheta * np.cos(theta) * np.sin(phi) + Bphi * np.cos(phi)
    BZ = Br * np.cos(theta) - Btheta * np.sin(theta)

    Bx = BX * np.cos(psi) + BZ * np.sin(psi)
    By = BY
    Bz = - BX * np.sin(psi) + BZ * np.cos(psi)

    b_field = np.real(np.stack([Bx, By, Bz], -1))
    return b_field


def solve_P(n, m):
    """
    Solve the differential equation from Low & Lou (1989).

    :param n: variable (only n=1)
    :param v0: start condition for dP/dmu
    :param P0: boundary condition for P(-1) and P(1)
    :return: interpolated functions for P and dP/dmu
    """

    def f(x, y, p):
        a2 = p[0]
        d2P_dmu2 = -(n * (n + 1) * y[0] + a2 * (1 + n) / n * y[0] ** (1 + 2 / n)) / (1 - x ** 2 + 1e-6)
        return [y[1], d2P_dmu2]

    def f_boundary(Pa, Pb, p):
        return np.array([Pa[0] - 0, Pb[0] - 0, Pa[1] - 10])

    mu = np.linspace(-1, 1, num=256)

    if m % 2 == 0:
        init = np.cos(mu * (m + 1) * np.pi / 2)
    else:
        init = np.sin(mu * (m + 1) * np.pi / 2)

    dinit = 10 * np.ones_like(init)  #
    initial = np.stack([init, dinit])

    @np.vectorize
    def shooting(a2_init):
        eval = solve_bvp(f, f_boundary, x=mu, y=initial, p=[a2_init], verbose=0, tol=1e-6)
        if eval.success == False:
            return None
        return eval

    # use shooting to find eigenvalues
    evals = shooting(np.linspace(0, 10, 100, dtype=np.float32))
    evals = [e for e in evals if e is not None]

    eigenvalues = np.array([e.p for e in evals])
    eigenvalues = sorted(set(np.round(eigenvalues, 4).reshape((-1,))))

    # get final solution
    eval = shooting([eigenvalues[-1]])[0]

    return eval.sol, eval.p[0]

In [2]:
import json
import os 

config_path = 'config_run.json'

with open(config_path) as config:
    info = json.load(config)

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= info['simul']['gpu_id']

n = info['exact']['n']
m = info['exact']['m']
l = info['exact']['l']
psi = eval(info['exact']['psi'])
resolution = info['exact']['resolution']
bounds = info['exact']['bounds']

B = get_analytic_b_field(n=n, m=m, l=l, psi=psi, resolution=resolution, bounds=bounds)

base_path = os.path.join(info['simul']['base_path'], "run/")
meta_path = info['simul']['meta_path']

bin = info['simul']['bin']

height = info['simul']['height']
spatial_norm = info['simul']['spatial_norm']
b_norm = info['simul']['b_norm']

meta_info = info['simul']['meta_info']
dim = info['simul']['dim']
positional_encoding = info['simul']['positional_encoding']
use_potential_boundary = info['simul']['use_potential_boundary']
potential_strides = info['simul']['potential_strides']
use_vector_potential = info['simul']['use_vector_potential']
w_div = info['simul']['w_div']
w_ff = info['simul']['w_ff']
decay_iterations = info['simul']['decay_iterations']
device = info['simul']['device']
work_directory = info['simul']['work_directory']

total_iterations = info['simul']['total_iterations']
batch_size = info['simul']['batch_size']
log_interval = info['simul']['log_interval']
validation_interval = info['simul']['validation_interval']
num_workers = info['simul']['num_workers']


os.makedirs(base_path, exist_ok=True)
b_bottom = B[:, :, 0, :]

In [3]:
import logging

import torch
from torch import nn
from torch.cuda import get_device_name
from torch.optim.lr_scheduler import ExponentialLR

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

from astropy.nddata import block_reduce

from torch.utils.data import Dataset, DataLoader, TensorDataset, RandomSampler
from tqdm import tqdm

from datetime import datetime

In [4]:
def jacobian(output, coords):
    jac_matrix = [torch.autograd.grad(output[:, i], coords,
                                      grad_outputs=torch.ones_like(output[:, i]).to(output),
                                      retain_graph=True,
                                      create_graph=True)[0]
                  for i in range(output.shape[1])]
    jac_matrix = torch.stack(jac_matrix, dim=1)
    return jac_matrix

class Sine(nn.Module):
    def __init__(self, w0=1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)
    
class PositionalEncoding(nn.Module):
    """
    Positional Encoding of the input coordinates.

    encodes x to (..., sin(2^k x), cos(2^k x), ...)
    k takes "num_freqs" number of values equally spaced between [0, max_freq]
    """

    def __init__(self, max_freq, num_freqs):
        """
        Args:
            max_freq (int): maximum frequency in the positional encoding.
            num_freqs (int): number of frequencies between [0, max_freq]
        """
        super().__init__()
        freqs = 2 ** torch.linspace(0, max_freq, num_freqs)
        self.register_buffer("freqs", freqs)  # (num_freqs)

    def forward(self, x):
        """
        Inputs:
            x: (batch, num_samples, in_features)
        Outputs:
            out: (batch, num_samples, 2*num_freqs*in_features)
        """
        x_proj = x.unsqueeze(dim=-2) * self.freqs.unsqueeze(dim=-1)  # (num_rays, num_samples, num_freqs, in_features)
        x_proj = x_proj.reshape(*x.shape[:-1], -1)  # (num_rays, num_samples, num_freqs*in_features)
        out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)],
                        dim=-1)  # (num_rays, num_samples, 2*num_freqs*in_features)
        return out
    
class BModel(nn.Module):

    def __init__(self, in_coords, out_values, dim, pos_encoding=False):
        super().__init__()
        if pos_encoding:
            posenc = PositionalEncoding(8, 20)
            d_in = nn.Linear(in_coords * 40, dim)
            self.d_in = nn.Sequential(posenc, d_in)
        else:
            self.d_in = nn.Linear(in_coords, dim)
        lin = [nn.Linear(dim, dim) for _ in range(8)]
        self.linear_layers = nn.ModuleList(lin)
        self.d_out = nn.Linear(dim, out_values)
        self.activation = Sine()  # torch.tanh

    def forward(self, x):
        x = self.activation(self.d_in(x))
        for l in self.linear_layers:
            x = self.activation(l(x))
        x = self.d_out(x)
        return x

class VectorPotentialModel(nn.Module):

    def __init__(self, in_coords, dim, pos_encoding=False):
        super().__init__()
        if pos_encoding:
            posenc = PositionalEncoding(8, 20)
            d_in = nn.Linear(in_coords * 40, dim)
            self.d_in = nn.Sequential(posenc, d_in)
        else:
            self.d_in = nn.Linear(in_coords, dim)
        lin = [nn.Linear(dim, dim) for _ in range(8)]
        self.linear_layers = nn.ModuleList(lin)
        self.d_out = nn.Linear(dim, 3)
        self.activation = Sine()  # torch.tanh

    def forward(self, x):
        coord = x
        x = self.activation(self.d_in(x))
        for l in self.linear_layers:
            x = self.activation(l(x))
        a = self.d_out(x)
        #
        jac_matrix = jacobian(a, coord)
        dAy_dx = jac_matrix[:, 1, 0]
        dAz_dx = jac_matrix[:, 2, 0]
        dAx_dy = jac_matrix[:, 0, 1]
        dAz_dy = jac_matrix[:, 2, 1]
        dAx_dz = jac_matrix[:, 0, 2]
        dAy_dz = jac_matrix[:, 1, 2]
        rot_x = dAz_dy - dAy_dz
        rot_y = dAx_dz - dAz_dx
        rot_z = dAy_dx - dAx_dy
        b = torch.stack([rot_x, rot_y, rot_z], -1)
        #
        return b

In [5]:
class PotentialModel(nn.Module):

    def __init__(self, b_n, r_p):
        super().__init__()
        self.register_buffer('b_n', b_n)
        self.register_buffer('r_p', r_p)
        c = np.zeros((1, 3))
        c[:, 2] = (1 / np.sqrt(2 * np.pi))
        c = torch.tensor(c, dtype=torch.float32, )
        self.register_buffer('c', c)

    def forward(self, coord):
        v1 = self.b_n[:, None]
        v2 = 2 * np.pi * ((-self.r_p[:, None] + coord[None, :] + self.c[None]) ** 2).sum(-1) ** 0.5
        potential = torch.sum(v1 / v2, dim=0)
        return potential
    
def get_potential_boundary(b_n, height, batch_size=2048):
    assert not np.any(np.isnan(b_n)), 'Invalid data value'

    cube_shape = (*b_n.shape, height)

    b_n = b_n.reshape((-1)).astype(np.float32)
    coords = [np.stack(np.mgrid[:cube_shape[0], :cube_shape[1], cube_shape[2] - 2:cube_shape[2] + 1], -1),
              np.stack(np.mgrid[:cube_shape[0], -1:2, :cube_shape[2]], -1),
              np.stack(np.mgrid[:cube_shape[0], cube_shape[1] - 2:cube_shape[1] + 1, :cube_shape[2]], -1),
              np.stack(np.mgrid[-1:2, :cube_shape[1], :cube_shape[2]], -1),
              np.stack(np.mgrid[cube_shape[0] - 2:cube_shape[0] + 1, :cube_shape[1], :cube_shape[2]], -1), ]
    coords_shape = [c.shape[:-1] for c in coords]
    flat_coords = np.concatenate([c.reshape(((-1, 3))) for c in coords])

    r_p = np.stack(np.mgrid[:cube_shape[0], :cube_shape[1], :1], -1).reshape((-1, 3))

    # torch code
    # r = (x * y, 3); coords = (x*y*z, 3), c = (1, 3)
    # --> (x * y, x * y * z, 3) --> (x * y, x * y * z) --> (x * y * z)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    with torch.no_grad():
        b_n = torch.tensor(b_n, dtype=torch.float32, )
        r_p = torch.tensor(r_p, dtype=torch.float32, )
        model = nn.DataParallel(PotentialModel(b_n, r_p, )).to(device)

        flat_coords = torch.tensor(flat_coords, dtype=torch.float32, )

        potential = []
        for coord, in tqdm(DataLoader(TensorDataset(flat_coords), batch_size=batch_size, num_workers=2),
                           desc='Potential Boundary'):
            coord = coord.to(device)
            p_batch = model(coord)
            potential += [p_batch.cpu()]

    potential = torch.cat(potential).numpy()
    idx = 0
    fields = []
    for s in coords_shape:
        p = potential[idx:idx + np.prod(s)].reshape(s)
        b = - 1 * np.stack(np.gradient(p, axis=[0, 1, 2], edge_order=2), axis=-1)
        fields += [b]
        idx += np.prod(s)

    fields = [fields[0][:, :, 1].reshape((-1, 3)),
              fields[1][:, 1, :].reshape((-1, 3)), fields[2][:, 1, :].reshape((-1, 3)),
              fields[3][1, :, :].reshape((-1, 3)), fields[4][1, :, :].reshape((-1, 3))]
    coords = [coords[0][:, :, 1].reshape((-1, 3)),
              coords[1][:, 1, :].reshape((-1, 3)), coords[2][:, 1, :].reshape((-1, 3)),
              coords[3][1, :, :].reshape((-1, 3)), coords[4][1, :, :].reshape((-1, 3))]
    return np.concatenate(coords), np.concatenate(fields)

def _load_potential_field_data(hmi_cube, height, reduce):
    if reduce > 1:
        hmi_cube = block_reduce(hmi_cube, (reduce, reduce, 1), func=np.mean)
        height = height // reduce
    pf_batch_size = int(1024 * 512 ** 2 / np.prod(hmi_cube.shape[:2]))  # adjust batch to AR size
    pf_coords, pf_values = get_potential_boundary(hmi_cube[:, :, 2], height, batch_size=pf_batch_size)
    pf_values = np.array(pf_values, dtype=np.float32)
    pf_coords = np.array(pf_coords, dtype=np.float32) * reduce # expand to original coordinate spacing
    pf_err = np.zeros_like(pf_values)
    return pf_coords, pf_err, pf_values

def _plot_data(n_hmi_cube, plot_path, b_norm):
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(n_hmi_cube[..., 0].transpose(), vmin=-b_norm, vmax=b_norm, cmap='gray', origin='lower')
    axs[1].imshow(n_hmi_cube[..., 1].transpose(), vmin=-b_norm, vmax=b_norm, cmap='gray', origin='lower')
    axs[2].imshow(n_hmi_cube[..., 2].transpose(), vmin=-b_norm, vmax=b_norm, cmap='gray', origin='lower')
    plt.savefig(os.path.join(plot_path, 'b.jpg'))
    plt.close()

def prep_b_data(b_bottom,
                height, spatial_norm, b_norm,
                potential_boundary=True, potential_strides=4,
                plot=False, plot_path=None):
    # load coordinates
    mf_coords = np.stack(np.mgrid[:b_bottom.shape[0], :b_bottom.shape[1], :1], -1)
    # flatten data
    mf_coords = mf_coords.reshape((-1, 3))
    mf_values = b_bottom.reshape((-1, 3))
    # load potential field
    if potential_boundary:
        pf_coords, pf_err, pf_values = _load_potential_field_data(b_bottom, height, potential_strides)
        # concatenate pf data points
        coords = np.concatenate([pf_coords, mf_coords])
        values = np.concatenate([pf_values, mf_values])
    else:
        coords = mf_coords
        values = mf_values

    coords = coords.astype(np.float32)
    values = values.astype(np.float32)

    # normalize data
    values = Normalize(-b_norm, b_norm, clip=False)(values) * 2 - 1

    # apply spatial normalization
    coords = coords / spatial_norm

    # stack to numpy array
    data = np.stack([coords, values], 1)

    if plot:
        _plot_data(b_bottom, plot_path, b_norm)

    return data

def calculate_loss(b, coords):
    jac_matrix = jacobian(b, coords)
    dBx_dx = jac_matrix[:, 0, 0]
    dBy_dx = jac_matrix[:, 1, 0]
    dBz_dx = jac_matrix[:, 2, 0]
    dBx_dy = jac_matrix[:, 0, 1]
    dBy_dy = jac_matrix[:, 1, 1]
    dBz_dy = jac_matrix[:, 2, 1]
    dBx_dz = jac_matrix[:, 0, 2]
    dBy_dz = jac_matrix[:, 1, 2]
    dBz_dz = jac_matrix[:, 2, 2]
    #
    rot_x = dBz_dy - dBy_dz
    rot_y = dBx_dz - dBz_dx
    rot_z = dBy_dx - dBx_dy
    #
    j = torch.stack([rot_x, rot_y, rot_z], -1)
    jxb = torch.cross(j, b, -1)
    force_loss = torch.sum(jxb ** 2, dim=-1) / (torch.sum(b ** 2, dim=-1) + 1e-7)
    divergence_loss = (dBx_dx + dBy_dy + dBz_dz) ** 2
    return divergence_loss, force_loss

class BoundaryDataset(Dataset):

    def __init__(self, batches_path):
        self.batches_path = batches_path

    def __len__(self):
        return np.load(self.batches_path, mmap_mode='r').shape[0]

    def __getitem__(self, idx):
        # lazy load data
        d = np.load(self.batches_path, mmap_mode='r')[idx]
        d = np.copy(d)
        coord, field = d[:, 0],  d[:, 1]
        return coord, field

class ImageDataset(Dataset):

    def __init__(self, cube_shape, norm, z=0):
        coordinates = np.stack(np.mgrid[:cube_shape[0],
                               :cube_shape[1]], -1)
        self.coordinates = coordinates
        self.coordinates_flat = coordinates.reshape((-1, 2))
        self.norm = norm
        self.z = z / self.norm

    def __len__(self, ):
        return self.coordinates_flat.shape[0]

    def __getitem__(self, idx):
        coord = self.coordinates_flat[idx]
        scaled_coord = [coord[0] / self.norm,
                        coord[1] / self.norm,
                        self.z]
        return np.array(scaled_coord, dtype=np.float32)

In [6]:
def coords(xbounds, ybounds, zbounds):
    return np.stack(np.mgrid[xbounds[0]:xbounds[1]+1, ybounds[0]:ybounds[1]+1, zbounds[0]:zbounds[1]+1], axis=-1).astype(np.float32)

In [7]:
class NF2Trainer:

    def __init__(self, base_path, b_bottom, height, spatial_norm, b_norm, meta_info=None, dim=256,
                 positional_encoding=False, meta_path=None, use_potential_boundary=True, potential_strides=1, use_vector_potential=False,
                 w_div=0.1, w_ff=0.1, decay_iterations=None,
                 device=None, work_directory=None):
        """Magnetic field extrapolations trainer

        :param base_path: path to the results folder.
        :param b_bottom: magnetic field data (x, y, (Bp, -Bt, Br)).
        :param height: height of simulation volume.
        :param spatial_norm: normalization of coordinate axis.
        :param b_norm: normalization of magnetic field strength.
        :param meta_info: additional data information. stored in the save state.
        :param dim: number of neurons per layer (8 layers).
        :param positional_encoding: use positional encoding.
        :param meta_path: start from a pre-learned simulation state.
        :param use_potential_boundary: use potential field as boundary condition. If None use an open boundary.
        :param potential_strides: use binned potential field boundary condition. Only applies if use_potential_boundary = True.
        :param use_vector_potential: derive the magnetic field from a vector potential.
        :param w_div: weighting parameter for divergence freeness of the simulation.
        :param w_ff: weighting parameter for force freeness of the simulation.
        :param decay_iterations: decay weighting for boundary condition (w_bc=1000) over n iterations to 1.
        :param device: device for model training.
        :param work_directory: directory to store scratch data (prepared batches).
        """

        # general parameters
        self.base_path = base_path
        work_directory = base_path if work_directory is None else work_directory
        self.work_directory = work_directory
        self.save_path = os.path.join(base_path, 'extrapolation_result.nf2')
        self.checkpoint_path = os.path.join(base_path, 'checkpoint.pt')

        # data parameters
        self.spatial_norm = spatial_norm
        self.height = height
        self.b_norm = b_norm
        self.meta_info = meta_info

        # init directories
        os.makedirs(base_path, exist_ok=True)
        os.makedirs(work_directory, exist_ok=True)

        # init log
        self.log = logging.getLogger()
        self.log.setLevel(logging.INFO)
        for hdlr in self.log.handlers[:]:  # remove all old handlers
            self.log.removeHandler(hdlr)
        self.log.addHandler(logging.FileHandler("{0}/{1}.log".format(base_path, "info_log")))  # set the new file handler
        self.log.addHandler(logging.StreamHandler())  # set the new console handler

        # log settings
        self.log.info('Configuration:')
        self.log.info(
            'dim: %d, w_div: %f, w_ff: %f, decay_iterations: %s, potential: %s, vector_potential: %s, ' % (
                dim, w_div, w_ff, str(decay_iterations), str(use_potential_boundary),
                str(use_vector_potential)))

        # setup device
        n_gpus = torch.cuda.device_count()
        device_names = [get_device_name(i) for i in range(n_gpus)]
        if device is None:
            device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.log.info('Using device: %s (gpus %d) %s' % (str(device), n_gpus, str(device_names)))
        self.device = device

        # prepare data
        self.b_bottom = b_bottom

        # load dataset
        self.data = prep_b_data(b_bottom, height, spatial_norm, b_norm,
                                plot=True, plot_path=base_path,
                                potential_boundary=use_potential_boundary, potential_strides=potential_strides)
        # self.cube_shape = [*b_bottom.shape[:-1], height]

        Nx, Ny, _ = b_bottom.shape
        Nz = height

        self.cube_shape = (Nx, Ny, Nz)

        # init model
        if use_vector_potential:
            model = VectorPotentialModel(3, dim, pos_encoding=positional_encoding)
        else:
            model = BModel(3, 3, dim, pos_encoding=positional_encoding)
        parallel_model = nn.DataParallel(model)
        parallel_model.to(device)
        opt = torch.optim.Adam(parallel_model.parameters(), lr=5e-4)
        self.model = model
        self.parallel_model = parallel_model

        # load last state
        if os.path.exists(self.checkpoint_path):
            state_dict = torch.load(self.checkpoint_path, map_location=device)
            start_iteration = state_dict['iteration']
            model.load_state_dict(state_dict['m'])
            opt.load_state_dict(state_dict['o'])
            history = state_dict['history']
            w_bc = state_dict['w_bc']
            self.log.info('Resuming training from iteration %d' % start_iteration)
        else:
            if meta_path:
                state_dict = torch.load(meta_path, map_location=device)['model'].state_dict() \
                    if meta_path.endswith('nf2') else torch.load(meta_path, map_location=device)['m']
                model.load_state_dict(state_dict)
                opt = torch.optim.Adam(parallel_model.parameters(), lr=5e-5)
                self.log.info('Loaded meta state: %s' % meta_path)
            # init
            start_iteration = 0
            w_bc = 1000 if decay_iterations else 1
            history = {'iteration': [], 'height': [],
                       'b_loss': [], 'divergence_loss': [], 'force_loss': [], 'sigma_angle': []}

        self.opt = opt
        self.start_iteration = start_iteration
        self.history = history
        self.w_bc = w_bc
        self.w_bc_decay = (1 / 1000) ** (1 / decay_iterations) if decay_iterations is not None else 1
        self.w_div, self.w_ff = w_div, w_ff

        collocation_coords = coords((0, Nx-1), (0, Ny-1), (0, Nz-1)).reshape(-1, 3)
        normalized_collocation_coords = collocation_coords / self.spatial_norm
        self.normalized_collocation_coords = torch.tensor(normalized_collocation_coords)

    def train(self, total_iterations, batch_size, log_interval=100, validation_interval=100, num_workers=None):
        """Start magnetic field extrapolation fit.

        :param total_iterations: number of iterations for training.
        :param batch_size: number of samples per iteration.
        :param log_interval: log training details every nth iteration.
        :param validation_interval: evaluate simulation every nth iteration.
        :param num_workers: number of workers for data loading (default system spec).
        :return: path of the final save state.
        """
        start_time = datetime.now()
        num_workers = os.cpu_count() if num_workers is None else num_workers

        model = self.parallel_model
        opt = self.opt
        device = self.device
        w_div, w_ff = self.w_div, self.w_ff

        # init
        scheduler = ExponentialLR(opt, gamma=(5e-5 / 5e-4) ** (1 / total_iterations))
        iterations = total_iterations - self.start_iteration
        if iterations <= 0:
            self.log.info('Training already finished!')
            return self.save_path

        # init loader
        data_loader, batches_path = self._init_loader(batch_size, self.data, num_workers, iterations)
        
        model.train()
        for iter, (boundary_coords, boundary_b) in tqdm(enumerate(data_loader, start=self.start_iteration),
                                                           total=len(data_loader), desc='Training'):

            boundary_coords, boundary_b= boundary_coords.to(device), boundary_b.to(device)

            perm = torch.randperm(self.normalized_collocation_coords.shape[0])
            idx = perm[:batch_size]
            co_coords = self.normalized_collocation_coords[idx].to(device)

            # concatenate boundary and random points
            # n_boundary_coords = boundary_coords.shape[0]
            # r = torch.cat([boundary_coords, co_coords], 0)
            r = co_coords
            r.requires_grad = True

            # forward step
            B = model(r)

            # if iter == 0:
            #     model.eval()
            #     torch.save({'model': self.model,
            #         'cube_shape': self.cube_shape,
            #         'b_norm': self.b_norm,
            #         'spatial_norm': self.spatial_norm,
            #         'meta_info': self.meta_info}, os.path.join(self.base_path, 'fields_%06d.nf2' % iter))
            #     self.plot_sample(iter-1, batch_size=batch_size)
            #     model.train()

            # compute boundary loss
            # boundary_B = B[:n_boundary_coords]
            boundary_B = model(boundary_coords)
            # bc_loss = torch.abs(boundary_B - boundary_b)
            # bc_loss = torch.mean(bc_loss.pow(2).sum(-1))

            bc_loss = torch.sum((boundary_B - boundary_b)**2, dim=-1)
            bc_loss = torch.mean(bc_loss)
            # compute div and ff loss
            # divergence_loss, force_free_loss = calculate_loss(b, coords)

            dBx_dr = torch.autograd.grad(B[:, 0], r, torch.ones_like(B[:, 0]), retain_graph=True, create_graph=True)[0]
            dBy_dr = torch.autograd.grad(B[:, 1], r, torch.ones_like(B[:, 1]), retain_graph=True, create_graph=True)[0]
            dBz_dr = torch.autograd.grad(B[:, 2], r, torch.ones_like(B[:, 2]), retain_graph=True, create_graph=True)[0]

            dBx_dx = dBx_dr[:, 0]
            dBx_dy = dBx_dr[:, 1]
            dBx_dz = dBx_dr[:, 2]

            dBy_dx = dBy_dr[:, 0]
            dBy_dy = dBy_dr[:, 1]
            dBy_dz = dBy_dr[:, 2]

            dBz_dx = dBz_dr[:, 0]
            dBz_dy = dBz_dr[:, 1]
            dBz_dz = dBz_dr[:, 2]

            rot_x = dBz_dy - dBy_dz
            rot_y = dBx_dz - dBz_dx
            rot_z = dBy_dx - dBx_dy

            J = torch.stack([rot_x, rot_y, rot_z], -1)
            JxB = torch.cross(J, B, dim=-1)

            divB = dBx_dx + dBy_dy + dBz_dz

            force_free_loss = torch.sum(JxB**2, dim=-1) / (torch.sum(B**2, dim=-1) + 1e-7)
            force_free_loss = torch.mean(force_free_loss)
            divergence_loss = torch.sum((divB)**2, dim=-1)
            divergence_loss = torch.mean(divergence_loss)

            loss = self.w_bc*bc_loss + w_ff*force_free_loss + w_div*divergence_loss

            if iter == 0:
                model.eval()
                torch.save({'BC_loss': bc_loss.detach().cpu().numpy(),
                    'w_bc': self.w_bc,
                    'divergence_loss': divergence_loss.mean().detach().cpu().numpy(),
                    'w_div': w_div,
                    'force_loss': force_free_loss.mean().detach().cpu().numpy(),
                    'w_ff': w_ff,}, os.path.join(self.base_path, 'loss_%06d.nf2' % iter))
                torch.save({'model': self.model,
                    'cube_shape': self.cube_shape,
                    'b_norm': self.b_norm,
                    'spatial_norm': self.spatial_norm,
                    'meta_info': self.meta_info}, os.path.join(self.base_path, 'fields_%06d.nf2' % iter))
                model.train()

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            opt.step()

            if (log_interval > 0 and (iter + 1) % log_interval == 0):
                # log loss
                self.log.info('[Iteration %06d/%06d] [loss: %.08f] [bc_loss: %.08f; div_loss: %.08f; ff_loss: %.08f] [w_bc: %f, LR: %f] [%s]' %
                        (iter + 1, total_iterations,
                        loss,
                        self.w_bc*bc_loss,
                        w_ff*force_free_loss,
                        w_div*divergence_loss,
                        self.w_bc,
                        scheduler.get_last_lr()[0],
                        datetime.now() - start_time))
                
                torch.save({'BC_loss': bc_loss.detach().cpu().numpy(),
                            'lambda_BC': self.w_bc,
                            'divergence_loss': divergence_loss.detach().cpu().numpy(),
                            'lambda_div': w_div,
                            'force_loss': force_free_loss.detach().cpu().numpy(),
                            'lambda_ff': w_ff,
                            'LR':scheduler.get_last_lr()[0]}, 
                            os.path.join(base_path, 'loss_%06d.nf2' % iter))
                torch.save({'model': model,
                            'cube_shape': self.cube_shape,
                            'b_norm': b_norm,
                            'spatial_norm': spatial_norm,
                            'meta_info': meta_info}, 
                            os.path.join(base_path, 'fields_%06d.nf2' % iter))

            # update training parameters
            if self.w_bc > 1:
                self.w_bc *= self.w_bc_decay
            if scheduler.get_last_lr()[0] > 5e-5:
                scheduler.step()

        # save final model state
        torch.save({'BC_loss': bc_loss.detach().cpu().numpy(),
                    'w_bc': self.w_bc,
                    'divergence_loss': divergence_loss.detach().cpu().numpy(),
                    'w_div': w_div,
                    'force_loss': force_free_loss.detach().cpu().numpy(),
                    'w_ff': w_ff,
                    'LR':scheduler.get_last_lr()[0]}, 
                    os.path.join(base_path, 'loss_final.nf2'))
        torch.save({'model': model,
                    'cube_shape': self.cube_shape,
                    'b_norm': b_norm,
                    'spatial_norm': spatial_norm,
                    'meta_info': meta_info}, 
                    os.path.join(base_path, 'fields_final.nf2'))
        torch.save({'m': model.state_dict(),
                    'o': opt.state_dict(), },
                    os.path.join(base_path, 'model_final.pt'))
        # cleanup
        os.remove(batches_path)

        return self.save_path
    

    def _init_loader(self, batch_size, data, num_workers, iterations):
        # shuffle data
        r = np.random.permutation(data.shape[0])
        data = data[r]
        # adjust to batch size
        pad = batch_size - data.shape[0] % batch_size
        data = np.concatenate([data, data[:pad]])
        # split data into batches
        n_batches = data.shape[0] // batch_size
        batches = np.array(np.split(data, n_batches), dtype=np.float32)
        # store batches to disk
        batches_path = os.path.join(self.work_directory, 'batches.npy')
        np.save(batches_path, batches)
        # create data loaders
        dataset = BoundaryDataset(batches_path)
        # create loader
        data_loader = DataLoader(dataset, batch_size=None, num_workers=num_workers, pin_memory=True,
                                 sampler=RandomSampler(dataset, replacement=True, num_samples=iterations))
        return data_loader, batches_path

    def save(self, iteration):
        torch.save({'model': self.model,
                    'cube_shape': self.cube_shape,
                    'b_norm': self.b_norm,
                    'spatial_norm': self.spatial_norm,
                    'meta_info': self.meta_info}, self.save_path)
        torch.save({'iteration': iteration + 1,
                    'm': self.model.state_dict(),
                    'o': self.opt.state_dict(),
                    'history': self.history,
                    'lambda_B': self.lambda_B},
                   self.checkpoint_path)

In [9]:
trainer = NF2Trainer(base_path, b_bottom, height, spatial_norm, b_norm, 
                     meta_info=meta_info, dim=dim, positional_encoding=positional_encoding, 
                     meta_path=meta_path, use_potential_boundary=use_potential_boundary, 
                     potential_strides=potential_strides, use_vector_potential=use_vector_potential,
                     w_div=w_div, w_ff=w_ff, decay_iterations=decay_iterations,
                     device=device, work_directory=work_directory)

trainer.train(total_iterations, batch_size, 
              log_interval=log_interval, validation_interval=validation_interval, 
              num_workers=num_workers)

Configuration:
dim: 256, w_div: 0.100000, w_ff: 0.100000, decay_iterations: 25000, potential: True, vector_potential: False, 
Using device: cuda (gpus 1) ['NVIDIA GeForce RTX 3060']
Potential Boundary: 100%|██████████| 1/1 [00:00<00:00,  1.49it/s]
Training:   0%|          | 98/50000 [00:06<52:05, 15.97it/s] [Iteration 000100/050000] [loss: 31.24823189] [bc_loss: 31.16707230; div_loss: 0.00014445; ff_loss: 0.08101460] [w_bc: 973.016041, LR: 0.000498] [0:00:06.343937]
Training:   0%|          | 198/50000 [00:12<52:11, 15.90it/s][Iteration 000200/050000] [loss: 29.34061050] [bc_loss: 29.28204727; div_loss: 0.00061701; ff_loss: 0.05794651] [w_bc: 946.498652, LR: 0.000495] [0:00:12.690126]
Training:   0%|          | 200/50000 [00:12<52:32, 15.80it/s]


KeyboardInterrupt: 