In [8]:
import torch
from torch import nn

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class Sine(nn.Module):
    """Element-wise sine activation with a frequency scale: ``sin(w0 * x)``."""

    def __init__(self, w0=1.):
        """Store ``w0``, the factor applied to the input before taking the sine."""
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        """Return ``sin(w0 * x)`` for the tensor ``x``."""
        scaled = self.w0 * x
        return scaled.sin()


class BModel(nn.Module):
    """Fully connected network with sine activations.

    Layout: ``Linear(in_coords, dim)`` -> ``n_layers`` x ``Linear(dim, dim)`` ->
    ``Linear(dim, out_values)``. The sine activation follows every layer except
    the last, which is left linear.

    Args:
        in_coords: size of the last axis of the input.
        out_values: size of the last axis of the output.
        dim: width of the hidden layers.
        n_layers: number of hidden ``dim -> dim`` layers. Defaults to 8, the
            count that used to be hard-coded, so existing calls and saved
            state dicts keep working.
    """

    def __init__(self, in_coords, out_values, dim, n_layers=8):
        super().__init__()
        # Layers are created in the same order as before (input, hidden, output)
        # so random initialisation under a fixed seed is unchanged.
        self.d_in = nn.Linear(in_coords, dim)
        hidden = [nn.Linear(dim, dim) for _ in range(n_layers)]
        self.linear_layers = nn.ModuleList(hidden)
        self.d_out = nn.Linear(dim, out_values)
        self.activation = Sine()  # torch.tanh was the commented-out alternative

    def forward(self, x):
        """Map ``x`` (last axis ``in_coords``) to a tensor with last axis ``out_values``."""
        x = self.activation(self.d_in(x))
        for layer in self.linear_layers:
            x = self.activation(layer(x))
        # No activation on the output layer.
        return self.d_out(x)

In [4]:
# 3 input coordinates -> 3 output values, hidden width 256.
model = BModel(3, 3, 256)

In [44]:
from torch.utils.data import Dataset, DataLoader, RandomSampler
import numpy as np

In [14]:
# Load the array stored in 'b.npy' (path is relative to the working directory).
b_slices = np.load('b.npy')

In [15]:
# Display the dimensions of the loaded array.
b_slices.shape

(512, 256, 256, 3)

In [16]:
# Integer index triple of every cell along the first three axes of b_slices,
# stored as float32 with the triple on the last axis.
grid_slices = tuple(slice(n) for n in b_slices.shape[:3])
coords = np.stack(np.mgrid[grid_slices], axis=-1).astype(np.float32)
coords.shape

(512, 256, 256, 3)

In [17]:
# Flatten both arrays to one row per grid point with three columns each.
# copy=False: when `coords` is already float32 (as produced by the grid cell)
# this skips a redundant full-size copy; any other dtype is still converted.
coords = coords.reshape((-1, 3)).astype(np.float32, copy=False)
# `values` keeps the default copying astype so it never aliases b_slices.
values = b_slices.reshape((-1, 3)).astype(np.float32)

In [18]:
# After flattening: one row per grid point, three columns.
coords.shape

(33554432, 3)

In [19]:
# Expected to have the same number of rows as `coords`.
values.shape

(33554432, 3)

In [20]:
# Keep the array's shape tuple; the trailing expression displays it.
cube_shape = b_slices.shape
cube_shape

(512, 256, 256, 3)

In [None]:
# Inclusive [min, max] index range along each of the first three axes.
# Derived from b_slices.shape rather than from the previous value of
# `cube_shape`, so re-running this cell always gives the same result instead
# of operating on its own output.
cube_shape = np.array([[0, n - 1] for n in b_slices.shape[:3]])

