In [None]:
import os
import logging
import numpy as np
import netCDF4

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

from datetime import datetime
from astropy.nddata import block_reduce

import torch
from torch import nn
from torch.cuda import get_device_name
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import Dataset, TensorDataset

from tqdm import tqdm

# ISEE

In [2]:
class nlfff:

      def __init__(self,filename):
            self.filename=filename

            nc=netCDF4.Dataset(self.filename,'r')
            self.NOAA=nc.NOAA
            self.year_month_day_time=nc.year_month_day_time
            self.project=nc.project
            self.production_date=nc.production_date
            self.version=nc.version
            self.data_doi=nc.data_doi
            self.http_link=nc.http_link
            self.Distributor=nc.Distributor
            
            nc_x=nc.variables['x']
            self.x=nc_x[:]
            print(nc_x.long_name,' unit:',nc_x.units)
            nc_y=nc.variables['y']
            self.y=nc_y[:]
            print(nc_y.long_name,' unit:',nc_y.units)
            nc_z=nc.variables['z']
            self.z=nc_z[:]
            print(nc_z.long_name,' unit:',nc_z.units)
            
            nc_bx=nc.variables['Bx']
            self.bx=nc_bx[:].transpose(2,1,0)
            print(nc_bx.long_name,' unit:',nc_bx.units)
            nc_by=nc.variables['By']
            self.by=nc_by[:].transpose(2,1,0)
            print(nc_by.long_name,' unit:',nc_by.units)
            nc_bz=nc.variables['Bz']
            self.bz=nc_bz[:].transpose(2,1,0)
            print(nc_bz.long_name,' unit:',nc_bz.units)
            
            nc_bxp=nc.variables['Bx_pot']
            self.bx_pot=nc_bxp[:].transpose(2,1,0)
            print(nc_bxp.long_name,' unit:',nc_bxp.units)
            nc_byp=nc.variables['By_pot']
            self.by_pot=nc_byp[:].transpose(2,1,0)
            print(nc_byp.long_name,' unit:',nc_byp.units)
            nc_bzp=nc.variables['Bz_pot']
            self.bz_pot=nc_bzp[:].transpose(2,1,0)
            print(nc_bzp.long_name,' unit:',nc_bzp.units)
            
      def info(self):
            self.Lx_Mm=max(self.x) - min(self.x)
            self.Ly_Mm=max(self.y) - min(self.y)
            print(f'(Lx, Ly) in Mm = ({self.Lx_Mm:.2f}, {self.Ly_Mm:.2f})\n')
            print(f"NOAA",self.NOAA)
            print(f'year_month_day_time',self.year_month_day_time)
            print(f"project",self.project)
            print(f"production_date",self.production_date)
            print(f"version",self.version)
            print(f"data_doi",self.data_doi)
            print(f"http_link",self.http_link)
            print(f"Distributor",self.Distributor)

      def plot(self):
            xs=12.0
            ys=4.0

            xmin=min(self.x)
            xmax=max(self.x)
            ymin=min(self.y)
            ymax=max(self.y)

            plt.close()
            fig=plt.figure(figsize=(xs,ys))
            ax1=fig.add_axes((0.08,0.35,0.25,0.25*xs/ys*(ymax-ymin)/(xmax-xmin)))
            ax2=fig.add_axes((0.4,0.35,0.25,0.25*xs/ys*(ymax-ymin)/(xmax-xmin)))
            ax3=fig.add_axes((0.72,0.35,0.25,0.25*xs/ys*(ymax-ymin)/(xmax-xmin)))
            cax1=fig.add_axes((0.08,0.15,0.25,0.05))
            cax2=fig.add_axes((0.4,0.15,0.25,0.05))
            cax3=fig.add_axes((0.72,0.15,0.25,0.05))
            
            vmin=-3000.0 
            vmax=3000.0
            
            im1=ax1.pcolormesh(self.x,self.y,self.bx[:,:,0].transpose(),vmin=vmin,vmax=vmax,cmap='gist_gray',shading='auto')
            im2=ax2.pcolormesh(self.x,self.y,self.by[:,:,0].transpose(),vmin=vmin,vmax=vmax,cmap='gist_gray',shading='auto')
            im3=ax3.pcolormesh(self.x,self.y,self.bz[:,:,0].transpose(),vmin=vmin,vmax=vmax,cmap='gist_gray',shading='auto')

            cbar1=plt.colorbar(im1,cax=cax1,orientation='horizontal')
            cbar2=plt.colorbar(im2,cax=cax2,orientation='horizontal')
            cbar3=plt.colorbar(im3,cax=cax3,orientation='horizontal')
            
            ax1.set_title('Bx [G]')
            ax1.set_xlabel('x [Mm]')
            ax1.set_ylabel('y [Mm]')
            
            ax2.set_title('By [G]')
            ax2.set_xlabel('x [Mm]')
            ax2.set_ylabel('y [Mm]')
            
            ax3.set_title('Bz [G]')
            ax3.set_xlabel('x [Mm]')
            ax3.set_ylabel('y [Mm]')
            
            #plt.pause(0.1)
            # plt.savefig('./B.png')

