# PINN NF2 (2nd try)
> https://github.com/RobertJaro/NF2

In [None]:
import numpy as np
from zpinn.lowloumag import LowLouMag

In [None]:
# Build the reference field on a 128 x 64 x 200 grid
# (presumably the analytic Low & Lou solution, judging by the class name).
b = LowLouMag(resolutions=[128, 64, 200])
b.calculate()

# Only the two horizontal grid sizes are needed below.
Nx, Ny, _ = b.grid.dimensions

# Cut out the z = 0 layer and turn its 'B' array into an (Nx, Ny, 3) ndarray.
bottom_subset = (0, Nx - 1, 0, Ny - 1, 0, 0)
bottom = b.grid.extract_subset(bottom_subset).extract_surface()
b_bottom = np.array(bottom['B'].reshape(Nx, Ny, 3))
b_bottom.shape

(128, 64, 3)

In [None]:
import torch
from torch import nn
from torch.cuda import get_device_name
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import Dataset, DataLoader, TensorDataset, RandomSampler

from tqdm import tqdm
from datetime import datetime
import os
import json
import logging

In [None]:
def coords(xbounds, ybounds, zbounds):
    """Return the grid of (x, y, z) points spanned by inclusive bounds.

    Args:
        xbounds, ybounds, zbounds: ``(first, last)`` pairs; both ends are included
            and points are spaced by 1.

    Returns:
        float32 array of shape ``(nx, ny, nz, 3)`` whose last axis holds the
        coordinates of each grid point.
    """
    axes = [np.arange(lo, hi + 1) for lo, hi in (xbounds[:2], ybounds[:2], zbounds[:2])]
    # 'ij' indexing keeps x as the first axis, y as the second and z as the third.
    grid = np.meshgrid(*axes, indexing='ij')
    return np.stack(grid, axis=-1).astype(np.float32)

In [None]:
class BoundaryDataset(Dataset):
    """Dataset over pre-built batches stored in a single ``.npy`` file.

    The file is opened memory-mapped on every access, so only the requested
    batch is read into memory. Each item is one whole batch, split along
    axis 1 into its two entries (index 0 and index 1).
    """

    def __init__(self, batches_path):
        # Path of the .npy file; it is opened on demand, not here.
        self.batches_path = batches_path

    def _open(self):
        """Open the batch file read-only as a memory map."""
        return np.load(self.batches_path, mmap_mode='r')

    def __len__(self):
        # Number of stored batches = length of the leading axis.
        return self._open().shape[0]

    def __getitem__(self, idx):
        # Materialise just this batch as a regular in-memory array.
        batch = np.array(self._open()[idx])
        return batch[:, 0], batch[:, 1]

In [None]:
class Sine(nn.Module):
    """Element-wise activation ``sin(w0 * x)``."""

    def __init__(self, w0=1.):
        super().__init__()
        # Scale applied to the input before taking the sine.
        self.w0 = w0

    def forward(self, x):
        scaled = self.w0 * x
        return scaled.sin()
    
class PositionalEncoding(nn.Module):
    """
    Sinusoidal feature map for input coordinates.

    Every input feature x is expanded to sin(2^k x) and cos(2^k x), where the
    exponents k are "num_freqs" values evenly spaced over [0, max_freq].
    """

    def __init__(self, max_freq, num_freqs):
        """
        Args:
            max_freq (int): largest exponent k used in the encoding.
            num_freqs (int): number of exponents taken from [0, max_freq].
        """
        super().__init__()
        exponents = torch.linspace(0, max_freq, num_freqs)
        # Stored as a buffer: part of the module state, but not a trainable parameter.
        self.register_buffer("freqs", 2 ** exponents)  # (num_freqs,)

    def forward(self, x):
        """
        Inputs:
            x: (..., in_features)
        Outputs:
            out: (..., 2*num_freqs*in_features); all sine terms first, then all cosine terms.
        """
        scaled = x[..., None, :] * self.freqs[:, None]  # (..., num_freqs, in_features)
        merged = scaled.reshape(tuple(x.shape[:-1]) + (-1,))  # (..., num_freqs*in_features)
        return torch.cat((merged.sin(), merged.cos()), dim=-1)

