In [1]:
import scipy
import numpy as np
import matplotlib.pyplot as plt
import torch

from utils import matrl2_error, rl2_error
from ops import injection2d, injection4d 
from ops import interp2d, interp1d_cols, interp1d_rows
from ops import restrict2d
from ops import fetch_nbrs2d #, fetch_nbrs4d
from ops import coord2idx2d, coord2idx4d
from ops import cat2d_nbr_coords
from ops import grid2d_coords, grid4d_coords
from einops import rearrange

from mlmm import Grid2D
from mlmm import K_local_interp_4D, K_local_eval_4D, K_local_assemble

In [2]:
# `l` is interpolated into the dataset directory name (data2d_129);
# presumably the grid size of the stored samples -- TODO confirm.
l = 129
# Read the MATLAB file and keep its 'F' entry for later use.
poisson_path = f'../pde_data/green_learning/data2d_{l}/poisson.mat'
poisson_data = scipy.io.loadmat(poisson_path)
F = poisson_data['F']

In [3]:
def kernel_func(pts_pairs):
    """Evaluate the kernel on a batch of 2-D point pairs.

    Parameters
    ----------
    pts_pairs : torch.Tensor
        2-D tensor whose columns 0..3 are read as ``(x1, y1, x2, y2)``.

    Returns
    -------
    torch.Tensor
        One value per row:
        ``1/(4*pi) * log(|p1 - p2|^2 / ((x1*y2 - x2*y1)^2 + (x1*x2 + y1*y2 - 1)^2))``.
        NaN entries become 0 and ``-inf`` entries (coincident points) become
        -2; pairs with either point on or outside the unit circle are zeroed.

    NOTE(review): the formula looks like the Dirichlet Green's function of the
    Laplacian on the unit disk -- confirm against the reference.
    """
    px, py, qx, qy = (pts_pairs[:, col] for col in range(4))

    # squared distance between the two points, and the denominator term
    sq_dist = (px - qx)**2 + (py - qy)**2
    denom = (px*qy - qx*py)**2 + (px*qx + py*qy - 1)**2

    raw = 1/(4*torch.pi) * torch.log(sq_dist / denom)

    # both points must lie strictly inside the unit disk
    inside = ((px**2 + py**2) < 1) & ((qx**2 + qy**2) < 1)

    return torch.nan_to_num(raw, neginf=-2) * inside

In [4]:
def ffunc(pts):
    """Evaluate ``(1 - (x^2 + y^2)) ** -0.5`` at a batch of 2-D points.

    Parameters
    ----------
    pts : torch.Tensor
        2-D tensor whose columns 0 and 1 are read as ``x`` and ``y``.

    Returns
    -------
    torch.Tensor
        One value per row. Points on the unit circle (where the expression
        is ``+inf``) and outside it (where it is NaN) are mapped to 0 by
        ``torch.nan_to_num``.
    """
    radius_sq = pts[:, 0]**2 + pts[:, 1]**2
    return torch.nan_to_num((1 - radius_sq)**-0.5, posinf=0)

In [56]:
# Scratch calculation: 2**n + 1 for n = 15 (same form as the per-side grid
# size nh = 2**(n-l)+1 used when building the multi-level grids).
# NOTE(review): looks like leftover exploration -- consider removing.
2**15+1

32769

In [43]:
# Multi-level evaluation of the kernel integral u = h^2 * K f on nested 2-D grids.
#   n : the finest grid has 2**n + 1 points per side
#   m : local neighbourhood half-width (K_IJ holds (2*m+1) x (2*m+1) local entries)
#   k : number of coarsening steps, i.e. k + 1 grid levels
n = 9
m = 4
k = 3

# build multi-level grids (index 0 = finest, index k = coarsest)
ml_grids = []
for l in range(k+1):
    nh = 2**(n-l)+1
    grid = Grid2D(nh, m)
    ml_grids.append(grid)

# build multi-level f: sample ffunc on the finest grid, then restrict level by level
ml_f = []
for l in range(k+1):
    if l == 0:
        x_h = ml_grids[0].x_h
        nh = ml_grids[0].nh
        f_h = ffunc(x_h).reshape(nh, nh)
    else:
        f_h = restrict2d(f_h[None,None])[0,0]
    ml_f.append(f_h)

# Reference solution on the finest grid (disabled). It needs the dense kernel
# reshaped to (nh*nh, nh*nh); with n = 9 that is 513**4 entries, so it is only
# practical for small n. Enable it to define `u_ref` for the error report below.
# # eval kernel at finest level
# finest_grid = ml_grids[0]
# finest_grid.init_grid_hh()
# K_hh = kernel_func(finest_grid.x_hh.reshape(-1,4))

# # eval kernel integral at finest level
# nH = finest_grid.nh
# HH = finest_grid.hh
# f_h = ml_f[0]
# u_ref = HH * (K_hh.reshape(nH*nH, nH*nH) @ f_h.reshape(-1)).reshape(nH,nH)

# eval kernel at coarsest level
coarest_grid = ml_grids[-1]
coarest_grid.init_grid_hh()
K_hh = kernel_func(coarest_grid.x_hh.reshape(-1,4))

# eval kernel integral at coarsest level (dense matrix-vector product)
nH = coarest_grid.nh
HH = coarest_grid.hh
f_h = ml_f[-1]
u_h = HH * (K_hh.reshape(nH*nH, nH*nH) @ f_h.reshape(-1)).reshape(nH,nH)
u_interp = interp2d(u_h[None,None])[0,0]
# print("m {:} : {:.4e} ".format(2*m, matrl2_error(u_interp, u_ref).numpy()))