In [3]:
# nc_filepath = '/home/tensor/workspace/pinn_study/_data/12673_20170905_202400/12673_20170905_202400.nc'
# nc_filepath = '/Users/mgjeon/Workspace/pinn_study/_data/12673_20170905_202400.nc'
nc_filepath = '/nas/obsdata/isee_nlfff_v1.2/12673/12673_20170905_202400.nc'
data = nlfff(nc_filepath)

x (westward)  unit: Mm
y (northward)  unit: Mm
z (out ot photosphere)  unit: Mm
Bx (westward)  unit: G
By (northward)  unit: G
Bz (out of photosphere)  unit: G
Bx_pot (westward)  unit: G
By_pot (northward)  unit: G
Bz_pot (out of photosphere)  unit: G


In [4]:
Lx = max(data.x) - min(data.x)
Ly = max(data.y) - min(data.y)
Lz = max(data.z) - min(data.z)

Nx = len(data.x)
Ny = len(data.y)
Nz = len(data.z)

dx = np.diff(data.x)[0]
dy = np.diff(data.y)[0]
dz = np.diff(data.z)[0]

print(f'Lx: {Lx}')
print(f'Ly: {Ly}')
print(f'Lz: {Lz}')

print(f'Nx: {Nx}')
print(f'Ny: {Ny}')
print(f'Nz: {Nz}')

# Mm per pixel
print(f'dx: {dx}')
print(f'dy: {dy}')
print(f'dz: {dz}')

# Mm^3 per pixel^3
dV = dx*dy*dz 
# cm^3 per pixel^3
dV = dx*dy*dz*(1e8**3)
print(f'dV: {dV}')

# (Nx-1) : # of pixels in x-direction
# (Nx-1)*dx : Mm in x-direction
print(f'(Nx-1)*dx: {(Nx-1)*dx}')
print(f'(Ny-1)*dy: {(Ny-1)*dy}')
print(f'(Nz-1)*dz: {(Nz-1)*dz}')

Lx: 250.724242944
Ly: 163.262298624
Lz: 163.262298624
Nx: 513
Ny: 257
Nz: 257
dx: 0.48969578700000227
dy: 0.6377433539999942
dz: 0.637743354
dV: 1.9916739845722572e+23
(Nx-1)*dx: 250.72424294400116
(Ny-1)*dy: 163.2622986239985
(Nz-1)*dz: 163.262298624


# NF2

In [5]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= "0"

# JSON
# base_path = '/home/tensor/workspace/pinn_study/_run/isee_spinn'
base_path = '/userhome/jeon_mg/workspace/pinn_study/_run/isee_plain'
# base_path = '/Users/mgjeon/Workspace/pinn_study/_run/isee_spinn'
# work_directory = '/home/tensor/workspace/pinn_study/_run/'
work_directory = '/userhome/jeon_mg/workspace/pinn_study/_run/'
meta_path = None

