In [1]:
import scipy
import scipy.io  # explicit: `import scipy` alone does not expose scipy.io on SciPy < 1.9
import numpy as np
import matplotlib.pyplot as plt
import torch

import sys
sys.path.append('../')
from utils import matrl2_error, rl2_error
from ops import injection2d, injection4d 
from ops import interp2d, interp1d_cols, interp1d_rows
from ops import restrict2d
from ops import fetch_nbrs2d #, fetch_nbrs4d
from ops import coord2idx2d, coord2idx4d
from ops import cat2d_nbr_coords
from ops import grid2d_coords, grid4d_coords
from einops import rearrange

In [2]:
# Load the Poisson dataset. N_DATA is the number of grid points per axis used in
# the dataset's directory name; it must equal the fine-grid size nh = 2**l + 1
# defined in the setup cell (f_h is reshaped to (nh, nh) there). It is
# deliberately not called `l`, which later denotes the multilevel refinement level.
N_DATA = 65
poisson_data = scipy.io.loadmat(f'../../pde_data/green_learning/data2d_{N_DATA}/poisson.mat')
# Only the right-hand sides F are used; each column is presumably one sample f
# flattened over the N_DATA x N_DATA grid (column 0 is used later). Other fields
# of poisson.mat (presumably 'A', 'U') are not needed here.
F = poisson_data['F']

In [3]:
def kernel_func(pts_pairs):
    """Evaluate the kernel K(p, q) for a batch of point pairs.

    Each row of ``pts_pairs`` is (x1, y1, x2, y2), with p = (x1, y1) and
    q = (x2, y2). The value is

        K = log(|p - q|^2 / ((p x q)^2 + (p.q - 1)^2)) / (4 pi),

    where (p x q)^2 + (p.q - 1)^2 = |p|^2 |q|^2 - 2 p.q + 1. This appears to be
    the Dirichlet Green's function of the Laplacian on the unit disk (up to
    sign convention).

    Special values:
      * p == q gives log(0) = -inf, which is replaced by -2 (regularised diagonal);
      * NaN -> 0 and +inf -> largest finite float (``torch.nan_to_num`` defaults);
      * pairs where either point is not strictly inside the unit disk are zeroed.

    Returns a 1-D tensor with one value per row of ``pts_pairs``.
    """
    x1, y1, x2, y2 = (pts_pairs[:, col] for col in range(4))

    # Both points must lie strictly inside the unit disk.
    inside = ((x1**2+y1**2) < 1) & ((x2**2+y2**2) < 1)

    dist_sq = (x1 - x2)**2 + (y1-y2)**2
    image_sq = (x1*y2-x2*y1)**2 + (x1*x2+y1*y2-1)**2
    kernel = 1/(4*torch.pi) * torch.log(dist_sq / image_sq)

    # Singular diagonal (log 0 = -inf) is regularised to -2 before masking.
    return torch.nan_to_num(kernel, neginf=-2) * inside

In [4]:
def ffunc(pts):
    """Analytic right-hand side f(x, y) = (1 - x^2 - y^2)^(-1/2).

    ``pts`` is indexed as pts[:, 0] = x and pts[:, 1] = y. Points on the unit
    circle (where the expression is +inf) map to 0. Points outside the disk
    produce NaN from the negative base, and ``torch.nan_to_num`` maps those to
    0 as well.
    """
    r_sq = pts[:, 0]**2 + pts[:, 1]**2
    f_val = (1 - r_sq)**-0.5
    return torch.nan_to_num(f_val, posinf=0)

In [5]:
# Multilevel setup. The fine grid h has nh = 2**l + 1 points per axis and the
# coarse grid H has nH = 2**(l-1) + 1. Spacings are 2/(n-1), presumably on
# [-1, 1] to match the unit-disk kernel (confirm in ops.grid2d_coords).
l = 6   # refinement level of the fine grid
k = 2   # NOTE(review): not referenced anywhere below in this notebook
d = 2   # NOTE(review): not referenced anywhere below in this notebook
M = 4   # coarse neighbourhood radius passed to fetch_nbrs2d in the correction cells

