In [1]:
from operator_aliasing.data.utils import get_data
from operator_aliasing.models.utils import get_model
from operator_aliasing.train.utils import get_loss

import torch
from torch.optim import AdamW

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



In [2]:
# Get Data
data_kwargs = {
    'dataset_name': 'darcy',
    'filter_lim': 3,
    'img_size': 16,
    'downsample_dim': -1,
    'train': True,
    'batch_size':16,
    'seed':0,
}

train_dataloader, test_loaders = get_data(**data_kwargs)

Loading test db for resolution 16 with 100 samples 
Loading test db for resolution 16 with 100 samples 


In [3]:
model = get_model(
            model_name='FNO2D',
            out_channels=1,
            in_channels=1,
            hidden_channels=32,
            max_modes=16,
        ).to(device)

In [4]:
class LpLoss(object):
    '''
    loss function with rel/abs Lp loss
    '''
    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        #Dimension and Lp-norm type are postive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def abs(self, x, y):
        num_examples = x.size()[0]

        #Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0)

        all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(all_norms)
            else:
                return torch.sum(all_norms)

        return all_norms

    def rel(self, x, y):
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms/y_norms)
            else:
                return torch.sum(diff_norms/y_norms)

        return diff_norms/y_norms

    def __call__(self, x, y):
        return self.rel(x, y)

# PINNs Loss

# u = model pred
# a = label
def FDM_Darcy(u, a, D=1):
    batchsize = u.size(0)
    size = u.size(1)
    u = u.reshape(batchsize, size, size)
    a = a.reshape(batchsize, size, size)
    dx = D / (size - 1)
    dy = dx

    # ux: (batch, size-2, size-2)
    ux = (u[:, 2:, 1:-1] - u[:, :-2, 1:-1]) / (2 * dx)
    uy = (u[:, 1:-1, 2:] - u[:, 1:-1, :-2]) / (2 * dy)

    # ax = (a[:, 2:, 1:-1] - a[:, :-2, 1:-1]) / (2 * dx)
    # ay = (a[:, 1:-1, 2:] - a[:, 1:-1, :-2]) / (2 * dy)
    # uxx = (u[:, 2:, 1:-1] -2*u[:,1:-1,1:-1] +u[:, :-2, 1:-1]) / (dx**2)
    # uyy = (u[:, 1:-1, 2:] -2*u[:,1:-1,1:-1] +u[:, 1:-1, :-2]) / (dy**2)

    a = a[:, 1:-1, 1:-1]
    # u = u[:, 1:-1, 1:-1]
    # Du = -(ax*ux + ay*uy + a*uxx + a*uyy)

    # inner1 = torch.mean(a*(ux**2 + uy**2), dim=[1,2])
    # inner2 = torch.mean(f*u, dim=[1,2])
    # return 0.5*inner1 - inner2

    aux = a * ux
    auy = a * uy
    auxx = (aux[:, 2:, 1:-1] - aux[:, :-2, 1:-1]) / (2 * dx)
    auyy = (auy[:, 1:-1, 2:] - auy[:, 1:-1, :-2]) / (2 * dy)
    Du = - (auxx + auyy)
    return Du


def darcy_loss(u, a):
    batchsize = u.size(0)
    size = u.size(1)
    u = u.reshape(batchsize, size, size)
    a = a.reshape(batchsize, size, size)
    lploss = LpLoss(size_average=True)

    # index_x = torch.cat([torch.tensor(range(0, size)), (size - 1) * torch.ones(size), torch.tensor(range(size-1, 1, -1)),
    #                      torch.zeros(size)], dim=0).long()
    # index_y = torch.cat([(size - 1) * torch.ones(size), torch.tensor(range(size-1, 1, -1)), torch.zeros(size),
    #                      torch.tensor(range(0, size))], dim=0).long()

    # boundary_u = u[:, index_x, index_y]
    # truth_u = torch.zeros(boundary_u.shape, device=u.device)
    # loss_u = lploss.abs(boundary_u, truth_u)

    Du = FDM_Darcy(u, a)
    f = torch.ones(Du.shape, device=u.device)
    loss_f = lploss.rel(Du, f)

    # im = (Du-f)[0].detach().cpu().numpy()
    # plt.imshow(im)
    # plt.show()

    # loss_f = FDM_Darcy(u, a)
    # loss_f = torch.mean(loss_f)
    return loss_f


In [6]:
epochs = 5
lr = 1e-3
weight_decay = 1e-8
step_size = 15
gamma = 0.5
# set up optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=step_size, gamma=gamma
)
loss = get_loss("l1")

data_loss_weight = 0.95
pinn_loss_weight = 0.05

for epoch in range(epochs):
    train_loss = 0.0
    for _step, batch in enumerate(train_dataloader):  
        input_batch = batch['x'].to(device)
        output_batch = batch['y'].to(device)
        optimizer.zero_grad()
        output_pred_batch = model(input_batch)
        data_loss = loss(output_pred_batch, output_batch)

        pinn_loss = darcy_loss(output_pred_batch.squeeze(), output_batch.squeeze())
        #loss_f = data_loss
        loss_f = data_loss_weight * data_loss + pinn_loss_weight * pinn_loss
        loss_f.backward()
        optimizer.step()
        train_loss += loss_f.item()
    train_loss /= len(train_dataloader)
    scheduler.step()
    print(f"{epoch=}, {train_loss=}")

epoch=0, train_loss=0.1665982005615083
epoch=1, train_loss=0.1642174647440986
epoch=2, train_loss=0.16323051920958928
epoch=3, train_loss=0.16174429134717064
epoch=4, train_loss=0.15996588198911577
