In [8]:
from operator_aliasing.data.utils import get_data
from operator_aliasing.models.utils import get_model
from operator_aliasing.train.utils import get_loss

import torch
from torch.optim import AdamW
import numpy as np
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Get Data
data_kwargs = {
    'dataset_name': 'darcy',
    'filter_lim': 3,
    'img_size': 16,
    'downsample_dim': -1,
    'train': True,
    'batch_size':16,
    'seed':0,
}

train_dataloader, test_loaders = get_data(**data_kwargs)

Loading test db for resolution 16 with 100 samples 
Loading test db for resolution 16 with 100 samples 


In [3]:
model = get_model(
            model_name='FNO2D',
            out_channels=1,
            in_channels=1,
            hidden_channels=32,
            max_modes=16,
        ).to(device)

In [4]:
class LpLoss(object):
    '''
    loss function with rel/abs Lp loss
    '''
    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        #Dimension and Lp-norm type are postive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def abs(self, x, y):
        num_examples = x.size()[0]

        #Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0)

        all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(all_norms)
            else:
                return torch.sum(all_norms)

        return all_norms

    def rel(self, x, y):
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms/y_norms)
            else:
                return torch.sum(diff_norms/y_norms)

        return diff_norms/y_norms

    def __call__(self, x, y):
        return self.rel(x, y)

# PINNs Loss

# u = model pred
# a = label
def FDM_Darcy(u, a, D=1):
    batchsize = u.size(0)
    size = u.size(1)
    u = u.reshape(batchsize, size, size)
    a = a.reshape(batchsize, size, size)
    dx = D / (size - 1)
    dy = dx

    # ux: (batch, size-2, size-2)
    ux = (u[:, 2:, 1:-1] - u[:, :-2, 1:-1]) / (2 * dx)
    uy = (u[:, 1:-1, 2:] - u[:, 1:-1, :-2]) / (2 * dy)

    # ax = (a[:, 2:, 1:-1] - a[:, :-2, 1:-1]) / (2 * dx)
    # ay = (a[:, 1:-1, 2:] - a[:, 1:-1, :-2]) / (2 * dy)
    # uxx = (u[:, 2:, 1:-1] -2*u[:,1:-1,1:-1] +u[:, :-2, 1:-1]) / (dx**2)
    # uyy = (u[:, 1:-1, 2:] -2*u[:,1:-1,1:-1] +u[:, 1:-1, :-2]) / (dy**2)

    a = a[:, 1:-1, 1:-1]
    # u = u[:, 1:-1, 1:-1]
    # Du = -(ax*ux + ay*uy + a*uxx + a*uyy)

    # inner1 = torch.mean(a*(ux**2 + uy**2), dim=[1,2])
    # inner2 = torch.mean(f*u, dim=[1,2])
    # return 0.5*inner1 - inner2

    aux = a * ux
    auy = a * uy
    auxx = (aux[:, 2:, 1:-1] - aux[:, :-2, 1:-1]) / (2 * dx)
    auyy = (auy[:, 1:-1, 2:] - auy[:, 1:-1, :-2]) / (2 * dy)
    Du = - (auxx + auyy)
    return Du


def darcy_loss(u, a):
    batchsize = u.size(0)
    size = u.size(1)
    u = u.reshape(batchsize, size, size)
    a = a.reshape(batchsize, size, size)
    lploss = LpLoss(size_average=True)

    # index_x = torch.cat([torch.tensor(range(0, size)), (size - 1) * torch.ones(size), torch.tensor(range(size-1, 1, -1)),
    #                      torch.zeros(size)], dim=0).long()
    # index_y = torch.cat([(size - 1) * torch.ones(size), torch.tensor(range(size-1, 1, -1)), torch.zeros(size),
    #                      torch.tensor(range(0, size))], dim=0).long()

    # boundary_u = u[:, index_x, index_y]
    # truth_u = torch.zeros(boundary_u.shape, device=u.device)
    # loss_u = lploss.abs(boundary_u, truth_u)

    Du = FDM_Darcy(u, a)
    f = torch.ones(Du.shape, device=u.device)
    loss_f = lploss.rel(Du, f)

    # im = (Du-f)[0].detach().cpu().numpy()
    # plt.imshow(im)
    # plt.show()

    # loss_f = FDM_Darcy(u, a)
    # loss_f = torch.mean(loss_f)
    return loss_f


In [11]:
import torch.nn as nn
class DarcyDataAndPinnsLoss(nn.Module):
    def __init__(self, pinn_loss_weight:float):
        super().__init__()
        self.L1 = nn.L1Loss()
        self.lploss = LpLoss(size_average=True)
        self.pinn_loss_weight = pinn_loss_weight

    def forward(self, model_pred:torch.Tensor, ground_truth:torch.Tensor):
        """Loss calculation.
            model_pred shape: batch_size x 1 (no time) x X_dim x Y_dim
            ground_truth shape: same as model pred
        """
        data_loss = self.L1(model_pred, ground_truth)

        batchsize = model_pred.size(0)
        size = ground_truth.size(-1)
        u = model_pred.reshape(batchsize, size, size)
        a = ground_truth.reshape(batchsize, size, size)
        Du = FDM_Darcy(u, a)
        f = torch.ones(Du.shape, device=u.device)
        pinn_loss = self.lploss.rel(Du, f)
        
        return (1 - self.pinn_loss_weight) * data_loss + self.pinn_loss_weight * pinn_loss

In [12]:
epochs = 5
lr = 1e-3
weight_decay = 1e-8
step_size = 15
gamma = 0.5
# set up optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=step_size, gamma=gamma
)
loss = get_loss("l1")