# d_slice = None
bin = 1

height = 257
spatial_norm = 320
b_norm = 2500

dim = 256

lambda_div = 1e-1
lambda_ff = 1e-1
decay_iterations = 50000

total_iterations = 100000
batch_size = 10000
log_interval = 10000
validation_interval = 10000
num_workers = 4

meta_info = None
positional_encoding = False
use_potential_boundary = True
potential_strides = 1
use_vector_potential = False

device = None

# init logging
os.makedirs(base_path, exist_ok=True)
log = logging.getLogger()
log.setLevel(logging.INFO)
for hdlr in log.handlers[:]:  # remove all old handlers
    log.removeHandler(hdlr)
log.addHandler(logging.FileHandler("{0}/{1}.log".format(base_path, "info_log")))  # set the new file handler
log.addHandler(logging.StreamHandler())  # set the new console handler

start_time = datetime.now()
base_path = os.path.join(base_path, 'dim%d_bin%d_pf%s_ld%s_lf%s' % (
        dim, bin, str(use_potential_boundary), lambda_div, lambda_ff))

os.makedirs(base_path, exist_ok=True)

b_cube = np.array(np.stack([data.bx[:, :, 0], data.by[:, :, 0], data.bz[:, :, 0]], axis=-1))
meta_info = None

# if d_slice is not None:
#     b_cube = b_cube[d_slice[0]:d_slice[1], d_slice[2]:d_slice[3]]

# if bin > 1:
#     b_cube = block_reduce(b_cube, (bin, bin, 1), np.mean)

In [6]:
mf_coords = np.stack(np.mgrid[:b_cube.shape[0], :b_cube.shape[1], :1], -1)
mf_coords = mf_coords.reshape((-1, 3))
mf_values = b_cube.reshape((-1, 3))

In [7]:
class PotentialModel(nn.Module):

    def __init__(self, b_n, r_p):
        super().__init__()
        self.register_buffer('b_n', b_n)
        self.register_buffer('r_p', r_p)
        c = np.zeros((1, 3))
        c[:, 2] = (1 / np.sqrt(2 * np.pi))
        c = torch.tensor(c, dtype=torch.float32, )
        self.register_buffer('c', c)

    def forward(self, coord):
        v1 = self.b_n[:, None]
        v2 = 2 * np.pi * ((-self.r_p[:, None] + coord[None, :] + self.c[None]) ** 2).sum(-1) ** 0.5
        potential = torch.sum(v1 / v2, dim=0)
        return potential

In [8]:
pf_batch_size = int(1024 * 512 ** 2 / np.prod(b_cube.shape[:2]))

b_n = b_cube[:, :, 2]
cube_shape = (*b_n.shape, height)
b_n = b_n.reshape((-1)).astype(np.float32)
coords = [np.stack(np.mgrid[:cube_shape[0], :cube_shape[1], cube_shape[2] - 2:cube_shape[2] + 1], -1),
            np.stack(np.mgrid[:cube_shape[0], -1:2, :cube_shape[2]], -1),
            np.stack(np.mgrid[:cube_shape[0], cube_shape[1] - 2:cube_shape[1] + 1, :cube_shape[2]], -1),
            np.stack(np.mgrid[-1:2, :cube_shape[1], :cube_shape[2]], -1),
            np.stack(np.mgrid[cube_shape[0] - 2:cube_shape[0] + 1, :cube_shape[1], :cube_shape[2]], -1), ]
coords_shape = [c.shape[:-1] for c in coords]
flat_coords = np.concatenate([c.reshape(((-1, 3))) for c in coords])