# fine grid
nh = 2**l+1  # num of pts on each axis
nhh = nh*nh  # num of fine grid points (2D); the fine kernel matrix is nhh x nhh
h = 2/(nh-1) # mesh size
hh = h*h     # 2D quadrature weight on the fine grid

# The loaded data must live on this fine grid. Fail early with a clear message
# instead of a shape error inside the matmul / reshape below.
assert F.shape[0] == nhh, (
    f"F has {F.shape[0]} rows but the fine grid needs nh*nh = {nhh}")

# coarse grid 
nH = 2**(l-1)+1
H = 2/(nH-1)
nHH = nH*nH 
HH = H*H     # 2D quadrature weight on the coarse grid

# fine grid pts
x_h, coords_h = grid2d_coords(nh)
x_hh, coords_hh = grid4d_coords(nh)

# fine grid kernel and input func
K_hh = kernel_func(x_hh)
f_h = torch.tensor(F[:,0]).float()   # first right-hand side of the dataset
# Alternative analytic right-hand side: f_h = ffunc(x_h)

# fine grid output func eval: dense reference solution u_h = hh * K_hh @ f_h
u_h = hh * (K_hh.reshape(nhh, nhh) @ f_h).reshape(nh,nh)
u_H_ref = injection2d(u_h[None,None])[0,0]   # NOTE(review): not used below

# coarse grid pts pairs
x_H, coords_H = grid2d_coords(nH)
x_HH, coords_HH = grid4d_coords(nH)

# coarse grid kernel and restricted input func
K_HH = kernel_func(x_HH)
f_H = restrict2d(f_h.reshape(nh, nh)[None,None])[0,0].reshape(-1)

# coarse grid output func eval, then interpolate back to the fine grid
u_H = HH * (K_HH.reshape(nHH, nHH) @ f_H).reshape(nH,nH)
u_h_woc = interp2d(u_H[None,None])[0,0] # without correction

In [6]:
# (x, y) = (2X, 2Y)
# Local correction at the coarse points i = (2X, 2Y).
# u_H was computed with the coarse kernel K_HH and the restricted f_H. This cell
# adds hh * sum_j (K(i, j) - K_(i, j)) * f_h(j) over the fine neighbours j that
# are NOT coarse points (odd x or odd y). Here K_ is K_IJ interpolated in j and
# K is the exact kernel from kernel_func.
# Naming: a trailing underscore marks the interpolated kernel; no underscore
# marks the exact kernel.
# NOTE(review): this relies on HH * K_HH @ restrict2d(f) equalling
# hh * sum_j K_(i, j) f_h(j), i.e. restrict2d being the scaled adjoint of the
# j-interpolation used here. Confirm against ops.restrict2d / interp1d_cols.
# fetch local: for every coarse point I, gather its coarse neighbours J within
# radius M (mx1/mx2/my1/my2) and build (I, J) pair coordinates.
IJ_coords = fetch_nbrs2d(coords_H, mx1=M, mx2=M, my1=M, my2=M)
IJ_coords = cat2d_nbr_coords(coords_H, IJ_coords) # J is I's nbrs on H
IJ_idx = coord2idx4d(IJ_coords, nH)
# K_IJ must be 4-D for the permute(0,3,1,2) below. It is presumably
# (nH*nH, 2M+1, 2M+1, 1). TODO confirm fetch_nbrs2d / coord2idx4d output shapes.
K_IJ = K_HH[IJ_idx]

# Local Kernel Interpolation, i = (2X, 2Y)
# here i = (2X, 2Y), therefore we use i instead of I
# interp j=(2X',2Y') to get j=(2X',y')
# (the permutes move the window axes so interp1d_cols acts on the y-window
# axis, then restore the original layout)
K_ij_xeven_yfull_ = interp1d_cols(K_IJ.permute(0,3,1,2)).permute(0,2,3,1) 
# select j=(2X',2Y'+1) from j=(2X',y')
K_ij_xeven_yodd_ = K_ij_xeven_yfull_[:,:,1::2] 
K_ij_xeven_yodd_ = K_ij_xeven_yodd_.reshape(nH*nH,-1)
# interp j=(2X',y') to get j=(2X'+1,y')
# (average of x-adjacent window rows = linear interpolation at the odd x)
K_ij_xodd_yfull_ = (K_ij_xeven_yfull_[:,:-1] + K_ij_xeven_yfull_[:,1:])/2 
K_ij_xodd_yfull_ = K_ij_xodd_yfull_.reshape(nH*nH,-1)