data_loss_weight = 0.95
pinn_loss_weight = 0.05

darcy_pinn_loss = DarcyDataAndPinnsLoss(pinn_loss_weight=0.5)

for epoch in range(epochs):
    train_loss = 0.0
    for _step, batch in enumerate(train_dataloader):  
        input_batch = batch['x'].to(device)
        output_batch = batch['y'].to(device)
        optimizer.zero_grad()
        output_pred_batch = model(input_batch)

        loss_f = darcy_pinn_loss(output_pred_batch, output_batch)
        loss_f.backward()
        optimizer.step()
        train_loss += loss_f.item()
    train_loss /= len(train_dataloader)
    scheduler.step()
    print(f"{epoch=}, {train_loss=}")

epoch=0, train_loss=0.16677389920703947
epoch=1, train_loss=0.16411511552712274
epoch=2, train_loss=0.16248389203397054
epoch=3, train_loss=0.16238959748593587
epoch=4, train_loss=0.16045862885694656


In [7]:
loss

L1Loss()

# Burgers

In [19]:
# Get Data
data_kwargs = {
    'dataset_name': 'burgers_pdebench',
    'filter_lim': -1,
    'img_size': 1024,
    'downsample_dim': -1,
    'batch_size':7,
    'seed':0,
    'initial_steps':10,
    'model_name':'FNO1d',
    'darcy_forcing_term':1,
}

train_dataloader, test_loaders = get_data(**data_kwargs)

model = get_model(
            model_name='FNO1D',
            out_channels=1,
            in_channels=10,
            hidden_channels=32,
            max_modes=16,
        ).to(device)

In [45]:
v = 1.0 # viscosity (from dataset)

def FDM_Burgers(u, v, D=1):
    batchsize = u.size(0)
    nt = u.size(1)
    nx = u.size(-1)

    u = u.reshape(batchsize, nt, nx)
    dt = D / (nt-1)
    dx = D / (nx)

    u_h = torch.fft.fft(u, dim=2)
    # Wavenumbers in y-direction
    k_max = nx//2
    k_x = torch.cat((torch.arange(start=0, end=k_max, step=1, device=u.device),
                     torch.arange(start=-k_max, end=0, step=1, device=u.device)), 0).reshape(1,1,nx)
    ux_h = 2j *np.pi*k_x*u_h
    uxx_h = 2j *np.pi*k_x*ux_h
    ux = torch.fft.irfft(ux_h[:, :, :k_max+1], dim=2, n=nx)
    uxx = torch.fft.irfft(uxx_h[:, :, :k_max+1], dim=2, n=nx)
    ut = (u[:, 2:, :] - u[:, :-2, :]) / (2 * dt)
    Du = ut + (ux*u - v*uxx)[:,1:-1,:]
    return Du

In [37]:
initial_steps = 10
for _step, batch in enumerate(train_dataloader):  
    input_batch = batch['x'].to(device).squeeze()
    output_batch = batch['y'].to(device).squeeze()
    img_size = input_batch.shape[-1]
    batch_size = input_batch.shape[0]
    loss_f = 0
    t_train = output_batch.shape[1]  # number of time steps
    shape = (batch_size, -1)
    for _dim in range(model.n_dim):
        shape += (img_size,)

    preds = []
    for t in range(initial_steps, t_train):
        # Extract target at current time step
        # squeeze out time dim
        output_at_time_step = output_batch[:, t : t + 1, ...].squeeze(dim=1)

        # Model run
        model_input = torch.reshape(input_batch, shape)
        output_pred_batch = model(model_input)
        preds.append(output_pred_batch)

    full_model_preds = torch.stack(preds, dim=1)
    break

In [36]:
torch.stack(preds, dim=1).shape

torch.Size([7, 91, 1, 1024])

In [22]:
input_batch.shape

torch.Size([7, 10, 1024])

In [46]:
output_pred_batch = input_batch # temp just to get correct shape

FDM_Burgers(full_model_preds, v).shape

torch.Size([7, 89, 1024])

In [24]:
u = input_batch
batchsize = u.size(0)
nt = u.size(1)
nx = u.size(2)
print(u.shape)

u = u.reshape(batchsize, nt, nx)
print(u.shape)
# lploss = LpLoss(size_average=True)

index_t = torch.zeros(nx,).long()
index_x = torch.tensor(range(nx)).long()
boundary_u = u[:, index_t, index_x]
#loss_u = F.mse_loss(boundary_u, u0)

torch.Size([7, 10, 1024])
torch.Size([7, 10, 1024])


In [26]:
boundary_u.shape

torch.Size([7, 1024])

In [None]:
def PINO_loss(u, u0, v):
    batchsize = u.size(0)
    nt = u.size(1)
    nx = u.size(2)

    u = u.reshape(batchsize, nt, nx)
    # lploss = LpLoss(size_average=True)

    index_t = torch.zeros(nx,).long()
    index_x = torch.tensor(range(nx)).long()
    boundary_u = u[:, index_t, index_x]
    loss_u = F.mse_loss(boundary_u, u0)

    Du = FDM_Burgers(u, v)[:, :, :]
    f = torch.zeros(Du.shape, device=u.device)
    loss_f = F.mse_loss(Du, f)

    return loss_u, loss_f

In [47]:
batch['x'].shape

torch.Size([7, 10, 1, 1024])

In [42]:
from torch import nn

loss = nn.L1Loss()

loss(full_model_preds, full_model_preds)

tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)

In [44]:
full_model_preds.shape

torch.Size([7, 91, 1, 1024])