r_p = np.stack(np.mgrid[:cube_shape[0], :cube_shape[1], :1], -1).reshape((-1, 3))

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
with torch.no_grad():
    b_n = torch.tensor(b_n, dtype=torch.float32, )
    r_p = torch.tensor(r_p, dtype=torch.float32, )
    model = nn.DataParallel(PotentialModel(b_n, r_p, )).to(device)

    flat_coords = torch.tensor(flat_coords, dtype=torch.float32, )

    potential = []
    for coord, in tqdm(DataLoader(TensorDataset(flat_coords), batch_size=pf_batch_size, num_workers=2),
                        desc='Potential Boundary'):
        coord = coord.to(device)
        p_batch = model(coord)
        potential += [p_batch.cpu()]

potential = torch.cat(potential).numpy()
idx = 0
fields = []
for s in coords_shape:
    p = potential[idx:idx + np.prod(s)].reshape(s)
    b = - 1 * np.stack(np.gradient(p, axis=[0, 1, 2], edge_order=2), axis=-1)
    fields += [b]
    idx += np.prod(s)

fields = [fields[0][:, :, 1].reshape((-1, 3)),
            fields[1][:, 1, :].reshape((-1, 3)), fields[2][:, 1, :].reshape((-1, 3)),
            fields[3][1, :, :].reshape((-1, 3)), fields[4][1, :, :].reshape((-1, 3))]
coords = [coords[0][:, :, 1].reshape((-1, 3)),
            coords[1][:, 1, :].reshape((-1, 3)), coords[2][:, 1, :].reshape((-1, 3)),
            coords[3][1, :, :].reshape((-1, 3)), coords[4][1, :, :].reshape((-1, 3))]

pf_coords, pf_values = np.concatenate(coords), np.concatenate(fields)

Potential Boundary: 100%|██████████| 778/778 [00:33<00:00, 23.15it/s]


In [9]:
pf_values = np.array(pf_values, dtype=np.float32)
pf_coords = np.array(pf_coords, dtype=np.float32)

In [10]:
coords = np.concatenate([pf_coords, mf_coords])
values = np.concatenate([pf_values, mf_values])
coords = coords.astype(np.float32)
values = values.astype(np.float32)

In [11]:
values

array([[ -1.0869141 ,   0.41046143,  -0.59832   ],
       [ -1.0871735 ,   0.41107178,  -0.6021805 ],
       [ -1.0874405 ,   0.4116516 ,  -0.6060486 ],
       ...,
       [ 64.65      , -66.61      ,  13.29      ],
       [ 92.42      ,  35.37      , -30.01      ],
       [  0.        ,   0.        ,   0.        ]], dtype=float32)

In [12]:
values = Normalize(-b_norm, b_norm, clip=False)(values) * 2 - 1

In [13]:
values

masked_array(
  data=[[-0.00043476,  0.00016415, -0.00023937],
        [-0.00043488,  0.00016451, -0.00024092],
        [-0.00043494,  0.00016463, -0.00024241],
        ...,
        [ 0.02585995, -0.02664405,  0.00531602],
        [ 0.03696799,  0.014148  , -0.01200402],
        [ 0.        ,  0.        ,  0.        ]],
  mask=False,
  fill_value=1e+20,
  dtype=float32)

In [14]:
coords

array([[  0.,   0., 256.],
       [  0.,   1., 256.],
       [  0.,   2., 256.],
       ...,
       [512., 254.,   0.],
       [512., 255.,   0.],
       [512., 256.,   0.]], dtype=float32)

In [15]:
coords = coords / spatial_norm

In [16]:
coords

array([[0.      , 0.      , 0.8     ],
       [0.      , 0.003125, 0.8     ],
       [0.      , 0.00625 , 0.8     ],
       ...,
       [1.6     , 0.79375 , 0.      ],
       [1.6     , 0.796875, 0.      ],
       [1.6     , 0.8     , 0.      ]], dtype=float32)

In [17]:
data = np.stack([coords, values], 1)