# Local Kernel Evaluation, i = (2X, 2Y)
# Build the matching fine-grid (i, j) coordinates and evaluate the exact kernel.
# IJ_coords[...,2:] are J's coarse coordinates; *2 maps them to fine-grid
# indices. Assuming interp1d_cols averages neighbours, midpoints of two even
# integers are integers, so the .int() cast is exact.
ij_xeven_yfull_coords = interp1d_cols(
    IJ_coords[...,2:].permute(0,3,1,2)*2).permute(0,2,3,1).int()
ij_xeven_yodd_coords = ij_xeven_yfull_coords[:,:,1::2]
ij_xeven_yodd_coords = cat2d_nbr_coords(coords_H*2, ij_xeven_yodd_coords)
ij_xeven_yodd_idx = coord2idx4d(ij_xeven_yodd_coords, nh)
x_ij_xeven_yodd = x_hh[ij_xeven_yodd_idx]
K_ij_xeven_yodd = kernel_func(x_ij_xeven_yodd.reshape(-1,4)).reshape(nH*nH,-1)

# Odd x: integer midpoint of x-adjacent even coordinates (2X' and 2X'+2 -> 2X'+1).
ij_xodd_yfull_coords = (ij_xeven_yfull_coords[:,:-1] + ij_xeven_yfull_coords[:,1:])//2
ij_xodd_yfull_coords = cat2d_nbr_coords(coords_H*2, ij_xodd_yfull_coords)
ij_xodd_yfull_idx = coord2idx4d(ij_xodd_yfull_coords, nh)
x_ij_xodd_yfull = x_hh[ij_xodd_yfull_idx]
K_ij_xodd_yfull = kernel_func(x_ij_xodd_yfull.reshape(-1,4)).reshape(nH*nH,-1)

# local fine f
# f_h at the same fine j points, flattened to the kernels' (nH*nH, -1) layout.
j_xeven_yodd_idx = coord2idx2d(ij_xeven_yodd_coords[...,2:], nh)
j_xodd_yfull_idx = coord2idx2d(ij_xodd_yfull_coords[...,2:], nh)
f_j_xeven_yodd = f_h[j_xeven_yodd_idx].reshape(nH*nH,-1)
f_j_xodd_yfull = f_h[j_xodd_yfull_idx].reshape(nH*nH,-1)

# local correct u_h_
# Corrected coarse solution: add the local (exact - interpolated) kernel error
# for both classes of non-coarse neighbours j.
u_H_ = u_H.reshape(-1) + \
      hh*((K_ij_xeven_yodd - K_ij_xeven_yodd_) * f_j_xeven_yodd).sum(axis=-1) + \
      hh*((K_ij_xodd_yfull - K_ij_xodd_yfull_) * f_j_xodd_yfull).sum(axis=-1)

# Interpolate the corrected coarse solution to the fine grid. Non-coarse fine
# points i still carry the i-interpolation error, handled in the next cells.
u_H_ = u_H_.reshape(nH,nH)
u_h_0 = interp2d(u_H_[None,None])[0,0]

In [7]:
# (x, y) != (2X, 2Y)
# Interpolated kernel K_(i, j) at the fine points i that are not coarse points.
# Step 1: exact kernel at the even points i = (2X, 2Y) against their fine
# neighbours j within radius 2*M, i.e. a (4M+1) x (4M+1) window per i.
i_xeven_yeven_j_coords = fetch_nbrs2d(coords_H*2, mx1=2*M, mx2=2*M, my1=2*M, my2=2*M)
i_xeven_yeven_j_coords = cat2d_nbr_coords(coords_H*2, i_xeven_yeven_j_coords) # here i=(2X, 2Y), j is i's nbrs on h
i_xeven_yeven_j_idx = coord2idx4d(i_xeven_yeven_j_coords, nh)
x_i_xeven_yeven_j = x_hh[i_xeven_yeven_j_idx]
K_i_xeven_yeven_j = kernel_func(x_i_xeven_yeven_j.reshape(-1,4)).reshape(nH*nH,-1)

