In [1]:
import numpy as np
import torch
from torch import nn

from neuralop.models import UNO

from rtmag.deeponet.model import MLP, Sine

from scipy.interpolate import RegularGridInterpolator as rgi

from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import Dataset

from pathlib import Path

from tqdm import tqdm

In [2]:
class DeepONetUNO(nn.Module):
    def __init__(self, trunk_in_dim, out_dim, hidden_dim, num_layers):
        super().__init__()
        # self.branch_inc = UNO(
        #                         hidden_channels=32,
        #                         in_channels=3,
        #                         out_channels=3,
        #                         lifting_channels=256,
        #                         projection_channels=256,
        #                         n_layers=6,

        #                         factorization="tucker",
        #                         implementation="factorized",
        #                         rank=0.5,

        #                         uno_n_modes=[[16,16],
        #                                         [ 8, 8],
        #                                         [ 8, 8],
        #                                         [ 8, 8],
        #                                         [ 8, 8],
        #                                         [16,16]],
        #                         uno_out_channels=[32,
        #                                             64,
        #                                             64,
        #                                             64,
        #                                             64,
        #                                             32],
        #                         uno_scalings=[[1.0,1.0],
        #                                         [0.5,0.5],
        #                                         [1.0,1.0],
        #                                         [0.5,0.5],
        #                                         [1.0,1.0],
        #                                         [0.5,0.5]]
        #                     )
        self.branch_inc = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3),
            nn.SiLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
        )
        self.branch_layer = nn.Linear(3*64*32, hidden_dim)
        self.trunk_layer = MLP(trunk_in_dim, hidden_dim, hidden_dim, num_layers)
        self.d_out = nn.Linear(hidden_dim, out_dim)
        self.activation = Sine()
    
    def forward(self, bc, x):
        """
        bc     : [batch_size, 3, 512, 256]
        x      : [batch_size, batch_coords, trunk_in_dim]

        output : [batch_size, batch_coords, out_dim]
        """
        branch_latent = self.branch_inc(bc)
        branch_latent = torch.flatten(branch_latent, 1)
        branch_latent = self.branch_layer(branch_latent)
        trunk_latent = self.trunk_layer(x)
        latent = branch_latent[:, None, :] * trunk_latent
        output = self.d_out(self.activation(latent))
        return output

In [3]:
class DeepONetDataset(Dataset):

    def __init__(self, file_list, b_norm, spatial_norm, cube_shape, 
                 bottom_batch_coords=1,
                 boundary_batch_coords=1, 
                 random_batch_coords=1):
        super().__init__()
        self.cube_shape = np.array([[0, cube_shape[0] - 1], [0, cube_shape[1] - 1], [0, cube_shape[2] - 1]])
        self.files = file_list
        self.b_norm = b_norm
        self.spatial_norm = spatial_norm
        self.bottom_batch_coords = int(bottom_batch_coords)
        self.boundary_batch_coords = int(boundary_batch_coords)
        self.random_batch_coords = int(random_batch_coords)
        self.float_tensor = torch.FloatTensor
        self.coords_shape = cube_shape

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        input_file = self.files[idx]
        input_file = Path(input_file)
        label_file = input_file.parent.parent / 'label' / input_file.name.replace('input_', 'label_')

        inputs = np.load(input_file)
        labels = np.load(label_file)

        branch_input = inputs['input'][..., 0].astype(np.float32)
        branch_input = branch_input / self.b_norm

        b_bottom = inputs['input'].transpose(1, 2, 3, 0)
        bottom_values = b_bottom.reshape(-1, 3).astype(np.float32)
        bottom_values = bottom_values / self.b_norm        

        bottom_coords = np.stack(np.mgrid[:b_bottom.shape[0], :b_bottom.shape[1], :b_bottom.shape[2]], -1).reshape(-1, 3).astype(np.float32)
        bottom_coords = bottom_coords / self.spatial_norm

        nx, ny, nz = self.coords_shape
        top_coords = np.stack(np.mgrid[:nx, :ny, (nz-1):nz], -1).reshape(-1, 3).astype(np.float32)
        lateral_1_coords = np.stack(np.mgrid[:nx, :1, :nz], -1).reshape(-1, 3).astype(np.float32)
        lateral_2_coords = np.stack(np.mgrid[:nx, (ny-1):ny, :nz], -1).reshape(-1, 3).astype(np.float32)
        lateral_3_coords = np.stack(np.mgrid[:1, :ny, :nz], -1).reshape(-1, 3).astype(np.float32)
        lateral_4_coords = np.stack(np.mgrid[(nx-1):nx, :ny, :nz], -1).reshape(-1, 3).astype(np.float32)

        boundary_coords = np.concatenate([top_coords, 
                                        lateral_1_coords, 
                                        lateral_2_coords, 
                                        lateral_3_coords, 
                                        lateral_4_coords], axis=0)
        
        boundary_coords = boundary_coords / self.spatial_norm

        random_coords = self.float_tensor(self.random_batch_coords, 3).uniform_()
        random_coords[:, 0] = (random_coords[:, 0] * (self.cube_shape[0, 1] - self.cube_shape[0, 0]) + self.cube_shape[0, 0])
        random_coords[:, 1] = (random_coords[:, 1] * (self.cube_shape[1, 1] - self.cube_shape[1, 0]) + self.cube_shape[1, 0])
        random_coords[:, 2] = (random_coords[:, 2] * (self.cube_shape[2, 1] - self.cube_shape[2, 0]) + self.cube_shape[2, 0])
        random_coords = random_coords / self.spatial_norm


        slices_values = labels['label'].transpose(1, 2, 3, 0)
        b_shape = slices_values.shape
        slices_values = slices_values.reshape(-1, 3).astype(np.float32)
        slices_values = slices_values / self.b_norm        

        slices_coords = np.stack(np.mgrid[:b_shape[0], :b_shape[1], :b_shape[2]], -1).astype(np.float32)
        slices_coords = slices_coords.reshape(-1, 3).astype(np.float32)
        slices_coords = slices_coords / self.spatial_norm

        #--- pick bottom points
        r = np.random.choice(bottom_coords.shape[0], self.bottom_batch_coords)
        bottom_coords = bottom_coords[r]
        bottom_values = bottom_values[r]

        #--- pick boundary points
        r = np.random.choice(boundary_coords.shape[0], self.boundary_batch_coords)
        boundary_coords = boundary_coords[r]

        
        samples = {'branch_input': branch_input,
                   'random_coords': random_coords,
                   'slices_coords': slices_coords,
                   'slices_values': slices_values,
                   'bottom_coords': bottom_coords,
                   'bottom_values': bottom_values,
                   'boundary_coords': boundary_coords}

        return samples