class BModel(nn.Module):
    """Fully connected network with sine activations.

    Layout: input layer -> 8 hidden ``dim x dim`` layers -> linear output layer.
    The sine activation follows the input layer and every hidden layer; the
    output layer is left linear.
    """

    def __init__(self, in_coords, out_values, dim, pos_encoding=False):
        """
        Args:
            in_coords: number of input features.
            out_values: number of output features.
            dim: width of the hidden layers.
            pos_encoding: if true, inputs first pass through
                ``PositionalEncoding(8, 20)``, which yields 2 * 20 = 40 features
                per input feature.
        """
        super().__init__()
        if not pos_encoding:
            self.d_in = nn.Linear(in_coords, dim)
        else:
            encoder = PositionalEncoding(8, 20)
            projection = nn.Linear(in_coords * 40, dim)
            self.d_in = nn.Sequential(encoder, projection)
        hidden_layers = [nn.Linear(dim, dim) for _ in range(8)]
        self.linear_layers = nn.ModuleList(hidden_layers)
        self.d_out = nn.Linear(dim, out_values)
        self.activation = Sine()  # torch.tanh

    def forward(self, x):
        hidden = self.activation(self.d_in(x))
        for layer in self.linear_layers:
            hidden = self.activation(layer(hidden))
        return self.d_out(hidden)

In [None]:
class PotentialModel(nn.Module):
    """Scalar potential built from point sources.

    For every query point ``r`` the module returns

        sum_p  b_n[p] / |r - r_p[p] + c|  / (2 * pi)

    with the constant offset ``c = (0, 0, 1 / sqrt(2 * pi))``.
    """

    def __init__(self, b_n, r_p):
        """
        Args:
            b_n: source strengths, one per source point -- shape (P,).
            r_p: source positions -- shape (P, 3).
        """
        super().__init__()
        self.register_buffer('b_n', b_n)
        self.register_buffer('r_p', r_p)
        # The offset added to every separation vector keeps the distance in
        # forward() non-zero even when a query point coincides with a source.
        offset = torch.tensor([[0, 0, 1 / np.sqrt(2 * np.pi)]], dtype=torch.float64)
        self.register_buffer('c', offset)

    def forward(self, r):
        """Evaluate the potential at query points ``r`` of shape (N, 3); returns shape (N,)."""
        # Pairwise separation vectors, shape (P, N, 3).
        separation = r.unsqueeze(0) - self.r_p.unsqueeze(1) + self.c.unsqueeze(0)
        distance = torch.sqrt((separation ** 2).sum(-1))  # (P, N)
        contributions = self.b_n.unsqueeze(1) / distance  # (P, N)
        return contributions.sum(0) / (2 * np.pi)