In [18]:
cube_shape = [*b_cube.shape[:-1], height]

In [19]:
class Sine(nn.Module):
    def __init__(self, w0=1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)
    
class BModel(nn.Module):

    def __init__(self, in_coords, out_values, dim):
        super().__init__()
        self.d_in = nn.Linear(in_coords, dim)
        lin = [nn.Linear(dim, dim) for _ in range(8)]
        self.linear_layers = nn.ModuleList(lin)
        self.d_out = nn.Linear(dim, out_values)
        self.activation = Sine()

    def forward(self, x):
        x = self.activation(self.d_in(x))
        for l in self.linear_layers:
            x = self.activation(l(x))
        b = self.d_out(x)
        return b

In [20]:
model = BModel(3, 3, dim).to(device)
parallel_model = nn.DataParallel(model)
parallel_model.to(device)
opt = torch.optim.Adam(parallel_model.parameters(), lr=5e-4)

In [21]:
class RandomCoordinateSampler():

    def __init__(self, cube_shape, spatial_norm, batch_size, cuda=True):
        self.cube_shape = cube_shape
        self.spatial_norm = spatial_norm
        self.batch_size = batch_size
        self.float_tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    def load_sample(self):
        random_coords = self.float_tensor(self.batch_size, 3).uniform_()
        random_coords[:, 0] *= self.cube_shape[0] / self.spatial_norm
        random_coords[:, 1] *= self.cube_shape[1] / self.spatial_norm
        random_coords[:, 2] *= self.cube_shape[2] / self.spatial_norm
        return random_coords

In [22]:
sampler = RandomCoordinateSampler(cube_shape, spatial_norm, batch_size * 2)
scheduler = ExponentialLR(opt, gamma=(5e-5 / 5e-4) ** (1 / total_iterations))

In [23]:
class BoundaryDataset(Dataset):

    def __init__(self, batches_path):
        """Data set for lazy loading a pre-batched numpy data array.

        :param batches_path: path to the numpy array.
        """
        self.batches_path = batches_path

    def __len__(self):
        return np.load(self.batches_path, mmap_mode='r').shape[0]

    def __getitem__(self, idx):
        # lazy load data
        d = np.load(self.batches_path, mmap_mode='r')[idx]
        d = np.copy(d)
        coord, field= d[:, 0],  d[:, 1]
        return coord, field
    
def _init_loader(batch_size, data, num_workers, iterations):
    # shuffle data
    r = np.random.permutation(data.shape[0])
    data = data[r]
    # adjust to batch size
    pad = batch_size - data.shape[0] % batch_size
    data = np.concatenate([data, data[:pad]])
    # split data into batches
    n_batches = data.shape[0] // batch_size
    batches = np.array(np.split(data, n_batches), dtype=np.float32)
    # store batches to disk
    batches_path = os.path.join(work_directory, 'batches.npy')
    np.save(batches_path, batches)
    # create data loaders
    dataset = BoundaryDataset(batches_path)
    # create loader
    data_loader = DataLoader(dataset, batch_size=None, num_workers=num_workers, pin_memory=True,
                                sampler=RandomSampler(dataset, replacement=True, num_samples=iterations))
    return data_loader, batches_path

In [24]:
data_loader, batches_path = _init_loader(batch_size, data, num_workers, total_iterations)

In [25]:
total_b_diff = []
total_divergence_loss = []
total_force_loss = []
model.train()

BModel(
  (d_in): Linear(in_features=3, out_features=256, bias=True)
  (linear_layers): ModuleList(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): Linear(in_features=256, out_features=256, bias=True)
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): Linear(in_features=256, out_features=256, bias=True)
  )
  (d_out): Linear(in_features=256, out_features=3, bias=True)
  (activation): Sine()
)