# multi-level correction: reverse the lists so index 0 is the coarsest level
ml_grids = ml_grids[::-1]
ml_f = ml_f[::-1]
K_IJ = K_hh[coarest_grid.ij_idx]
K_IJ = K_IJ.reshape(nH,nH,2*m+1,2*m+1)

for l in range(k):
    # quantities on the next finer level (l + 1)
    nh = ml_grids[l+1].nh
    hh = ml_grids[l+1].hh
    f_h = ml_f[l+1]

    # local kernel evaluation
    idx_corr_even, idx_corr_odd = ml_grids[l].fetch_local_idx()
    x_2Ij = ml_grids[l].fetch_K_local_x()
    K_local_even_lst, K_local_odd_lst = K_local_eval_4D(x_2Ij, kernel_func)
    print("K_IJ {:} : ".format(l), K_IJ.shape)
    K_ij = K_local_assemble(K_IJ, K_local_even_lst, K_local_odd_lst)
    K_2Ij = K_ij[::2,::2]

    # local kernel interpolation
    K_local_even_, K_local_odd_ = K_local_interp_4D(K_IJ, K_2Ij)

    # calculate difference between evaluated and interpolated local kernels
    # (comprehension variable is `blk`, not `k`, to avoid confusion with the level count)
    K_local_even = torch.cat([blk.reshape(-1) for blk in K_local_even_lst], dim=0)
    K_local_odd = torch.cat([blk.reshape(-1) for blk in K_local_odd_lst], dim=0)
    K_corr_even = K_local_even - K_local_even_
    K_corr_odd = K_local_odd - K_local_odd_

    # even correction: applied on the coarse grid (after injection), then interpolated up
    K_corr_even_sparse = torch.sparse_coo_tensor(idx_corr_even, K_corr_even,(nh**2,nh**2))
    u_corr_ = torch.sparse.mm(K_corr_even_sparse, f_h.reshape(-1,1)).reshape(nh,nh)
    u_corr_ = hh * injection2d(u_corr_[None,None])[0,0]
    u_h_ = u_h + u_corr_
    u_h_ = interp2d(u_h_[None,None])[0,0]
    # print("m {:} : {:.4e} ".format(2*m, matrl2_error(u_h_, u_ref).numpy()))

    # odd correction: applied directly on the finer grid
    K_corr_odd_sparse = torch.sparse_coo_tensor(idx_corr_odd, K_corr_odd,(nh**2,nh**2))
    u_corr_ = hh*torch.sparse.mm(K_corr_odd_sparse, f_h.reshape(-1,1)).reshape(nh,nh)
    u_h_ = u_h_ + u_corr_

    # drop m entries from each side of the two local axes for the next level
    K_IJ = K_ij[:,:,m:-m,m:-m]
    u_h = u_h_

# Report the relative error only when a reference solution exists: `u_ref` is
# not assigned in this cell unless the disabled block above is enabled, so an
# unguarded print raises NameError on a fresh kernel. `u_h` is the same tensor
# as `u_h_` after the loop and is also defined when k == 0.
if "u_ref" in globals():
    print("m {:} : {:.4e} ".format(2*m, matrl2_error(u_h, u_ref).numpy()))
else:
    print("u_ref is not defined; skipping the relative-error report")

K_IJ 0 :  torch.Size([33, 33, 17, 17])
K_IJ 1 :  torch.Size([65, 65, 17, 17])
m 16 : 6.5506e-03 


In [35]:
# Relative error of the multi-level result against the reference solution.
# NOTE(review): relies on `m`, `u_h_` and `u_ref` already living in the kernel.
print(f"m {2*m} : {matrl2_error(u_h_, u_ref).numpy():.4e} ")

m 2 : 1.0112e-02 


In [37]:
%%timeit

# Times one dense kernel-vector product: K_hh viewed as an (nh*nh, nh*nh)
# matrix applied to the flattened f_h, scaled by hh and reshaped to (nh, nh).
# NOTE(review): hh, K_hh, nh and f_h must already exist in the kernel; this
# cell's execution count precedes the multi-level cell's, so the values that
# were timed came from an earlier run -- confirm they are on the same level.
u_h = hh * (K_hh.reshape(nh*nh, nh*nh) @ f_h.reshape(-1)).reshape(nh,nh)

16.3 ms ± 94.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [38]:
%%timeit 

# Times the two-level evaluation: a dense product on the coarse grid plus
# the sparse even and odd corrections on the fine grid.
# NOTE(review): K_HH and f_H are not assigned in any cell visible in this
# notebook; they must come from earlier kernel state -- confirm or define them.
# Coarse-grid dense product.
u_H = HH * (K_HH.reshape(nH*nH, nH*nH) @ f_H.reshape(-1)).reshape(nH,nH)
# Even correction, computed on the fine grid and sampled at every second point.
u_corr_ = torch.sparse.mm(K_corr_even_sparse, f_h.reshape(-1,1)).reshape(nh,nh)
u_corr_ = u_corr_[::2,::2]
u_H_ = u_H + hh * u_corr_
# Interpolate the corrected coarse result, then add the odd correction.
u_h_ = interp2d(u_H_[None,None])[0,0]
u_corr_ = torch.sparse.mm(K_corr_odd_sparse, f_h.reshape(-1,1)).reshape(nh,nh)
u_h_ = u_h_ + hh*u_corr_

12.6 ms ± 27.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