In [34]:
class MyDataset(Dataset):
    """Single-item dataset that draws a fresh random sample on every access.

    The array at ``data_path`` is read with ``np.load`` on first access and then
    cached, instead of being reloaded for every item. It is assumed to have
    shape ``(n0, n1, n2, 3)`` -- TODO confirm for inputs other than 'b.npy'.

    Args:
        data_path: path of the ``.npy`` file.
        b_norm: divisor applied to the sampled values.
        spatial_norm: divisor applied to all coordinates.
        boundary_batch_coords: number of grid points sampled per item.
        random_batch_coords: number of uniformly drawn points per item.
    """

    def __init__(self, data_path, b_norm, spatial_norm, boundary_batch_coords, random_batch_coords):
        self.data_path = data_path
        self.b_norm = b_norm
        self.spatial_norm = spatial_norm
        self.boundary_batch_coords = int(boundary_batch_coords)
        self.random_batch_coords = int(random_batch_coords)
        # No longer used internally; kept so code reading this attribute still works.
        self.float_tensor = torch.FloatTensor
        # Filled lazily by _load_cube().
        self._b_slices = None

    def __getstate__(self):
        # Drop the cached array when pickling (e.g. for spawned DataLoader
        # workers); each copy reloads it on first use.
        state = self.__dict__.copy()
        state['_b_slices'] = None
        return state

    def __len__(self):
        return 1

    def _load_cube(self):
        """Return the data array, reading it from disk only the first time."""
        if self._b_slices is None:
            self._b_slices = np.load(self.data_path)
        return self._b_slices

    def __getitem__(self, idx):
        """Return a dict with 'random_coords' (tensor), 'coords' and 'values' (float32 arrays).

        ``idx`` is ignored; every call draws new random samples.
        """
        b_slices = self._load_cube()
        grid_shape = b_slices.shape[:3]

        # Uniform points in [0, n - 1] along each axis, then normalised.
        upper = torch.tensor([n - 1 for n in grid_shape], dtype=torch.float32)
        random_coords = torch.rand(self.random_batch_coords, 3, dtype=torch.float32) * upper
        random_coords = random_coords / self.spatial_norm

        # Sample flat grid indices (with replacement) and convert them to index
        # triples directly, rather than materialising the full coordinate grid.
        n_points = int(np.prod(grid_shape))
        r = np.random.choice(n_points, self.boundary_batch_coords)
        coords = np.stack(np.unravel_index(r, grid_shape), -1).astype(np.float32)
        coords = coords / self.spatial_norm
        values = b_slices.reshape((-1, 3))[r].astype(np.float32) / self.b_norm

        samples = {'random_coords': random_coords,
                   'coords': coords,
                   'values': values}

        return samples
    

In [35]:
# Values are divided by 2500 and coordinates by 255; each item carries
# 100 sampled grid points and 50 uniformly drawn points.
dataset = MyDataset(data_path='b.npy', b_norm=2500, spatial_norm=255,
                    boundary_batch_coords=100, random_batch_coords=50)

In [45]:
# batch_size=None switches off automatic batching, so every dataset item is
# passed through as one ready-made batch.
batch_size = None
num_workers = 4
num_samples = 200  # number of items the sampler draws per pass over the loader
sampler = RandomSampler(dataset, replacement=True, num_samples=num_samples)
don_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                        num_workers=num_workers, pin_memory=True)

In [46]:
# Pull a single item from the loader to inspect what a batch looks like.
batch = next(iter(don_loader))

In [52]:
# Move the three entries of the sample batch onto the compute device.
random_coords, coords, values = (batch[key].to(device)
                                 for key in ('random_coords', 'coords', 'values'))

In [53]:
# Report the shape of each tensor, one per line.
print(random_coords.shape, coords.shape, values.shape, sep='\n')

torch.Size([50, 3])
torch.Size([100, 3])
torch.Size([100, 3])


In [50]:
# Place the model's parameters on the selected device.
model = model.to(device)

In [57]:
def jacobian(output, coords):
    """Return the Jacobian of ``output`` with respect to ``coords``.

    The result has shape ``coords.shape[:-1] + (output.shape[-1], coords.shape[-1])``
    and is laid out as ``jac[..., i, j] = d output[..., i] / d coords[..., j]``
    (component first, coordinate second).

    Each component is differentiated with an all-ones ``grad_outputs``, which sums
    over the leading axes. The entries are therefore per-point derivatives only
    when every output row depends solely on its own row of ``coords``.

    Args:
        output: tensor computed from ``coords``; its last axis indexes components.
        coords: tensor with ``requires_grad`` set; its last axis indexes coordinates.

    Returns:
        Tensor of stacked gradients. The graph is kept (``create_graph=True``) so
        the result can itself be differentiated.
    """
    rows = []
    for i in range(output.shape[-1]):
        component = output[..., i]
        grad = torch.autograd.grad(component, coords,
                                   grad_outputs=torch.ones_like(component),
                                   retain_graph=True, create_graph=True, allow_unused=True)[0]
        if grad is None:
            # allow_unused=True returns None when this component does not depend
            # on coords; the derivative is then zero.
            grad = torch.zeros_like(coords)
        rows.append(grad)
    # Each gradient has the coordinate index on its last axis, so the component
    # index must go on the axis before it.
    return torch.stack(rows, dim=-2)