# Step 2: linear interpolation in i between neighbouring even points
# (trailing underscore = interpolated). Axes 0 and 1 of the (nH, nH, ...) grid
# are the x and y positions of i. Odd-odd points average the even-odd result in x.
# NOTE(review): entries with the same window index are averaged, i.e. K(i1, j1)
# is averaged with K(i2, j2). If fetch_nbrs2d centres each window on its own i,
# then j1 = i1 + offset and j2 = i2 + offset, so this interpolates along the
# (i, j) diagonal. By contrast, interp2d(u_H_) in the previous cell interpolates
# u in i at fixed j, whose error term would use (K(i1, j) + K(i2, j))/2 with
# j = i + offset. Confirm this is intended.
K_i_xeven_yeven_j = K_i_xeven_yeven_j.reshape(nH,nH,4*M+1,4*M+1)
K_i_xodd_yeven_j_ = (K_i_xeven_yeven_j[:-1] + K_i_xeven_yeven_j[1:])/2
K_i_xeven_yodd_j_ = (K_i_xeven_yeven_j[:,:-1] + K_i_xeven_yeven_j[:,1:])/2
K_i_xodd_yodd_j_ = (K_i_xeven_yodd_j_[:-1] + K_i_xeven_yodd_j_[1:])/2

# Step 3: flatten to (number of target points, window size). This must run after
# step 2, because K_i_xodd_yodd_j_ above is built from the unflattened K_i_xeven_yodd_j_.
K_i_xodd_yeven_j_ = K_i_xodd_yeven_j_.reshape((nH-1)*nH,-1)
K_i_xeven_yodd_j_ = K_i_xeven_yodd_j_.reshape(nH*(nH-1),-1)
K_i_xodd_yodd_j_ = K_i_xodd_yodd_j_.reshape((nH-1)*(nH-1),-1)

In [8]:
# Exact kernel K(i, j) at the non-coarse fine points i, using the same window
# layout as the previous cell. The (i, j) coordinates of an odd point are the
# entry-wise integer midpoint of the two adjacent even-point coordinates.
# Grid of target points: (nH, nH, 4M+1, 4M+1, 4). Re-running this cell is safe,
# because the reshape of an already-reshaped tensor is a no-op.
i_xeven_yeven_j_coords = i_xeven_yeven_j_coords.reshape(nH,nH,4*M+1,4*M+1,4)
# NOTE(review): (a + b)//2 is the exact midpoint only when a + b is even. For the
# i part (2X and 2X+2) it always is. For the j part it holds when the two windows
# are shifted copies (j = i + offset). Confirm for windows that fetch_nbrs2d
# clips at the boundary.
i_xodd_yeven_j_coords = (i_xeven_yeven_j_coords[:-1] + i_xeven_yeven_j_coords[1:])//2
i_xeven_yodd_j_coords = (i_xeven_yeven_j_coords[:,:-1] + i_xeven_yeven_j_coords[:,1:])//2
i_xodd_yodd_j_coords = (i_xeven_yodd_j_coords[:-1] + i_xeven_yodd_j_coords[1:])//2

# Flat fine-grid (i, j) pair indices ...
i_xodd_yeven_j_idx = coord2idx4d(i_xodd_yeven_j_coords, nh)
i_xeven_yodd_j_idx = coord2idx4d(i_xeven_yodd_j_coords, nh)
i_xodd_yodd_j_idx = coord2idx4d(i_xodd_yodd_j_coords, nh)

# ... and the corresponding physical (x1, y1, x2, y2) rows from x_hh.
x_i_xodd_yeven_j = x_hh[i_xodd_yeven_j_idx]
x_i_xeven_yodd_j = x_hh[i_xeven_yodd_j_idx]
x_i_xodd_yodd_j = x_hh[i_xodd_yodd_j_idx]

