In [1]:
import scipy
import numpy as np
import matplotlib.pyplot as plt
import torch

from utils import matrl2_error, rl2_error
from ops import injection2d, injection4d 
from ops import interp2d, interp1d_cols, interp1d_rows
from ops import restrict2d
from ops import fetch_nbrs2d #, fetch_nbrs4d
from ops import coord2idx2d, coord2idx4d
from ops import cat2d_nbr_coords
from ops import grid2d_coords, grid4d_coords
from einops import rearrange

In [2]:
# Size tag that selects the dataset directory (data2d_<l>) read below.
l = 129
# Load the Poisson .mat file; the path is relative to the notebook's working directory.
poisson_data = scipy.io.loadmat(f'../pde_data/green_learning/data2d_{l}/poisson.mat')
# Keep only the 'F' entry of the file (presumably the forcing terms -- TODO confirm).
F = poisson_data['F']

In [3]:
def kernel_func(pts_pairs):
    """Evaluate the log-ratio kernel on a batch of 2D point pairs.

    Each row of ``pts_pairs`` is read as ``(x1, y1, x2, y2)``. The value is
    ``1/(4*pi) * log(|p1 - p2|^2 / ((x1*y2 - x2*y1)^2 + (p1.p2 - 1)^2))``.
    NOTE(review): this looks like the closed-form Green's function of the
    Laplacian on the unit disk -- confirm the sign convention.

    Special cases:
      * ``-inf`` (coincident points) is replaced by ``-2``;
      * NaN becomes 0 and ``+inf`` becomes the largest finite value
        (``torch.nan_to_num`` defaults);
      * the result is zeroed unless both points lie strictly inside the
        unit disk.

    Returns a 1D tensor with one kernel value per row.
    """
    src_x, src_y, tgt_x, tgt_y = (pts_pairs[:, col] for col in range(4))

    # True only where both end points are strictly inside the unit circle.
    both_inside = (src_x**2 + src_y**2 < 1) & (tgt_x**2 + tgt_y**2 < 1)

    dist_sq = (src_x - tgt_x)**2 + (src_y - tgt_y)**2
    cross_term = src_x*tgt_y - tgt_x*src_y
    shifted_dot = src_x*tgt_x + src_y*tgt_y - 1
    denom = cross_term**2 + shifted_dot**2

    raw = 1/(4*torch.pi) * torch.log(dist_sq / denom)

    # Sanitize before masking so that inf * 0 never produces NaN.
    return torch.nan_to_num(raw, neginf=-2) * both_inside

In [4]:
def ffunc(pts):
    """Return ``(1 - (x^2 + y^2)) ** -0.5`` for every row ``(x, y)`` of ``pts``.

    On the unit circle the expression is ``+inf`` and outside it is NaN;
    both are mapped to 0 by ``torch.nan_to_num`` (``posinf=0``, NaN -> 0 by
    default).
    """
    radius_sq = pts[:, 0]**2 + pts[:, 1]**2
    return torch.nan_to_num((1 - radius_sq)**-0.5, posinf=0)

