# PINN NF2 (custom)
> https://github.com/RobertJaro/NF2

## Low & Lou (1990) NLFFF

In [1]:
import numpy as np
from zpinn.lowloumag import LowLouMag

In [2]:
# Low and Lou (1990) nonlinear force-free field (NLFFF) sampled on a 32 x 64 x 100 grid.
b = LowLouMag(resolutions=[32, 64, 100])
b

### Low and Lou (1990) NLFFF
bounds = [-1, 1, -1, 1, 0, 2]<br>
resolutions = [32, 64, 100]<br>
n = 1<br>
m = 1<br>
l = 0.3<br>
Phi = 1.5707963267948966<br>


In [3]:
# Evaluate the field; presumably this populates b.grid with the B / mag / alpha arrays shown below.
b.calculate()

In [4]:
# Inspect the uniform grid and its point-data arrays.
b.grid

Header,Data Arrays
"UniformGridInformation N Cells193347 N Points204800 X Bounds-1.000e+00, 1.000e+00 Y Bounds-1.000e+00, 1.000e+00 Z Bounds0.000e+00, 2.000e+00 Dimensions32, 64, 100 Spacing6.452e-02, 3.175e-02, 2.020e-02 N Arrays3",NameFieldTypeN CompMinMax BPointsfloat643-1.343e+022.231e+02 magPointsfloat6413.360e-012.258e+02 alphaPointsfloat641-9.682e+009.682e+00

UniformGrid,Information
N Cells,193347
N Points,204800
X Bounds,"-1.000e+00, 1.000e+00"
Y Bounds,"-1.000e+00, 1.000e+00"
Z Bounds,"0.000e+00, 2.000e+00"
Dimensions,"32, 64, 100"
Spacing,"6.452e-02, 3.175e-02, 2.020e-02"
N Arrays,3

Name,Field,Type,N Comp,Min,Max
B,Points,float64,3,-134.3,223.1
mag,Points,float64,1,0.336,225.8
alpha,Points,float64,1,-9.682,9.682


In [5]:
# Number of grid points along x and y (the z size is not needed for the bottom boundary).
Nx, Ny, _ =  b.grid.dimensions
Nx, Ny

(32, 64)

In [6]:
# Index extent (x_min, x_max, y_min, y_max, z_min, z_max) selecting the bottom z=0 layer.
bottom_subset = (0, Nx-1, 0, Ny-1, 0, 0)
bottom_subset

(0, 31, 0, 63, 0, 0)

In [7]:
# Extract the z=0 layer and convert it to a surface (PolyData) carrying the point data.
bottom = b.grid.extract_subset(bottom_subset).extract_surface()
bottom

Header,Data Arrays
"PolyDataInformation N Cells1953 N Points2048 N Strips0 X Bounds-1.000e+00, 1.000e+00 Y Bounds-1.000e+00, 1.000e+00 Z Bounds0.000e+00, 0.000e+00 N Arrays5",NameFieldTypeN CompMinMax BPointsfloat643-1.343e+022.231e+02 magPointsfloat6412.433e+002.258e+02 alphaPointsfloat641-9.682e+009.682e+00 vtkOriginalPointIdsPointsint6410.000e+002.047e+03 vtkOriginalCellIdsCellsint6410.000e+001.952e+03

PolyData,Information
N Cells,1953
N Points,2048
N Strips,0
X Bounds,"-1.000e+00, 1.000e+00"
Y Bounds,"-1.000e+00, 1.000e+00"
Z Bounds,"0.000e+00, 0.000e+00"
N Arrays,5

Name,Field,Type,N Comp,Min,Max
B,Points,float64,3,-134.3,223.1
mag,Points,float64,1,2.433,225.8
alpha,Points,float64,1,-9.682,9.682
vtkOriginalPointIds,Points,int64,1,0.0,2047.0
vtkOriginalCellIds,Cells,int64,1,0.0,1952.0