def calculate_pde_loss(b, coords):
    """Return per-point divergence and force-free residuals of ``b``.

    The Jacobian from ``jacobian(b, coords)`` is read as
    ``jac[..., i, j] = d b_i / d coords_j``. NOTE(review): this relies on
    ``jacobian`` using that layout -- confirm against its stacking axis. Both
    returned quantities are unchanged if the two axes are swapped: the
    divergence uses only the diagonal, and the curl merely flips sign before
    being squared.

    Args:
        b: tensor whose last axis holds three components, computed from ``coords``.
        coords: tensor whose last axis holds three coordinates.

    Returns:
        ``(loss_div, loss_ff)``: the squared divergence, and
        ``|curl(b) x b|^2 / (|b|^2 + 1e-7)``, each without the last axis.
    """
    jac = jacobian(b, coords)

    def partial(component, axis):
        # Derivative of one component along one coordinate axis.
        return jac[..., component, axis]

    divergence = partial(0, 0) + partial(1, 1) + partial(2, 2)

    curl = torch.stack([partial(2, 1) - partial(1, 2),
                        partial(0, 2) - partial(2, 0),
                        partial(1, 0) - partial(0, 1)], -1)

    curl_cross_b = torch.cross(curl, b, -1)
    b_squared = torch.sum(b ** 2, dim=-1)
    # The small constant keeps the ratio finite where b vanishes.
    loss_ff = torch.sum(curl_cross_b ** 2, dim=-1) / (b_squared + 1e-7)

    loss_div = divergence ** 2
    return loss_div, loss_ff

In [58]:
# Initial learning rate handed to the optimizer.
lr_start = 1e-5

In [59]:
# Adam over all model parameters, starting at lr_start.
optimizer = torch.optim.Adam(model.parameters(), lr=lr_start)

In [60]:
from tqdm import tqdm

In [None]:
# Weights of the three loss terms. No visible cell defines them, so each falls
# back to 1.0 unless it was already set earlier in the session.
# TODO: choose these values deliberately.
lambda_b = globals().get('lambda_b', 1.0)
lambda_div = globals().get('lambda_div', 1.0)
lambda_ff = globals().get('lambda_ff', 1.0)


def save_checkpoints(tag, step):
    """Write ``<tag>.pt`` (resume state) and ``model_<tag>.pt`` (weights plus metadata)."""
    torch.save({'idx': step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
               f"{tag}.pt")
    # The normalisation constants come from the dataset that produced the batches.
    torch.save({'model_state_dict': model.state_dict(),
                'spatial_norm': dataset.spatial_norm,
                'b_norm': dataset.b_norm,
                'cube_shape': cube_shape},
               f"model_{tag}.pt")


tqdm_loader = tqdm(don_loader, desc='Train')

best_loss = np.inf

model = model.train()

for idx, batch in enumerate(tqdm_loader):

    # The dataset yields exactly these three keys. With batch_size=None the
    # tensors have no batch axis: (n_points, 3).
    boundary_coords = batch['coords'].to(device)
    b_true = batch['values'].to(device)  # separate name: do not overwrite the global b_slices
    random_coords = batch['random_coords'].to(device)

    # Sampled grid points first, uniformly drawn points after, along the point axis.
    n_boundary_coords = boundary_coords.shape[0]
    train_coords = torch.cat([boundary_coords, random_coords], 0)
    # Gradients with respect to the coordinates are needed for the PDE terms.
    train_coords.requires_grad_(True)

    b = model(train_coords)

    # Data term: only the sampled grid points have reference values.
    # nansum skips entries that are NaN.
    b_pred = b[:n_boundary_coords]
    loss_bc = torch.mean(torch.nansum((b_pred - b_true).pow(2), -1))

    loss_div, loss_ff = calculate_pde_loss(b, train_coords)
    loss_div, loss_ff = loss_div.mean(), loss_ff.mean()

    loss = lambda_b * loss_bc + \
           lambda_div * loss_div + \
           lambda_ff * loss_ff

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    tqdm_loader.set_description(f"Loss {loss.item():.4f} bc {loss_bc.item():.4f} div {loss_div.item():.4f} ff {loss_ff.item():.4f}")

    # Always refresh the "last" checkpoint; refresh "best" on improvement.
    save_checkpoints('last', idx)

    loss_value = loss.item()
    if loss_value < best_loss:
        save_checkpoints('best', idx)
        best_loss = loss_value