In [26]:
def jacobian(output, coords):
    jac_matrix = [torch.autograd.grad(output[:, i], coords,
                                      grad_outputs=torch.ones_like(output[:, i]).to(output),
                                      retain_graph=True,
                                      create_graph=True)[0]
                  for i in range(output.shape[1])]
    jac_matrix = torch.stack(jac_matrix, dim=1)
    return jac_matrix

def calculate_loss(b, coords):
    jac_matrix = jacobian(b, coords)
    dBx_dx = jac_matrix[:, 0, 0]
    dBy_dx = jac_matrix[:, 1, 0]
    dBz_dx = jac_matrix[:, 2, 0]
    dBx_dy = jac_matrix[:, 0, 1]
    dBy_dy = jac_matrix[:, 1, 1]
    dBz_dy = jac_matrix[:, 2, 1]
    dBx_dz = jac_matrix[:, 0, 2]
    dBy_dz = jac_matrix[:, 1, 2]
    dBz_dz = jac_matrix[:, 2, 2]
    #
    rot_x = dBz_dy - dBy_dz
    rot_y = dBx_dz - dBz_dx
    rot_z = dBy_dx - dBx_dy
    #
    j = torch.stack([rot_x, rot_y, rot_z], -1)
    jxb = torch.cross(j, b, -1)
    force_loss = torch.sum(jxb ** 2, dim=-1) / (torch.sum(b ** 2, dim=-1) + 1e-7)
    divergence_loss = (dBx_dx + dBy_dy + dBz_dz) ** 2
    return divergence_loss, force_loss

In [27]:
lambda_B = 1000 if decay_iterations else 1
lambda_B_decay = (1 / 1000) ** (1 / decay_iterations) if decay_iterations is not None else 1
history = {'iteration': [], 'height': [],
                       'b_loss': [], 'divergence_loss': [], 'force_loss': [], 'sigma_angle': []}

In [28]:
save_path = os.path.join(base_path, 'extrapolation_result.nf2')
checkpoint_path = os.path.join(base_path, 'checkpoint.pt')

In [29]:
def plot_sample(iteration, n_samples=10, batch_size=4096):
        fig, axs = plt.subplots(3, n_samples, figsize=(n_samples * 4, 12))
        heights = np.linspace(0, 1, n_samples) ** 2 * (height - 1)  # more samples from lower heights
        imgs = np.array([get_image(h, batch_size) for h in heights])
        for i in range(3):
            for j in range(10):
                v_min_max = np.max(np.abs(imgs[j, ..., i]))
                axs[i, j].imshow(imgs[j, ..., i].transpose(), cmap='gray', vmin=-v_min_max, vmax=v_min_max,
                                 origin='lower')
                axs[i, j].set_axis_off()
        for j, h in enumerate(heights):
            axs[0, j].set_title('%.01f' % h)
        fig.tight_layout()
        fig.savefig(os.path.join(base_path, '%06d.jpg' % (iteration + 1)))
        plt.close(fig)

def get_image(z=0, batch_size=4096):
    image_loader = DataLoader(ImageDataset([*cube_shape, 3], spatial_norm, z),
                                batch_size=batch_size, shuffle=False)
    image = []
    for coord in image_loader:
        coord.requires_grad = True
        pred_pix = model(coord.to(device))
        image.extend(pred_pix.detach().cpu().numpy())
    image = np.array(image).reshape((*cube_shape[:2], 3))
    return image

def save(iteration):
    torch.save({'model': model,
                'cube_shape': cube_shape,
                'b_norm': b_norm,
                'spatial_norm': spatial_norm,
                'meta_info': meta_info}, save_path)
    torch.save({'iteration': iteration + 1,
                'm': model.state_dict(),
                'o': opt.state_dict(),
                'history': history,
                'lambda_B': lambda_B},
                checkpoint_path)
    
