In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]= "0"

import numpy as np
b = np.load("b.npy")

height = 64
spatial_norm = 32
b_norm = 100

In [2]:
b_bottom = b[:, :, 0, :]

Nx, Ny, _ = b_bottom.shape
Nz = height

cube_shape = (Nx, Ny, Nz)

def coords(xbounds, ybounds, zbounds):
    return np.stack(np.mgrid[xbounds[0]:xbounds[1]+1, ybounds[0]:ybounds[1]+1, zbounds[0]:zbounds[1]+1], axis=-1).reshape(-1, 3)

bottom_coords = coords((0, Nx-1), (0, Ny-1), (0, 0))
bottom_values = b_bottom.reshape(-1, 3)

# top_lateral_coords = [coords((0, Nx-1), (0, Ny-1), (Nz-1, Nz-1)),
#                       coords((0, 0), (0, Ny-1), (0, Nz-1)),
#                       coords((Nx-1, Nx-1), (0, Ny-1), (0, Nz-1)),
#                       coords((0, Nx-1), (0, 0), (0, Nz-1)),
#                       coords((0, Nx-1), (Ny-1, Ny-1), (0, Nz-1))]

# top_lateral BC +- 1 normal!
top_lateral_coords = [coords((0, Nx-1), (0, Ny-1), (Nz-2, Nz)),
                      coords((-1, 1), (0, Ny-1), (0, Nz-1)),
                      coords((Nx-2, Nx), (0, Ny-1), (0, Nz-1)),
                      coords((0, Nx-1), (-1, 1), (0, Nz-1)),
                      coords((0, Nx-1), (Ny-2, Ny), (0, Nz-1))]

In [3]:
r_top_lateral = np.concatenate(top_lateral_coords)
r_bottom = bottom_coords
bz_bottom = bottom_values[:, 2]

In [4]:
c = np.array([[0, 0, 1/np.sqrt(2*np.pi)]])
denominator_vector = r_top_lateral[:, None] - r_bottom[None, :] + c[None, :]
denominator = np.sqrt(np.sum(denominator_vector**2, -1))
potential_numpy = (1/(2*np.pi))*np.sum((bz_bottom[None, :] / denominator), -1)

In [20]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class PotentialModel(nn.Module):

    def __init__(self, b_n, r_p):
        super().__init__()
        self.register_buffer('b_n', b_n)
        self.register_buffer('r_p', r_p)
        c = np.array([[0, 0, 1/np.sqrt(2*np.pi)]])
        c = torch.tensor(c, dtype=torch.float32)
        self.register_buffer('c', c)

    def forward(self, r):
        v1 = self.b_n[:, None]
        print(v1.shape)
        v2 = 2 * np.pi * ((r[None, :] - self.r_p[:, None] + self.c[None]) ** 2).sum(-1) ** 0.5
        potential = torch.sum(v1 / v2, dim=0)
        return potential
    
with torch.no_grad():
    b_n = torch.tensor(bz_bottom, dtype=torch.float32)
    r_p = torch.tensor(r_bottom, dtype=torch.float32)
    model = nn.DataParallel(PotentialModel(b_n, r_p)).to(device)

    flat_coords = torch.tensor(r_top_lateral, dtype=torch.float32)

    potential = []
    for coord, in tqdm(DataLoader(TensorDataset(flat_coords), batch_size=1000, num_workers=2),
                        desc='Potential Boundary'):
        coord = coord.to(device)
        p_batch = model(coord)
        potential += [p_batch.cpu()]

potential = torch.cat(potential).numpy()

Potential Boundary:  42%|████▏     | 26/62 [00:00<00:00, 137.95it/s]

torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size([1, 1, 3])
torch.Size([4096, 1])
torch.Size

Potential Boundary: 100%|██████████| 62/62 [00:00<00:00, 189.33it/s]


In [16]:
torch.cat(potential)

tensor([7.3860, 7.3316, 7.2765,  ..., 7.3935, 7.3410, 7.2877])

In [6]:
potential_numpy[:10]

array([7.38598259, 7.33160802, 7.27652588, 7.43107096, 7.37563895,
       7.31950322, 7.47625497, 7.41972848, 7.36250503, 7.5215176 ])

In [7]:
potential[:10]

array([7.385983 , 7.331608 , 7.276526 , 7.431071 , 7.37564  , 7.319504 ,
       7.4762526, 7.419729 , 7.362506 , 7.521517 ], dtype=float32)

In [9]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
class PotentialModel(nn.Module):

    def __init__(self, b_n, r_p):
        super().__init__()
        self.register_buffer('b_n', b_n)
        self.register_buffer('r_p', r_p)
        c = np.array([[0, 0, 1/np.sqrt(2*np.pi)]])
        c = torch.tensor(c, dtype=torch.float32)
        self.register_buffer('c', c)

    def forward(self, r):
        v1 = self.b_n[None, :]
        v2 = 2 * np.pi * ((r[:, None] - self.r_p[None, :] + self.c[None, :]) ** 2).sum(-1) ** 0.5
        potential = torch.sum(v1 / v2, dim=0)
        return potential
    
pf_batch_size = int(1024 * 512 ** 2 / np.prod(b_bottom.shape[:2]))
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
with torch.no_grad():
    b_n = torch.tensor(bz_bottom, dtype=torch.float32)
    r_p = torch.tensor(r_bottom, dtype=torch.float32)
    model = nn.DataParallel(PotentialModel(b_n, r_p)).to(device)

    flat_coords = torch.tensor(r_top_lateral, dtype=torch.float32, )

    potential = []
    for coord, in tqdm(DataLoader(TensorDataset(flat_coords), batch_size=pf_batch_size, num_workers=2),
                        desc='Potential Boundary'):
        coord = coord.to(device)
        print(coord)
        p_batch = model(coord)
        print(p_batch)
        potential += [p_batch.cpu()]

potential = torch.cat(potential).numpy()

Potential Boundary: 100%|██████████| 1/1 [00:00<00:00,  1.85it/s]


In [10]:
potential[:10]

array([-385.80853, -406.6559 , -425.59357, -444.3722 , -462.82825,
       -480.83823, -498.30423, -515.13776, -531.2613 , -546.61127],
      dtype=float32)