In [1]:
%matplotlib inline
import importlib
if importlib.util.find_spec("matplotlib_inline") is not None:
    import matplotlib_inline
    matplotlib_inline.backend_inline.set_matplotlib_formats('png')
else:
    from IPython.display import set_matplotlib_formats
    set_matplotlib_formats('png')

In [2]:
import datetime
import glob
import os
import random
from collections import OrderedDict as odict
from collections import defaultdict
from copy import deepcopy

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorboardX
import torch
from matplotlib.ticker import EngFormatter
from scipy.special import gamma
from scipy.stats import special_ortho_group
from torch import nn

In [3]:
show_figs = False   # display each figure inline
save_figs = False   # write each figure to disk as a PDF
plt.rcParams["axes.grid"] = False
# Output directories. The savefig calls below write into ./02_poisson/, so it must
# exist as well; os.makedirs is the portable equivalent of `mkdir -p`.
os.makedirs('figures', exist_ok=True)
os.makedirs('02_poisson', exist_ok=True)

## Theory

Consider the $d$-dimensional space $\mathbb{R}^{d}$, and the following charge:

$$\rho(x) = \delta^d(x).$$

For $d \neq 2$, the analytical solution to the system

$$\nabla \cdot \vec{E} = \rho$$

$$\nabla V = \vec{E}$$

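is given by the closed form below (note the sign convention $\vec{E} = +\nabla V$, the opposite of the usual electrostatics convention $\vec{E} = -\nabla V$):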

$$V_{\vec{x}} = \frac{\Gamma(d/2)}{2\cdot\pi^{d/2}\cdot (2-d)} \|\vec{x}\|^{2-d}, $$

$$\vec{E}_{\vec{x}} = \frac{\Gamma(d/2)}{2\cdot \pi^{d/2}\cdot \|\vec{x}\|^{d}} \vec{x}.$$

For $d=2$, $\vec{E}_{\vec{x}}$ is the same, but for $V_{\vec{x}}$ we have

$$V_{\vec{x}} = \frac{1}{2\pi} \ln(\|\vec{x}\|).$$

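We want to solve this system using the divergence theorem, which equates the flux of $\vec{E}$ through the boundary of a volume with the charge enclosed by that volume: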

$$\iint_{S_{d-1}(V)} \vec{E}\cdot \hat{n}\text{ d}S = \iiint_{V_d} \nabla\cdot\vec{E}\text{ d}V.$$

Keep in mind that the $(d-1)$-dimensional surface area of a $d$-dimensional ball of radius $r$ (i.e. the area of its bounding sphere) is
$$\iint_{S_{d-1}(V^{\text{d-Ball}}_{r})} 1\text{ d}S = \frac{2\cdot \pi^{d/2}}{\Gamma(d/2)}\cdot r^{d-1}.$$

In [4]:
device_name = 'cuda:0'
tch_device = torch.device(device_name)
tch_dtype = torch.double
torch.pi = torch.tensor(np.pi).to(device=tch_device, dtype=tch_dtype)

def isscalar(v):
    """Return True if `v` holds exactly one value.

    Torch tensors and NumPy arrays count as scalars when they contain a single
    element, whatever their shape (e.g. ``()`` or ``(1, 1)``). Anything else
    falls back to ``np.isscalar`` (Python / NumPy scalar types).
    """
    if torch.is_tensor(v):
        return v.numel() == 1
    if isinstance(v, np.ndarray):
        # Mirror the tensor rule; np.isscalar(np.array(1.0)) would return False.
        return v.size == 1
    return np.isscalar(v)

## Defining the Problem and the Analytical Solution

In [5]:
class DeltaProblem:
    """Superposition of weighted point charges (Dirac deltas) in R^d.

    Represents ``rho(x) = sum_i w_i delta^d(x - mu_i)`` and provides its exact
    ball integrals plus the analytic potential and field. Every method accepts
    either NumPy arrays or torch tensors and returns the same kind.

    Args:
        weights: np.ndarray of shape (N,), the charge weights w_i.
        locations: np.ndarray of shape (N, d), the charge locations mu_i.
    """
    def __init__(self, weights, locations):
        self.weights = weights
        self.locations = locations
        self.n = self.weights.size
        self.d = self.locations.shape[1]
        assert self.weights.shape   == (self.n,)
        assert self.locations.shape == (self.n, self.d)
        # Torch copies on the notebook-wide device/dtype, used for tensor inputs.
        self.weights_tch = torch.from_numpy(self.weights).to(tch_device, tch_dtype)
        self.locations_tch = torch.from_numpy(self.locations).to(tch_device, tch_dtype)

    def _backend(self, arr):
        """Return (array library, weights, locations) matching the type of `arr`."""
        if torch.is_tensor(arr):
            return torch, self.weights_tch, self.locations_tch
        return np, self.weights, self.locations

    def _offsets(self, x):
        """Return (lib, w, x - mu, |x - mu|) for query points `x` of shape (N_x, d).

        The offsets have shape (N_x, N_mu, d) and the distances (N_x, N_mu).
        """
        lib, w, mu = self._backend(x)
        N_mu, d = self.n, self.d
        N_x = x.shape[0]
        assert x.shape == (N_x, d)
        x_diff_mu = x.reshape(N_x, 1, d) - mu.reshape(1, N_mu, d)
        assert x_diff_mu.shape == (N_x, N_mu, d)
        x_dists = lib.sqrt(lib.square(x_diff_mu).sum(-1))
        assert x_dists.shape == (N_x, N_mu)
        return lib, w, x_diff_mu, x_dists

    def integrate_volumes(self, volumes):
        """Total charge strictly inside each ball.

        Args:
            volumes: dict with 'type' == 'balls', 'centers' of shape (N_v, d) and
                'radii' of shape (N_v,); NumPy arrays or torch tensors.

        Returns:
            Array/tensor of shape (N_v,): sum of the weights of charges whose
            distance to the ball center is below its radius.
        """
        assert volumes['type'] == 'balls'
        centers = volumes['centers']
        radii = volumes['radii']
        N_v = centers.shape[0]
        N_mu, d = self.n, self.d
        assert radii.shape == (N_v,)
        lib, w, mu = self._backend(centers)

        c_diff_mu = centers.reshape(N_v, 1, d) - mu.reshape(1, N_mu, d)
        assert c_diff_mu.shape == (N_v, N_mu, d)
        distl2 = lib.sqrt(lib.square(c_diff_mu).sum(-1))
        assert distl2.shape == (N_v, N_mu)
        integ = ((distl2 < radii.reshape(N_v, 1)) * w.reshape(1, N_mu)).sum(-1)
        assert integ.shape == (N_v,)
        return integ

    def potential(self, x):
        """Analytic potential at query points `x` of shape (N_x, d); returns shape (N_x,).

        For d != 2: V = sum_i w_i Gamma(d/2) / (2 pi^{d/2} (2 - d)) |x - mu_i|^{2-d}.
        For d == 2: V = sum_i w_i ln|x - mu_i| / (2 pi).
        """
        lib, w, _, x_dists = self._offsets(x)
        N_x, N_mu, d = x.shape[0], self.n, self.d
        if d != 2:
            poten1 = (x_dists**(2-d))
            assert poten1.shape == (N_x, N_mu)
            poten2 = (poten1 * w.reshape(1, N_mu)).sum(-1)
            assert poten2.shape == (N_x,)
            cst = gamma(d/2) / (2*(lib.pi**(d/2)))
            cst = cst / (2-d)
            assert isscalar(cst)
            poten = cst * poten2
        else:
            poten1 = lib.log(x_dists)
            assert poten1.shape == (N_x, N_mu)
            poten2 = (poten1 * w.reshape(1, N_mu)).sum(-1)
            assert poten2.shape == (N_x,)
            poten = poten2 / (2*lib.pi)
        assert poten.shape == (N_x,)
        return poten

    def field(self, x):
        """Analytic field E = grad V at query points `x` of shape (N_x, d); returns (N_x, d).

        E = sum_i w_i Gamma(d/2) / (2 pi^{d/2}) (x - mu_i) / |x - mu_i|^d.
        """
        lib, w, x_diff_mu, x_dists = self._offsets(x)
        N_x, N_mu, d = x.shape[0], self.n, self.d
        # Per-charge scalar factor w_i / |x - mu_i|^d.
        coef = (x_dists**(-d)) * w.reshape(1, N_mu)
        assert coef.shape == (N_x, N_mu)
        cst = gamma(d/2) / (2*(lib.pi**(d/2)))
        assert isscalar(cst)
        # Each charge contributes along its own offset (x - mu_i), not along x itself.
        field = cst * (coef.reshape(N_x, N_mu, 1) * x_diff_mu).sum(1)
        assert field.shape == (N_x, d)
        return field

## Visualizing the True Potential and Fields

In [6]:
# 250 x 250 evaluation grid on [-1, 1]^2 shared by the heatmaps below.
# Torch tensors (with requires_grad on the 1-D axes) plus host NumPy copies.
x1_1d, x2_1d = (torch.linspace(-1.0, 1.0, 250, requires_grad=True, dtype=tch_dtype, device=tch_device)
                for _axis in range(2))
x1_msh, x2_msh = torch.meshgrid(x1_1d, x2_1d)
x1, x2 = x1_msh.reshape(-1, 1), x2_msh.reshape(-1, 1)
x1_1d_c, x2_1d_c = x1_1d.reshape(-1, 1), x2_1d.reshape(-1, 1)
x1_msh_np, x2_msh_np = (msh.detach().cpu().numpy() for msh in (x1_msh, x2_msh))
# Flattened (N, 2) query points for the potential / field evaluations.
x = torch.cat([x1, x2], dim=1)
x_np = x.detach().cpu().numpy()
# Aliases used by the plotting cells.
x_plt, x_plt_np = x, x_np
x1_plt_msh_np, x2_plt_msh_np = x1_msh_np, x2_msh_np

In [7]:
def do_plot(x1_msh_np, x2_msh_np, v_msh_np, e_msh_np=None, e_percentile_cap=None, 
            dpi=72, fig_ax=None, vec_ss=None, cnorm=None, print_colorbar=True,
            cmap='RdBu'):
    """Heatmap of a potential on a 2-D mesh, optionally overlaid with field arrows.

    Args:
        x1_msh_np, x2_msh_np: 2-D coordinate meshes with the same shape as `v_msh_np`.
        v_msh_np: potential values on the mesh. They are shifted to zero mean before
            plotting, since a potential is only defined up to a constant.
        e_msh_np: optional field of shape (*v_msh_np.shape, 2), drawn with `quiver`.
        e_percentile_cap: if given, arrows longer than this percentile of all arrow
            lengths are shrunk to that length, so singular regions do not dominate.
        dpi: resolution of a newly created figure (unused when `fig_ax` is given).
        fig_ax: optional (fig, ax) pair to draw into instead of a new figure.
        vec_ss: optional int stride used to sub-sample the arrows.
        cnorm: optional matplotlib Normalize, e.g. to share a colour scale across panels.
        print_colorbar: whether to attach a colourbar to `ax`.
        cmap: colormap name or object.

    Returns:
        (fig, ax)
    """
    plt.ioff()
    if fig_ax is None:
        fig = plt.figure(dpi=dpi)
        ax = plt.gca()
    else:
        fig, ax = fig_ax

    # Centre on the mean of the finite values only. A single +/-inf sample (a grid
    # point sitting on a charge) would otherwise make the mean, and thus every
    # pixel, non-finite. pcolormesh masks the remaining invalid pixels itself.
    finite = np.isfinite(v_msh_np)
    v_shift = v_msh_np[finite].mean() if finite.any() else 0.0
    v_msh_np_ = v_msh_np - v_shift

    im = ax.pcolormesh(x1_msh_np, x2_msh_np, v_msh_np_, shading='auto', 
                       norm=cnorm, cmap=cmap, linewidth=0, rasterized=True)
    if print_colorbar:
        # With an explicit norm, a standalone mappable makes the colourbar follow that norm.
        mappable = mpl.cm.ScalarMappable(norm=cnorm, cmap=cmap) if cnorm is not None else im
        fig.colorbar(mappable, ax=ax)

    if e_msh_np is not None:
        if e_percentile_cap is not None:
            e_size = np.sqrt((e_msh_np**2).sum(axis=-1))
            e_size_cap = np.percentile(a=e_size, q=e_percentile_cap, axis=None)
            cap_coef = np.ones_like(e_size)
            too_long = e_size > e_size_cap
            cap_coef[too_long] = e_size_cap / e_size[too_long]
            e_msh_capped = e_msh_np * cap_coef.reshape(*e_msh_np.shape[:-1], 1)
        else:
            e_msh_capped = e_msh_np

        if vec_ss is None:
            x1_msh_np_q, x2_msh_np_q, e_msh_capped_q = x1_msh_np, x2_msh_np, e_msh_capped
        else:
            assert isinstance(vec_ss, int)
            x1_msh_np_q = x1_msh_np[::vec_ss, ::vec_ss]
            x2_msh_np_q = x2_msh_np[::vec_ss, ::vec_ss]
            e_msh_capped_q = e_msh_capped[::vec_ss, ::vec_ss, :]
        ax.quiver(x1_msh_np_q, x2_msh_np_q, e_msh_capped_q[:, :, 0], e_msh_capped_q[:, :, 1])
    return fig, ax

In [8]:
# Example problem: three unit charges on the diagonal of [-1, 1]^2.
# (A single-charge variant: weights=np.array([1.0]), locations=np.array([[0.0, 0.0]]).)
prob2d_ex1 = DeltaProblem(weights=np.array([1.0, 1.0, 1.0]),
                          locations=np.array([[ 0.0,  0.0],
                                              [-0.5, -0.5],
                                              [ 0.5,  0.5]]))

In [9]:
# Analytic potential and field of the example on the plotting grid.
v_ex, e_ex = prob2d_ex1.potential(x_plt_np), prob2d_ex1.field(x_plt_np)

fig, ax = do_plot(x1_plt_msh_np, x2_plt_msh_np,
                  v_ex.reshape(x1_plt_msh_np.shape),
                  e_ex.reshape(x1_plt_msh_np.shape + (2,)),
                  e_percentile_cap=90, dpi=36, vec_ss=3)

if show_figs:
    plt.show(fig)

### Defining the Volume Sampler

In [10]:
class BallSampler:
    """Random sampler of d-dimensional balls used as integration volumes.

    Centers are drawn uniformly from the box [x_min, x_max] and radii uniformly
    from [r_min, r_max). Call `seed` before sampling.

    Args:
        x_min, x_max: np.ndarray of shape (d,), bounding box of the centers.
        r_min, r_max: floats, range of the radii.
    """
    def __init__(self, x_min, x_max, r_min, r_max):
        self.np_random = None
        self.tch_random = None
        self.d = x_min.size
        self.x_min = x_min.reshape(1, self.d)
        self.x_max = x_max.reshape(1, self.d)
        self.r_min = r_min
        self.r_max = r_max
        self.x_size = (self.x_max - self.x_min)
        self.r_size = (self.r_max - self.r_min)

        # Torch copies of the bounds, used by lib='torch' sampling.
        self.x_min_tch = torch.from_numpy(self.x_min).to(device=tch_device, dtype=tch_dtype)
        self.x_max_tch = torch.from_numpy(self.x_max).to(device=tch_device, dtype=tch_dtype)
        self.r_min_tch = torch.tensor(self.r_min).to(device=tch_device, dtype=tch_dtype)
        self.r_max_tch = torch.tensor(self.r_max).to(device=tch_device, dtype=tch_dtype)
        self.x_size_tch = torch.from_numpy(self.x_size).to(device=tch_device, dtype=tch_dtype)
        self.r_size_tch = torch.tensor(self.r_size).to(device=tch_device, dtype=tch_dtype)

    def seed(self, seed=None):
        """(Re)create the NumPy and torch RNGs from `seed` and return `seed`.

        The torch generator is seeded from the NumPy stream, so one seed fixes both.
        """
        self.tch_random = torch.Generator(device=tch_device)
        self.np_random = np.random.RandomState(seed=seed)
        # Explicit int64: the default C-long dtype is 32-bit on Windows and rejects
        # this upper bound. Where C long is 64-bit this gives the same draw.
        tch_seed = self.np_random.randint(0, 0x0fff_ffff_ffff_ffff, dtype=np.int64)
        self.tch_random.manual_seed(int(tch_seed))
        return seed

    def __call__(self, n=1, lib='numpy'):
        """Sample `n` balls with the NumPy (lib='numpy') or torch (lib='torch') RNG.

        Returns:
            dict with 'type' == 'balls', 'centers' of shape (n, d) and 'radii' of shape (n,).

        Raises:
            RuntimeError: if `seed` has not been called or `lib` is unknown.
        """
        if self.np_random is None:
            raise RuntimeError('BallSampler.seed() must be called before sampling.')
        if lib == 'numpy':
            radii = self.np_random.uniform(self.r_min, self.r_max, size=(n,))
            centers = self.np_random.uniform(0.0, 1.0, size=(n, self.d))
            centers = centers * self.x_size + self.x_min
        elif lib == 'torch':
            radii = torch.empty(n, device=tch_device, dtype=tch_dtype)
            radii = radii.uniform_(generator=self.tch_random) * self.r_size_tch + self.r_min_tch
            centers = torch.empty(n, self.d, device=tch_device, dtype=tch_dtype)
            centers = centers.uniform_(generator=self.tch_random) * self.x_size_tch + self.x_min_tch
        else:
            raise RuntimeError('Not implemented!')
        d = dict()
        d['type'] = 'balls'
        d['centers'] = centers
        d['radii'] = radii
        return d

### Visualizing the Sampler and Integrator

In [11]:
# Three unit charges; sample 10 balls with the torch backend and integrate the
# charge enclosed by each one.
prob2d_ex2 = DeltaProblem(weights=np.ones(3),
                          locations=np.array([[0.0, 0.0], [-0.5, -0.5], [0.5, 0.5]]))

volsampler_2d = BallSampler(x_min=-np.ones(2), x_max=np.ones(2), r_min=0.1, r_max=1.5)
volsampler_2d.seed(12345)

vols = volsampler_2d(n=10, lib='torch')
integs = prob2d_ex2.integrate_volumes(vols)

# Convert every tensor entry of `vols` to a host NumPy array, in place, for plotting.
for key in list(vols):
    val = vols[key]
    if torch.is_tensor(val):
        vols[key] = val.detach().cpu().numpy()

In [12]:
fig = plt.figure(dpi=36)
ax = plt.gca()

# Colour each ball by the total charge it encloses, on a scale running from the sum
# of the negative weights to the sum of the positive weights.
max_integ = prob2d_ex2.weights[prob2d_ex2.weights > 0].sum()
min_integ = prob2d_ex2.weights[prob2d_ex2.weights < 0].sum()
# pyplot.get_cmap works across Matplotlib versions; mpl.cm.get_cmap was removed in 3.9.
cmap = plt.get_cmap('RdBu')
cnorm = mpl.colors.Normalize(vmin=min_integ, vmax=max_integ)

# Charges as black stars, sampled balls as outlined circles.
ax.scatter(prob2d_ex2.locations[:,0], prob2d_ex2.locations[:,1], marker='*', color='black', s=150)
for center, radius, integ in zip(vols['centers'], vols['radii'], integs):
    # 1 - cnorm(...) flips RdBu so that positive enclosed charge is drawn red.
    circle = plt.Circle(center, radius, fill=False, 
                        color=cmap(1.0-cnorm(integ.item())))
    ax.add_patch(circle)
ax.set_aspect('equal', adjustable='box')

if show_figs:
    plt.show(fig)

### Sphere Sampling

In [13]:
class SphereSampler:
    """Samples surface points and outward unit normals on the boundary spheres of balls.

    For d == 2 and d == 3, `do_detspacing=True` uses evenly spaced base directions
    (angles / equal-area latitude bands) with an independent random rotation per
    sphere. Otherwise the directions are i.i.d. uniform on the unit sphere. Call
    `seed` before sampling.
    """
    def __init__(self):
        self.np_random = None
        self.tch_random = None

    def seed(self, seed=None):
        """(Re)create the NumPy and torch RNGs from `seed` and return `seed`."""
        self.tch_random = torch.Generator(device=tch_device)
        self.np_random = np.random.RandomState(seed=seed)
        # Explicit int64: the default C-long dtype is 32-bit on Windows and rejects
        # this upper bound. Where C long is 64-bit this gives the same draw.
        tch_seed = self.np_random.randint(0, 0x0fff_ffff_ffff_ffff, dtype=np.int64)
        self.tch_random.manual_seed(int(tch_seed))
        return seed

    def np_exlinspace(self, start, end, n):
        """Midpoints of `n` equal sub-intervals of [start, end] (NumPy)."""
        assert n >= 1
        a = np.linspace(start, end, n, endpoint=False) 
        b = a + 0.5 * (end - a[-1])
        return b

    def tch_exlinspace(self, start, end, n):
        """Midpoints of `n` equal sub-intervals of [start, end] (torch)."""
        assert n >= 1
        a = torch.linspace(start, end, n+1, device=tch_device, dtype=tch_dtype)[:-1] 
        b = a + 0.5 * (end - a[-1])
        return b

    def _unit_directions(self, N_v, d, n, do_detspacing, use_np):
        """Unit vectors: shape (1, n, d) when deterministic, (N_v, n, d) when random."""
        exlinspace = self.np_exlinspace if use_np else self.tch_exlinspace
        sin = np.sin if use_np else torch.sin
        cos = np.cos if use_np else torch.cos

        def concat(parts):
            return np.concatenate(parts, axis=1) if use_np else torch.cat(parts, dim=1)

        if do_detspacing and (d == 2):
            theta = exlinspace(0.0, 2*np.pi, n)
            assert theta.shape == (n,)
            theta_2d = theta.reshape(n, 1)
            x_tilde_2d = concat([cos(theta_2d), sin(theta_2d)])
        elif do_detspacing and (d == 3):
            n_sqrt = int(round(np.sqrt(n)))
            assert n == n_sqrt * n_sqrt, 'Need n to be int-square for now!'
            theta_1d = exlinspace(0.0, 2*np.pi, n_sqrt)
            unit_unif = exlinspace(0.0, 1.0, n_sqrt)
            # phi = arccos(1 - 2u) makes cos(phi) uniform, i.e. equal-area latitude bands.
            phi_1d = np.arccos(1-2*unit_unif) if use_np else torch.arccos(1-2*unit_unif)
            meshgrid = np.meshgrid if use_np else torch.meshgrid
            theta_msh, phi_msh = meshgrid(theta_1d, phi_1d)
            assert theta_msh.shape == (n_sqrt, n_sqrt)
            assert phi_msh.shape == (n_sqrt, n_sqrt)
            theta_2d, phi_2d = theta_msh.reshape(n, 1), phi_msh.reshape(n, 1)
            x_tilde_2d = concat([sin(phi_2d) * cos(theta_2d),
                                 sin(phi_2d) * sin(theta_2d),
                                 cos(phi_2d)])
        elif not do_detspacing:
            # Normalised Gaussian vectors are uniformly distributed on the sphere.
            if use_np:
                x_tilde_unnorm = self.np_random.randn(N_v, n, d)
                x_tilde_l2 = np.sqrt(np.square(x_tilde_unnorm).sum(axis=-1))
            else:
                x_tilde_unnorm = torch.empty(N_v, n, d, device=tch_device, dtype=tch_dtype)
                x_tilde_unnorm = x_tilde_unnorm.normal_(generator=self.tch_random)
                x_tilde_l2 = torch.sqrt(torch.square(x_tilde_unnorm).sum(dim=-1))
            x_tilde = x_tilde_unnorm / x_tilde_l2.reshape(N_v, n, 1)
            assert x_tilde.shape == (N_v, n ,d)
            return x_tilde
        else:
            raise RuntimeError('Not implemented yet!')
        assert x_tilde_2d.shape == (n ,d)
        return x_tilde_2d.reshape(1, n, d)

    def _rotations(self, N_v, d, use_np):
        """`N_v` random rotation matrices of shape (N_v, d, d), drawn from the NumPy RNG."""
        rot_mats_np = special_ortho_group.rvs(dim=d, size=N_v, random_state=self.np_random)
        # rvs drops the batch axis when size == 1; restore it.
        rot_mats_np = np.reshape(rot_mats_np, (N_v, d, d))
        if use_np:
            return rot_mats_np
        return torch.from_numpy(rot_mats_np).to(device=tch_device, dtype=tch_dtype)

    def __call__(self, volumes, n, do_detspacing=True):
        """Sample `n` points on the boundary sphere of every ball in `volumes`.

        Args:
            volumes: dict with 'type' == 'balls', 'centers' of shape (N_v, d) and
                'radii' of shape (N_v,), all NumPy arrays or all torch tensors.
            n: points per sphere (must be a perfect square when d == 3 and do_detspacing).
            do_detspacing: evenly spaced base directions (d in {2, 3}) instead of i.i.d. ones.

        Returns:
            dict with 'points' (N_v, n, d), 'normals' (N_v, n, d) outward unit normals
            and 'areas' (N_v,) sphere surface areas, in the same array library as the input.
        """
        assert volumes['type'] == 'balls'
        centers = volumes['centers']
        radii = volumes['radii']
        N_v, d = centers.shape
        assert radii.shape == (N_v,)
        use_np = not torch.is_tensor(centers)
        matmul = np.matmul if use_np else torch.matmul

        # Directions first, then rotations: keeps the original RNG consumption order.
        x_tilde = self._unit_directions(N_v, d, n, do_detspacing, use_np)
        rot_mats = self._rotations(N_v, d, use_np)
        assert rot_mats.shape == (N_v, d, d)

        # Rotated unit directions are the normals; scaling by the radius and shifting by
        # the center gives the points. Broadcasting covers the (1, n, d) case.
        normals = matmul(x_tilde, rot_mats)
        assert normals.shape == (N_v, n, d)
        points = normals * radii.reshape(N_v, 1, 1) + centers.reshape(N_v, 1, d)
        assert points.shape == (N_v, n, d)

        # Surface area of a radius-r sphere in R^d: 2 pi^{d/2} / Gamma(d/2) * r^{d-1}.
        cst = (2*(np.pi**(d/2))) / gamma(d/2)
        csts = cst * (radii**(d-1))
        assert csts.shape == (N_v,)

        ret_dict = dict(points=points, normals=normals, areas=csts)
        return ret_dict

In [14]:
# Same three-charge setup as above; sample 10 balls, then 100 evenly spaced
# (randomly rotated) points with normals on each boundary circle.
prob2d_ex3 = DeltaProblem(weights=np.ones(3),
                          locations=np.array([[0.0, 0.0], [-0.5, -0.5], [0.5, 0.5]]))

volsampler_2d = BallSampler(x_min=-np.ones(2), x_max=np.ones(2), r_min=0.1, r_max=1.5)
volsampler_2d.seed(12345)

sphsampler_2d = SphereSampler()
sphsampler_2d.seed(12345)

vols = volsampler_2d(n=10, lib='torch')
sphsamps2d = sphsampler_2d(vols, 100, do_detspacing=True)
# Host NumPy copies for plotting.
points, surfacenorms = (
    arr.detach().cpu().numpy() if torch.is_tensor(arr) else arr
    for arr in (sphsamps2d['points'], sphsamps2d['normals'])
)
points.shape, surfacenorms.shape

((10, 100, 2), (10, 100, 2))

In [15]:
fig = plt.figure(dpi=36)
ax = plt.gca()

# Colour scale from this example's own charges (sum of negative to sum of positive weights).
max_integ = prob2d_ex3.weights[prob2d_ex3.weights > 0].sum()
min_integ = prob2d_ex3.weights[prob2d_ex3.weights < 0].sum()
# pyplot.get_cmap works across Matplotlib versions; mpl.cm.get_cmap was removed in 3.9.
cmap = plt.get_cmap('RdBu')
cnorm = mpl.colors.Normalize(vmin=min_integ, vmax=max_integ)

# Enclosed charge of the balls sampled in the previous cell. It is recomputed here
# rather than relying on `integs` left over from an earlier cell.
integs_ex3 = prob2d_ex3.integrate_volumes(vols)

ax.scatter(prob2d_ex3.locations[:,0], prob2d_ex3.locations[:,1], marker='*', color='black', s=150)
for pnts, srfnrms, integ in zip(points, surfacenorms, integs_ex3):
    # Surface points coloured by enclosed charge (red = positive); arrows are the normals.
    ax.scatter(pnts[:,0], pnts[:,1], marker='o', color=cmap(1.0-cnorm(integ.item())), s=1)
    ax.quiver(pnts[:,0], pnts[:,1], srfnrms[:, 0], srfnrms[:, 1], width=0.002)
ax.set_aspect('equal', adjustable='box')

if show_figs:
    plt.show(fig)

### Defining the Problem

In [16]:
# Set the problem: three unit charges on the diagonal of [-1, 1]^2.
problem = DeltaProblem(weights=np.array([1.0, 1.0, 1.0]),
                       locations=np.array([[ 0.0,  0.0],
                                           [-0.5, -0.5],
                                           [ 0.5,  0.5]]))
# Spatial dimension, taken from the problem so the two can never disagree.
d = problem.d

volsampler = BallSampler(x_min=np.array([-1.0, -1.0]), x_max=np.array([1.0, 1.0]), r_min=0.1, r_max=1.5)
sphsampler = SphereSampler()

### Function Approximation

In [17]:
class ffnn(nn.Module):
    """Fully connected network R^inp_width -> R with SiLU activations.

    Layout: Linear(inp_width, nn_width), then `num_hidden` x Linear(nn_width, nn_width),
    then Linear(nn_width, 1). SiLU follows every layer except the last. All layers
    live on the notebook-wide `tch_device` / `tch_dtype`.
    """
    def __init__(self, inp_width=2, nn_width=10, num_hidden=2):
        super().__init__()
        to_kw = dict(device=tch_device, dtype=tch_dtype)
        # Creation order (first, hidden..., last) fixes the parameter-initialisation
        # sequence under a given torch seed.
        self.layer_first = nn.Linear(inp_width, nn_width).to(**to_kw)
        self.layer_hidden = nn.ModuleList(
            [nn.Linear(nn_width, nn_width).to(**to_kw) for _ in range(num_hidden)])
        self.layer_last = nn.Linear(nn_width, 1).to(**to_kw)

    def forward(self, x):
        """Map inputs of shape (..., inp_width) to outputs of shape (..., 1)."""
        act = nn.functional.silu
        h = act(self.layer_first(x))
        for layer in self.layer_hidden:
            h = act(layer(h))
        return self.layer_last(h)

### Heatmap with MSE Losses

In [18]:
nn_width, nn_hidden = 64, 2
model = ffnn(d, nn_width, nn_hidden)

cmap = 'RdBu'

# Colour normalisation for the heatmaps below. None lets each panel auto-scale.
# For a fixed shared scale (a shared colourbar is then drawn), use e.g.
# mpl.colors.Normalize(vmin=-0.75, vmax=0.3, clip=False).
cnorm = None

# Smoothing factor read by `ema` for the training curves.
ema_gamma = 0.999

def ema(x_np, decay=None):
    """Exponential moving average of a 1-D sequence.

    ``y[0] = x[0]`` and ``y[i] = decay * y[i-1] + (1 - decay) * x[i]``.

    Args:
        x_np: 1-D array-like (NumPy array, pandas Series, list).
        decay: smoothing factor; defaults to the notebook-level ``ema_gamma``,
            read at call time.

    Returns:
        Float np.ndarray with the same length as ``x_np`` (empty in, empty out).
    """
    if decay is None:
        decay = ema_gamma
    # Force a float dtype so integer inputs are not truncated by the averaging.
    x_arr = np.asarray(x_np, dtype=float)
    y = np.empty_like(x_arr)
    if y.size == 0:
        return y
    x = x_arr.tolist()  # Python floats: fast element access in the loop below
    lasty = y[0] = x[0]
    for i in range(1, y.size):
        y[i] = lasty = lasty * decay + (1.0 - decay) * x[i]
    return y

In [19]:
# Compare the ground-truth potential with the final checkpoints of the low- and
# high-variance MSE training runs, one heatmap panel each.
# NOTE(review): each glob pattern is assumed to match exactly one run directory;
# glob order is unspecified, so `[0]` is ambiguous if several runs match.
lowvardir = f'01_poisson/01_lowvar_mse_*'
highvardir = f'01_poisson/02_highvar_mse_*'
lowvarckpts = f'{lowvardir}/checkpoints.pt'
highvarckpts = f'{highvardir}/checkpoints.pt'

fig, axes = plt.subplots(1, 3, figsize=(6.5, 1.9), dpi=72)
# Standalone mappable for the shared colourbar (only used when cnorm is set).
cmappable = mpl.cm.ScalarMappable(norm=cnorm, cmap=cmap)

# Panel 0: analytic solution on the plotting grid.
v_true = problem.potential(x_plt_np)
e_true = problem.field(x_plt_np)
showvecs = False

fig, ax = do_plot(x1_plt_msh_np, x2_plt_msh_np, 
                  v_true.reshape(*x1_plt_msh_np.shape),
                  e_true.reshape(*x1_plt_msh_np.shape, 2) if showvecs else None, 
                  e_percentile_cap=90, fig_ax=(fig, axes[0]), cnorm=cnorm,
                  vec_ss=2, print_colorbar=False)
ax.set_title('Ground Truth')

# Panels 1-2: network predictions from the latest checkpoint of each run.
loopvars = [(lowvarckpts, axes[1], 'Low Var. Training'), 
            (highvarckpts, axes[2], 'High Var. Training')]
for ckptglobpath, ax, ax_title in loopvars:
    ckptspath = glob.glob(ckptglobpath)[0]
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckptweights = torch.load(ckptspath)

    # The checkpoint file maps sortable keys (presumably training iterations) to
    # model state dicts; take the one with the largest key.
    ckptiters = sorted(ckptweights.keys())
    mdlsdcpu = ckptweights[max(ckptiters)]
    mdlsd = {key:val.to(device=tch_device) for key, val in mdlsdcpu.items()}
    model.load_state_dict(mdlsd)

    # Evaluate on the grid as a leaf Parameter so autograd gives d(sum V)/dx,
    # i.e. the predicted field under this notebook's E = +grad V convention.
    points_plt = nn.Parameter(x_plt.unsqueeze(0))
    v_pred = model(points_plt)
    e_pred, = torch.autograd.grad(v_pred.sum(), [points_plt], grad_outputs=None, retain_graph=False,
                                  create_graph=False, only_inputs=True, allow_unused=False)
    v_pred_np = v_pred.squeeze().detach().cpu().numpy()
    e_pred_np = e_pred.squeeze().detach().cpu().numpy()
    
    fig, ax = do_plot(x1_plt_msh_np, x2_plt_msh_np, 
                      v_pred_np.reshape(*x1_plt_msh_np.shape),
                      e_pred_np.reshape(*x1_plt_msh_np.shape, 2) if showvecs else None, 
                      e_percentile_cap=90, fig_ax=(fig, ax), cnorm=cnorm,
                      print_colorbar=False, vec_ss=2)
    ax.set_title(ax_title)
    # Only the leftmost panel keeps y tick labels.
    ax.set_yticks([])
    ax.set_yticks([], minor=True)

if cnorm is not None:
    fig.colorbar(cmappable, ax=axes)

if save_figs:
    fig.savefig('./02_poisson/msegt_heatmap.pdf', dpi=200, bbox_inches="tight")
    
if show_figs:
    plt.show(fig)

### MSE Training Curves

In [20]:
# EMA-smoothed training curves of the two MSE runs: training loss (left) and the
# logged 'npvm' column, plotted as "Integration Variance" (right).
lowvarcsvpath = glob.glob(f'{lowvardir}/progress.csv')[0]
highvarcsvpath = glob.glob(f'{highvardir}/progress.csv')[0]
dflowvar = pd.read_csv(lowvarcsvpath)
dfhighvar = pd.read_csv(highvarcsvpath)

# NOTE(review): sharey=True makes the right panel inherit the left panel's log
# scale and y-limits (0.005 to 5) - confirm this is intended for the variance panel.
fig, axes = plt.subplots(1, 2, figsize=(5.1, 2.1), dpi=72, sharex=True, sharey=True)
    
ax = axes[0]
ax.plot(dflowvar['epoch'], ema(dflowvar['loss']), color='blue', lw=2, label='Low Var')
ax.plot(dfhighvar['epoch'], ema(dfhighvar['loss']), color='red', lw=2, label='high Var')

# Epoch axis with engineering-notation ticks (e.g. 100k).
ax.set_xticks(np.linspace(0, 200_000, 5))
engfmt = EngFormatter(sep='')
ax.xaxis.set_major_formatter(engfmt)
ax.set_xlabel(f'Epoch')

ax.set_yscale('log', base=10)
ax.set_yticks([0.01, 0.03, 0.1, 0.3, 1, 3])
ax.set_ylim(0.005, 5)
ax.yaxis.set_major_formatter(mpl.ticker.ScalarFormatter())
ax.set_ylabel(f'Training Loss')

# Hand-placed curve labels; presumably N is the number of surface samples per
# sphere in each run (TODO confirm against the training config).
ax.annotate('N=2', xy=(100_000, 0.25), xytext=(140_000, 1.0),
            arrowprops=dict(arrowstyle="->", connectionstyle="angle3,angleA=0,angleB=-120"))

ax.annotate('N=100', xy=(100_000, 0.035), xytext=(140_000, 0.008),
            arrowprops=dict(arrowstyle="->", connectionstyle="angle3,angleA=0,angleB=120"))

ax = axes[1]
ax.plot(dflowvar['epoch'], ema(dflowvar['npvm']), color='blue', lw=2, label='Low Var')
ax.plot(dfhighvar['epoch'], ema(dfhighvar['npvm']), color='red', lw=2, label='high Var')

ax.set_xticks(np.linspace(0, 200_000, 5))
engfmt = EngFormatter(sep='')
ax.xaxis.set_major_formatter(engfmt)
ax.set_xlabel(f'Epoch')
ax.set_ylabel(f'Integration Variance')

ax.annotate('N=100', xy=(100_000, 1.6), xytext=(140_000, 2.5),
            arrowprops=dict(arrowstyle="->", connectionstyle="angle3,angleA=0,angleB=-150"))

ax.annotate('N=2', xy=(100_000, 0.26), xytext=(140_000, 0.1),
            arrowprops=dict(arrowstyle="->", connectionstyle="angle3,angleA=0,angleB=140"))

# NOTE(review): set_tight_layout is deprecated in recent Matplotlib
# (fig.set_layout_engine('tight') is the replacement).
fig.set_tight_layout(True)

if save_figs:
    fig.savefig('./02_poisson/loss_vs_epoch_mse.pdf', dpi=200, bbox_inches="tight")

if show_figs:
    plt.show(fig)

### Heatmap with the Bootstrapped Loss

In [21]:
# Main- vs target-network heatmaps for three bootstrapped runs.
# Each entry: (run-directory glob, checkpoint key or 'max', output filename postfix).
showvecs = False
loopvars = [('03_highvar_bstrap_*', 2000, '_diverged'),
            ('04_highvar_bstrap_*', 199000, '_lowq'),
            ('06_highvar_bstrap_*', 183000, '')]

for dirname, ckptiter, postfix in loopvars: 
    mainckpts = f'01_poisson/{dirname}/checkpoints.pt'
    trgtckpts = f'01_poisson/{dirname}/checkpoints_trg.pt'

    fig, axes = plt.subplots(1, 2, figsize=(4.1, 1.9), dpi=72)
    cmappable = mpl.cm.ScalarMappable(norm=cnorm, cmap=cmap)

    loopvars2 = [(mainckpts, axes[0], 'Main Model'), 
                 (trgtckpts, axes[1], 'Target Model')]
    for i, (ckptglobpath, ax, ax_title) in enumerate(loopvars2):
        ckptspath = glob.glob(ckptglobpath)[0]
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckptweights = torch.load(ckptspath)

        ckptiters = sorted(ckptweights.keys())
        mdlsdcpu = ckptweights[max(ckptiters) if ckptiter == 'max' else ckptiter]
        mdlsd = {key:val.to(device=tch_device) for key, val in mdlsdcpu.items()}
        model.load_state_dict(mdlsd)

        # Predicted potential on the grid; its input gradient is the predicted field.
        points_plt = nn.Parameter(x_plt.unsqueeze(0))
        v_pred = model(points_plt)
        e_pred, = torch.autograd.grad(v_pred.sum(), [points_plt], grad_outputs=None, retain_graph=False,
                                      create_graph=False, only_inputs=True, allow_unused=False)
        v_pred_np = v_pred.squeeze().detach().cpu().numpy()
        e_pred_np = e_pred.squeeze().detach().cpu().numpy()

        fig, ax = do_plot(x1_plt_msh_np, x2_plt_msh_np, 
                          v_pred_np.reshape(*x1_plt_msh_np.shape),
                          e_pred_np.reshape(*x1_plt_msh_np.shape, 2) if showvecs else None, 
                          e_percentile_cap=90, fig_ax=(fig, ax), cnorm=cnorm,
                          print_colorbar=False, vec_ss=5)
        ax.set_title(ax_title)
        if i > 0:
            ax.set_yticks([])
            ax.set_yticks([], minor=True)

    if cnorm is not None:
        fig.colorbar(cmappable, ax=axes)

    if save_figs:
        fig.savefig(f'./02_poisson/bstrap_heatmap{postfix}.pdf', dpi=200, bbox_inches="tight")

    # Show each run's figure inside the loop; otherwise only the last one is displayed.
    if show_figs:
        plt.show(fig)

### Bootstrapped Training Curves

In [22]:
# Training-loss curve of each bootstrapped run (green), iterating over `loopvars`
# from the previous cell. For the non-diverged runs, the two MSE runs are overlaid
# (dashed) for reference.
for dirname, ckptiter, postfix in loopvars: 
    maincsvpath = glob.glob(f'01_poisson/{dirname}/progress.csv')[0]
    dfmain = pd.read_csv(maincsvpath)

    fig, axes = plt.subplots(1, 1, figsize=(2.5, 2.6), dpi=72, sharex=True, sharey=True)

    # Re-sets the global smoothing factor that `ema` reads.
    ema_gamma = 0.999

    ax = axes
    epochs, losses = dfmain['epoch'], ema(dfmain['loss'])
    # Thin long curves to every 1000th point (after smoothing) to keep the PDF small.
    if postfix in ('', '_lowq'):
        epochs, losses = epochs[::1000], losses[::1000]
    ax.plot(epochs, losses, color='green', lw=2, label='Bootstrapped')
    if postfix in ('', '_lowq'):
        ax.plot(dfhighvar['epoch'][::1000], ema(dfhighvar['loss'])[::1000], 
                color='red', lw=1., ls='--', label='high Var')
        ax.plot(dflowvar['epoch'][::1000], ema(dflowvar['loss'])[::1000], 
                color='blue', lw=1., ls='--', label='Low Var')

    # The diverged run is plotted over a much shorter epoch range.
    if postfix in ('_diverged',):
        ax.set_xticks(np.linspace(0, 10_000, 6))
    else:
        ax.set_xticks(np.linspace(0, 200_000, 5))
    engfmt = EngFormatter(sep='')
    ax.xaxis.set_major_formatter(engfmt)
    ax.set_xlabel(f'Epoch')

    ax.set_yscale('log', base=10)
    if postfix not in ('_diverged',):
        ax.set_yticks([0.01, 0.03, 0.1, 0.3, 1, 3])
        ax.set_ylim(0.005, 5)
        ax.yaxis.set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_ylabel(f'Training Loss')

    # Hand-placed curve labels.
    if postfix == '':
        ax.annotate('N=2', xy=(40_000, 1.3), xytext=(80_000, 0.5),
                    arrowprops=dict(arrowstyle="->", connectionstyle="angle3,angleA=0,angleB=150"))
    elif postfix == '_lowq':
        ax.annotate('N=2', xy=(40_000, 0.5), xytext=(80_000, 1.2),
                arrowprops=dict(arrowstyle="->", connectionstyle="angle3,angleA=0,angleB=-90"))

    fig.set_tight_layout(True)
    
    if save_figs:
        fig.savefig(f'./02_poisson/loss_vs_epoch_bstrap{postfix}.pdf', dpi=200, bbox_inches="tight")
    
    if show_figs:
        plt.show(fig)

### Can we compute the analytical solution's surface integration variance?

In [23]:
# Evaluate the analytic field on many sphere surfaces: 1000 random balls, each with
# 10,000 evenly spaced, randomly rotated boundary points.
n_spheres = 1000
n_points = 10000

volsampler.seed(12345)
sphsampler.seed(12345)

# Sampling the volumes
volsamps = volsampler(n=n_spheres, lib='torch')

# Sampling the points from the spheres
sphsamps = sphsampler(volsamps, n_points, do_detspacing=True)
# Wrapped as a Parameter, presumably so gradients w.r.t. the points can be taken.
# NOTE(review): problem.field below therefore records an autograd graph over all
# n_spheres * n_points points (large memory); consider torch.no_grad() if no later
# cell differentiates through points_e_true.
points = nn.Parameter(sphsamps['points'])
surfacenorms = sphsamps['normals']
areas = sphsamps['areas']
assert points.shape == (n_spheres, n_points, d)
assert surfacenorms.shape == (n_spheres, n_points, d)
assert areas.shape == (n_spheres,)

# True field at every sampled surface point, shape (n_spheres, n_points, d).
points_e_true = problem.field(points.reshape(n_spheres * n_points, d)).reshape(n_spheres, n_points, d)

In [24]:
# Distribution of the true field magnitude ||E||_2 over all sampled surface points:
# histogram of log10 ||E|| (left) and the quantile function of ||E|| (right).
fig, axes = plt.subplots(1, 2, figsize=(6, 2.2), dpi=72)

field_sizes = points_e_true.norm(dim=-1).reshape(-1)
assert field_sizes.shape == (n_spheres*n_points,)
# Probability levels 0, 0.001, ..., 1 and the matching quantiles of ||E||.
q = torch.linspace(0, 1, 1001, dtype=tch_dtype, device=tch_device)
quants = field_sizes.quantile(q)

# 10,000 evenly spaced order statistics: a small subsample with the same distribution,
# cheap to histogram. Float64 keeps the ~1e7-range indices exact before rounding.
keepidxs = torch.linspace(0, n_spheres*n_points-1, 10000, dtype=torch.double)
keepidxs = keepidxs.round().to(dtype=torch.long, device=tch_device)
field_sizes_sorted = torch.sort(field_sizes).values
field_sizes_ss = field_sizes_sorted.index_select(dim=0, index=keepidxs)

ax = axes[0]
ax.hist(field_sizes_ss.log10().detach().cpu().numpy(), bins=100)
ax.set_ylabel('Count')
ax.set_xlabel(r'$\log_{10}||\vec{E}||_2$')

ax = axes[1]
ax.plot(q.detach().cpu().numpy(), quants.detach().cpu().numpy(), color='blue', lw=2)
# x runs over probability levels in [0, 1]; y holds raw magnitudes shown on a log axis.
ax.set_xlabel('Quantile')
ax.set_ylabel(r'$||\vec{E}||_2$')
ax.set_yscale('log')

fig.set_tight_layout(True)

if save_figs:
    fig.savefig('./02_poisson/count_vs_esize_true.pdf', dpi=200, bbox_inches="tight")

if show_figs:
    plt.show(fig)

As the plots above show, surface points very close to a charge, where the true field is singular, are sampled with non-negligible probability. These near-singular samples push the variance of the surface-normal product $\vec{E}\cdot\hat{n}$ to infinity.