def validate(z, batch_size):
    b, j, div, coords = get_cube(z, batch_size)
    b = b.unsqueeze(0) * b_norm
    j = j.unsqueeze(0) * b_norm / spatial_norm
    div = div.unsqueeze(0) * b_norm / spatial_norm

    norm = b.pow(2).sum(-1).pow(0.5) * j.pow(2).sum(-1).pow(0.5)
    angle = torch.cross(j, b, dim=-1).pow(2).sum(-1).pow(0.5) / norm
    sig = torch.asin(torch.clip(angle, -1. + 1e-7, 1. - 1e-7)) * (180 / np.pi)
    sig = torch.abs(sig)
    weighted_sig = np.average(sig.numpy(), weights=j.pow(2).sum(-1).pow(0.5).numpy())

    b_diff = torch.abs(b[0, :, :, 0, :] - b_cube)
    b_diff = torch.clip(b_diff, 0, None)
    b_diff = torch.sqrt((b_diff ** 2).sum(-1))

    b_norm = b.pow(2).sum(-1).pow(0.5) + 1e-7
    div_loss = div / b_norm
    for_loss = torch.cross(j, b, dim=-1).pow(2).sum(-1).pow(0.5) / b_norm

    return b_diff.mean().numpy(), torch.mean(div_loss).numpy(), \
            torch.mean(for_loss).numpy(), weighted_sig

def get_cube(max_height, batch_size=int(1e4)):
    b = []
    j = []
    div = []

    coords = np.stack(np.mgrid[:cube_shape[0], :cube_shape[1], :max_height], -1)
    coords = torch.tensor(coords / spatial_norm, dtype=torch.float32)
    coords_shape = coords.shape[:-1]
    coords = coords.view((-1, 3))
    for k in tqdm(range(int(np.ceil(coords.shape[0] / batch_size))), desc='Validation'):
        coord = coords[k * batch_size: (k + 1) * batch_size]
        coord.requires_grad = True
        coord = coord.to(device)
        b_batch = model(coord)

        jac_matrix = jacobian(b_batch, coord)
        dBx_dx = jac_matrix[:, 0, 0]
        dBy_dx = jac_matrix[:, 1, 0]
        dBz_dx = jac_matrix[:, 2, 0]
        dBx_dy = jac_matrix[:, 0, 1]
        dBy_dy = jac_matrix[:, 1, 1]
        dBz_dy = jac_matrix[:, 2, 1]
        dBx_dz = jac_matrix[:, 0, 2]
        dBy_dz = jac_matrix[:, 1, 2]
        dBz_dz = jac_matrix[:, 2, 2]
        #
        rot_x = dBz_dy - dBy_dz
        rot_y = dBx_dz - dBz_dx
        rot_z = dBy_dx - dBx_dy
        #
        j_batch = torch.stack([rot_x, rot_y, rot_z], -1)
        div_batch = torch.abs(dBx_dx + dBy_dy + dBz_dz)
        #
        b += [b_batch.detach().cpu()]
        j += [j_batch.detach().cpu()]
        div += [div_batch.detach().cpu()]

    b = torch.cat(b, dim=0).view((*coords_shape, 3))
    j = torch.cat(j, dim=0).view((*coords_shape, 3))
    div = torch.cat(div, dim=0).view(coords_shape)
    return b, j, div, coords

def plotHistory(self):
    history = self.history
    plt.figure(figsize=(12, 16))
    plt.subplot(411)
    plt.plot(history['iteration'], history['b_loss'], label='B')
    plt.xlabel('Iteration')
    plt.ylabel('B')
    plt.subplot(412)
    plt.plot(history['iteration'], history['divergence_loss'], label='Divergence')
    plt.xlabel('Iteration')
    plt.ylabel('Divergence')
    plt.subplot(413)
    plt.plot(history['iteration'], history['force_loss'], label='Force')
    plt.xlabel('Iteration')
    plt.ylabel('Force')
    plt.subplot(414)
    plt.plot(history['iteration'], history['sigma_angle'], label='Angle')
    plt.xlabel('Iteration')
    plt.ylabel('Angle')
    plt.tight_layout()
    plt.savefig(os.path.join(self.base_path, 'history.jpg'))
    plt.close()