In [4]:
def jacobian(output, coords):
    jac_matrix = [torch.autograd.grad(output[..., i], coords,
                                      grad_outputs=torch.ones_like(output[..., i]).to(output),
                                      retain_graph=True, create_graph=True, allow_unused=True)[0]
                  for i in range(output.shape[-1])]
    jac_matrix = torch.stack(jac_matrix, dim=-1)
    return jac_matrix


def calculate_pde_loss(b, coords):
    jac_matrix = jacobian(b, coords)
    dBx_dx = jac_matrix[..., 0, 0]
    dBy_dx = jac_matrix[..., 1, 0]
    dBz_dx = jac_matrix[..., 2, 0]
    dBx_dy = jac_matrix[..., 0, 1]
    dBy_dy = jac_matrix[..., 1, 1]
    dBz_dy = jac_matrix[..., 2, 1]
    dBx_dz = jac_matrix[..., 0, 2]
    dBy_dz = jac_matrix[..., 1, 2]
    dBz_dz = jac_matrix[..., 2, 2]
    #
    curl_x = dBz_dy - dBy_dz
    curl_y = dBx_dz - dBz_dx
    curl_z = dBy_dx - dBx_dy
    #
    j = torch.stack([curl_x, curl_y, curl_z], -1)
    #
    jxb = torch.cross(j, b, -1)
    loss_ff = torch.sum(jxb ** 2, dim=-1) / (torch.sum(b ** 2, dim=-1) + 1e-7)
    loss_ff = loss_ff.mean()

    loss_div = (dBx_dx + dBy_dy + dBz_dz) ** 2
    loss_div = loss_div.mean()
    return loss_div, loss_ff

In [5]:
branch_inc = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3),
            nn.SiLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
        )

In [6]:
test = torch.rand(1, 3, 512, 256)

In [7]:
branch_inc(test).shape

torch.Size([1, 3, 64, 32])

In [8]:
data_path = [
    '/mnt/f/isee_dataset/11158/input/input_11158_20110213_120000.npz',
]

label_path = [
    '/mnt/f/isee_dataset/11158/label/label_11158_20110213_120000.npz',
]

In [9]:
b_bottom = np.load(data_path[0])['input'].astype(np.float32)[:, :-1, :-1, :]

bx1 = b_bottom[0]
by1 = b_bottom[1]
bz1 = b_bottom[2]

nx, ny, nz = bx1.shape
spatial_norm = 256
b_norm = 2500

b_bottom = b_bottom / b_norm

x1 = np.linspace(0, (nx-1), nx).astype(np.float32) / spatial_norm
y1 = np.linspace(0, (ny-1), ny).astype(np.float32) / spatial_norm
z1 = np.linspace(0, (nz-1), nz).astype(np.float32) / spatial_norm

bxs = rgi((x1,y1,z1), bx1, bounds_error=False, fill_value=0)
bys = rgi((x1,y1,z1), by1, bounds_error=False, fill_value=0)
bzs = rgi((x1,y1,z1), bz1, bounds_error=False, fill_value=0)


def bx(x, y, z):
    """
        Evaluate Bx at given point(s).
    """
    xx = np.stack((x, y, z), axis=len(np.shape(x)))
    return bxs(xx)

def by(x, y, z):
    """
        Evaluate By at given point(s).
    """
    xx = np.stack((x, y, z), axis=len(np.shape(x)))
    return bys(xx)


def bz(x, y, z):
    """
        Evaluate Bz at given point(s).
    """
    xx = np.stack((x, y, z), axis=len(np.shape(x)))
    return bzs(xx)