# Exact kernel, flattened to (number of target points, window size) so it lines
# up with the interpolated K_i_*_j_ tensors. nH*(nH-1) == (nH-1)*nH, so the
# first two reshapes have matching row counts.
K_i_xodd_yeven_j = kernel_func(x_i_xodd_yeven_j.reshape(-1,4)).reshape(nH*(nH-1),-1)
K_i_xeven_yodd_j = kernel_func(x_i_xeven_yodd_j.reshape(-1,4)).reshape(nH*(nH-1),-1)
K_i_xodd_yodd_j = kernel_func(x_i_xodd_yodd_j.reshape(-1,4)).reshape((nH-1)*(nH-1),-1)

In [9]:
def _apply_local_correction(u_prev, rows, cols, i_j_coords, K_exact, K_interp,
                            f_fine, n_fine, weight):
    """Return a copy of ``u_prev`` corrected at the fine points ``u_prev[rows, cols]``.

    Adds ``weight * sum_j (K_exact - K_interp)[i, j] * f_fine[j]`` to each target
    point i. The neighbour coordinates j are taken from ``i_j_coords[..., 2:]``.

    Args:
        u_prev: current fine-grid solution; it is not modified.
        rows, cols: slices selecting the target fine points in ``u_prev``.
        i_j_coords: (i, j) pair coordinates laid out as (m, n, x, y, 4). Here
            (m, n) indexes the target points (and must match
            ``u_prev[rows, cols]``), and (x, y) indexes the neighbour window.
        K_exact, K_interp: exact and interpolated kernels, shape (m*n, window).
        f_fine: flattened fine-grid right-hand side.
        n_fine: number of fine points per axis (passed to coord2idx2d).
        weight: quadrature weight (hh).

    Returns:
        A new tensor. ``u_prev`` is cloned first so earlier stages stay intact.
    """
    j_idx = coord2idx2d(i_j_coords[..., 2:], n_fine)
    f_j = rearrange(f_fine[j_idx], 'm n x y c -> (m n) (x y c)')
    correction = weight * ((K_exact - K_interp) * f_j).sum(axis=-1)
    u_next = u_prev.clone()
    u_next[rows, cols] = u_next[rows, cols] + correction.reshape(i_j_coords.shape[:2])
    return u_next


# Successively correct the three classes of fine points that are not coarse
# points. Each stage starts from the previous one, so corrections accumulate.
# Row index = x, column index = y.

# odd x, even y
u_h_1 = _apply_local_correction(
    u_h_0, slice(1, -1, 2), slice(None, None, 2),
    i_xodd_yeven_j_coords, K_i_xodd_yeven_j, K_i_xodd_yeven_j_, f_h, nh, hh)

# even x, odd y
u_h_2 = _apply_local_correction(
    u_h_1, slice(None, None, 2), slice(1, -1, 2),
    i_xeven_yodd_j_coords, K_i_xeven_yodd_j, K_i_xeven_yodd_j_, f_h, nh, hh)

# odd x, odd y
u_h_3 = _apply_local_correction(
    u_h_2, slice(1, -1, 2), slice(1, -1, 2),
    i_xodd_yodd_j_coords, K_i_xodd_yodd_j, K_i_xodd_yodd_j_, f_h, nh, hh)

In [10]:
# Relative errors of each stage against the dense fine-grid solution u_h.
# Every stage uses the same fine neighbourhood radius 2*M, so the "m" value
# is identical on every line. The trailing label says which stage it is.
stages = [
    ("no correction", u_h_woc),
    ("coarse-point correction", u_h_0),
    ("+ odd-x/even-y points", u_h_1),
    ("+ even-x/odd-y points", u_h_2),
    ("+ odd-x/odd-y points", u_h_3),
]
for label, u_approx in stages:
    print("m {:} : {:.4e} ({})".format(
        2*M, matrl2_error(u_approx, u_h).numpy(), label))

m 8 : 7.9922e-02 
m 8 : 1.2953e-02 
m 8 : 1.1170e-02 
m 8 : 9.7552e-03 
m 8 : 7.7303e-03 