In [30]:
for iter, (boundary_coords, b_true) in tqdm(enumerate(data_loader, start=0),
                                                    total=len(data_loader), desc='Training'):
    opt.zero_grad()
    # load input data
    boundary_coords, b_true= boundary_coords.to(device), b_true.to(device)
    random_coords = sampler.load_sample()
    random_coords = random_coords.to(device)

    # concatenate boundary and random points
    n_boundary_coords = boundary_coords.shape[0]
    coords = torch.cat([boundary_coords, random_coords], 0)
    coords.requires_grad = True

    # forward step
    b = model(coords)

    # compute boundary loss
    boundary_b = b[:n_boundary_coords]
    b_diff = torch.abs(boundary_b - b_true)
    b_diff = torch.mean(b_diff.pow(2).sum(-1))

    # compute div and ff loss
    divergence_loss, force_loss = calculate_loss(b, coords)

    # reset grad from auto-gradient operation
    opt.zero_grad()
    # compute loss
    (b_diff * lambda_B +
        divergence_loss.mean() * lambda_div +
        force_loss.mean() * lambda_ff).backward()
    # update step
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
    opt.step()

    # save loss information
    total_b_diff += [b_diff.detach().cpu().numpy()]
    total_divergence_loss += [divergence_loss.mean().detach().cpu().numpy()]
    total_force_loss += [force_loss.mean().detach().cpu().numpy()]

    # update training parameters
    if lambda_B > 1:
        lambda_B *= lambda_B_decay
    if scheduler.get_last_lr()[0] > 5e-5:
        scheduler.step()
    # logging
    if log_interval > 0 and (iter + 1) % log_interval == 0:
        # log loss
        logging.info('[Iteration %06d/%06d] [B-Field: %.08f; Div: %.08f; For: %.08f] [%s]' %
                        (iter + 1, total_iterations,
                        np.mean(total_b_diff),
                        np.mean(total_divergence_loss),
                        np.mean(total_force_loss),
                        datetime.now() - start_time))
        # reset
        total_b_diff = []
        total_divergence_loss = []
        total_force_loss = []

        # plot sample
        model.eval()
        plot_sample(iter, batch_size=batch_size)
        model.train()

        # log decay parameters
        logging.info('Lambda B: %f' % (lambda_B))
        logging.info('LR: %f' % (scheduler.get_last_lr()[0]))
    # validation
    if validation_interval > 0 and (iter + 1) % validation_interval == 0:
        model.eval()
        save(iter)
        # validate and plot
        mean_b, total_divergence, mean_force, sigma_angle = validate(height, batch_size)
        logging.info('Validation [Cube: B: %.03f; Div: %.03f; For: %.03f; Sig: %.03f]' %
                        (mean_b, total_divergence, mean_force, sigma_angle))
        #
        history['iteration'].append(iter + 1)
        history['b_loss'].append(mean_b.mean())
        history['divergence_loss'].append(total_divergence)
        history['force_loss'].append(mean_force)
        history['sigma_angle'].append(sigma_angle)
        plotHistory()
        #
        model.train()
# save final model state
torch.save({'m': model.state_dict(),
            'o': opt.state_dict(), },
            os.path.join(base_path, 'final.pt'))
torch.save({'model': model,
            'cube_shape': cube_shape,
            'b_norm': b_norm,
            'spatial_norm': spatial_norm,
            'meta_info': meta_info}, save_path)
# cleanup
os.remove(batches_path)

Training:  10%|▉         | 9999/100000 [11:58<1:47:45, 13.92it/s][Iteration 010000/100000] [B-Field: 0.00232540; Div: 0.16562338; For: 0.34179765] [0:12:38.160740]
Lambda B: 251.188643
LR: 0.000397
Training:  10%|▉         | 9999/100000 [12:05<1:48:45, 13.79it/s]


NameError: name 'self' is not defined