In [8]:
class Grid2D:
    """One level of a 2D grid together with its neighbor-pair lookups.

    The grid has ``nh = 2**n + 1`` points per axis. Construction runs, in
    order, ``init_grid``, ``fetch_nbrs``, ``fetch_fine_even`` and
    ``fetch_fine_odd``, then stores ``h``.

    Args:
        n: grid level; sets ``nh = 2**n + 1``.
        m: neighbor range, forwarded to ``fetch_nbrs2d`` as all four of
            ``mx1``, ``mx2``, ``my1`` and ``my2``.

    Attributes set during construction:
        x_h, coords_h: outputs of ``grid2d_coords(nh)``.
        x_hh, coords_hh: outputs of ``grid4d_coords(nh)``.
        ij_coords, ij_idx, x_ij: neighbor-pair coordinates, their indices
            into ``x_hh``, and the gathered (squeezed) entries of ``x_hh``.
        ij_xeven_yodd_idx, x_ij_xeven_yodd: indices and gathered entries of
            ``x_hh`` produced by ``fetch_fine_even``.
        h: ``x_h[1] - x_h[0]`` (presumably the grid spacing -- TODO confirm
            the layout of ``x_h``).
    """

    def __init__(self, n, m):
        self.n = n
        self.m = m
        self.nh = 2**n + 1
        self.init_grid()
        self.fetch_nbrs()
        self.fetch_fine_even()
        self.fetch_fine_odd()
        self.h = self.x_h[1] - self.x_h[0]

    def init_grid(self):
        """Build and store the 2D and 4D coordinate tables for ``nh``."""
        x_h, coords_h = grid2d_coords(self.nh)
        x_hh, coords_hh = grid4d_coords(self.nh)
        self.x_h = x_h
        self.x_hh = x_hh
        self.coords_h = coords_h
        self.coords_hh = coords_hh

    def fetch_nbrs(self):
        """Collect neighbor pairs within range ``m`` and gather them from ``x_hh``."""
        ij_coords = fetch_nbrs2d(
            self.coords_h, mx1=self.m, mx2=self.m,
            my1=self.m, my2=self.m)

        self.ij_coords = cat2d_nbr_coords(
            self.coords_h, ij_coords)

        self.ij_idx = coord2idx4d(self.ij_coords, self.nh)
        self.x_ij = torch.squeeze(self.x_hh[self.ij_idx])

    def fetch_fine_even(self):
        """Build the "x even, y odd" pair lookup from the doubled coordinates.

        The last two components of ``ij_coords`` are doubled, passed through
        ``interp1d_cols``, and every second entry along the third axis
        (starting at 1) is kept. Those coordinates are paired with the
        doubled ``coords_h``, converted to indices and gathered from
        ``x_hh``.

        The results are stored as ``ij_xeven_yodd_idx`` and
        ``x_ij_xeven_yodd``; previously they were computed and then dropped.
        """
        coords_H = self.coords_h
        IJ_coords = self.ij_coords
        # interp1d_cols is applied with the coordinate axis moved to dim 1
        # (assumes it wants channel-first input -- TODO confirm), then the
        # layout is restored.
        ij_xeven_yfull_coords = interp1d_cols(
            IJ_coords[...,2:].permute(0,3,1,2)*2).permute(0,2,3,1).int()
        ij_xeven_yodd_coords = ij_xeven_yfull_coords[:,:,1::2]
        ij_xeven_yodd_coords = cat2d_nbr_coords(coords_H*2, ij_xeven_yodd_coords)
        self.ij_xeven_yodd_idx = coord2idx4d(ij_xeven_yodd_coords, self.nh)
        # Index this grid's own table: the original read a global `x_hh`,
        # which does not exist on a fresh kernel.
        self.x_ij_xeven_yodd = self.x_hh[self.ij_xeven_yodd_idx]

    def fetch_fine_odd(self):
        """Placeholder for the lookup that ``__init__`` expects after ``fetch_fine_even``.

        The constructor calls this method, but the class had no
        implementation, so every construction failed with ``AttributeError``.
        It is a deliberate no-op for now.

        TODO: implement (presumably the odd-x counterpart of
        ``fetch_fine_even`` -- confirm the intended contents).
        """
        return None

In [17]:
n = 7 # total (finest) level; Grid2D uses 2**n + 1 points per axis
k = 2 # coarse level; the hierarchy loop below runs k+1 times
d = 2 # problem dimension
M = 2 # local range (presumably the neighbor range `m` of Grid2D -- TODO confirm)

In [18]:
# Multigrid hierarchy: k+1 Grid2D instances, starting from level n.
ml_grids = []

# Iterate on a local copy so that re-running this cell does not alter the
# config value `n`; `_` keeps the loop from clobbering the dataset size `l`.
level = n
for _ in range(k+1):
    # Grid2D requires the neighbor range as its second argument.
    ml_grids.append(Grid2D(level, M))
    # NOTE(review): this steps the level 7 -> 4 -> 2 (nh = 129, 17, 5), not
    # down by one per grid; confirm this is the intended coarsening.
    level = (level-1)//2+1
    

17