In [8]:
# VTK/pyvista store point data with x varying fastest (Fortran order), so the flat
# (Nx*Ny, 3) array is reshaped with order='F' to obtain b_bottom[x, y, component].
# A C-order reshape would interleave rows of different y values along x.
bottom_points = np.asarray(bottom.points).reshape(Nx, Ny, 3, order='F')
# Guard the ordering assumption: x must increase along axis 0 and y along axis 1.
assert np.all(np.diff(bottom_points[:, :, 0], axis=0) > 0) and np.all(np.diff(bottom_points[:, :, 1], axis=1) > 0), \
    'unexpected point ordering in `bottom`'
b_bottom = np.ascontiguousarray(np.asarray(bottom['B']).reshape(Nx, Ny, 3, order='F'))
b_bottom.shape

(32, 64, 3)

## PINN

$$
\mathbf{B}(z=0)
$$

In [9]:
# Bottom boundary condition B(z=0) fed to the PINN, shape (Nx, Ny, 3).
b_bottom.shape

(32, 64, 3)

$$
\mathcal{L}_\text{ff}(\boldsymbol{\theta}; \mathcal{T}_f) = \frac{1}{|\mathcal{T}_f|} \sum_{\boldsymbol{x}\in \mathcal{T}_f} \frac{|(\nabla \times \mathbf{\hat{B}})\times \mathbf{\hat{B}}|^2}{|\mathbf{\hat{B}}|^2}
$$

$$
\mathcal{L}_\text{div}(\boldsymbol{\theta}; \mathcal{T}_f) = \frac{1}{|\mathcal{T}_f|} \sum_{\boldsymbol{x}\in \mathcal{T}_f} |\nabla \cdot \mathbf{\hat{B}}|^2
$$

$$
\mathcal{L}_\text{bc}(\boldsymbol{\theta};\mathcal{T}_b)=\frac{1}{|\mathcal{T}_b|}\sum_{\boldsymbol{x}\in\mathcal{T}_b}{|\mathbf{\hat{B}}-\mathbf{B}|^2}
$$

$$
\mathcal{L} = w_{\text{ff}}\mathcal{L}_\text{ff} + w_{\text{div}}\mathcal{L}_\text{div} +  w_{\text{bc}}\mathcal{L}_{\text{bc}}
$$

In [10]:
import torch
from torch import nn
from torch.cuda import get_device_name
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import Dataset, DataLoader, TensorDataset, RandomSampler

import os
import json
import logging
import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime

In [11]:
# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# These variables only take effect if set before CUDA is initialised in the process.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0")

In [12]:
def create_coordinates(bounds):
    """Build an inclusive lattice of (x, y, z) coordinates with unit step.

    Args:
        bounds: ``(x0, x1, y0, y1, z0, z1)``; both ends of every range are included.

    Returns:
        Array of shape ``(nx, ny, nz, 3)`` whose last axis holds the (x, y, z)
        coordinate of each lattice point (``ij`` indexing: x is the first axis).
    """
    x0, x1, y0, y1, z0, z1 = (bounds[i] for i in range(6))
    axes = [np.arange(lo, hi + 1) for lo, hi in ((x0, x1), (y0, y1), (z0, z1))]
    return np.stack(np.meshgrid(*axes, indexing='ij'), axis=-1)

In [13]:
class Sine(nn.Module):
    """Element-wise sinusoidal activation ``sin(w0 * x)``."""

    def __init__(self, w0=1.):
        super().__init__()
        self.w0 = w0  # frequency factor applied before the sine

    def forward(self, x):
        scaled = x * self.w0
        return scaled.sin()
    