In [None]:
def prepare_bc_data(b_bottom, height, device, b_norm, spatial_norm):
    """Assemble normalised boundary-condition samples for all six box faces.

    The bottom face (z = 0) uses the given field ``b_bottom``. The top face and
    the four lateral faces are filled with ``-grad(potential)``, where the
    potential is ``PotentialModel`` driven by the z-component of ``b_bottom``.

    Args:
        b_bottom: array of shape (Nx, Ny, 3) with the field on the bottom face.
        height: number of grid points along z (Nz).
        device: torch device used to evaluate the potential model.
        b_norm: divisor applied to all field values.
        spatial_norm: divisor applied to all coordinates.

    Returns:
        float64 array of shape (n_points, 2, 3); ``[:, 0]`` holds the normalised
        coordinates and ``[:, 1]`` the normalised field values. Top/lateral
        points come first, bottom points last.
    """
    Nx, Ny, _ = b_bottom.shape
    Nz = height

    bottom_values = np.double(b_bottom.reshape(-1, 3))
    bottom_coords = np.double(coords((0, Nx-1), (0, Ny-1), (0, 0)).reshape(-1, 3))

    # Order: top, x = 0, x = Nx-1, y = 0, y = Ny-1.
    top_lateral_coordinates = [coords((0, Nx-1), (0, Ny-1), (Nz-1, Nz-1)).reshape(-1, 3),
                               coords((0, 0), (0, Ny-1), (0, Nz-1)).reshape(-1, 3),
                               coords((Nx-1, Nx-1), (0, Ny-1), (0, Nz-1)).reshape(-1, 3),
                               coords((0, Nx-1), (0, 0), (0, Nz-1)).reshape(-1, 3),
                               coords((0, Nx-1), (Ny-1, Ny-1), (0, Nz-1)).reshape(-1, 3)]

    b_n = torch.tensor(bottom_values[:, 2], dtype=torch.float64)
    r_p = torch.tensor(bottom_coords, dtype=torch.float64)

    model = nn.DataParallel(PotentialModel(b_n, r_p)).to(device)

    pf_fields = []
    pf_coords = []
    for face_coords in top_lateral_coordinates:
        r_coords = torch.tensor(face_coords, dtype=torch.float64)
        # Roughly ten batches per face; never let the batch size drop to zero
        # (a face with fewer than 10 points would otherwise break the DataLoader).
        pf_batch_size = max(1, int(np.prod(r_coords.shape[:-1]) // 10))

        fields = []
        for r, in tqdm(DataLoader(TensorDataset(r_coords), batch_size=pf_batch_size, num_workers=2),
                       desc='Potential Boundary'):
            r = r.to(device).requires_grad_(True)
            p_batch = model(r)
            # Only first derivatives are needed, so no higher-order graph is
            # built and the graph is released right after this call.
            b_p = -1 * torch.autograd.grad(p_batch, r, torch.ones_like(p_batch))[0]
            fields.append(b_p.detach().cpu().numpy())
        pf_fields.append(np.concatenate(fields))
        pf_coords.append(r_coords.numpy())

    top_lateral_values = np.concatenate(pf_fields)
    top_lateral_coords = np.concatenate(pf_coords)

    boundary_values = np.concatenate([top_lateral_values, bottom_values])
    boundary_coords = np.concatenate([top_lateral_coords, bottom_coords])

    normalized_boundary_values = boundary_values / b_norm
    normalized_boundary_coords = boundary_coords / spatial_norm

    # Pair every coordinate with its field value along a new axis 1.
    boundary_data = np.stack([normalized_boundary_coords, normalized_boundary_values], 1)

    return boundary_data

In [None]:
class NF2Trainer:
    """Fit a neural field ``B(x, y, z)`` to boundary data and physics constraints.

    The loss combines three terms:
      * boundary loss   -- squared mismatch to the samples from ``prepare_bc_data``,
      * force-free loss -- ``|J x B|^2 / |B|^2`` with ``J = curl B``,
      * divergence loss -- ``(div B)^2``,
    the latter two evaluated on randomly drawn grid points of the full cube.
    """

    def __init__(self, base_path, b_bottom, height, spatial_norm, b_norm, meta_info=None, dim=256,
                 positional_encoding=False, meta_path=None, use_potential_boundary=True, potential_strides=1, use_vector_potential=False,
                 w_div=0.1, w_ff=0.1, decay_iterations=None,
                 device=None, work_directory=None):
        """
        Args:
            base_path: output directory (log file, snapshots, checkpoint); created if missing.
            b_bottom: array of shape (Nx, Ny, 3) with the field on the bottom boundary.
            height: number of grid points along z.
            spatial_norm: divisor applied to all coordinates.
            b_norm: divisor applied to all field values.
            meta_info: opaque value stored alongside every saved model.
            dim: hidden width of the network.
            positional_encoding: forwarded to the model as ``pos_encoding``.
            meta_path: optional file with initial weights; when given (and no
                checkpoint exists) the optimizer starts at lr 5e-5 instead of 5e-4.
            use_potential_boundary: only logged; not otherwise used by this class.
            potential_strides: accepted for interface compatibility; unused here.
            use_vector_potential: select ``VectorPotentialModel`` instead of ``BModel``.
            w_div: weight of the divergence loss.
            w_ff: weight of the force-free loss.
            decay_iterations: if truthy, the boundary weight starts at 1000 and decays
                to 1 over this many iterations; otherwise it is fixed at 1.
            device: torch device; defaults to CUDA when available, else CPU.
            work_directory: where the temporary batch file is written
                (defaults to ``base_path``); created if missing.
        """
        # paths -- directories must exist before the log file handler opens its file
        self.base_path = base_path
        os.makedirs(self.base_path, exist_ok=True)
        self.checkpoint_path = os.path.join(base_path, 'checkpoint.pt')
        # location of the final model snapshot; returned by train() and written by save()
        self.save_path = os.path.join(base_path, 'fields_final.nf2')
        work_directory = base_path if work_directory is None else work_directory
        os.makedirs(work_directory, exist_ok=True)
        self.work_directory = work_directory

        # logging (replaces every handler on the root logger)
        self.log = logging.getLogger()
        self.log.setLevel(logging.INFO)
        for hdlr in self.log.handlers[:]:  # remove all old handlers
            self.log.removeHandler(hdlr)
            hdlr.close()  # release the log file held by a previous trainer
        self.log.addHandler(logging.FileHandler(os.path.join(base_path, 'info_log.log')))  # set the new file handler
        self.log.addHandler(logging.StreamHandler())  # set the new console handler
        self.log.info('Configuration:')
        self.log.info(
            'dim: %d, w_div: %f, w_ff: %f, decay_iterations: %s, potential: %s, vector_potential: %s, ' % (
                dim, w_div, w_ff, str(decay_iterations), str(use_potential_boundary),
                str(use_vector_potential)))

        # info
        self.spatial_norm = spatial_norm
        self.height = height
        self.b_norm = b_norm
        self.meta_info = meta_info

        n_gpus = torch.cuda.device_count()
        device_names = [get_device_name(i) for i in range(n_gpus)]
        if device is None:
            device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.log.info('Using device: %s (gpus %d) %s' % (str(device), n_gpus, str(device_names)))
        self.device = device

        # prepare data
        self.b_bottom = b_bottom

        # load dataset
        self.data = prepare_bc_data(b_bottom, height, device, b_norm, spatial_norm)

        Nx, Ny, _ = b_bottom.shape
        Nz = height

        self.cube_shape = (Nx, Ny, Nz)

        # init model
        if use_vector_potential:
            # NOTE(review): VectorPotentialModel is not defined in the visible part of
            # this file -- confirm it is available before enabling this option.
            model = VectorPotentialModel(3, dim, pos_encoding=positional_encoding)
        else:
            model = BModel(3, 3, dim, pos_encoding=positional_encoding)
        parallel_model = nn.DataParallel(model)
        parallel_model.to(device)
        opt = torch.optim.Adam(parallel_model.parameters(), lr=5e-4)
        self.model = model
        self.parallel_model = parallel_model

        # load last state
        if os.path.exists(self.checkpoint_path):
            state_dict = torch.load(self.checkpoint_path, map_location=device)
            start_iteration = state_dict['iteration']
            model.load_state_dict(state_dict['m'])
            opt.load_state_dict(state_dict['o'])
            history = state_dict['history']
            w_bc = state_dict['w_bc']
            self.log.info('Resuming training from iteration %d' % start_iteration)
        else:
            if meta_path:
                # '.nf2' files hold the pickled model, anything else a dict with weights under 'm'
                state_dict = torch.load(meta_path, map_location=device)['model'].state_dict() \
                    if meta_path.endswith('nf2') else torch.load(meta_path, map_location=device)['m']
                model.load_state_dict(state_dict)
                opt = torch.optim.Adam(parallel_model.parameters(), lr=5e-5)
                self.log.info('Loaded meta state: %s' % meta_path)
            # init
            start_iteration = 0
            w_bc = 1000 if decay_iterations else 1
            history = {'iteration': [], 'height': [],
                       'b_loss': [], 'divergence_loss': [], 'force_loss': [], 'sigma_angle': []}

        self.opt = opt
        self.start_iteration = start_iteration
        self.history = history
        self.w_bc = w_bc
        # Per-iteration factor taking w_bc from 1000 to 1 in `decay_iterations` steps.
        # Uses the same truthiness test as the w_bc initialisation, so 0 means "no decay".
        self.w_bc_decay = (1 / 1000) ** (1 / decay_iterations) if decay_iterations else 1
        self.w_div, self.w_ff = w_div, w_ff

        # every grid point of the cube is a candidate collocation point
        collocation_coords = coords((0, Nx-1), (0, Ny-1), (0, Nz-1)).reshape(-1, 3)
        normalized_collocation_coords = collocation_coords / self.spatial_norm
        self.normalized_collocation_coords = torch.tensor(normalized_collocation_coords)

    def train(self, total_iterations, batch_size, log_interval=100, validation_interval=100, num_workers=None):
        """Start magnetic field extrapolation fit.

        :param total_iterations: number of iterations for training.
        :param batch_size: number of samples per iteration.
        :param log_interval: log training details and save snapshots every nth iteration.
        :param validation_interval: accepted for interface compatibility; currently unused.
        :param num_workers: number of workers for data loading (default system spec).
        :return: path of the final save state.
        """
        start_time = datetime.now()
        num_workers = os.cpu_count() if num_workers is None else num_workers

        model = self.parallel_model
        opt = self.opt
        device = self.device
        w_div, w_ff = self.w_div, self.w_ff

        iterations = total_iterations - self.start_iteration
        if iterations <= 0:
            self.log.info('Training already finished!')
            return self.save_path

        # learning rate decays exponentially by a factor of 10 over the full run
        scheduler = ExponentialLR(opt, gamma=(5e-5 / 5e-4) ** (1 / total_iterations))

        # init loader
        data_loader, batches_path = self._init_loader(batch_size, self.data, num_workers, iterations)

        try:
            model.train()
            for iteration, (boundary_coords, boundary_b) in tqdm(
                    enumerate(data_loader, start=self.start_iteration),
                    total=len(data_loader), desc='Training'):

                boundary_coords, boundary_b = boundary_coords.to(device), boundary_b.to(device)

                # draw `batch_size` distinct collocation points from the cube
                perm = torch.randperm(self.normalized_collocation_coords.shape[0])
                r = self.normalized_collocation_coords[perm[:batch_size]].to(device)
                r.requires_grad = True

                # forward step
                B = model(r)

                # boundary loss: mean squared vector mismatch
                boundary_B = model(boundary_coords)
                bc_loss = torch.mean(torch.sum((boundary_B - boundary_b) ** 2, dim=-1))

                # physics losses on the collocation points
                divergence_loss, force_free_loss = self._physics_losses(B, r)

                loss = self.w_bc * bc_loss + w_ff * force_free_loss + w_div * divergence_loss

                if iteration == 0:
                    # record the untrained state before the first optimizer step
                    self._log_progress(iteration, total_iterations, loss, bc_loss, divergence_loss,
                                       force_free_loss, scheduler.get_last_lr()[0], start_time)
                    torch.save({'BC_loss': bc_loss.detach().cpu().numpy(),
                                'w_bc': self.w_bc,
                                'divergence_loss': divergence_loss.detach().cpu().numpy(),
                                'w_div': w_div,
                                'force_loss': force_free_loss.detach().cpu().numpy(),
                                'w_ff': w_ff, },
                               os.path.join(self.base_path, 'loss_%06d.nf2' % iteration))
                    self._save_fields(os.path.join(self.base_path, 'fields_%06d.nf2' % iteration))

                opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                opt.step()

                if log_interval > 0 and (iteration + 1) % log_interval == 0:
                    self._log_progress(iteration, total_iterations, loss, bc_loss, divergence_loss,
                                       force_free_loss, scheduler.get_last_lr()[0], start_time)
                    torch.save({'BC_loss': bc_loss.detach().cpu().numpy(),
                                'lambda_BC': self.w_bc,
                                'divergence_loss': divergence_loss.detach().cpu().numpy(),
                                'lambda_div': w_div,
                                'force_loss': force_free_loss.detach().cpu().numpy(),
                                'lambda_ff': w_ff,
                                'LR': scheduler.get_last_lr()[0]},
                               os.path.join(self.base_path, 'loss_%06d.nf2' % iteration))
                    self._save_fields(os.path.join(self.base_path, 'fields_%06d.nf2' % iteration))

                # update training parameters
                if self.w_bc > 1:
                    self.w_bc *= self.w_bc_decay
                    if self.w_bc <= 1:
                        self.w_bc = 1
                if scheduler.get_last_lr()[0] > 5e-5:
                    scheduler.step()

            # save final model state
            torch.save({'BC_loss': bc_loss.detach().cpu().numpy(),
                        'w_bc': self.w_bc,
                        'divergence_loss': divergence_loss.detach().cpu().numpy(),
                        'w_div': w_div,
                        'force_loss': force_free_loss.detach().cpu().numpy(),
                        'w_ff': w_ff,
                        'LR': scheduler.get_last_lr()[0]},
                       os.path.join(self.base_path, 'loss_final.nf2'))
            self._save_fields(self.save_path)
            # Weights of the bare model (no DataParallel 'module.' prefix), so the
            # file can be loaded back through `meta_path`.
            torch.save({'m': self.model.state_dict(),
                        'o': opt.state_dict(), },
                       os.path.join(self.base_path, 'model_final.pt'))
        finally:
            # cleanup -- also when training is interrupted or fails
            try:
                os.remove(batches_path)
            except FileNotFoundError:
                pass

        return self.save_path

    @staticmethod
    def _physics_losses(B, r):
        """Return ``(divergence_loss, force_free_loss)`` for field ``B`` at points ``r``.

        ``B`` has shape (N, 3) and must have been computed from ``r`` (shape (N, 3),
        ``requires_grad=True``). Both results are scalar means over the N points and
        stay differentiable with respect to the model parameters.
        """
        # rows of the Jacobian: gradient of each field component w.r.t. (x, y, z)
        dBx_dr, dBy_dr, dBz_dr = [
            torch.autograd.grad(B[:, k], r, torch.ones_like(B[:, k]),
                                retain_graph=True, create_graph=True)[0]
            for k in range(3)]

        # J = curl B
        rot_x = dBz_dr[:, 1] - dBy_dr[:, 2]
        rot_y = dBx_dr[:, 2] - dBz_dr[:, 0]
        rot_z = dBy_dr[:, 0] - dBx_dr[:, 1]
        J = torch.stack([rot_x, rot_y, rot_z], -1)
        JxB = torch.cross(J, B, dim=-1)

        divB = dBx_dr[:, 0] + dBy_dr[:, 1] + dBz_dr[:, 2]

        # 1e-7 guards against division by zero where B vanishes
        force_free_loss = torch.mean(torch.sum(JxB ** 2, dim=-1) / (torch.sum(B ** 2, dim=-1) + 1e-7))
        # divB is already one value per point: average it directly. Summing over
        # dim=-1 first would add up the whole batch and make the loss scale with
        # the batch size.
        divergence_loss = torch.mean(divB ** 2)
        return divergence_loss, force_free_loss

    def _log_progress(self, iteration, total_iterations, loss, bc_loss, divergence_loss, force_free_loss,
                      lr, start_time):
        """Log the weighted loss terms; arguments follow the order of the format string."""
        self.log.info('[Iteration %06d/%06d] [loss: %.08f] [bc_loss: %.08f; div_loss: %.08f; ff_loss: %.08f] [w_bc: %f, LR: %f] [%s]' %
                      (iteration + 1, total_iterations,
                       float(loss),
                       float(self.w_bc * bc_loss),
                       float(self.w_div * divergence_loss),
                       float(self.w_ff * force_free_loss),
                       self.w_bc,
                       lr,
                       datetime.now() - start_time))

    def _save_fields(self, path):
        """Pickle the bare model together with the metadata needed to evaluate it."""
        torch.save({'model': self.model,
                    'cube_shape': self.cube_shape,
                    'b_norm': self.b_norm,
                    'spatial_norm': self.spatial_norm,
                    'meta_info': self.meta_info}, path)

    @staticmethod
    def _make_batches(data, batch_size):
        """Split ``data`` along axis 0 into float32 batches of exactly ``batch_size`` rows.

        When the number of rows is not a multiple of ``batch_size`` the last batch
        is completed by wrapping around to the start of ``data`` (repeatedly if
        needed). Returns an array of shape ``(n_batches, batch_size, *data.shape[1:])``.
        """
        n_samples = data.shape[0]
        if batch_size <= 0:
            raise ValueError('batch_size must be positive, got %d' % batch_size)
        if n_samples == 0:
            raise ValueError('cannot build batches from empty data')
        n_batches = -(-n_samples // batch_size)  # ceiling division
        wrapped_idx = np.arange(n_batches * batch_size) % n_samples
        padded = data[wrapped_idx]
        return padded.reshape(n_batches, batch_size, *data.shape[1:]).astype(np.float32)

    def _init_loader(self, batch_size, data, num_workers, iterations):
        """Shuffle ``data``, store it as batches on disk and return ``(data_loader, batches_path)``.

        The loader yields ``iterations`` batches drawn at random with replacement.
        """
        # shuffle data
        shuffled = data[np.random.permutation(data.shape[0])]
        # split data into equally sized batches
        batches = self._make_batches(shuffled, batch_size)
        # store batches to disk
        batches_path = os.path.join(self.work_directory, 'batches.npy')
        np.save(batches_path, batches)
        # batch_size=None: every dataset item already is a complete batch
        dataset = BoundaryDataset(batches_path)
        data_loader = DataLoader(dataset, batch_size=None, num_workers=num_workers, pin_memory=True,
                                 sampler=RandomSampler(dataset, replacement=True, num_samples=iterations))
        return data_loader, batches_path

    def save(self, iteration):
        """Write the model snapshot to ``save_path`` and a resumable checkpoint.

        The checkpoint keys match what ``__init__`` reads when resuming.
        """
        self._save_fields(self.save_path)
        torch.save({'iteration': iteration + 1,
                    'm': self.model.state_dict(),
                    'o': self.opt.state_dict(),
                    'history': self.history,
                    'w_bc': self.w_bc},
                   self.checkpoint_path)

In [None]:
# Load the run configuration and unpack it into module-level variables.
config_path = 'config_run.json'

with open(config_path) as config:
    info = json.load(config)

exact_cfg = info['exact']
simul_cfg = info['simul']

# Select the GPU by PCI bus order. Environment values must be strings, so the
# id is converted in case the JSON stores it as a number.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(simul_cfg['gpu_id'])

# --- parameters of the reference ("exact") field ---
n = exact_cfg['n']
m = exact_cfg['m']
l = exact_cfg['l']
# SECURITY: eval() executes arbitrary code taken from the config file. Only use
# config files you trust; consider storing psi as a plain number instead.
psi = eval(exact_cfg['psi'])
resolution = exact_cfg['resolution']
bounds = exact_cfg['bounds']

# --- output locations ---
base_path = os.path.join(simul_cfg['base_path'], "run/")
meta_path = simul_cfg['meta_path']

bin = simul_cfg['bin']

# --- normalisation and domain size ---
height = simul_cfg['height']
spatial_norm = simul_cfg['spatial_norm']
b_norm = simul_cfg['b_norm']

# --- trainer options ---
meta_info = simul_cfg['meta_info']
dim = simul_cfg['dim']
positional_encoding = simul_cfg['positional_encoding']
use_potential_boundary = simul_cfg['use_potential_boundary']
potential_strides = simul_cfg['potential_strides']
use_vector_potential = simul_cfg['use_vector_potential']
w_div = simul_cfg['w_div']
w_ff = simul_cfg['w_ff']
decay_iterations = simul_cfg['decay_iterations']
device = simul_cfg['device']
work_directory = simul_cfg['work_directory']

# --- training loop options ---
total_iterations = simul_cfg['total_iterations']
batch_size = simul_cfg['batch_size']
log_interval = simul_cfg['log_interval']
validation_interval = simul_cfg['validation_interval']
num_workers = simul_cfg['num_workers']

In [None]:
# Make sure the output directory exists before the trainer writes into it.
os.makedirs(base_path, exist_ok=True)

# Optional trainer settings, all taken from the loaded configuration.
trainer_options = dict(
    meta_info=meta_info,
    dim=dim,
    positional_encoding=positional_encoding,
    meta_path=meta_path,
    use_potential_boundary=use_potential_boundary,
    potential_strides=potential_strides,
    use_vector_potential=use_vector_potential,
    w_div=w_div,
    w_ff=w_ff,
    decay_iterations=decay_iterations,
    device=device,
    work_directory=work_directory,
)
trainer = NF2Trainer(base_path, b_bottom, height, spatial_norm, b_norm, **trainer_options)

# Run the fit.
trainer.train(
    total_iterations,
    batch_size,
    log_interval=log_interval,
    validation_interval=validation_interval,
    num_workers=num_workers,
)

Configuration:
dim: 256, w_div: 0.100000, w_ff: 0.100000, decay_iterations: 25000, potential: True, vector_potential: False, 
Using device: cuda (gpus 1) ['NVIDIA GeForce RTX 3060']
Potential Boundary: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 11.26it/s]
Potential Boundary: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 80.83it/s]
Potential Boundary: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 85.27it/s]
Potential Boundary: 100%|█████████████████████████

KeyboardInterrupt: 