def get_bottom(x, y, z):
    """
        Evaluate Bx, By, Bz at given point(s).
    """
    return np.stack([bx(x, y, z), 
                     by(x, y, z), 
                     bz(x, y, z)], -1).astype(np.float32)

def is_bottom(coord, tol=1e-4):
    bools = torch.logical_not(coord < tol)
    return torch.tensor(bools, dtype=torch.float32)

def output_transform(model, bc, coord, b_norm):
    return is_bottom(coord[..., 2].detach().cpu())[..., None] * model(bc, coord).detach().cpu() \
        + get_bottom(coord[..., 0].detach().cpu(), coord[..., 1].detach().cpu(), coord[..., 2].detach().cpu()) / b_norm

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
model = DeepONetUNO(3, 3, 256, 8).to(device)

In [12]:
b_bottom.shape

(3, 512, 256, 1)

In [13]:
bc = b_bottom[..., 0]
bc = bc[None, :]
bc = torch.tensor(bc, dtype=torch.float32)
bc.shape

torch.Size([1, 3, 512, 256])

In [14]:
nnz = 256
coords = np.stack(np.mgrid[:nx, :ny, :nnz], -1).astype(np.float32)
coords_shape = coords.shape
coords = coords.reshape(-1, 3)
coords = torch.tensor(coords, dtype=torch.float32)
coords = coords / spatial_norm
coords.shape

torch.Size([33554432, 3])

In [15]:
coords

tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0039],
        [0.0000, 0.0000, 0.0078],
        ...,
        [1.9961, 0.9961, 0.9883],
        [1.9961, 0.9961, 0.9922],
        [1.9961, 0.9961, 0.9961]])

In [21]:
cube_shape = (512, 256, 256)
b_norm = 2500
spatial_norm = 256

bottom_batch_coords = 2e4
boundary_batch_coords = 2e4
random_batch_coords = 1e4

don_dataset = DeepONetDataset(data_path, b_norm, spatial_norm, cube_shape,
                                 bottom_batch_coords,
                                 boundary_batch_coords, 
                                 random_batch_coords)

batch_size = 1
num_workers = 4
num_samples = int(2e3)
don_loader = DataLoader(don_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                        sampler=RandomSampler(don_dataset, replacement=True, num_samples=num_samples))

In [22]:
total_iterations = num_samples // batch_size
total_iterations

2000

In [23]:
base_path = Path("TL")
base_path.mkdir(parents=True, exist_ok=True)

In [24]:
lr_start = 1e-5

optimizer = torch.optim.Adam(model.parameters(), lr=lr_start)

lambda_slices = 1
lambda_random = 1
lambda_div = 1
lambda_ff = 1

In [25]:
tqdm_loader = tqdm(don_loader, desc='Train')

best_loss = np.inf

for idx, batch in enumerate(tqdm_loader):
    model = model.train()

    branch_input = batch['branch_input'].to(device)

    slices_coords = batch['slices_coords'].to(device)
    b_slices = batch['slices_values'].to(device)
    b_slices_pred = output_transform(model, branch_input, slices_coords, b_norm)
    loss_slices = (b_slices - b_slices_pred).pow(2).mean()

    random_coords = batch['random_coords'].to(device)
    random_coords.requires_grad = True
    b_random = model(branch_input, random_coords).to(device)
    b_random_pred = output_transform(model, branch_input, random_coords, b_norm)
    loss_random = (b_random - b_random_pred).pow(2).mean()

    loss_div, loss_ff = calculate_pde_loss(b_random_pred, random_coords)

    loss = lambda_slices * loss_slices + \
           lambda_random * loss_random + \
           lambda_div * loss_div + \
           lambda_ff * loss_ff

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    tqdm_loader.set_description(f"Loss {loss.item():.4g} sli {loss_slices.item():.4g} ran {loss_random.item():.4g} div {loss_div.item():.4g} ff {loss_ff.item():.4g}")

    torch.save({'idx': idx, 
                'model_state_dict': model.state_dict(), 
                'optimizer_state_dict': optimizer.state_dict()}, 
                base_path / "last.pt")

    torch.save({'model_state_dict': model.state_dict(),
                'spatial_norm':spatial_norm,
                'b_norm':b_norm,
                'cube_shape':cube_shape}, base_path / "model_last.pt")

    if loss.item() < best_loss:
        torch.save({'idx': idx, 
                    'model_state_dict': model.state_dict(), 
                    'optimizer_state_dict': optimizer.state_dict()}, 
                    base_path / "best.pt")
        torch.save({'model_state_dict': model.state_dict(),
                    'spatial_norm':spatial_norm,
                    'b_norm':b_norm,
                    'cube_shape':cube_shape}, base_path / "model_best.pt")
        best_loss = loss.item()

  return torch.tensor(bools, dtype=torch.float32)
Train:   0%|          | 0/2000 [00:45<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 32.31 GiB. GPU 0 has a total capacity of 11.99 GiB of which 9.84 GiB is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 843.91 MiB is allocated by PyTorch, and 88.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)