class PositionalEncoding(nn.Module):
    """
    Fourier-feature encoding of the input coordinates.

    Maps x to ``(sin(f_1 x), ..., sin(f_K x), cos(f_1 x), ..., cos(f_K x))`` with
    K = ``num_freqs`` frequencies ``f_k = 2**e_k``, the exponents ``e_k`` evenly
    spaced over ``[0, max_freq]``.
    """

    def __init__(self, max_freq, num_freqs):
        """
        Args:
            max_freq (int): largest exponent; the highest frequency is ``2**max_freq``.
            num_freqs (int): number of exponents (and thus frequencies) in [0, max_freq].
        """
        super().__init__()
        exponents = torch.linspace(0, max_freq, num_freqs)
        # registered as a buffer so it follows the module across .to(device)
        self.register_buffer("freqs", 2 ** exponents)  # (num_freqs)

    def forward(self, x):
        """
        Inputs:
            x: (..., in_features)
        Outputs:
            (..., 2*num_freqs*in_features): all sines first, then all cosines; within
            each half, feature d at frequency k sits at index k*in_features + d.
        """
        scaled = x[..., None, :] * self.freqs[:, None]  # (..., num_freqs, in_features)
        scaled = scaled.flatten(start_dim=-2)           # (..., num_freqs*in_features)
        return torch.cat((scaled.sin(), scaled.cos()), dim=-1)

class BModel(nn.Module):
    """Fully connected, sine-activated network mapping coordinates to field values.

    Args:
        in_coords: number of input coordinates (3 for x, y, z).
        out_values: number of output values (3 for Bx, By, Bz).
        dim: width of every hidden layer.
        pos_encoding: if True, inputs pass through ``PositionalEncoding(8, 20)``
            before the first linear layer.
        num_layers: number of hidden ``dim -> dim`` layers (default 8).
    """

    def __init__(self, in_coords, out_values, dim, pos_encoding=False, num_layers=8):
        super().__init__()
        if pos_encoding:
            posenc = PositionalEncoding(8, 20)
            # the encoder emits one sin and one cos per frequency and per coordinate;
            # derive the width from it instead of hard-coding 2 * 20 = 40
            enc_features = in_coords * 2 * posenc.freqs.numel()
            self.d_in = nn.Sequential(posenc, nn.Linear(enc_features, dim))
        else:
            self.d_in = nn.Linear(in_coords, dim)
        self.linear_layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.d_out = nn.Linear(dim, out_values)
        self.activation = Sine()  # torch.tanh

    def forward(self, x):
        """x: (..., in_coords) -> (..., out_values); no activation on the output layer."""
        x = self.activation(self.d_in(x))
        for layer in self.linear_layers:
            x = self.activation(layer(x))
        return self.d_out(x)

In [14]:
class BoundaryDataset(Dataset):
    """Dataset over pre-built boundary batches stored in one ``.npy`` file.

    Each item of the stored array is a whole batch; ``batch[:, 0]`` holds the
    coordinates and ``batch[:, 1]`` the field values. The file is memory-mapped
    on every access, so only the requested batch is read into memory.
    """

    def __init__(self, batches_path):
        self.batches_path = batches_path

    def _open(self):
        # read-only memory map of the batch file (lazy loading)
        return np.load(self.batches_path, mmap_mode='r')

    def __len__(self):
        return self._open().shape[0]

    def __getitem__(self, idx):
        batch = np.array(self._open()[idx])  # private in-memory copy of one batch
        return batch[:, 0], batch[:, 1]

In [15]:
class PotentialModel(nn.Module):
    """Potential field generated by the normal field on the bottom boundary.

    For each query point r it evaluates ``sum_p b_n[p] / |r - r_p + c| / (2*pi)``,
    summing over the source points r_p. The offset ``c = (0, 0, 1/sqrt(2*pi))``
    keeps the denominator nonzero when r lies in the same z plane as r_p.
    """

    def __init__(self, b_n, r_p):
        super().__init__()
        # buffers: moved by .to(device) together with the module, never trained
        self.register_buffer('b_n', b_n)
        self.register_buffer('r_p', r_p)
        offset = torch.zeros((1, 3), dtype=torch.float64)
        offset[0, 2] = 1 / np.sqrt(2 * np.pi)
        self.register_buffer('c', offset)

    def forward(self, r):
        """r: (n_query, 3) -> potential of shape (n_query,)."""
        separation = r.unsqueeze(0) - self.r_p.unsqueeze(1) + self.c.unsqueeze(0)  # (n_source, n_query, 3)
        distance = separation.pow(2).sum(dim=-1).sqrt()
        return (self.b_n.unsqueeze(1) / distance).sum(dim=0) / (2 * np.pi)

In [16]:
def prepare_bc_data(device, b_bottom, height, b_norm, spatial_norm):
    """Assemble normalised boundary-condition samples for the PINN.

    The bottom face uses the supplied field. The top face and the four lateral
    faces use ``-grad(phi)`` of a potential field (``PotentialModel``) driven by
    the bottom ``Bz`` component.

    Args:
        device: torch device on which the potential field is evaluated.
        b_bottom: bottom boundary field of shape (Nx, Ny, 3).
        height: number of grid points along z (Nz).
        b_norm: divisor applied to all field values.
        spatial_norm: divisor applied to all (pixel-index) coordinates.

    Returns:
        Array of shape (n_points, 2, 3): ``[:, 0]`` normalised coordinates,
        ``[:, 1]`` normalised field values (top/lateral points first, bottom last).
    """
    Nx, Ny, _ = b_bottom.shape
    Nz = height

    bottom_values = np.double(b_bottom.reshape(-1, 3))
    bottom_coords = np.double(create_coordinates((0, Nx-1, 0, Ny-1, 0, 0)).reshape(-1, 3))

    # top face, then the x=0, x=Nx-1, y=0 and y=Ny-1 faces.
    # NOTE(review): the lateral faces include z=0, so bottom-edge points appear both with
    # a potential-field value and with the observed bottom value -- confirm this is intended.
    face_bounds = [(0, Nx-1, 0, Ny-1, Nz-1, Nz-1),
                   (0, 0, 0, Ny-1, 0, Nz-1),
                   (Nx-1, Nx-1, 0, Ny-1, 0, Nz-1),
                   (0, Nx-1, 0, 0, 0, Nz-1),
                   (0, Nx-1, Ny-1, Ny-1, 0, Nz-1)]

    b_n = torch.tensor(bottom_values[:, 2], dtype=torch.float64)
    r_p = torch.tensor(bottom_coords, dtype=torch.float64)
    model = nn.DataParallel(PotentialModel(b_n, r_p)).to(device)

    pf_fields = []
    pf_coords = []
    for bounds in face_bounds:
        r_coords = torch.tensor(create_coordinates(bounds).reshape(-1, 3), dtype=torch.float64)
        # ~10 batches per face, but never 0 (DataLoader rejects batch_size=0 on tiny faces)
        pf_batch_size = max(1, r_coords.shape[0] // 10)

        fields = []
        for r, in tqdm(DataLoader(TensorDataset(r_coords), batch_size=pf_batch_size, num_workers=2),
                       desc='Potential Boundary'):
            r = r.to(device).requires_grad_(True)
            p_batch = model(r)
            # B = -grad(phi); only first derivatives are needed, so no higher-order graph
            b_p = -torch.autograd.grad(p_batch, r, torch.ones_like(p_batch))[0]
            fields.append(b_p.detach().cpu().numpy())
        pf_fields.append(np.concatenate(fields))
        pf_coords.append(r_coords.numpy())

    top_lateral_values = np.concatenate(pf_fields)
    top_lateral_coords = np.concatenate(pf_coords)

    boundary_values = np.concatenate([top_lateral_values, bottom_values])
    boundary_coords = np.concatenate([top_lateral_coords, bottom_coords])

    normalized_boundary_values = boundary_values / b_norm
    normalized_boundary_coords = boundary_coords / spatial_norm

    return np.stack([normalized_boundary_coords, normalized_boundary_values], 1)

In [17]:
def create_boundary_batches(boundary_data, batch_size, total_iterations, num_workers,
                            batches_path='boundary_batches.npy'):
    """Shuffle boundary samples into fixed-size batches on disk and wrap them in a DataLoader.

    Args:
        boundary_data: array of shape (n_points, ...) (e.g. the output of ``prepare_bc_data``).
        batch_size: number of samples per batch.
        total_iterations: number of batches the loader yields (sampled with replacement).
        num_workers: number of DataLoader worker processes.
        batches_path: file the float32 batches are written to; the caller should delete it.

    Returns:
        ``(data_loader, batches_path)``.

    Raises:
        ValueError: if ``boundary_data`` is empty or ``batch_size`` < 1.
    """
    n_points = boundary_data.shape[0]
    if n_points == 0:
        raise ValueError('boundary_data is empty')
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1')

    # shuffle data
    boundary_data = boundary_data[np.random.permutation(n_points)]

    # pad by wrapping around to the next multiple of batch_size
    # (no padding when it already divides; works even when batch_size > n_points)
    n_padded = -(-n_points // batch_size) * batch_size
    boundary_data = boundary_data[np.arange(n_padded) % n_points]

    # consecutive chunks of batch_size samples -> (n_batches, batch_size, ...)
    n_batches = n_padded // batch_size
    boundary_batches = boundary_data.reshape(n_batches, batch_size, *boundary_data.shape[1:]).astype(np.float32)

    # store batches to disk; the dataset memory-maps them lazily
    np.save(batches_path, boundary_batches)
    boundary_dataset = BoundaryDataset(batches_path)
    boundary_data_loader = DataLoader(boundary_dataset, batch_size=None, num_workers=num_workers, pin_memory=True,
                                      sampler=RandomSampler(boundary_dataset, replacement=True,
                                                            num_samples=total_iterations))
    return boundary_data_loader, batches_path

In [18]:
def plot_b_bottom(b_bottom):
    """Show Bz of the bottom boundary as contour lines, filled contours and an image.

    Args:
        b_bottom: array indexed ``[x, y, component]``; component 2 (Bz) is plotted,
            transposed so that x runs along the horizontal axis.
    """
    plt.close()
    fig, axes = plt.subplots(nrows=1, ncols=3)
    bz = b_bottom[:, :, 2].transpose()

    axes[0].contour(bz, origin='lower', cmap='plasma')
    axes[1].contourf(bz, origin='lower', cmap='plasma')
    image = axes[2].imshow(bz, origin='lower', cmap='plasma')

    for ax in axes:
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title(r"$B_z(z=0)$")
        ax.set_aspect('equal')

    # one shared horizontal colorbar, driven by the imshow panel's colour mapping
    fig.colorbar(image, ax=axes, orientation='horizontal', pad=0.2)
    plt.show()

In [19]:
class NF2Trainer:
    """Fit a ``BModel`` PINN to boundary data under force-free and divergence-free penalties.

    Per-iteration loss: ``w_bc * L_bc + w_ff * L_ff + w_div * L_div``.
    ``L_bc`` is evaluated on a batch of boundary samples (from ``prepare_bc_data``).
    ``L_ff`` and ``L_div`` are evaluated on ``batch_size`` random collocation points of
    the (Nx, Ny, Nz) pixel grid. ``w_bc`` decays geometrically from 1000 to 1 over
    ``decay_iterations``. The learning rate decays from 5e-4 towards 5e-5 over
    ``total_iterations``.

    Args:
        device: torch device used for training.
        batch_size: boundary samples and collocation points per iteration.
        b_bottom: bottom boundary field, shape (Nx, Ny, 3).
        height: number of grid points along z (Nz).
        b_norm: divisor applied to field values.
        spatial_norm: divisor applied to pixel coordinates.
        decay_iterations: iterations for ``w_bc`` to decay from 1000 to 1.
        total_iterations: number of training iterations.
        base_path: directory for loss/field checkpoints (created if missing).
        log_interval: log and checkpoint every ``log_interval`` iterations; <= 0 disables.
        meta_info: arbitrary object stored alongside the field checkpoints.
    """

    def __init__(self, device, batch_size, b_bottom, height, b_norm=150, spatial_norm=150,
                 decay_iterations=25000, total_iterations=50000,
                 base_path='.', log_interval=1000, meta_info=None):
        self.device = device
        self.batch_size = batch_size

        self.b_bottom = b_bottom
        self.height = height
        self.Nx, self.Ny, _ = self.b_bottom.shape
        self.Nz = height
        self.cube_shape = (self.Nx, self.Ny, self.Nz)

        self.b_norm = b_norm
        self.spatial_norm = spatial_norm
        self.total_iterations = total_iterations

        self.base_path = base_path
        os.makedirs(base_path, exist_ok=True)
        self.log_interval = log_interval
        self.meta_info = meta_info
        # INFO records only appear if logging is configured, e.g. logging.basicConfig(level=logging.INFO)
        self.log = logging.getLogger('NF2Trainer')

        self.boundary_data = prepare_bc_data(self.device, self.b_bottom, self.height, self.b_norm, self.spatial_norm)

        collocation_coords = create_coordinates((0, self.Nx-1, 0, self.Ny-1, 0, self.Nz-1)).reshape(-1, 3)
        normalized_collocation_coords = collocation_coords / self.spatial_norm
        # float32 to match the model weights (float64 inputs fail inside nn.Linear)
        self.normalized_collocation_coords = torch.tensor(normalized_collocation_coords, dtype=torch.float32)

        self.Bmodel = nn.DataParallel(BModel(3, 3, 256)).to(device)
        self.opt = torch.optim.Adam(self.Bmodel.parameters(), lr=5e-4)
        self.scheduler = ExponentialLR(self.opt, gamma=(5e-5 / 5e-4) ** (1 / total_iterations))

        self.w_ff = 1
        self.w_div = 1
        self.w_bc = 1000
        self.w_bc_decay = (1 / 1000) ** (1 / decay_iterations)

    @staticmethod
    def _compute_losses(model, r, boundary_coords, boundary_b):
        """Return scalar tensors ``(bc_loss, divergence_loss, force_free_loss)``.

        ``r`` are collocation points with ``requires_grad=True``; the physics losses are
        means over these points, and the boundary loss is the mean squared field error.
        """
        B = model(r)
        boundary_B = model(boundary_coords)
        bc_loss = torch.mean(torch.sum((boundary_B - boundary_b) ** 2, dim=-1))

        # jac[:, i, j] = dB_i / dr_j; create_graph so the losses stay differentiable
        jac = torch.stack([torch.autograd.grad(B[:, i], r, torch.ones_like(B[:, i]),
                                               retain_graph=True, create_graph=True)[0]
                           for i in range(3)], dim=1)

        curl = torch.stack([jac[:, 2, 1] - jac[:, 1, 2],
                            jac[:, 0, 2] - jac[:, 2, 0],
                            jac[:, 1, 0] - jac[:, 0, 1]], -1)
        JxB = torch.cross(curl, B, dim=-1)
        force_free_loss = torch.mean(torch.sum(JxB ** 2, dim=-1) / (torch.sum(B ** 2, dim=-1) + 1e-7))

        divB = jac[:, 0, 0] + jac[:, 1, 1] + jac[:, 2, 2]
        divergence_loss = torch.mean(divB ** 2)  # mean over points, as in L_div
        return bc_loss, divergence_loss, force_free_loss

    def _log_and_save(self, it, tag, loss, losses, start_time):
        """Log the current losses and write ``loss_<tag>.nf2`` and ``fields_<tag>.nf2``."""
        bc_loss, divergence_loss, force_free_loss = losses
        lr = self.scheduler.get_last_lr()[0]
        self.log.info('[Iteration %06d/%06d] [loss: %.08f] [bc_loss: %.08f; div_loss: %.08f; ff_loss: %.08f] [w_bc: %f, LR: %f] [%s]',
                      it + 1, self.total_iterations,
                      loss.item(),
                      self.w_bc * bc_loss.item(),
                      self.w_div * divergence_loss.item(),
                      self.w_ff * force_free_loss.item(),
                      self.w_bc,
                      lr,
                      datetime.now() - start_time)
        torch.save({'BC_loss': bc_loss.detach().cpu().numpy(),
                    'w_bc': self.w_bc,
                    'divergence_loss': divergence_loss.detach().cpu().numpy(),
                    'w_div': self.w_div,
                    'force_loss': force_free_loss.detach().cpu().numpy(),
                    'w_ff': self.w_ff,
                    'LR': lr},
                   os.path.join(self.base_path, 'loss_%s.nf2' % tag))
        torch.save({'model': self.Bmodel,
                    'cube_shape': self.cube_shape,
                    'b_norm': self.b_norm,
                    'spatial_norm': self.spatial_norm,
                    'meta_info': self.meta_info},
                   os.path.join(self.base_path, 'fields_%s.nf2' % tag))

    def train(self):
        """Run the training loop, writing checkpoints to ``self.base_path``.

        Checkpoints are written at iteration 0 (before the first update), every
        ``log_interval`` iterations (after the update), and once at the end. The
        temporary boundary-batch file is removed even if training fails.
        """
        boundary_data_loader, boundary_batches_path = create_boundary_batches(
            self.boundary_data, self.batch_size,
            total_iterations=self.total_iterations, num_workers=os.cpu_count())

        model = self.Bmodel
        opt = self.opt
        scheduler = self.scheduler
        n_collocation = self.normalized_collocation_coords.shape[0]
        start_time = datetime.now()
        losses = None

        try:
            for it, (boundary_coords, boundary_b) in tqdm(enumerate(boundary_data_loader),
                                                          total=len(boundary_data_loader), desc='Training'):
                boundary_coords = boundary_coords.to(self.device)
                boundary_b = boundary_b.to(self.device)

                # random collocation points (without replacement) for the physics losses
                idx = torch.randperm(n_collocation)[:self.batch_size]
                r = self.normalized_collocation_coords[idx].to(self.device).requires_grad_(True)

                losses = self._compute_losses(model, r, boundary_coords, boundary_b)
                bc_loss, divergence_loss, force_free_loss = losses
                loss = self.w_bc * bc_loss + self.w_ff * force_free_loss + self.w_div * divergence_loss

                if it == 0:
                    # snapshot of the untrained model
                    self._log_and_save(it, '%06d' % it, loss, losses, start_time)

                opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                opt.step()

                if self.log_interval > 0 and (it + 1) % self.log_interval == 0:
                    self._log_and_save(it, '%06d' % it, loss, losses, start_time)

                # update training parameters
                if self.w_bc > 1:
                    self.w_bc = max(self.w_bc * self.w_bc_decay, 1)
                if scheduler.get_last_lr()[0] > 5e-5:
                    scheduler.step()

            # save final model state
            if losses is not None:
                self._log_and_save(it, 'final', loss, losses, start_time)
                torch.save({'m': model.state_dict(),
                            'o': opt.state_dict()},
                           os.path.join(self.base_path, 'model_final.pt'))
        finally:
            # cleanup
            os.remove(boundary_batches_path)

In [20]:
# Training configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 10000  # boundary samples and collocation points per iteration
height = 100        # grid points along z; matches the Low-Lou cube's Nz

In [21]:
# Builds the boundary data (bottom field + potential-field top/lateral faces), the model and the optimizer.
trainer = NF2Trainer(device, batch_size, b_bottom, height)

Potential Boundary: 100%|█████████████| 11/11 [00:00<00:00, 97.23it/s]
Potential Boundary: 100%|█████████████| 10/10 [00:00<00:00, 90.31it/s]
Potential Boundary: 100%|█████████████| 10/10 [00:00<00:00, 97.85it/s]
Potential Boundary: 100%|████████████| 10/10 [00:00<00:00, 105.74it/s]
Potential Boundary: 100%|████████████| 10/10 [00:00<00:00, 112.86it/s]


In [None]:
# Runs total_iterations (default 50000) optimisation steps and writes loss/field checkpoints.
trainer.train()

