#### Summary

This script plays with 3 ideas on the Poisson problem:

1. Performing latent parameter search on the true problem oracle using the Gaussian average gradient idea.

2. Performing some meta-learning to learn a quickly adaptable parameter set to many charge location configurations.

3. Performing MCMC to look for the latent problem parameters using the true oracle.

Only MCMC showed promising results, so you can find the MCMC implementation in the `16_poisson` notebook.

## The Poisson Problem Script

In [None]:
import matplotlib.pyplot as plt
import importlib
%matplotlib inline
if importlib.util.find_spec("matplotlib_inline") is not None:
    import matplotlib_inline
    matplotlib_inline.backend_inline.set_matplotlib_formats('retina')
else:
    from IPython.display import set_matplotlib_formats
    set_matplotlib_formats('retina')

plt.ioff();

In [None]:
import numpy as np
import torch
import json
import time
import os
import shutil
import socket
import random
import pathlib
import fnmatch
import datetime
import resource
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorboardX
import psutil
import logging
import torch.distributions
from pyinstrument import Profiler
from torch import nn
from copy import deepcopy
from itertools import chain
from scipy.special import gamma
from os.path import exists, isdir
from collections import defaultdict
from collections import OrderedDict as odict
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
from bspinn.io_utils import DataWriter
from bspinn.io_utils import get_git_commit
from bspinn.io_utils import preproc_cfgdict
from bspinn.io_utils import hie2deep, deep2hie

from bspinn.tch_utils import isscalar
from bspinn.tch_utils import EMA
from bspinn.tch_utils import BatchRNG
from bspinn.tch_utils import bffnn
from bspinn.tch_utils import profmem

from bspinn.io_cfg import configs_dir
from bspinn.io_cfg import results_dir
from bspinn.io_cfg import storage_dir

## Theory

Consider the $d$-dimensional space $\mathbb{R}^{d}$, and the following charge:

$$\rho(x) = \delta^d(x).$$

For $d \neq 2$, the analytical solution to the system

$$\nabla \cdot \vec{E} = \rho$$

$$\nabla V = \vec{E}$$

can be defined as 

$$V_{\vec{x}} = \frac{\Gamma(d/2)}{2\cdot\pi^{d/2}\cdot (2-d)} \|\vec{x}\|^{2-d}, $$

$$\vec{E}_{\vec{x}} = \frac{\Gamma(d/2)}{2\cdot \pi^{d/2}\cdot \|\vec{x}\|^{d}} \vec{x}.$$

For $d=2$, $\vec{E}_{\vec{x}}$ is the same, but for $V_{\vec{x}}$ we have

$$V_{\vec{x}} = \frac{1}{2\pi} \ln(\|\vec{x}\|).$$

We want to solve this system using the divergence theorem:

$$\iint_{S_{d-1}(V)} \vec{E}\cdot \hat{n}\text{ d}S = \iiint_{V_d} \nabla\cdot\vec{E}\text{ d}V.$$

Keep in mind that the $(d-1)$-dimensional surface area of a $d$-dimensional ball with radius $r$ is
$$\iint_{S_{d-1}(V^{\text{d-Ball}}_{r})} 1\text{ d}S = \frac{2\cdot \pi^{d/2}}{\Gamma(d/2)}\cdot r^{d-1}.$$

### Dimensionality Scaling

We will assume that our domain of solution is a $d$-ball centered at zero with a radius of $r_b$, and denote its volume by $C_1$:
$$C_1 := \int_{V_{r_b}^{d\text{-Ball}}} 1 d\vec{x} = \frac{2\pi^{d/2}}{d\cdot\Gamma(d/2)} r_b^d$$

#### The Expectation of the Analytical Solution

$$E_v := \int_{V_{r_b}^{d\text{-Ball}}} V_{\vec{x}} d\vec{x} = \int \frac{\Gamma(d/2)}{2\cdot\pi^{d/2}\cdot (2-d)} \|\vec{x}\|^{2-d} d\vec{x}$$

$$ = C_1 \cdot \int \frac{\Gamma(d/2)}{2\cdot\pi^{d/2}\cdot (2-d)} \|\vec{x}\|^{2-d} \cdot \frac{1}{C_1} d\vec{x} $$

$$ = C_1 \cdot \frac{\Gamma(d/2)}{2\cdot\pi^{d/2}\cdot (2-d)} \int \|\vec{x}\|^{2-d} \cdot \frac{1}{C_1} d\vec{x} $$

$$ = \frac{r_b^d}{d\cdot(2-d)} \cdot \int \|\vec{x}\|^{2-d} \cdot \frac{1}{C_1} d\vec{x} $$

$$ = \frac{r_b^d}{d\cdot(2-d)} \cdot \mathbb{E}_{\vec{x}} [\|\vec{x}\|^{2-d}] $$

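Define the radius of $\vec{x}$ as $r=\|\vec{x}\|$, where $\vec{x}$ is distributed uniformly over the ball (i.e., with density $\frac{1}{C_1}$). The cumulative distribution of $r$ is

$$Pr(\|\vec{x}\|<r) = (\frac{r}{r_b})^d,$$

and differentiating it with respect to $r$ gives the density

$$p(r) = \frac{d \cdot r^{d-1}}{r_b^d}, \quad 0 \le r \le r_b.$$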

Therefore, we have

$$E_v = \frac{r_b^d}{d\cdot(2-d)} \cdot \mathbb{E}_{\vec{x}} [r^{2-d}] $$

$$ = \frac{r_b^d}{d\cdot(2-d)} \cdot \int_{r=0}^{r_b} r^{2-d} \frac{d \cdot r^{d-1}}{r_b^d} dr$$

$$ = \frac{r_b^d}{d\cdot(2-d)} \cdot \frac{d}{r_b^d} \int_{r=0}^{r_b} r dr$$

$$ = \frac{r_b^d}{d\cdot(2-d)} \cdot \frac{d}{r_b^d} \cdot \frac{r_b^2}{2}$$

$$ = \frac{r_b^2}{2\cdot(2-d)}$$

#### The Expectation of the Volume Ratio

$$\mathbb{E}_{r\sim U[r_l, r_h]}[(\frac{r}{r_b})^d] = \frac{1}{r_h - r_l} \int_{r_l}^{r_h} (\frac{r}{r_b})^d dr$$

$$=\frac{1}{d+1} \cdot \frac{1}{r_b^d} \frac{r_h^{d+1} - r_l^{d+1}}{r_h - r_l}.$$

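By setting $r_h=r_b$ and letting $r_l \to 0$, the above value approaches $\frac{1}{d+1}$. For $0 < r_l < r_b$ it equals $\frac{1 - (r_l/r_b)^{d+1}}{(d+1)\,(1 - r_l/r_b)} = \frac{1}{d+1}\sum_{k=0}^{d} (r_l/r_b)^k$, which is larger than $\frac{1}{d+1}$.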

### Defining the Problem and the Analytical Solution

In [None]:
class DeltaProblem:
    """
    A batch of `n_bch` Poisson problems, each made of `n_chrg` weighted
    point (Dirac delta) charges in `d` dimensions, together with the
    analytical potential and field of those charges.

    The evaluation methods accept either numpy arrays or torch tensors;
    the backend is picked from the type of the input.

    Parameters
    ----------
    weights: (np.array) the charge weights with shape `(n_bch, n_chrg)`.

    locations: (np.array) the charge locations with shape
        `(n_bch, n_chrg, d)`.

    tch_device, tch_dtype: the torch device and dtype of the tensor
        copies of `weights` and `locations`.
    """
    def __init__(self, weights, locations, tch_device, tch_dtype):
        # weights          -> np.array -> shape=(n_bch, n_chrg)
        # locations.shape  -> np.array -> shape=(n_bch, n_chrg, d)
        self.weights = weights
        self.locations = locations
        self.n_bch, self.n_chrg = self.weights.shape
        self.d = self.locations.shape[-1]
        assert self.weights.shape == (self.n_bch, self.n_chrg,)
        assert self.locations.shape == (self.n_bch, self.n_chrg, self.d)
        # Torch copies, used whenever the method inputs are tensors.
        self.weights_tch = torch.from_numpy(
            self.weights).to(tch_device, tch_dtype)
        self.locations_tch = torch.from_numpy(
            self.locations).to(tch_device, tch_dtype)
        # Batch shape/rank; read by `get_prob_sol` via `shape` and `ndim`.
        self.shape = (self.n_bch,)
        self.tch_pi = torch.tensor(np.pi, device=tch_device, dtype=tch_dtype)
        self.ndim = 1

    def integrate_volumes(self, volumes):
        """
        Total charge weight strictly inside each ball.

        Parameters
        ----------
        volumes: (dict) with `type == 'ball'`, `centers` of shape
            `(n_bch, n_v, d)` and `radii` of shape `(n_bch, n_v)`.

        Returns
        -------
        integ: (np.array or torch.tensor) with shape `(n_bch, n_v)`.
        """
        assert volumes['type'] == 'ball'
        centers = volumes['centers']
        radii = volumes['radii']
        n_v = radii.shape[-1]
        n_bch, n_chrg, d = self.n_bch, self.n_chrg, self.d
        assert radii.shape == (n_bch, n_v,)
        assert centers.shape == (n_bch, n_v, d)
        lib = torch if torch.is_tensor(centers) else np
        mu = self.locations_tch if torch.is_tensor(centers) else self.locations
        w = self.weights_tch if torch.is_tensor(centers) else self.weights

        c_diff_mu = centers.reshape(
            n_bch, n_v, 1, d) - mu.reshape(n_bch, 1, n_chrg, d)
        assert c_diff_mu.shape == (n_bch, n_v, n_chrg, d)
        distl2 = lib.sqrt(lib.square(c_diff_mu).sum(-1))
        assert distl2.shape == (n_bch, n_v, n_chrg)
        # Boolean "charge inside ball" mask times the charge weights.
        integ = ((distl2 < radii.reshape(n_bch, n_v, 1))
                 * w.reshape(n_bch, 1, n_chrg)).sum(-1)
        assert integ.shape == (n_bch, n_v)
        return integ

    def potential(self, x):
        """
        Analytical potential at the points `x` of shape `(n_bch, n_x, d)`.

        Uses `Gamma(d/2) / (2 pi^(d/2) (2-d)) * |x-mu|^(2-d)` for d != 2 and
        `ln|x-mu| / (2 pi)` for d == 2, weighted and summed over charges.
        Returns an array/tensor of shape `(n_bch, n_x)`.
        """
        lib = torch if torch.is_tensor(x) else np
        lib_pi = self.tch_pi if torch.is_tensor(x) else np.pi
        w = self.weights_tch if torch.is_tensor(x) else self.weights
        mu = self.locations_tch if torch.is_tensor(x) else self.locations
        n_bch, n_chrg, d = self.n_bch, self.n_chrg, self.d
        n_x = x.shape[-2]
        assert x.shape == (
            n_bch, n_x, d), f'x.shape={x.shape}, (n_bch, n_x, d)={(n_bch, n_x, d)}'
        x_diff_mu = x.reshape(n_bch, n_x, 1, d) - \
            mu.reshape(n_bch, 1, n_chrg, d)
        assert x_diff_mu.shape == (n_bch, n_x, n_chrg, d)
        x_dists = lib.sqrt(lib.square(x_diff_mu).sum(-1))
        assert x_dists.shape == (n_bch, n_x, n_chrg)
        if d != 2:
            poten1 = (x_dists**(2-d))
            assert poten1.shape == (n_bch, n_x, n_chrg)
            poten2 = (poten1 * w.reshape(n_bch, 1, n_chrg)).sum(-1)
            assert poten2.shape == (n_bch, n_x)
            cst = gamma(d/2) / (2*(lib_pi**(d/2)))
            cst = cst / (2-d)
            assert isscalar(cst)
            poten = cst * poten2
            assert poten.shape == (n_bch, n_x)
        else:
            poten1 = lib.log(x_dists)
            assert poten1.shape == (n_bch, n_x, n_chrg)
            poten2 = (poten1 * w.reshape(n_bch, 1, n_chrg)).sum(-1)
            assert poten2.shape == (n_bch, n_x)
            poten = poten2 / (2*lib_pi)
            assert poten.shape == (n_bch, n_x)
        return poten

    def field(self, x):
        """
        Analytical field (the gradient of `potential`) at the points `x`
        of shape `(n_bch, n_x, d)`.

        For each charge k, the contribution is
        `Gamma(d/2) / (2 pi^(d/2)) * w_k * (x - mu_k) / |x - mu_k|^d`.
        The contributions are summed over charges, giving shape
        `(n_bch, n_x, d)`.
        """
        lib = torch if torch.is_tensor(x) else np
        lib_pi = self.tch_pi if torch.is_tensor(x) else np.pi
        w = self.weights_tch if torch.is_tensor(x) else self.weights
        mu = self.locations_tch if torch.is_tensor(x) else self.locations
        n_bch, n_chrg, d = self.n_bch, self.n_chrg, self.d
        n_x = x.shape[-2]
        assert x.shape == (n_bch, n_x, d)
        x_diff_mu = x.reshape(n_bch, n_x, 1, d) - \
            mu.reshape(n_bch, 1, n_chrg, d)
        assert x_diff_mu.shape == (n_bch, n_x, n_chrg, d)
        x_dists = lib.sqrt(lib.square(x_diff_mu).sum(-1))
        assert x_dists.shape == (n_bch, n_x, n_chrg)
        # Per-charge scalar coefficient w_k / |x - mu_k|^d.
        coef = (x_dists**(-d)) * w.reshape(n_bch, 1, n_chrg)
        assert coef.shape == (n_bch, n_x, n_chrg)
        cst = gamma(d/2) / (2*(lib_pi**(d/2)))
        assert isscalar(cst)
        # Each charge pushes along its own displacement (x - mu_k), not along
        # x itself; the two only coincide for charges located at the origin.
        field = cst * (coef.reshape(n_bch, n_x, n_chrg, 1) * x_diff_mu).sum(-2)
        assert field.shape == (n_bch, n_x, d)
        return field

    def state_dict(self):
        """Return the torch tensors that define the charges."""
        return dict(weights=self.weights_tch, locations=self.locations_tch)

    def load_state_dict(self, state_dict):
        """Re-initialize from a `state_dict()`, keeping the current device/dtype."""
        weights = state_dict['weights'].detach().cpu().numpy()
        locations = state_dict['locations'].detach().cpu().numpy()
        tch_device = self.weights_tch.device
        tch_dtype = self.weights_tch.dtype
        self.__init__(weights, locations, tch_device, tch_dtype)


### Defining the Volume Sampler

In [None]:
class BallSampler:
    """
    Samples batches of d-balls (centers and radii), independently for each
    of the `n_bch` problems in a batch.

    Parameters
    ----------
    c_dstr: (str) the center distribution. One of:
        'uniform' -> box with `low`/`high` of shape `(n_bch, d)`;
        'normal'  -> `loc` of shape `(n_bch, d)`, `scale` with `n_bch` entries;
        'ball'    -> uniform inside a ball with center `c` of shape
                     `(n_bch, d)` and radius `r` with `n_bch` entries.

    c_params: (dict of np.array) the center distribution parameters. Every
        entry must be consumed by the chosen `c_dstr`.

    r_dstr: (str) the radius distribution. 'uniform' draws radii uniformly
        in `[low, high]`; 'unifdpow' draws `r**d` uniformly in
        `[low**d, high**d]`.

    r_params: (dict of np.array) exactly `low` and `high`, each with
        `n_bch` entries and `0 <= low <= high`.

    batch_rng: the batched RNG. This class reads its `lib` ('numpy' or
        'torch'), `device` and `dtype`, and calls `uniform(shape)` and
        `normal(shape)`.
    """
    def __init__(self, c_dstr, c_params, r_dstr, r_params, batch_rng):
        assert isinstance(c_params, dict)
        for name, param in c_params.items():
            msg_ = f'center param {name} is not np.array'
            assert isinstance(param, np.ndarray), msg_

        assert isinstance(r_params, dict)
        for name, param in r_params.items():
            msg_ = f'radius param {name} is not np.array'
            assert isinstance(param, np.ndarray), msg_

        self.batch_rng = batch_rng
        self.lib = batch_rng.lib

        ##############################################################
        ################# Center Sampling Parameters #################
        ##############################################################
        c_params_ = c_params.copy()
        self.c_dstr = c_dstr
        if c_dstr == 'uniform':
            c_low = c_params_.pop('low')
            c_high = c_params_.pop('high')

            n_bch, dim = c_low.shape

            self.c_low_np = c_low.reshape(n_bch, 1, dim)
            self.c_high_np = c_high.reshape(n_bch, 1, dim)
            self.c_size_np = (self.c_high_np - self.c_low_np)

            if self.lib == 'torch':
                self.c_low_tch = torch.from_numpy(self.c_low_np).to(
                    device=self.batch_rng.device, dtype=self.batch_rng.dtype)
                self.c_high_tch = torch.from_numpy(self.c_high_np).to(
                    device=self.batch_rng.device, dtype=self.batch_rng.dtype)
                self.c_size_tch = torch.from_numpy(self.c_size_np).to(
                    device=self.batch_rng.device, dtype=self.batch_rng.dtype)

            self.c_low = self.c_low_np if self.lib == 'numpy' else self.c_low_tch
            self.c_size = self.c_size_np if self.lib == 'numpy' else self.c_size_tch
        elif c_dstr == 'normal':
            c_loc = c_params_.pop('loc')
            c_scale = c_params_.pop('scale')

            n_bch, dim = c_loc.shape
            self.c_loc_np = c_loc.reshape(n_bch, 1, dim)
            # One isotropic scale per batch element.
            self.c_scale_np = c_scale.reshape(n_bch, 1, 1)

            if self.lib == 'torch':
                self.c_loc_tch = torch.from_numpy(self.c_loc_np).to(
                    device=self.batch_rng.device, dtype=self.batch_rng.dtype)
                self.c_scale_tch = torch.from_numpy(self.c_scale_np).to(
                    device=self.batch_rng.device, dtype=self.batch_rng.dtype)

            self.c_loc = self.c_loc_np if self.lib == 'numpy' else self.c_loc_tch
            self.c_scale = self.c_scale_np if self.lib == 'numpy' else self.c_scale_tch
        elif c_dstr == 'ball':
            c_cntr = c_params_.pop('c')
            c_radi = c_params_.pop('r')

            n_bch, dim = c_cntr.shape
            self.c_cntr_np = c_cntr.reshape(n_bch, 1, dim)
            self.c_radi_np = c_radi.reshape(n_bch, 1, 1)

            if self.lib == 'torch':
                self.c_cntr_tch = torch.from_numpy(self.c_cntr_np).to(
                    device=self.batch_rng.device, dtype=self.batch_rng.dtype)
                self.c_radi_tch = torch.from_numpy(self.c_radi_np).to(
                    device=self.batch_rng.device, dtype=self.batch_rng.dtype)

            self.c_cntr = self.c_cntr_np if self.lib == 'numpy' else self.c_cntr_tch
            self.c_radi = self.c_radi_np if self.lib == 'numpy' else self.c_radi_tch
        else:
            raise ValueError(f'c_dstr="{c_dstr}" not implemented')

        msg_ = f'Some center parameters were left unused: {list(c_params_.keys())}'
        assert len(c_params_) == 0, msg_

        self.n_bch, self.d = n_bch, dim

        ##############################################################
        ################# Radius Sampling Parameters #################
        ##############################################################
        r_params_ = r_params.copy()
        r_low = r_params_.pop('low')
        r_high = r_params_.pop('high')

        # Radii are drawn as u ** r_upow with u uniform in the transformed
        # interval [low ** (1/r_upow), high ** (1/r_upow)].
        if r_dstr == 'uniform':
            self.r_upow = 1.0
        elif r_dstr == 'unifdpow':
            self.r_upow = 1.0 / self.d
        else:
            raise ValueError(f'r_dstr={r_dstr} not implemented')

        r_low_rshp = r_low.reshape(self.n_bch, 1)
        r_high_rshp = r_high.reshape(self.n_bch, 1)
        assert (r_low >= 0.0).all()
        assert (r_high >= r_low).all()

        self.r_dstr = r_dstr
        self.r_low_np = np.power(r_low_rshp, 1.0/self.r_upow)
        self.r_high_np = np.power(r_high_rshp, 1.0/self.r_upow)
        self.r_size_np = (self.r_high_np - self.r_low_np)

        if self.lib == 'torch':
            self.r_low_tch = torch.from_numpy(self.r_low_np).to(
                device=self.batch_rng.device, dtype=self.batch_rng.dtype)
            self.r_high_tch = torch.from_numpy(self.r_high_np).to(
                device=self.batch_rng.device, dtype=self.batch_rng.dtype)
            self.r_size_tch = torch.from_numpy(self.r_size_np).to(
                device=self.batch_rng.device, dtype=self.batch_rng.dtype)

        self.r_low = self.r_low_np if self.lib == 'numpy' else self.r_low_tch
        self.r_size = self.r_size_np if self.lib == 'numpy' else self.r_size_tch

        msg_ = f'Some radius parameters were left unused: {list(r_params_.keys())}'
        assert len(r_params_) == 0, msg_

    def __call__(self, n=1):
        """
        Draw `n` balls per batch element.

        Returns
        -------
        dict with `type='ball'`, `centers` of shape `(n_bch, n, d)` and
        `radii` of shape `(n_bch, n)`.
        """
        radii = self.r_low + self.r_size * \
            self.batch_rng.uniform((self.n_bch, n))
        radii = radii ** self.r_upow

        if self.c_dstr == 'uniform':
            centers = self.batch_rng.uniform((self.n_bch, n, self.d))
            centers = centers * self.c_size + self.c_low
        elif self.c_dstr == 'normal':
            centers = self.batch_rng.normal((self.n_bch, n, self.d))
            centers = centers * self.c_scale + self.c_loc
        elif self.c_dstr == 'ball':
            # Uniform direction: a normalized isotropic Gaussian.
            rnd1 = self.batch_rng.normal((self.n_bch, n, self.d))
            rnd1 = rnd1 / ((rnd1**2).sum(-1, keepdims=True)**0.5)

            # Radius fraction u**(1/d) makes the points uniform in volume.
            rnd2 = self.batch_rng.uniform((self.n_bch, n, 1))
            rnd2 = rnd2 ** (1./self.d)

            centers = self.c_radi * rnd2 * rnd1 + self.c_cntr
        else:
            raise ValueError(f'c_dstr="{self.c_dstr}" not implemented')

        d = dict()
        d['type'] = 'ball'
        d['centers'] = centers
        d['radii'] = radii
        return d



### Surface Sampling

In [None]:
class SphereSampler:
    """
    Samples points on the surfaces of batches of d-balls, together with
    their unit outward normals and the total surface area of each ball.

    The backend (numpy or torch) follows the type of the `centers` passed
    to `__call__`, and must agree with `batch_rng.lib`.
    """
    def __init__(self, batch_rng):
        # `batch_rng` provides `dtype`, `device`, `lib`, `shape`,
        # `normal(shape)` and `so_n(shape)`, all of which are used below.
        self.tch_dtype = batch_rng.dtype
        self.tch_device = batch_rng.device
        self.batch_rng = batch_rng

    def np_exlinspace(self, start, end, n):
        """Midpoints of `n` equal-width bins covering `[start, end]` (numpy)."""
        assert n >= 1
        a = np.linspace(start, end, n, endpoint=False)
        # `end - a[-1]` is one bin width; shift by half of it to the centers.
        b = a + 0.5 * (end - a[-1])
        return b

    def tch_exlinspace(self, start, end, n):
        """Torch counterpart of `np_exlinspace` on the sampler's device/dtype."""
        assert n >= 1
        a = torch.linspace(start, end, n+1,
                           device=self.tch_device,
                           dtype=self.tch_dtype)[:-1]
        b = a + 0.5 * (end - a[-1])
        return b

    def __call__(self, volumes, n, do_detspacing=True):
        """
        Parameters
        ----------
        volumes: (dict) with `type == 'ball'`, `centers` of shape
            `(n_bch, n_v, d)` and `radii` of shape `(n_bch, n_v)`.

        n: (int) the number of surface points per ball. With deterministic
            spacing in 3-d this must be a perfect square.

        do_detspacing: (bool) if True, place the points on a fixed,
            evenly spaced pattern (only d in {2, 3}) and rotate it per ball
            with `batch_rng.so_n`. If False, draw random directions (torch
            only).

        Returns
        -------
        dict with `points` and `normals` of shape `(n_bch, n_v, n, d)` and
        `areas` (the total sphere surface area) of shape `(n_bch, n_v)`.
        """
        assert volumes['type'] == 'ball'
        centers = volumes['centers']
        radii = volumes['radii']
        n_bch, n_v, d = centers.shape
        use_np = not torch.is_tensor(centers)
        assert centers.shape == (n_bch, n_v, d)
        assert radii.shape == (n_bch, n_v)
        assert not (use_np) or (self.batch_rng.lib == 'numpy')
        assert use_np or (self.batch_rng.device == centers.device)
        assert use_np or (self.batch_rng.dtype == centers.dtype)
        assert self.batch_rng.shape == (n_bch,)
        exlinspace = self.np_exlinspace if use_np else self.tch_exlinspace
        meshgrid = np.meshgrid if use_np else torch.meshgrid
        sin = np.sin if use_np else torch.sin
        cos = np.cos if use_np else torch.cos
        arccos = np.arccos if use_np else torch.arccos
        matmul = np.matmul if use_np else torch.matmul

        if do_detspacing and (d in (2, 3)):
            if d == 2:
                theta = exlinspace(0.0, 2*np.pi, n)
                assert theta.shape == (n,)
                theta_2d = theta.reshape(n, 1)
                x_tilde_lst = [cos(theta_2d), sin(theta_2d)]
            else:
                n_sqrt = int(np.sqrt(n))
                assert n == n_sqrt * n_sqrt, 'Need n to be int-square for now!'
                theta_1d = exlinspace(0.0, 2*np.pi, n_sqrt)
                unit_unif = exlinspace(0.0, 1.0, n_sqrt)
                # phi = arccos(1 - 2u) spaces cos(phi) evenly in (-1, 1),
                # which gives equal-area latitude bands on the sphere.
                phi_1d = arccos(1-2*unit_unif)
                theta_msh, phi_msh = meshgrid(theta_1d, phi_1d)
                assert theta_msh.shape == (n_sqrt, n_sqrt)
                assert phi_msh.shape == (n_sqrt, n_sqrt)
                theta_2d, phi_2d = theta_msh.reshape(n, 1), phi_msh.reshape(n, 1)
                assert theta_2d.shape == (n, 1)
                assert phi_2d.shape == (n, 1)
                x_tilde_lst = [sin(phi_2d) * cos(theta_2d),
                               sin(phi_2d) * sin(theta_2d), cos(phi_2d)]
            if use_np:
                x_tilde_2d = np.concatenate(x_tilde_lst, axis=1)
            else:
                x_tilde_2d = torch.cat(x_tilde_lst, dim=1)
            assert x_tilde_2d.shape == (n, d)
            x_tilde_4d = x_tilde_2d.reshape(1, 1, n, d)
            # numpy arrays have no `.expand`; `broadcast_to` is the analogue.
            if use_np:
                x_tilde = np.broadcast_to(x_tilde_4d, (n_bch, 1, n, d))
            else:
                x_tilde = x_tilde_4d.expand(n_bch, 1, n, d)
            assert x_tilde.shape == (n_bch, 1, n, d)
            # One random rotation per ball; matmul broadcasts over n_v.
            rot_mats = self.batch_rng.so_n((n_bch, n_v, d, d))
            assert rot_mats.shape == (n_bch, n_v, d, d)
            x_tilde_rot = matmul(x_tilde, rot_mats)
        elif (not do_detspacing) and (not use_np):
            x_tilde_unnorm = self.batch_rng.normal((n_bch, n_v, n, d))
            x_tilde_l2 = torch.sqrt(torch.square(x_tilde_unnorm).sum(dim=-1))
            x_tilde = x_tilde_unnorm / x_tilde_l2.reshape(n_bch, n_v, n, 1)
            assert x_tilde.shape == (n_bch, n_v, n, d)
            x_tilde_rot = x_tilde
        else:
            raise RuntimeError('Not implemented yet!')
        assert x_tilde_rot.shape == (n_bch, n_v, n, d)

        points = x_tilde_rot * \
            radii.reshape(n_bch, n_v, 1, 1) + centers.reshape(n_bch, n_v, 1, d)
        assert points.shape == (n_bch, n_v, n, d)

        # Surface area of a d-ball of radius r: 2 pi^(d/2) / Gamma(d/2) r^(d-1).
        cst = (2*(np.pi**(d/2))) / gamma(d/2)
        csts = cst * (radii**(d-1))
        assert csts.shape == (n_bch, n_v)

        # The rotated unit directions are exactly the outward normals.
        ret_dict = dict(points=points, normals=x_tilde_rot, areas=csts)
        return ret_dict



### Visualization

In [None]:
def get_nn_sol(model, x, n_eval=None, get_field=True, 
    out_lib='numpy'):
    """
    Gets a model and evaluates it minibatch-wise on the tensor x. 
    The minibatch size is capped at n_eval. The output will have the 
    predicted potentials and the vector fields at them.

    Parameters
    ----------
    model: (nn.module) the batched neural network. It must expose
        `shape` (its batch dimensions) and `ndim`.

    x: (torch.tensor) the evaluation points. This array should be 
        >=2-dimensional and have a shape of `(..., x_rows, x_cols)`.
        Missing leading batch dimensions are broadcast against
        `model.shape`.

    n_eval: (int or None) the maximum mini-batch size. If None is 
        given, `x_rows` will be used as `n_eval`.

    get_field: (bool) whether to also compute the vector fields, i.e.
        the gradient of the potential w.r.t. the inputs. Gradients are
        enabled for this even if the caller is inside `torch.no_grad()`.
        
    out_lib: (str) determines the output tensor type. Should be either 
        'numpy' or 'torch'.
    
    Output Dictionary
    ----------
    v (or v_np for numpy): (np.array or torch.tensor) the evaluated
        potentials with a shape of `(*model.shape, x_rows)` where
        model.shape is the batch dimensions of the model. 

    e (or e_np for numpy): (np.array, torch.tensor or None) the
        evaluated vector fields with a shape of
        `(*model.shape, x_rows, x_cols)`, or None if `get_field` is False.
    """
    x_rows, x_cols = tuple(x.shape)[-2:]
    x_bd_ = tuple(x.shape)[:-2]
    x_bd = (1,) if len(x_bd_) == 0 else x_bd_
    msg_ = f'Cannot have {x.shape} fed to {model.shape}'
    assert len(x_bd) <= model.ndim, msg_
    if len(x_bd) < model.ndim:
        # Left-pad with singleton batch dims so x can be expanded to model.shape.
        x_bd = (1,) * (model.ndim - len(x_bd)) + tuple(x_bd)
    assert all((a == b) or (a == 1) or (b == 1) 
               for a, b in zip(x_bd, model.shape)), msg_
    n_eval = x_rows if n_eval is None else n_eval
    if out_lib == 'numpy':
        to_lib = lambda a: a.detach().cpu().numpy()
        lib_cat = lambda al, axis: np.concatenate(al, axis=axis)
        lpf = '_np'
    elif out_lib == 'torch':
        to_lib = lambda a: a
        lib_cat = lambda al, axis: torch.cat(al, dim=axis)
        lpf = ''
    else:
        raise ValueError(f'outlib={out_lib} not defined.')

    n_batches = int(np.ceil(x_rows / n_eval))
    v_pred_list = []
    e_pred_list = []
    # The field is obtained through autograd, so gradients must be enabled
    # when it is requested; otherwise keep the caller's grad mode.
    with torch.set_grad_enabled(torch.is_grad_enabled() or get_field):
        for i in range(n_batches):
            x_i = x[..., (i*n_eval):((i+1)*n_eval), :]
            xi_rows = x_i.shape[-2]
            x_ii = x_i.reshape(*x_bd, xi_rows, x_cols)
            x_iii = x_ii.expand(*model.shape, xi_rows, x_cols)
            # A fresh leaf that requires grad, so we can differentiate w.r.t. it.
            x_iiii = nn.Parameter(x_iii)
            v_pred_i = model(x_iiii).squeeze(-1)
            v_pred_list.append(to_lib(v_pred_i.detach()))
            if get_field:
                # `autograd.grad` returns a tuple with one gradient per input.
                e_pred_i, = torch.autograd.grad(v_pred_i.sum(), [x_iiii],
                    grad_outputs=None, retain_graph=False, create_graph=False,
                    only_inputs=True, allow_unused=False)
                e_pred_list.append(to_lib(e_pred_i.detach()))

    # Rows are the last axis of v and the second-to-last axis of e.
    v_pred = lib_cat(v_pred_list, -1)
    if get_field:
        e_pred = lib_cat(e_pred_list, -2)
    else:
        e_pred = None

    outdict = {f'v{lpf}': v_pred, f'e{lpf}': e_pred}
    return outdict


def get_prob_sol(problem, x, n_eval=None, get_field=True, 
    out_lib='numpy'):
    """
    Gets a problem and evaluates the analytical solution to its 
    potentials and vector fields minibatch-wise on the tensor x. 
    The minibatch size is capped at n_eval. The output will have the 
    predicted potentials and the vector fields at them.

    Parameters
    ----------
    problem: (object) the problem with both the `potential` and 
        `field` methods for analytical solution evaluation, and the
        `shape`/`ndim` attributes describing its batch dimensions.

    x: (torch.tensor) the evaluation points. This array should be 
        >=2-dimensional and have a shape of `(..., x_rows, x_cols)`.
        Missing leading batch dimensions are broadcast against
        `problem.shape`.

    n_eval: (int or None) the maximum mini-batch size. If None is 
        given, `x_rows` will be used as `n_eval`.

    get_field: (bool) whether to also evaluate `problem.field`.

    out_lib: (str) determines the output tensor type. Should be either 
        'numpy' or 'torch'.

    Output Dictionary
    ----------
    v (or v_np for numpy): the evaluated potentials with a shape of
        `(*problem.shape, x_rows)`. 

    e (or e_np for numpy): the evaluated vector fields with a shape of
        `(*problem.shape, x_rows, x_cols)`, or None if `get_field` is
        False.
    """

    assert hasattr(problem, 'potential')
    assert callable(problem.potential)
    assert hasattr(problem, 'field')
    assert callable(problem.field)

    x_rows, x_cols = tuple(x.shape)[-2:]
    x_bd_ = tuple(x.shape)[:-2]
    x_bd = (1,) if len(x_bd_) == 0 else x_bd_
    msg_ = f'Cannot have {x.shape} fed to {problem.shape}'
    assert len(x_bd) <= problem.ndim, msg_
    if len(x_bd) < problem.ndim:
        # Left-pad with singleton batch dims so x can be expanded to problem.shape.
        x_bd = (1,) * (problem.ndim - len(x_bd)) + tuple(x_bd)
    assert all((a == b) or (a == 1) or (b == 1) 
               for a, b in zip(x_bd, problem.shape)), msg_
    n_eval = x_rows if n_eval is None else n_eval
    if out_lib == 'numpy':
        to_lib = lambda a: a.detach().cpu().numpy()
        lib_cat = lambda al, axis: np.concatenate(al, axis=axis)
        lpf = '_np'
    elif out_lib == 'torch':
        to_lib = lambda a: a
        lib_cat = lambda al, axis: torch.cat(al, dim=axis)
        lpf = ''
    else:
        raise ValueError(f'outlib={out_lib} not defined.')

    n_batches = int(np.ceil(x_rows / n_eval))
    v_list = []
    e_list = []
    for i in range(n_batches):
        x_i = x[..., (i*n_eval):((i+1)*n_eval), :]
        xi_rows = x_i.shape[-2]
        x_ii = x_i.reshape(*x_bd, xi_rows, x_cols)
        x_iii = x_ii.expand(*problem.shape, xi_rows, x_cols)
        v_i = problem.potential(x_iii)
        v_list.append(to_lib(v_i))
        if get_field:
            e_i = problem.field(x_iii)
            e_list.append(to_lib(e_i))

    # Rows are the last axis of v and the second-to-last axis of e.
    v = lib_cat(v_list, -1)
    if get_field:
        e = lib_cat(e_list, -2)
    else:
        e = None
    outdict = {f'v{lpf}': v, f'e{lpf}': e}
    return outdict


def make_grid(x_low, x_high, dim, n_gpd, lib):
    """
    Creates a grid of points using the meshgrid functions
    
    Parameters
    ----------
    x_low: (list) a list of length `dim` with floats 
        representing the lower limits of the grid.
    
    x_high: (list) a list of length `dim` with floats 
        representing the higher limits of the grid.
    
    dim: (int) the dimension of the grid space (>= 1).
    
    n_gpd: (int) the number of points in each 
        grid dimension. This yields a total of 
        `n_gpd**dim` points in the total grid.
        
    lib: (str) either 'torch' or 'numpy'. This determines 
        the type of `x` output.
        
    Outputs
    -------
    x: (torch.tensor or np.array) a 2-d tensor or array 
        with the shape of `(n_gpd**dim, dim)`. Row k holds the k-th
        entries of the flattened meshgrid arrays.
    
    xi_msh_np: (list of np.array) a list of length `dim` 
        with meshgrid arrays each with a shape of 
        `[n_gpd] * dim`.
    """
    assert dim >= 1
    assert len(x_low) == dim
    assert len(x_high) == dim
    assert lib in ('torch', 'numpy')
    library = torch if lib == 'torch' else np
    tnper = lambda a: a.cpu().detach().numpy()
    nper = tnper if lib == 'torch' else lambda a: a

    n_g_plt = n_gpd ** dim
    msh_shape = (n_gpd,) * dim

    xi_1d = [library.linspace(lo, hi, n_gpd) for lo, hi in zip(x_low, x_high)]
    assert all(tuple(xi.shape) == (n_gpd,) for xi in xi_1d)

    # Default indexing of each library is kept as-is; `x` and the meshes
    # always share the same flattening order, which is what plotting needs.
    xi_msh = library.meshgrid(*xi_1d)
    assert all(tuple(xi.shape) == msh_shape for xi in xi_msh)

    xi_cols = [xi.reshape(n_g_plt, 1) for xi in xi_msh]
    if lib == 'torch':
        x = torch.cat(xi_cols, dim=1)
    else:
        x = np.concatenate(xi_cols, axis=1)
    assert tuple(x.shape) == (n_g_plt, dim)

    xi_msh_np = [nper(xi) for xi in xi_msh]
    outdict = dict(x=x, xi_msh_np=xi_msh_np)

    return outdict


def plot_sol(x1_msh_np, x2_msh_np, sol_dict, fig=None, ax=None, cax=None):
    """
    Draw a 2-d solution: the potential as a colour mesh and, when present,
    the field as a quiver plot. Both are averaged over all leading batch
    entries.

    Parameters
    ----------
    x1_msh_np, x2_msh_np: (np.array) square meshgrid arrays of shape
        `(n_gpd, n_gpd)`.

    sol_dict: (dict) with `v_np`, whose last axis has `n_gpd**2` entries,
        and `e_np`, which is either None or has trailing shape
        `(n_gpd**2, 2)`.

    fig, ax, cax: optional matplotlib objects. When `fig` is None a new
        figure and a colour-bar axis are created (`ax` and `cax` must then
        be None). When `fig` is given, `ax` is required and the colour bar
        is drawn only if `cax` is given.

    Returns
    -------
    (fig, ax, cax)
    """
    n_gpd = x1_msh_np.shape[0]
    dim = x1_msh_np.ndim
    assert dim == 2, f'dim={dim}, x1_msh_np.shape={x1_msh_np.shape}'
    grid_shape = (n_gpd, n_gpd)
    assert x1_msh_np.shape == grid_shape
    assert x2_msh_np.shape == grid_shape
    n_g = n_gpd ** dim

    if fig is not None:
        assert ax is not None
    else:
        assert ax is None
        assert cax is None
        fig, ax = plt.subplots(1, 1, figsize=(3.0, 2.5), dpi=72)
        cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)

    # Arrows longer than this percentile of all arrow lengths are shrunk to
    # it, so that a few huge vectors do not dwarf the rest of the plot.
    e_percentile_cap = 90

    v_all = sol_dict['v_np']
    assert v_all.shape[-1] == n_g
    v_mean = v_all.reshape(-1, *grid_shape).mean(axis=0)
    mesh = ax.pcolormesh(x1_msh_np, x2_msh_np, v_mean,
                         shading='auto', cmap='RdBu')
    if cax is not None:
        fig.colorbar(mesh, cax=cax)

    e_all = sol_dict['e_np']
    if e_all is None:
        return fig, ax, cax

    assert e_all.shape[-2:] == (n_g, dim)
    e_mean = e_all.reshape(-1, *grid_shape, dim).mean(axis=0)
    if e_percentile_cap is None:
        e_plot = e_mean
    else:
        e_norm = np.sqrt(np.sum(e_mean ** 2, axis=-1))
        norm_cap = np.percentile(a=e_norm, q=e_percentile_cap, axis=None)
        # Scale factor is cap/|e| for over-long arrows and 1 elsewhere.
        scale = np.divide(norm_cap, e_norm, out=np.ones_like(e_norm),
                          where=e_norm > norm_cap)
        e_plot = e_mean * scale[..., np.newaxis]

    ax.quiver(x1_msh_np, x2_msh_np, e_plot[:, :, 0], e_plot[:, :, 1])
    return fig, ax, cax


def get_perfdict(e_pnts, e_mdlsol, e_prbsol):
    """
    Computes the plain, bias-corrected, slope-corrected, and normalized
    slope-corrected error metrics for the solutions of a Poisson problem.

    This function computes four types of MSE and MAE statistics:

        1. Plain: just take the model and ground truth solutions and
            subtract them to get the errors. No bias- or slope-correction
            is applied to offset those degrees of freedom.

            shorthand: 'pln'

        2. Bias-corrected: subtracts the average value from both the model
            and ground truth solutions, and then computes the errors.

            shorthand: 'bc'

        3. Slope-corrected: any affine function (a constant plus a linear
            term) can be added to a Poisson solution without violating the
            Poisson equation. This metric removes the ordinary least-squares
            affine fit (with an intercept) from both the model and the ground
            truth solutions, and then computes the errors. The evaluation
            points are centered before the fit, so the result is invariant
            to adding any affine function to either solution.

            shorthand: 'slc'

        4. Normalized slope-corrected: the slope-corrected solutions are
            each divided by their standard deviation over the evaluation
            points before being subtracted, which also removes the scale.
            This is nan for a seed whose slope-corrected solution is
            constant (zero standard deviation).

            shorthand: 'scn'

    Parameters
    ----------
    e_pnts: (torch.tensor) The input points to the model and the ground truth.
        This should have a shape of (n_seeds, n_evlpnts, dim).

    e_mdlsol: (torch.tensor) The model solution with a
        (n_seeds, n_evlpnts) shape.

    e_prbsol: (torch.tensor) The ground truth solution with a
        (n_seeds, n_evlpnts) shape.

    Output
    ------
    outdict: (dict) A mapping between the error keys and their numpy arrays,
        each with an (n_seeds,) shape. The error keys are
        `f'{kind}/{stat}'` for kind in ('pln', 'bc', 'slc', 'scn') and
        stat in ('mse', 'mae').
    """
    n_seeds, n_evlpnts, dim = e_pnts.shape
    assert e_mdlsol.shape == (n_seeds, n_evlpnts)
    assert e_prbsol.shape == (n_seeds, n_evlpnts)

    with torch.no_grad():
        # The plain non-processed error matrix
        err_pln = e_mdlsol - e_prbsol
        assert err_pln.shape == (n_seeds, n_evlpnts)

        # The bias-corrected error matrix
        e_mdlsol2 = e_mdlsol - e_mdlsol.mean(dim=1, keepdim=True)
        assert e_mdlsol2.shape == (n_seeds, n_evlpnts)
        e_prbsol2 = e_prbsol - e_prbsol.mean(dim=1, keepdim=True)
        assert e_prbsol2.shape == (n_seeds, n_evlpnts)
        err_bc = e_mdlsol2 - e_prbsol2
        assert err_bc.shape == (n_seeds, n_evlpnts)

        # The slope-corrected error matrix. The points are centered so that
        # regressing the (already centered) solutions on them without an
        # intercept equals an OLS fit with an intercept. With uncentered
        # points, the residuals would not be invariant to adding an affine
        # function whenever the points have a non-zero mean.
        e_pntsc = e_pnts - e_pnts.mean(dim=1, keepdim=True)
        assert e_pntsc.shape == (n_seeds, n_evlpnts, dim)
        e_pntstrans = e_pntsc.transpose(-1, -2)
        assert e_pntstrans.shape == (n_seeds, dim, n_evlpnts)
        e_pntsig = e_pntstrans.matmul(e_pntsc)
        assert e_pntsig.shape == (n_seeds, dim, dim)
        # The pseudo-inverse stays defined even for a degenerate point cloud
        e_pntsiginv = torch.pinverse(e_pntsig)
        assert e_pntsiginv.shape == (n_seeds, dim, dim)
        e_pntpinv = e_pntsiginv.matmul(e_pntstrans)
        assert e_pntpinv.shape == (n_seeds, dim, n_evlpnts)

        def remove_slope(sol2):
            # Subtracts the least-squares linear fit of `sol2` on the
            # centered points.
            beta = e_pntpinv.matmul(sol2.unsqueeze(-1))
            assert beta.shape == (n_seeds, dim, 1)
            slope = e_pntsc.matmul(beta).squeeze(-1)
            assert slope.shape == (n_seeds, n_evlpnts)
            return sol2 - slope

        e_mdlsol3 = remove_slope(e_mdlsol2)
        e_prbsol3 = remove_slope(e_prbsol2)
        err_slc = e_mdlsol3 - e_prbsol3
        assert err_slc.shape == (n_seeds, n_evlpnts)

        # The normalized slope-corrected error matrix
        e_mdlsol4 = e_mdlsol3 / e_mdlsol3.std(dim=1, keepdim=True)
        assert e_mdlsol4.shape == (n_seeds, n_evlpnts)
        e_prbsol4 = e_prbsol3 / e_prbsol3.std(dim=1, keepdim=True)
        assert e_prbsol4.shape == (n_seeds, n_evlpnts)
        err_scn = e_mdlsol4 - e_prbsol4
        assert err_scn.shape == (n_seeds, n_evlpnts)

        # Computing the mse and mae values
        outdict = dict()
        errs = (('pln', err_pln), ('bc', err_bc),
                ('slc', err_slc), ('scn', err_scn))
        for err_name, err in errs:
            err_mse = err.square().mean(dim=-1)
            assert err_mse.shape == (n_seeds,)
            err_mae = err.abs().mean(dim=-1)
            assert err_mae.shape == (n_seeds,)
            outdict[f'{err_name}/mse'] = err_mse.detach().cpu().numpy()
            outdict[f'{err_name}/mae'] = err_mae.detach().cpu().numpy()

    return outdict


def eval_pnts(problem, model, target, e_pnts, do_bootstrap,
    n_seeds, n_evlpnts, dim, eval_bs):
    """
    Scores the model (and, when bootstrapping, the target) against the
    ground truth on a batch of evaluation points.

    Parameters
    ----------
    problem: the Poisson problem whose solution is the ground truth.
    model, target: the networks to score. `target` is only evaluated
        when `do_bootstrap` is True.
    e_pnts: (torch.tensor) evaluation points with an
        (n_seeds, n_evlpnts, dim) shape.
    do_bootstrap: (bool) whether to also score `target`.
    n_seeds, n_evlpnts, dim: the expected sizes of `e_pnts`.
    eval_bs: passed as `n_eval` to the solution helpers.

    Output
    ------
    eperfs: (dict) the `get_perfdict` outputs keyed by 'mdl' (and 'trg'),
        flattened with `deep2hie` into keys like 'mdl/pln/mse'.
    """
    sol_shape = (n_seeds, n_evlpnts)
    assert e_pnts.shape == (*sol_shape, dim)
    sol_kwargs = dict(n_eval=eval_bs, get_field=False, out_lib='torch')

    # The ground-truth solution
    e_prbsol = get_prob_sol(problem, e_pnts, **sol_kwargs)['v']
    assert e_prbsol.shape == sol_shape

    # The network solutions (model first, then the optional target),
    # evaluated without tracking gradients
    nets = [('mdl', model)]
    if do_bootstrap:
        nets.append(('trg', target))

    net_sols = dict()
    for net_key, net in nets:
        with torch.no_grad():
            net_sol = get_nn_sol(net, e_pnts, **sol_kwargs)['v']
        assert net_sol.shape == sol_shape
        net_sols[net_key] = net_sol

    nested = {net_key: get_perfdict(e_pnts, net_sol, e_prbsol)
              for net_key, net_sol in net_sols.items()}
    # Example keys of the flattened output:
    #   'mdl/pln/mse', 'mdl/pln/mae', 'mdl/bc/mse', 'mdl/bc/mae',
    #   'mdl/slc/mse', 'mdl/slc/mae', 'mdl/scn/mse', 'mdl/scn/mae',
    #   and the same set prefixed with 'trg/' when bootstrapping.
    return deep2hie(nested, dictcls=dict)


## Utility Functions for Sanity Checks

In [None]:
#########################################################
########### Sanity Checking Utility Functions ###########
#########################################################

# Message template used by `get_arr` when a shape check fails. Its four
# slots are: option name, target-shape formula, evaluated target shape,
# and the inferred source shape.
msg_bcast = ('{} should be np broadcastable to {}={}. '
             'However, it has an inferred shape of {}.')


def get_arr(name, trgshp_str, trns_opts):
    """
    Reads the option `name` from `trns_opts` as a numpy array and checks
    its shape against a target-shape formula.

    The raw value is `trns_opts[name]`. If it is a string, it is first
    evaluated with `eval_formula`, with `trns_opts` as the variables, so
    it may reference other options. The target shape is obtained the same
    way from `trgshp_str`. The resulting array must have as many axes as
    the target shape, and each pair of axis lengths must either be equal
    or contain a 1. Otherwise, an AssertionError carrying a descriptive
    message is raised.

    NOTE(review): the per-axis test also accepts a target axis of length 1
    paired with a longer source axis, which `np.broadcast_to` would
    reject. In other words, it checks mutual broadcastability, not
    broadcastability *to* the target shape.

    Parameters
    ----------
    name: (str) name of the option / hyper-parameter.

    trgshp_str: (str) a python expression for the target shape, evaluated
        with the entries of `trns_opts` in scope (for example,
        "(n_seeds, dim)").

    trns_opts: (dict) the options; must contain `name` and every variable
        the formulas refer to.

    Output
    ------
    val_np: (np.ndarray) the option value as a numpy array.
    """
    assert name in trns_opts, \
        f'"{name}" must be in trns_opts but it isnt: {trns_opts}'

    raw_val = trns_opts[name]
    needs_eval = isinstance(raw_val, str)
    val_np = np.array(eval_formula(raw_val, trns_opts) if needs_eval
                      else raw_val)

    trg_shape = eval_formula(trgshp_str, trns_opts)
    src_shape = val_np.shape
    err_msg = msg_bcast.format(name, trgshp_str, trg_shape, src_shape)

    # Same rank first, then an axis-by-axis compatibility check
    assert len(src_shape) == len(trg_shape), err_msg
    for src_len, trg_len in zip(src_shape, trg_shape):
        assert src_len == trg_len or 1 in (src_len, trg_len), err_msg

    return val_np


def eval_formula(formula, variables):
    """
    Gets a string formula and uses the `eval` function of python to
    translate it into a python object. The necessary variables for
    translation are provided through the `variables` argument.

    Names in the formula are resolved from `variables` first and then
    from this module's globals (so, e.g., `np` is available), followed
    by the builtins. `variables` itself is not modified.

    Security: `eval` runs arbitrary code, so only pass formulas that come
    from trusted config files.

    Parameters
    ----------
    formula (str): a string that can be passed to `eval`.
        Example: "[np.sqrt(dim), 'a', None]"

    variables (dict): a dictionary of variables used in the formula.
        Example: {"dim": 4}

    Output
    ------
    pyobj (object): the translated formula into a python object
        Example: [2.0, 'a', None]
    """
    # The names must be handed to `eval` explicitly. Updating `locals()`
    # inside a function is not a supported way of creating variables, and
    # it has no effect at all since Python 3.13 (PEP 667). A single merged
    # namespace is used, rather than separate globals/locals, so the
    # variables are also visible inside comprehensions and lambdas in the
    # formula.
    namespace = {**globals(), **variables}
    pyobj = eval(formula, namespace)
    return pyobj


def chck_dstrargs(opt, cfgdict, dstr2args, opt2req=None, parnt_optdstr=None):
    """
    Checks if the distribution arguments are provided correctly. Works
    with hierarchical options through recursive applications. An
    AssertionError with a descriptive message is raised if one of the
    checks fails.

    The distribution of `opt` is read from `cfgdict[f'{opt}/dstr']` and
    defaults to 'fixed'. A 'fixed' option must have a non-None value
    `cfgdict[opt]`. For any other distribution, each required argument
    suffix `arg` names a child option `f'{opt}{arg}'`, which is checked
    recursively. In both cases, argument suffixes that belong only to
    other distributions must not be specified.

    Parameters
    ----------
    opt: (str) the option name.

    cfgdict: (dict) the (flat) config dictionary.

    dstr2args: (dict) a mapping between each distribution name and the
        suffixes of its required arguments. Must contain 'fixed' whenever
        any option may omit its `/dstr` key.

    opt2req: (dict or None) required arguments for an option itself, not
        necessarily required by the option's distribution. None means
        no such requirements.

    parnt_optdstr: (tuple or None) the `(parent_opt, parent_dstr)` pair,
        passed by the recursive calls to improve the error messages.
    """
    opt_dstr = cfgdict.get(f'{opt}/dstr', 'fixed')

    msg_ = f'Unknown "{opt}/dstr" value "{opt_dstr}": it should be one of '
    msg_ += f'{list(dstr2args.keys())}'
    assert opt_dstr in dstr2args, msg_

    opt2req = dict() if opt2req is None else opt2req
    optreqs = opt2req.get(opt, tuple())
    # Arguments that must be present: those of the chosen distribution
    # plus any that the option itself requires.
    must_spec = list(dstr2args[opt_dstr]) + list(optreqs)
    # Arguments of all other distributions must be absent, unless they
    # are required here as well.
    avid_spec = list(chain.from_iterable(
        v for k, v in dstr2args.items() if k != opt_dstr))
    avid_spec = [k for k in avid_spec if k not in must_spec]

    if opt_dstr == 'fixed':
        # A fixed option is a leaf: it only needs a value of its own, so
        # the recursion stops here.
        # NOTE(review): `optreqs` are only mentioned in the message below;
        # they are not checked in this branch. Confirm this is intended.
        msg_ = f'"{opt}" must be specified.'
        if parnt_optdstr is not None:
            parnt_opt, parnt_dstr = parnt_optdstr
            msg_ += f' "{parnt_opt}" was specified as "{parnt_dstr}", and'
        msg_ += f' "{opt}" was specified as "{opt_dstr}".'
        if len(optreqs) > 0:
            msg_ += f' Also, "{opt}" requires "{optreqs}" to be specified.'
        opt_val = cfgdict.get(opt, None)
        assert opt_val is not None, msg_
    else:
        # Each required argument is an option in its own right (possibly
        # with its own distribution), so it is checked recursively.
        for arg in must_spec:
            opt_arg = f'{opt}{arg}'
            chck_dstrargs(opt_arg, cfgdict, dstr2args, opt2req,
                          (opt, opt_dstr))

    for arg in avid_spec:
        opt_arg = f'{opt}{arg}'
        opt_arg_val = cfgdict.get(opt_arg, None)
        msg_ = f'"{opt_arg}" should not be specified, since "{opt}" '
        msg_ += f'appears to follow the "{opt_dstr}" distribution.'
        assert opt_arg_val is None, msg_



## Plain Training Example

In [None]:
#########################################################
################### Mandatory Options ###################
#########################################################
# One independent run per seed: 0, 1000, ..., 99000 (100 seeds).
rng_seed_list = list(range(0, 100_000, 1000))
dim = 2

# Surface sampling: the number of spheres per step, and how many points
# on each sphere are evaluated by the model vs. by the target network.
n_srf = 400
n_srfpts_mdl, n_srfpts_trg = 1, 1
do_detspacing = False
do_dblsampling = False

do_bootstrap = True

# # The original values for 3-charges
# tau = 0.999
# w_trgreg = 1.0
# w_trg = 0.99

# Fast training values for 1-charge
tau, w_trgreg, w_trg = 0.984, 2.0, 0.99

# Optimization
opt_type, n_epochs, lr = 'sgd', 5_000, 0.001

# Network architecture
nn_dstr, nn_width, nn_hidden, nn_act = 'mlp', 64, 2, 'tanh'

# Evaluation and checkpointing
eval_bs, n_evlpnts = 256, 1000
eval_frq, chkpnt_period = 100, 20000

In [None]:
# Derived options and assertions
n_points = n_srfpts_mdl + n_srfpts_trg

# Double sampling splits the target points of each sphere into two
# interleaved halves, so at least two target points are needed.
assert not (do_dblsampling) or (n_srfpts_trg > 1)
if w_trg is None:
    # Default: weigh the target by its share of the surface points
    w_trg = n_srfpts_trg / n_points
# Without any model points, all the weight must go to the target.
assert not (n_srfpts_mdl == 0) or (w_trg == 1.0)
n_rsdls = 2 if do_dblsampling else 1

if eval_bs is None:
    eval_bs = max(n_srfpts_mdl, n_srfpts_trg) * n_srf

#########################################################
########### I/O-Related Options and Operations ##########
#########################################################
# Fall back to the CPU so the notebook also runs on machines without a GPU
# (a hard-coded 'cuda:0' fails there).
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
tch_device = torch.device(device_name)
tch_dtype = torch.float32

#########################################################
########### Constructing the Batch RNG Object ###########
#########################################################
n_seeds = len(rng_seed_list)
rng_seeds = np.array(rng_seed_list)
rng = BatchRNG(shape=(n_seeds,), lib='torch',
               device=tch_device, dtype=tch_dtype,
               unif_cache_cols=1_000_000,
               norm_cache_cols=5_000_000)
rng.seed(np.broadcast_to(rng_seeds, rng.shape))

#########################################################
########## Defining the Poisson Problem Object ##########
#########################################################
# A single charge of weight 1 at the origin, for every seed
chrg_n = 1
chrg_w = np.ones((n_seeds, chrg_n))
chrg_mu = np.zeros((n_seeds, chrg_n, dim))
problem = DeltaProblem(weights=chrg_w, locations=chrg_mu,
    tch_device=tch_device, tch_dtype=tch_dtype)

# Volume sampler: ball centers are drawn from the unit ball around the
# origin; radii use the 'unifdpow' distribution on [0, 1].
vol_c_params = dict(c=np.zeros((n_seeds, dim)), r=np.ones(n_seeds))
vol_r_params = dict(low=np.zeros(n_seeds), high=np.ones(n_seeds))
volsampler = BallSampler(c_dstr='ball', c_params=vol_c_params,
                         r_dstr='unifdpow', r_params=vol_r_params,
                         batch_rng=rng)

srfsampler = SphereSampler(batch_rng=rng)

In [None]:
# Each run gets its own numbered, timestamped sub-directory.
# NOTE: this rebinds the `storage_dir` imported from bspinn.io_cfg.
storage_dir = './15_search/01_plain/'
pathlib.Path(storage_dir).mkdir(parents=True, exist_ok=True)
strgidx = len([x for x in os.listdir(storage_dir)
               if isdir(f'{storage_dir}/{x}')])
dtnow = datetime.datetime.now().isoformat(timespec='seconds')
# Compact timestamp, e.g. '2024-01-02T03:04:05' -> '240102T030405'
dtnow_ = dtnow[2:].translate(str.maketrans('', '', '-:.'))
cfgstrg_dir = f'{storage_dir}/{strgidx:02d}_{dtnow_}'
pathlib.Path(cfgstrg_dir).mkdir(parents=True, exist_ok=True)

# Close a writer left open by an earlier execution of this cell
if 'tbwriter' in locals():
    tbwriter.close()
tbwriter = tensorboardX.SummaryWriter(cfgstrg_dir)
logging.getLogger("tensorboardX.x2num").setLevel(logging.CRITICAL)

In [None]:
# Initializing the model
model = bffnn(dim, nn_width, nn_hidden, nn_act, (n_seeds,), rng)
if do_bootstrap:
    # The target starts as an exact copy of the model
    target = bffnn(dim, nn_width, nn_hidden, nn_act, (n_seeds,), rng)
    target.load_state_dict(model.state_dict())
else:
    target = model

# Set the optimizer. `opt_type` is honored here (it is also recorded in
# the saved config) instead of always using SGD.
if opt_type == 'sgd':
    opt = torch.optim.SGD(model.parameters(), lr)
elif opt_type == 'adam':
    opt = torch.optim.Adam(model.parameters(), lr)
else:
    raise ValueError(f'Unknown opt_type: {opt_type}')

# Evaluation tools. The evaluation points are drawn from the same RNG
# stream as the training samples.
erng = rng
last_perfdict = dict()
ema = EMA(gamma=0.999, gamma_sq=0.998)
trn_sttime = time.time()

# Constructing the (n_gpd x n_gpd) visualization grid over [-1, 1]^dim,
# replicated once per seed.
with torch.no_grad():
    n_gpd = 50
    n_g = n_gpd ** dim
    elow = -np.ones(dim)
    ehigh = np.ones(dim)
    gdict = make_grid(elow, ehigh, dim, n_gpd, 'torch')
    grid_x_ = gdict['x']
    assert grid_x_.shape == (n_g, dim)
    grid_x = (grid_x_.unsqueeze(0)
                     .expand(n_seeds, n_g, dim)
                     .to(tch_device, tch_dtype))
    assert grid_x.shape == (n_seeds, n_g, dim)

    x1_msh_np, x2_msh_np = gdict['xi_msh_np']

# One figure (with a colorbar axes) each for the model, target, and
# ground truth. These are re-drawn and sent to tensorboard during training.
with plt.style.context('default'):
    figax_list = [plt.subplots(1, 1, figsize=(3.2, 2.5), dpi=100)
                  for _fig_idx in range(3)]
    fig_mdl, ax_mdl = figax_list[0]
    fig_trg, ax_trg = figax_list[1]
    fig_gt, ax_gt = figax_list[2]
    cax_list = [make_axes_locatable(ax_i).append_axes('right', size='5%', pad=0.05)
                for ax_i in (ax_mdl, ax_trg, ax_gt)]
    cax_mdl, cax_trg, cax_gt = cax_list

stat_history = defaultdict(list)
train_history = odict()

for epoch in range(n_epochs+1):
    opt.zero_grad()

    # Sampling the volumes
    volsamps = volsampler(n=n_srf)

    # Sampling the points from the spheres. Wrapping them in a Parameter
    # makes them a leaf that requires grad, so the networks can be
    # differentiated w.r.t. their inputs below.
    srfsamps = srfsampler(volsamps, n_points, do_detspacing=do_detspacing)
    points = nn.Parameter(srfsamps['points'])
    surfacenorms = srfsamps['normals']
    areas = srfsamps['areas']
    assert points.shape == (n_seeds, n_srf, n_points, dim)
    assert surfacenorms.shape == (n_seeds, n_srf, n_points, dim)
    assert areas.shape == (n_seeds, n_srf,)

    # The first `n_srfpts_mdl` points of each sphere are evaluated by the
    # model, and the remaining `n_srfpts_trg` ones by the target.
    points_mdl = points[:, :, :n_srfpts_mdl, :]
    assert points_mdl.shape == (n_seeds, n_srf, n_srfpts_mdl, dim)
    points_trg = points[:, :, n_srfpts_mdl:, :]
    assert points_trg.shape == (n_seeds, n_srf, n_srfpts_trg, dim)

    surfacenorms_mdl = surfacenorms[:, :, :n_srfpts_mdl, :]
    assert surfacenorms_mdl.shape == (n_seeds, n_srf, n_srfpts_mdl, dim)
    surfacenorms_trg = surfacenorms[:, :, n_srfpts_mdl:, :]
    assert surfacenorms_trg.shape == (n_seeds, n_srf, n_srfpts_trg, dim)

    # Making surface integral predictions using the reference model.
    # create_graph=True keeps the input gradient differentiable, so the
    # loss can be back-propagated into the model parameters.
    u_mdl = model(points_mdl)
    assert u_mdl.shape == (n_seeds, n_srf, n_srfpts_mdl, 1)
    nabla_x_u_mdl, = torch.autograd.grad(u_mdl.sum(), [points_mdl],
        grad_outputs=None, retain_graph=True, create_graph=True,
        only_inputs=True, allow_unused=False)
    assert nabla_x_u_mdl.shape == (n_seeds, n_srf, n_srfpts_mdl, dim)
    normprods_mdl = (nabla_x_u_mdl * surfacenorms_mdl).sum(dim=-1)
    assert normprods_mdl.shape == (n_seeds, n_srf, n_srfpts_mdl)
    if n_srfpts_mdl > 0:
        mean_normprods_mdl = normprods_mdl.mean(dim=-1, keepdim=True)
        assert mean_normprods_mdl.shape == (n_seeds, n_srf, 1)
    else:
        mean_normprods_mdl = 0.0

    # Making surface integral predictions using the target model. With
    # bootstrapping, no graph is built through the target's input
    # gradient. Without it, `target` is the model itself and needs one.
    u_trg = target(points_trg)
    assert u_trg.shape == (n_seeds, n_srf, n_srfpts_trg, 1)
    nabla_x_u_trg, = torch.autograd.grad(u_trg.sum(), [points_trg],
        grad_outputs=None, retain_graph=True, create_graph=not(do_bootstrap),
        only_inputs=True, allow_unused=False)
    assert nabla_x_u_trg.shape == (n_seeds, n_srf, n_srfpts_trg, dim)

    normprods_trg = (nabla_x_u_trg * surfacenorms_trg).sum(dim=-1)
    assert normprods_trg.shape == (n_seeds, n_srf, n_srfpts_trg)
    if do_dblsampling:
        assert n_rsdls == 2

        # Two separate estimates from the interleaved halves of the
        # target points
        mean_normprods_trg1 = normprods_trg[..., 0::2].mean(
            dim=-1, keepdim=True)
        assert mean_normprods_trg1.shape == (n_seeds, n_srf, 1)

        mean_normprods_trg2 = normprods_trg[..., 1::2].mean(
            dim=-1, keepdim=True)
        assert mean_normprods_trg2.shape == (n_seeds, n_srf, 1)

        mean_normprods_trg = torch.cat(
            [mean_normprods_trg1, mean_normprods_trg2], dim=-1)
        assert mean_normprods_trg.shape == (n_seeds, n_srf, n_rsdls)
    else:
        assert n_rsdls == 1

        mean_normprods_trg = normprods_trg.mean(dim=-1, keepdim=True)
        assert mean_normprods_trg.shape == (n_seeds, n_srf, n_rsdls)

    # Linearly combining the reference and target predictions
    mean_normprods = (       w_trg  * mean_normprods_trg +
                      (1.0 - w_trg) * mean_normprods_mdl)
    assert mean_normprods.shape == (n_seeds, n_srf, n_rsdls)

    # Considering the surface areas
    pred_surfintegs = mean_normprods * areas.reshape(n_seeds, n_srf, 1)
    assert pred_surfintegs.shape == (n_seeds, n_srf, n_rsdls)

    # Getting the reference volume integrals
    ref_volintegs = problem.integrate_volumes(volsamps)
    assert ref_volintegs.shape == (n_seeds, n_srf)

    # Getting the residual terms
    resterms = pred_surfintegs - ref_volintegs.reshape(n_seeds, n_srf, 1)
    assert resterms.shape == (n_seeds, n_srf, n_rsdls)

    # Multiplying the residual terms: with double sampling, the product of
    # the two residual estimates replaces the squared residual.
    if do_dblsampling:
        resterms_prod = resterms.prod(dim=-1)
        assert resterms_prod.shape == (n_seeds, n_srf)
    else:
        resterms_prod = torch.square(resterms).squeeze(-1)
        assert resterms_prod.shape == (n_seeds, n_srf)

    # Computing the main loss
    loss_main = resterms_prod.mean(-1)
    assert loss_main.shape == (n_seeds,)

    # Regularizing the model towards the (gradient-free) target values
    if do_bootstrap:
        with torch.no_grad():
            u_mdl_prime = target(points_mdl)
        loss_trgreg = torch.square(u_mdl - u_mdl_prime).mean([-3, -2, -1])
        assert loss_trgreg.shape == (n_seeds,)
    else:
        loss_trgreg = torch.zeros(n_seeds, device=tch_device, dtype=tch_dtype)
        assert loss_trgreg.shape == (n_seeds,)

    # The total loss
    loss = loss_main + w_trgreg * loss_trgreg
    assert loss.shape == (n_seeds,)

    loss_sum = loss.sum()
    loss_sum.backward()

    # We will not update in the first epoch so that we will
    # record the initialization statistics as well. Instead,
    # we will update an extra epoch at the end.
    if (epoch > 0):
        opt.step()

    # Updating the target network: an exponential moving average of the
    # model parameters with rate `tau`
    if do_bootstrap and (epoch > 0):
        model_sd = model.state_dict()
        target_sd = target.state_dict()
        newtrg_sd = dict()
        with torch.no_grad():
            for key, param in model_sd.items():
                param_trg = target_sd[key]
                newtrg_sd[key] = tau * param_trg + (1-tau) * param
        target.load_state_dict(newtrg_sd)

    # Saving the grid solutions to tensorboard
    if epoch % 1000 == 0:
        with torch.no_grad():
            grid_prbsol = get_prob_sol(problem, grid_x, n_eval=eval_bs,
                get_field=False, out_lib='numpy')
            assert grid_prbsol['v_np'].shape == (n_seeds, n_g)

            grid_mdlsol = get_nn_sol(model, grid_x, n_eval=eval_bs,
                get_field=False, out_lib='numpy')
            assert grid_mdlsol['v_np'].shape == (n_seeds, n_g)

            if do_bootstrap:
                grid_trgsol = get_nn_sol(target, grid_x, n_eval=eval_bs,
                    get_field=False, out_lib='numpy')
                # This check must stay inside the branch: `grid_trgsol`
                # is undefined when bootstrapping is disabled.
                assert grid_trgsol['v_np'].shape == (n_seeds, n_g)

            soltd_list = [('gt', grid_prbsol, fig_gt, ax_gt, cax_gt, 'Ground Truth'),
                          ('mdl', grid_mdlsol, fig_mdl, ax_mdl, cax_mdl, 'Prediction')]
            if do_bootstrap:
                soltd_list += [('trg', grid_trgsol, fig_trg, ax_trg, cax_trg, 'Target')]
            for sol_t, sol_dict, fig, ax, cax, ttl in soltd_list:
                plot_sol(x1_msh_np, x2_msh_np, sol_dict, fig=fig, ax=ax, cax=cax)
                ax.set_title(ttl)
                fig.set_tight_layout(True)
                tbwriter.add_figure(f'viz/{sol_t}', fig, epoch)
            tbwriter.flush()

    if epoch % eval_frq == 0:
        # Sampling the evaluation points uniformly inside random balls:
        # a radius of u^(1/dim) * r and a uniformly random direction.
        with torch.no_grad():
            evols = volsampler(n=n_evlpnts)
            assert evols['type'] == 'ball'

            e_c = evols['centers']
            assert e_c.shape == (n_seeds, n_evlpnts, dim)

            e_r_ = evols['radii']
            assert e_r_.shape == (n_seeds, n_evlpnts)

            e_r = e_r_.unsqueeze(dim=-1)
            assert e_r.shape == (n_seeds, n_evlpnts, 1)

            untrd = erng.uniform((n_seeds, n_evlpnts, 1))
            assert untrd.shape == (n_seeds, n_evlpnts, 1)

            untr = untrd.pow(1.0 / dim)
            assert untr.shape == (n_seeds, n_evlpnts, 1)

            e_pntrs = untr * e_r
            assert e_pntrs.shape == (n_seeds, n_evlpnts, 1)

            etheta = erng.normal((n_seeds, n_evlpnts, dim))
            assert etheta.shape == (n_seeds, n_evlpnts, dim)

            ethtilde = etheta / etheta.norm(dim=-1, keepdim=True)
            assert ethtilde.shape == (n_seeds, n_evlpnts, dim)

            e_pnts = e_c + ethtilde * e_pntrs
            assert e_pnts.shape == (n_seeds, n_evlpnts, dim)

        eperfs = eval_pnts(problem, model, target, e_pnts, do_bootstrap,
            n_seeds, n_evlpnts, dim, eval_bs)
        # Only the bias-corrected metrics are logged
        for kk, vv in eperfs.items():
            if '/bc/' in kk:
                tbwriter.add_scalar(f'perf/{kk}', vv.mean(), epoch)

    # computing the normal product variances
    with torch.no_grad():
        normprods = torch.cat([normprods_mdl, normprods_trg], dim=-1)
        npvm = (normprods.var(dim=-1)*areas.square()).mean(-1)

    # Computing the loss moving averages
    loss_ema_mean, loss_ema_std_mean = ema('loss', loss)
    npvm_ema_mean, npvm_ema_std_mean = ema('npvm', npvm)
    if epoch % 1000 == 0:
        print_str = f'Epoch {epoch}, EMA loss = {loss_ema_mean:.4f}'
        print_str += f' +/- {2*loss_ema_std_mean:.4f}'
        print_str += f', EMA Field-Norm Product Variance = {npvm_ema_mean:.4f}'
        print_str += f' +/- {2*npvm_ema_std_mean:.4f} ({time.time()-trn_sttime:0.1f} s)'
        print(print_str, flush=True)

    tbwriter.add_scalar('loss/total', loss.mean(), epoch)
    tbwriter.add_scalar('loss/main', loss_main.mean(), epoch)
    tbwriter.add_scalar('loss/trgreg', loss_trgreg.mean(), epoch)
    tbwriter.add_scalar('loss/npvm', npvm.mean(), epoch)

    # Checkpointing CPU copies of the model, target, and problem states
    if epoch % chkpnt_period == 0:
        train_history[f'{epoch}/mdl'] = deepcopy({k: v.cpu() for k, v
            in model.state_dict().items()})
        train_history[f'{epoch}/trg'] = deepcopy({k: v.cpu() for k, v
            in target.state_dict().items()})
        train_history[f'{epoch}/prb'] = deepcopy({k: v.cpu() for k, v
            in problem.state_dict().items()})

print(f'Training finished in {time.time() - trn_sttime:.1f} seconds.')
tbwriter.flush()

In [None]:
# Side-by-side comparison of the ground-truth, model, and (when
# bootstrapping) target solutions on the visualization grid.
n_rows = 1
n_cols = 2 + do_bootstrap
fig, axes = plt.subplots(n_rows, n_cols,
                         figsize=(n_cols * 3.5, n_rows * 3),
                         dpi=72, sharex=True, sharey=True)
cax = None  # plot_sol draws no colorbar without a colorbar axes

# Computing the model, target and ground truth solutions
prob_sol = get_prob_sol(problem, grid_x, n_eval=eval_bs, get_field=False)
with torch.no_grad():
    mdl_sol = get_nn_sol(model, grid_x, n_eval=eval_bs, get_field=False)
    if do_bootstrap:
        trg_sol = get_nn_sol(target, grid_x, n_eval=eval_bs, get_field=False)

soltd_list = [('gt', prob_sol, axes[0], 'Ground Truth'),
              ('mdl', mdl_sol, axes[1], 'Prediction')]
if do_bootstrap:
    soltd_list.append(('trg', trg_sol, axes[2], 'Target'))
for sol_t, sol_dict, ax, ttl in soltd_list:
    plot_sol(x1_msh_np, x2_msh_np, sol_dict, fig=fig, ax=ax, cax=cax)
    ax.set_title(ttl)
fig

In [None]:
# Persist the checkpoints and the hyper-parameters next to the
# tensorboard logs of this run.
torch.save(train_history, f'{cfgstrg_dir}/train_history.pt')

hp_dict = {
    'dim': dim, 'n_srf': n_srf,
    'n_srfpts_mdl': n_srfpts_mdl, 'n_srfpts_trg': n_srfpts_trg,
    'do_detspacing': do_detspacing, 'do_dblsampling': do_dblsampling,
    'do_bootstrap': do_bootstrap,
    'tau': tau, 'w_trgreg': w_trgreg, 'w_trg': w_trg,
    'opt_type': opt_type, 'n_epochs': n_epochs, 'lr': lr,
    'nn_dstr': nn_dstr, 'nn_width': nn_width,
    'nn_hidden': nn_hidden, 'nn_act': nn_act,
}

with open(f'{cfgstrg_dir}/config.json', 'w') as outfile:
    json.dump(hp_dict, outfile)

### Toy Example: Latent Parameter Identification

Here, we implement a zero-order search method in the `GPMBO` class and use it to recover the charge locations in the following Poisson problem.

In [None]:
class GPMBO:
    """Zeroth-order optimizer using a batch of Gaussian search distributions.

    Each of the ``n_seeds`` independent runs keeps a search distribution
    ``N(mu, S S^T)`` over a ``dim``-dimensional space, where
    ``S = R @ diag(exp(sig_log))`` and ``R = expm(L - L^T)`` with
    ``L = tril(sig_rot)``. ``L - L^T`` is skew-symmetric, so ``R`` is a
    rotation.

    ``ask`` returns ``n_mdl`` candidate points per run. ``tell`` receives their
    scores (higher is better) and takes one optimizer step on the
    score-function (likelihood-ratio) gradient estimate of the expected
    score, using an exponential moving average of past scores as the baseline.

    Parameters
    ----------
    dim : int
        Dimension of the search space.
    n_mdl : int
        Number of candidate points drawn per run and iteration.
    n_seeds : int
        Number of independent, batched runs.
    lr : float
        Optimizer learning rate.
    init_mu : float
        Initial value of every entry of the mean.
    init_std : float
        Initial standard deviation along every axis.
    gamma : float
        Persistence of the standard-normal noise across iterations, in [0, 1).
        0 draws fresh noise every iteration.
    yb_gamma : float
        EMA decay of the per-run score baseline.
    rng : object
        Batched RNG. Its ``normal(shape)`` is assumed to return a tensor on
        ``tch_device`` with ``tch_dtype`` (TODO confirm against BatchRNG).
    optim : {'adam', 'sgd'}
        Optimizer used for the distribution parameters.
    tch_device, tch_dtype :
        Device and dtype of all state tensors.
    opt_siglog : bool
        Whether the per-axis log-std ``sig_log`` is optimized.

    Raises
    ------
    ValueError
        If ``optim`` is not 'adam' or 'sgd'.
    """

    def __init__(self, dim, n_mdl, n_seeds, lr, init_mu, init_std,
                 gamma, yb_gamma, rng, optim, tch_device, tch_dtype,
                 opt_siglog=True):
        # Search-distribution means, one per run.
        mu = init_mu * torch.ones(n_seeds, 1, dim, device=tch_device, dtype=tch_dtype)
        mu = torch.nn.Parameter(mu)
        assert mu.shape == (n_seeds, 1, dim)

        with torch.no_grad():
            # Identity start: tril(I) - tril(I)^T = 0, so R = expm(0) = I.
            sig_rot = torch.eye(dim, device=tch_device, dtype=tch_dtype)
            sig_rot = sig_rot.reshape(1, 1, dim, dim).expand(n_seeds, 1, dim, dim).clone()
            sig_log = torch.full((n_seeds, 1, dim), np.log(init_std), device=tch_device, dtype=tch_dtype)
        sig_rot = torch.nn.Parameter(sig_rot)
        assert sig_rot.shape == (n_seeds, 1, dim, dim)
        sig_log = torch.nn.Parameter(sig_log)
        assert sig_log.shape == (n_seeds, 1, dim)

        assert not sig_rot.isnan().any()

        # Standard-normal noise defining the current batch of queries.
        epsilon = rng.normal((n_seeds, n_mdl, dim, 1))
        opt_vars = [mu, sig_rot, sig_log] if opt_siglog else [mu, sig_rot]
        if optim == 'adam':
            opt = torch.optim.Adam(opt_vars, lr=lr)
        elif optim == 'sgd':
            opt = torch.optim.SGD(opt_vars, lr=lr)
        else:
            raise ValueError(optim)

        self.mu, self.sig_rot, self.sig_log = mu, sig_rot, sig_log
        self.opt = opt
        self.rng = rng
        self.epsilon = epsilon
        self.dim = dim
        self.n_mdl = n_mdl
        self.n_seeds = n_seeds
        self.pi = torch.tensor(np.pi).to(device=tch_device, dtype=tch_dtype)
        # Per-run score baseline (variance reduction), updated in `tell`.
        self.y_base = torch.zeros(n_seeds, 1, device=tch_device, dtype=tch_dtype)
        self.gamma = gamma
        self.yb_gamma = yb_gamma

    def _rotation(self):
        """Return the per-run rotation ``R = expm(L - L^T)``, ``L = tril(sig_rot)``."""
        sig_rot_ = self.sig_rot.tril()
        assert sig_rot_.shape == (self.n_seeds, 1, self.dim, self.dim)
        sig_r = torch.matrix_exp(sig_rot_ - sig_rot_.transpose(-1, -2))
        assert sig_r.shape == (self.n_seeds, 1, self.dim, self.dim)
        return sig_r

    def ask(self):
        """Return the current queries ``mu + S @ epsilon``.

        The result has shape ``(n_seeds, n_mdl, dim)`` and carries no
        autograd graph.
        """
        n_seeds, n_mdl, dim = self.n_seeds, self.n_mdl, self.dim

        # "Std" matrix S = R diag(exp(sig_log)): scales the columns of R.
        sig_sig = self.sig_log.exp().reshape(n_seeds, 1, 1, dim)
        assert sig_sig.shape == (n_seeds, 1, 1, dim)
        sig = self._rotation() * sig_sig
        assert sig.shape == (n_seeds, 1, dim, dim)

        with torch.no_grad():
            x_query = sig.matmul(self.epsilon)
            assert x_query.shape == (n_seeds, n_mdl, dim, 1)

            x_query = x_query.squeeze(dim=-1)
            assert x_query.shape == (n_seeds, n_mdl, dim)

            x_query = x_query + self.mu
            assert x_query.shape == (n_seeds, n_mdl, dim)

        return x_query

    def tell(self, y):
        """Update the search distributions from the scores of the last ``ask``.

        Parameters
        ----------
        y : torch.Tensor
            Scores to maximize, shape ``(n_seeds, n_mdl)``, without NaNs.
        """
        n_seeds, n_mdl, dim = self.n_seeds, self.n_mdl, self.dim
        epsilon, mu = self.epsilon, self.mu
        gamma, y_base = self.gamma, self.y_base
        sig_log = self.sig_log
        yb_gamma = self.yb_gamma

        # Re-create the points handed out by the last `ask` (same epsilon).
        x = self.ask().detach()
        assert x.shape == (n_seeds, n_mdl, dim)
        assert y.shape == (n_seeds, n_mdl)
        assert not(y.isnan().any())

        self.opt.zero_grad()

        e = (x - mu).unsqueeze(-2)
        assert e.shape == (n_seeds, n_mdl, 1, dim)

        # S^{-T} = R diag(exp(-sig_log)), so ||e^T R diag(1/s)||^2 is the
        # Mahalanobis term e^T (S S^T)^{-1} e.
        sig_lam = (-sig_log).exp().reshape(n_seeds, 1, 1, dim)
        assert sig_lam.shape == (n_seeds, 1, 1, dim)
        sig = self._rotation() * sig_lam
        assert sig.shape == (n_seeds, 1, dim, dim)
        eT_sig = e.matmul(sig)
        assert eT_sig.shape == (n_seeds, n_mdl, 1, dim)
        eT_siginv_e = eT_sig.square().sum(dim=-1).squeeze(-1)
        assert eT_siginv_e.shape == (n_seeds, n_mdl)

        # log|det S| = sum(sig_log) because R is a rotation.
        sigma_halflogdet = sig_log.sum(dim=-1)
        assert sigma_halflogdet.shape == (n_seeds, 1)
        logpdf = -0.5 * eT_siginv_e - sigma_halflogdet
        assert logpdf.shape == (n_seeds, n_mdl)

        # exp(logp - stopgrad(logp)) == 1 in value, and its gradient is
        # grad(logp): the likelihood-ratio gradient estimator.
        prob_ratio = (logpdf - logpdf.detach()).exp()
        assert prob_ratio.shape == (n_seeds, n_mdl)
        y_probratio = (y - y_base) * prob_ratio
        assert y_probratio.shape == (n_seeds, n_mdl)
        max_obj = y_probratio.mean(dim=-1)
        assert max_obj.shape == (n_seeds,)

        min_obj = -max_obj.sum()
        min_obj.backward()

        self.opt.step()

        with torch.no_grad():
            # Refresh the noise with an AR(1) step that keeps every entry
            # marginally N(0, 1); otherwise the queries would not follow the
            # density whose log-pdf is differentiated above.
            deps = self.rng.normal((n_seeds, n_mdl, dim, 1))
            self.epsilon = gamma * epsilon + (1. - gamma ** 2) ** 0.5 * deps
            assert self.epsilon.shape == (n_seeds, n_mdl, dim, 1)

            # EMA of the per-run mean score, used as the next baseline.
            dyb = y.mean(dim=-1, keepdim=True)
            assert dyb.shape == (n_seeds, 1)
            self.y_base = yb_gamma * y_base + (1. - yb_gamma) * dyb

In [None]:
# Example 1: Fitting to the potentials without query noise
chrg_n = 3
n_mdl = 100
query_noise = 0.0
mbo_lr = 0.05
n_mboiter = 400
y_type = 'potential'

# # Example 2: Fitting to the potentials with query noise
# chrg_n = 3
# n_mdl = 100
# query_noise = 0.05
# mbo_lr = 0.005
# n_mboiter = 4000
# y_type = 'potential'

# # Example 3: Fitting to the fields with noise
# chrg_n = 1
# n_mdl = 1000
# query_noise = 0.05
# mbo_lr = 0.005
# n_mboiter = 3500
# y_type = 'field'

# Common settings
dim = 2
n_seeds = 10
n_true = 20
opt_siglog = True
init_std = 0.5

# Use the first GPU when available, otherwise fall back to the CPU.
tch_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tch_dtype = torch.float32

# Creating the RNG
seeds_arr = (np.arange(n_seeds) * 1000).tolist()
rng = BatchRNG(shape=(n_seeds,), lib='torch',
       device=tch_device, dtype=tch_dtype,
       unif_cache_cols=1_000_000,
       norm_cache_cols=5_000_000)
rng.seed(np.array(seeds_arr))

# Creating the true problem: `chrg_n` unit charges at fixed locations,
# identical for every seed.
true_chrg_w = np.ones((n_seeds, chrg_n))
assert true_chrg_w.shape == (n_seeds, chrg_n)
true_chrg_mu = np.array([[-0.5, -0.5],
                         [ 0.5,  0.5],
                         [ 0.0,  0.0]])[:chrg_n, :]
true_chrg_mu = np.broadcast_to(true_chrg_mu[None, ...],
    (n_seeds, chrg_n, dim)).copy()
assert true_chrg_mu.shape == (n_seeds, chrg_n, dim)
true_problem = DeltaProblem(weights=true_chrg_w,
    locations=true_chrg_mu,
    tch_device=tch_device, tch_dtype=tch_dtype)

# Observation points in [-1, 1]^dim (assuming `rng.uniform` draws from
# [0, 1)) and the true oracle values at them.
true_x = rng.uniform((n_seeds, n_true, dim)) * 2 - 1
assert true_x.shape == (n_seeds, n_true, dim)
if y_type == 'potential':
    y_dim = 1
    mse_mul = 100
    mse_clip = np.inf
    true_y = true_problem.potential(true_x).unsqueeze(-1)
    assert true_y.shape == (n_seeds, n_true, 1)
elif y_type == 'field':
    y_dim = dim
    mse_mul = 1
    mse_clip = 10
    true_y = true_problem.field(true_x)
    assert true_y.shape == (n_seeds, n_true, y_dim)
else:
    raise ValueError(f'y_type={y_type} undefined')

# The search space is the flattened set of all charge locations.
sdim = chrg_n * dim
mbo = GPMBO(dim=sdim, n_mdl=n_mdl, n_seeds=n_seeds,
            lr=mbo_lr, init_mu=0.0, init_std=init_std, gamma=0.0,
            yb_gamma=0.9, rng=rng, optim='adam',
            tch_device=tch_device, tch_dtype=tch_dtype,
            opt_siglog=opt_siglog)

all_mbo_mu = []
for mbo_iter in range(n_mboiter):
    php_query = mbo.ask()
    assert php_query.shape == (n_seeds, n_mdl, sdim)

    # Adding noise to the query to make the problem more challenging
    php_query = php_query + query_noise * rng.normal((n_seeds, n_mdl, sdim))
    assert php_query.shape == (n_seeds, n_mdl, sdim)

    # Running a fake query system: one candidate problem per queried point
    loc_query = php_query.reshape(n_seeds*n_mdl, chrg_n, dim).detach().cpu().numpy()
    w_query = torch.ones(n_seeds*n_mdl, chrg_n).detach().cpu().numpy()

    problem_query = DeltaProblem(weights=w_query,
        locations=loc_query,
        tch_device=tch_device, tch_dtype=tch_dtype)

    # Every candidate of a seed is evaluated at that seed's observation points.
    x_query = true_x.reshape(n_seeds, 1, n_true, dim)
    x_query = x_query.expand(n_seeds, n_mdl, n_true, dim)
    x_query = x_query.reshape(n_seeds*n_mdl, n_true, dim)

    if y_type == 'potential':
        y_query_ = problem_query.potential(x_query).unsqueeze(-1)
        assert y_query_.shape == (n_seeds*n_mdl, n_true, y_dim)
    elif y_type == 'field':
        y_query_ = problem_query.field(x_query)
        assert y_query_.shape == (n_seeds*n_mdl, n_true, y_dim)
    else:
        raise ValueError(f'y_type={y_type} undefined')

    y_query = y_query_.reshape(n_seeds, n_mdl, n_true, y_dim)
    assert y_query.shape == (n_seeds, n_mdl, n_true, y_dim)

    y_err = y_query - true_y.reshape(n_seeds, 1, n_true, y_dim)
    assert y_err.shape == (n_seeds, n_mdl, n_true, y_dim)

    y_mse_ = y_err.square().sum(dim=-1)
    assert y_mse_.shape == (n_seeds, n_mdl, n_true)

    # Clip per-point errors so singular values near a charge cannot dominate.
    y_mse_ = torch.clip(y_mse_ , 0.0, mse_clip)
    assert y_mse_.shape == (n_seeds, n_mdl, n_true)

    y_mse = mse_mul * y_mse_.mean(dim=-1)
    assert y_mse.shape == (n_seeds, n_mdl)

    if mbo_iter % 100 == 0:
        print(f'MSE: {y_mse.mean():.4f} +/- {y_mse.std()/np.sqrt(n_seeds*n_mdl):.4f}')

    mbo.tell(-y_mse)
    # `mbo.mu` is a Parameter updated in place by the optimizer; store a
    # snapshot, otherwise every history entry aliases the final mean.
    all_mbo_mu.append(mbo.mu.detach().clone())

mbo_mus_ = torch.cat(all_mbo_mu, dim=-2)
assert mbo_mus_.shape == (n_seeds, n_mboiter, chrg_n*dim)
mbo_mus = mbo_mus_.reshape(n_seeds, n_mboiter, chrg_n, dim)
assert mbo_mus.shape == (n_seeds, n_mboiter, chrg_n, dim)

print('The final Gaussian means:')
a = mbo.mu.reshape(n_seeds, chrg_n*dim).detach().cpu().numpy()
print(np.array_str(a, precision=3, suppress_small=True) + '\n' + '-'*80)

print('The final Gaussian stds:')
a = mbo.sig_log.exp().reshape(n_seeds, chrg_n*dim).detach().cpu().numpy()
print(np.array_str(a, precision=3, suppress_small=True) + '\n' + '-'*80)

In [None]:
# # Example 1: Fitting to the potentials without query noise
# chrg_n = 3
# n_mdl = 100
# query_noise = 0.0
# mbo_lr = 0.01
# n_mboiter = 4000
# y_type = 'potential'

# # Example 2: Fitting to the potentials with query noise
# chrg_n = 3
# n_mdl = 100
# query_noise = 0.05
# mbo_lr = 0.005
# n_mboiter = 4000
# y_type = 'potential'

# # Example 3: Fitting to the fields with noise
# chrg_n = 1
# n_mdl = 1000
# query_noise = 0.05
# mbo_lr = 0.005
# n_mboiter = 3500
# y_type = 'field'

# Example 4: Fitting to the fields without query noise
chrg_n = 3
n_mdl = 100
query_noise = 0.0
mbo_lr = 0.005
n_mboiter = 40000
y_type = 'field'

# Common settings
dim = 2
n_seeds = 100
n_true = 50
opt_siglog = False
init_std = 0.01
# Weight of the log-prior term in the MBO score (1.0 = unit weighting).
w_prior = 1.0

# Use the first GPU when available, otherwise fall back to the CPU.
tch_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tch_dtype = torch.float32

# Standard-normal prior over the flattened charge locations. It needs the
# device/dtype above, so it is built after they are set.
sdim = chrg_n * dim
tch_normal = torch.distributions.normal.Normal
prior_dist = tch_normal(loc=torch.zeros(sdim, device=tch_device, dtype=tch_dtype),
                        scale=torch.ones(sdim, device=tch_device, dtype=tch_dtype))

# Creating the RNG
seeds_arr = (np.arange(n_seeds) * 1000).tolist()
rng = BatchRNG(shape=(n_seeds,), lib='torch',
       device=tch_device, dtype=tch_dtype,
       unif_cache_cols=1_000,
       norm_cache_cols=5_000)
rng.seed(np.array(seeds_arr))

# Creating the true problem: `chrg_n` unit charges at fixed locations,
# identical for every seed.
true_chrg_w = np.ones((n_seeds, chrg_n))
assert true_chrg_w.shape == (n_seeds, chrg_n)
true_chrg_mu = np.array([[-0.5, -0.5],
                         [ 0.5,  0.5],
                         [ 0.0,  0.0]])[:chrg_n, :]
true_chrg_mu = np.broadcast_to(true_chrg_mu[None, ...],
    (n_seeds, chrg_n, dim)).copy()
assert true_chrg_mu.shape == (n_seeds, chrg_n, dim)
true_problem = DeltaProblem(weights=true_chrg_w,
    locations=true_chrg_mu,
    tch_device=tch_device, tch_dtype=tch_dtype)

# Observation points (mapped to [-1, 1]^dim) and the true oracle values.
true_x = rng.uniform((n_seeds, n_true, dim)) * 2 - 1
assert true_x.shape == (n_seeds, n_true, dim)
if y_type == 'potential':
    y_dim = 1
    mse_mul = 100
    mse_clip = np.inf
    true_y = true_problem.potential(true_x).unsqueeze(-1)
    assert true_y.shape == (n_seeds, n_true, 1)
elif y_type == 'field':
    y_dim = dim
    mse_mul = 1
    mse_clip = 10
    true_y = true_problem.field(true_x)
    assert true_y.shape == (n_seeds, n_true, y_dim)
else:
    raise ValueError(f'y_type={y_type} undefined')


mbo = GPMBO(dim=sdim, n_mdl=n_mdl, n_seeds=n_seeds,
            lr=mbo_lr, init_mu=0.0, init_std=init_std, gamma=0.0,
            yb_gamma=0.9, rng=rng, optim='adam',
            tch_device=tch_device, tch_dtype=tch_dtype,
            opt_siglog=opt_siglog)

all_mbo_mu = []
for mbo_iter in range(n_mboiter):
    php_query = mbo.ask()
    assert php_query.shape == (n_seeds, n_mdl, sdim)

    # Adding noise to the query to make the problem more challenging
    php_query = php_query + query_noise * rng.normal((n_seeds, n_mdl, sdim))
    assert php_query.shape == (n_seeds, n_mdl, sdim)

    # Running a fake query system: one candidate problem per queried point
    loc_query = php_query.reshape(n_seeds*n_mdl, chrg_n, dim).detach().cpu().numpy()
    w_query = torch.ones(n_seeds*n_mdl, chrg_n).detach().cpu().numpy()

    problem_query = DeltaProblem(weights=w_query,
        locations=loc_query,
        tch_device=tch_device, tch_dtype=tch_dtype)

    # Every candidate of a seed is evaluated at that seed's observation points.
    x_query = true_x.reshape(n_seeds, 1, n_true, dim)
    x_query = x_query.expand(n_seeds, n_mdl, n_true, dim)
    x_query = x_query.reshape(n_seeds*n_mdl, n_true, dim)

    if y_type == 'potential':
        y_query_ = problem_query.potential(x_query).unsqueeze(-1)
        assert y_query_.shape == (n_seeds*n_mdl, n_true, y_dim)
    elif y_type == 'field':
        y_query_ = problem_query.field(x_query)
        assert y_query_.shape == (n_seeds*n_mdl, n_true, y_dim)
    else:
        raise ValueError(f'y_type={y_type} undefined')

    y_query = y_query_.reshape(n_seeds, n_mdl, n_true, y_dim)
    assert y_query.shape == (n_seeds, n_mdl, n_true, y_dim)

    y_err = y_query - true_y.reshape(n_seeds, 1, n_true, y_dim)
    assert y_err.shape == (n_seeds, n_mdl, n_true, y_dim)

    y_mse_ = y_err.square().sum(dim=-1)
    assert y_mse_.shape == (n_seeds, n_mdl, n_true)

    # Clip per-point errors so singular values near a charge cannot dominate.
    y_mse_ = torch.clip(y_mse_ , 0.0, mse_clip)
    assert y_mse_.shape == (n_seeds, n_mdl, n_true)

    y_mse = mse_mul * y_mse_.mean(dim=-1)
    assert y_mse.shape == (n_seeds, n_mdl)

    if mbo_iter % 100 == 0:
        print(f'MSE: {y_mse.mean():.4f} +/- {y_mse.std()/np.sqrt(n_seeds*n_mdl):.4f}')

    # Score = -MSE plus the weighted log-prior of the (noisy) query.
    mbo_score = -y_mse + w_prior * prior_dist.log_prob(php_query).sum(dim=-1).detach()
    assert mbo_score.shape == (n_seeds, n_mdl)

    mbo.tell(mbo_score)
    # `mbo.mu` is a Parameter updated in place by the optimizer; store a
    # snapshot, otherwise every history entry aliases the final mean.
    all_mbo_mu.append(mbo.mu.detach().clone())

mbo_mus_ = torch.cat(all_mbo_mu, dim=-2)
assert mbo_mus_.shape == (n_seeds, n_mboiter, chrg_n*dim)
mbo_mus = mbo_mus_.reshape(n_seeds, n_mboiter, chrg_n, dim)
assert mbo_mus.shape == (n_seeds, n_mboiter, chrg_n, dim)

print('The final Gaussian means:')
a = mbo.mu.reshape(n_seeds, chrg_n*dim).detach().cpu().numpy()
print(np.array_str(a, precision=3, suppress_small=True) + '\n' + '-'*80)

if opt_siglog:
    print('The final Gaussian stds:')
    a = mbo.sig_log.exp().reshape(n_seeds, chrg_n*dim).detach().cpu().numpy()
    print(np.array_str(a, precision=3, suppress_small=True) + '\n' + '-'*80)

## Meta-Training

In [None]:
def set_configs():
    """Define the meta-training hyper-parameters and publish them as globals.

    Every local defined here, including the derived options, is exported to
    the module namespace through ``globals().update(locals())``. The other
    setup functions and the training loop read these names directly.
    Each option is assigned exactly once so the effective configuration is
    what you read.
    """
    #########################################################
    ################### Mandatory Options ###################
    #########################################################
    rng_seed_list = list(range(0, 100_000, 1000))
    dim = 2

    # Surface sampling: `n_srf` surfaces per seed. On each one, the first
    # `n_srfpts_mdl` points are evaluated by the model and the remaining
    # `n_srfpts_trg` points by the target.
    # Alternative (previously overridden) setting:
    #   n_srf = 400; n_srfpts_mdl = 1; n_srfpts_trg = 1; do_bootstrap = True
    n_srf = 8
    n_srfpts_mdl = 1
    n_srfpts_trg = 100
    do_detspacing = False
    do_dblsampling = False

    do_bootstrap = False

    # # The original values for 3-charges
    # tau = 0.999
    # w_trgreg = 1.0
    # w_trg = 0.99

    # Fast training values for 1-charge
    tau = 0.984
    w_trgreg = 2.0
    w_trg = 0.99

    opt_type = 'sgd'
    n_epochs = 1_000_000

    nn_dstr = 'mlp'
    nn_width = 64
    nn_hidden = 2
    nn_act = 'tanh'

    eval_bs = 256
    n_evlpnts = 1000
    eval_frq = 100
    chkpnt_period = 20000

    # Derived options and assertions
    n_points = n_srfpts_mdl + n_srfpts_trg

    # Double sampling splits the target points in two halves.
    assert not (do_dblsampling) or (n_srfpts_trg > 1)
    if w_trg is None:
        w_trg = n_srfpts_trg / n_points
    # Without model points, the whole estimate must come from the target.
    assert not (n_srfpts_mdl == 0) or (w_trg == 1.0)
    n_rsdls = 2 if do_dblsampling else 1

    if eval_bs is None:
        eval_bs = max(n_srfpts_mdl, n_srfpts_trg) * n_srf

    #########################################################
    ########### I/O-Related Options and Operations ##########
    #########################################################
    device_name = 'cuda:0'
    tch_device = torch.device(device_name)
    tch_dtype = torch.float32

    globals().update(locals())

def make_rng():
    """Create the batched RNG with one stream per entry of `rng_seed_list`.

    Reads the globals `rng_seed_list`, `tch_device` and `tch_dtype`, which
    `set_configs` defines. Exports `n_seeds`, `rng_seeds` and `rng` as globals
    via `globals().update(locals())`.
    """
    #########################################################
    ########### Constructing the Batch RNG Object ###########
    #########################################################
    n_seeds = len(rng_seed_list)
    rng_seeds = np.array(rng_seed_list)
    # Large normal/uniform caches; exact caching semantics live in BatchRNG.
    rng = BatchRNG(shape=(n_seeds,), lib='torch',
                   device=tch_device, dtype=tch_dtype,
                   unif_cache_cols=1_000_000,
                   norm_cache_cols=5_000_000)
    rng.seed(np.broadcast_to(rng_seeds, rng.shape))
    globals().update(locals())

def make_problem():
    """Create a batched `DeltaProblem` with `chrg_n` unit-weight charges per seed.

    All charge locations start at the origin. The meta-training loop
    overwrites `problem.locations_tch` with freshly sampled tasks. The global
    `chrg_n` must be set before calling this. Exports `chrg_w`, `chrg_mu` and
    `problem` as globals.
    """
    chrg_w = np.ones((n_seeds, chrg_n))
    chrg_mu = np.zeros((n_seeds, chrg_n, dim))
    problem = DeltaProblem(weights=chrg_w, locations=chrg_mu,
        tch_device=tch_device, tch_dtype=tch_dtype)
    globals().update(locals())    

def make_vol_samplers():
    """Build the ball (volume) sampler and the sphere-surface sampler.

    `volsampler` draws balls. Centers come from the 'ball' distribution with
    c=0, r=1, and radii from the 'unifdpow' distribution with low=0, high=1.
    The exact semantics are defined by `BallSampler`; presumably the centers
    lie in the unit ball and the radii in [0, 1] (TODO confirm).
    `srfsampler` samples points on sphere surfaces. Both draw from the shared
    batch RNG `rng`. Locals are exported as globals.
    """
    vol_c_params = dict(c=np.zeros((n_seeds, dim)), r=np.ones(n_seeds))
    vol_r_params = dict(low=np.zeros(n_seeds), high=np.ones(n_seeds))
    volsampler = BallSampler(c_dstr='ball', c_params=vol_c_params,
                             r_dstr='unifdpow', r_params=vol_r_params,
                             batch_rng=rng)
    srfsampler = SphereSampler(batch_rng=rng)
    globals().update(locals())
    
def prep_tb():
    """Create a fresh, numbered run directory and a TensorBoard writer for it.

    The run directory is `<storage_dir>/<NN>_<timestamp>`, where NN is the
    number of sub-directories already present. A writer left over from an
    earlier call or cell is closed before the new one is created.

    Exports (via `globals().update(locals())`) `storage_dir`, `strgidx`,
    `dtnow`, `dtnow_`, `cfgstrg_dir` and `tbwriter`.
    NOTE: this overwrites the module-level `storage_dir` imported from
    `bspinn.io_cfg` with './15_search/02_meta/'.
    """
    storage_dir = './15_search/02_meta/'
    pathlib.Path(storage_dir).mkdir(parents=True, exist_ok=True)
    strgidx = sum(isdir(f'{storage_dir}/{x}') for x in os.listdir(storage_dir))
    dtnow = datetime.datetime.now().isoformat(timespec='seconds')
    dtnow_ = dtnow[2:].replace('-', '').replace(':', '').replace('.', '')
    cfgstrg_dir = f'{storage_dir}/{strgidx:02d}_{dtnow_}'
    pathlib.Path(cfgstrg_dir).mkdir(parents=True, exist_ok=True)

    # The previous writer lives in globals(), never in locals(), so it must
    # be looked up there. Closing it releases its file handles.
    if globals().get('tbwriter') is not None:
        globals()['tbwriter'].close()
    tbwriter = tensorboardX.SummaryWriter(cfgstrg_dir)
    logging.getLogger("tensorboardX.x2num").setLevel(logging.CRITICAL)
    globals().update(locals())
    
def make_model():
    """Build the batched network `model` and its bootstrap `target`.

    With bootstrapping, the target is a second network whose weights are
    copied from the model. Without it, the target is the model object
    itself. `model` and `target` are exported as globals.
    """
    # The model is always built first so the RNG stream is consumed in the
    # same order regardless of `do_bootstrap`.
    model = bffnn(dim, nn_width, nn_hidden, nn_act, (n_seeds,), rng)
    if not do_bootstrap:
        target = model
    else:
        target = bffnn(dim, nn_width, nn_hidden, nn_act, (n_seeds,), rng)
        target.load_state_dict(model.state_dict())
    globals().update(locals())

def pre_train_prep():
    """Prepare evaluation tools, the plotting grid, the figures and histories.

    Exports all locals as globals. These include `erng`, `ema`, `trn_sttime`,
    `grid_x` (n_seeds, n_g, dim), `x1_msh_np`/`x2_msh_np`, the figure/axes/
    colorbar-axes triples for model, target and ground truth,
    `stat_history` and a fresh, empty `train_history`.
    """
    # Evaluation tools
    # NOTE: `erng` is the same object as `rng`, so evaluation draws consume
    # the training RNG stream.
    erng = rng
    last_perfdict = dict()
    ema = EMA(gamma=0.999, gamma_sq=0.998)
    trn_sttime = time.time()

    # Constructing the grid points: n_gpd points per axis over [-1, 1]^dim,
    # shared (expanded, not copied) across seeds.
    with torch.no_grad():
        n_gpd = 50
        n_g = (n_gpd ** dim)
        elow, ehigh = -np.ones(dim), np.ones(dim)
        gdict = make_grid(elow, ehigh, dim, n_gpd, 'torch')
        grid_x_ = gdict['x']
        assert grid_x_.shape == (n_g, dim)
        grid_x = grid_x_.reshape(1, n_g, dim).expand(n_seeds, n_g, dim)
        grid_x = grid_x.to(tch_device, tch_dtype)
        assert grid_x.shape == (n_seeds, n_g, dim)

        x1_msh_np, x2_msh_np = gdict['xi_msh_np']

    # One figure per panel (model, target, ground truth), each with a
    # colorbar axis appended on the right.
    with plt.style.context('default'):
        figax_list = [plt.subplots(1, 1, figsize=(3.2, 2.5), dpi=100) for _ in range(3)]
        (fig_mdl, ax_mdl), (fig_trg, ax_trg), (fig_gt, ax_gt) = figax_list
        cax_list = [make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05) 
                    for ax in (ax_mdl, ax_trg, ax_gt)]
        cax_mdl, cax_trg, cax_gt = cax_list
    stat_history = defaultdict(list)
    train_history = odict()
    globals().update(locals())

def calc_grad():
    """Compute the surface-vs-volume integral residual loss and back-propagate it.

    For every seed and each of the `n_srf` sampled volumes (`volsamps`, a
    global the caller is expected to set), the boundary flux of the network
    gradient is estimated by Monte Carlo: the mean of grad(u).n over the
    sampled surface points, times the surface area. It is compared with
    `problem.integrate_volumes(volsamps)`, which is presumably the volume
    integral of the source term (divergence-theorem residual).

    The first `n_srfpts_mdl` surface points are evaluated with `model`, the
    rest with `target`, and both estimates are blended with weight `w_trg`.
    With `do_bootstrap`, an extra term (weight `w_trgreg`) pulls `model`
    towards the frozen `target` at the model points.

    Calls `loss_sum.backward()`, so gradients are accumulated into whatever
    `loss` depends on; zeroing them beforehand is left to the caller. All
    locals (e.g. `loss`, `loss_main`, `loss_trgreg`, `normprods_mdl`,
    `normprods_trg`, `areas`) are exported as globals.
    """
    # Sampling `n_points` points on the surface of each sampled volume
    srfsamps = srfsampler(volsamps, n_points, do_detspacing=do_detspacing)
    # A Parameter so that input gradients (grad_x u) can be taken below.
    points = nn.Parameter(srfsamps['points'])
    surfacenorms = srfsamps['normals']
    areas = srfsamps['areas']
    assert points.shape == (n_seeds, n_srf, n_points, dim)
    assert surfacenorms.shape == (n_seeds, n_srf, n_points, dim)
    assert areas.shape == (n_seeds, n_srf,)

    # Split points into the model part (first n_srfpts_mdl) and target part.
    points_mdl = points[:, :, :n_srfpts_mdl, :]
    assert points_mdl.shape == (n_seeds, n_srf, n_srfpts_mdl, dim)
    points_trg = points[:, :, n_srfpts_mdl:, :]
    assert points_trg.shape == (n_seeds, n_srf, n_srfpts_trg, dim)

    surfacenorms_mdl = surfacenorms[:, :, :n_srfpts_mdl, :]
    assert surfacenorms_mdl.shape == (n_seeds, n_srf, n_srfpts_mdl, dim)
    surfacenorms_trg = surfacenorms[:, :, n_srfpts_mdl:, :]
    assert surfacenorms_trg.shape == (n_seeds, n_srf, n_srfpts_trg, dim)

    # Making surface integral predictions using the reference model
    u_mdl = model(points_mdl)
    assert u_mdl.shape == (n_seeds, n_srf, n_srfpts_mdl, 1)
    # create_graph=True so the loss can be differentiated w.r.t. model weights.
    nabla_x_u_mdl, = torch.autograd.grad(u_mdl.sum(), [points_mdl],
        grad_outputs=None, retain_graph=True, create_graph=True,
        only_inputs=True, allow_unused=False)
    assert nabla_x_u_mdl.shape == (n_seeds, n_srf, n_srfpts_mdl, dim)
    # Normal derivative grad(u).n at every model point.
    normprods_mdl = (nabla_x_u_mdl * surfacenorms_mdl).sum(dim=-1)
    assert normprods_mdl.shape == (n_seeds, n_srf, n_srfpts_mdl)
    if n_srfpts_mdl > 0:
        mean_normprods_mdl = normprods_mdl.mean(dim=-1, keepdim=True)
        assert mean_normprods_mdl.shape == (n_seeds, n_srf, 1)
    else:
        mean_normprods_mdl = 0.0

    # Making surface integral predictions using the target model
    u_trg = target(points_trg)
    assert u_trg.shape == (n_seeds, n_srf, n_srfpts_trg, 1)
    # With bootstrapping, no graph is built here, so the target branch does
    # not propagate gradients into the weights. Without it, target is model
    # and the graph is kept.
    nabla_x_u_trg, = torch.autograd.grad(u_trg.sum(), [points_trg],
        grad_outputs=None, retain_graph=True, create_graph=not(do_bootstrap),
        only_inputs=True, allow_unused=False)
    assert nabla_x_u_trg.shape == (n_seeds, n_srf, n_srfpts_trg, dim)

    normprods_trg = (nabla_x_u_trg * surfacenorms_trg).sum(dim=-1)
    assert normprods_trg.shape == (n_seeds, n_srf, n_srfpts_trg)
    if do_dblsampling:
        assert n_rsdls == 2

        # Two estimates from the even- and odd-indexed target points.
        mean_normprods_trg1 = normprods_trg[..., 0::2].mean(
            dim=-1, keepdim=True)
        assert mean_normprods_trg1.shape == (n_seeds, n_srf, 1)

        mean_normprods_trg2 = normprods_trg[..., 1::2].mean(
            dim=-1, keepdim=True)
        assert mean_normprods_trg2.shape == (n_seeds, n_srf, 1)

        mean_normprods_trg = torch.cat(
            [mean_normprods_trg1, mean_normprods_trg2], dim=-1)
        assert mean_normprods_trg.shape == (n_seeds, n_srf, n_rsdls)
    else:
        assert n_rsdls == 1

        mean_normprods_trg = normprods_trg.mean(dim=-1, keepdim=True)
        assert mean_normprods_trg.shape == (n_seeds, n_srf, n_rsdls)

    # Linearly combining the reference and target predictions
    mean_normprods = (       w_trg  * mean_normprods_trg +
                      (1.0 - w_trg) * mean_normprods_mdl)
    assert mean_normprods.shape == (n_seeds, n_srf, n_rsdls)

    # Considering the surface areas
    pred_surfintegs = mean_normprods * areas.reshape(n_seeds, n_srf, 1)
    assert pred_surfintegs.shape == (n_seeds, n_srf, n_rsdls)

    # Getting the reference volume integrals
    ref_volintegs = problem.integrate_volumes(volsamps)
    assert ref_volintegs.shape == (n_seeds, n_srf)

    # Getting the residual terms
    resterms = pred_surfintegs - ref_volintegs.reshape(n_seeds, n_srf, 1)
    assert resterms.shape == (n_seeds, n_srf, n_rsdls)

    # Multiplying the residual terms: with double sampling, the product of
    # the two residual estimates replaces the square (assuming the two
    # halves are independent, this avoids the variance bias of squaring a
    # single noisy estimate).
    if do_dblsampling:
        resterms_prod = resterms.prod(dim=-1)
        assert resterms_prod.shape == (n_seeds, n_srf)
    else:
        resterms_prod = torch.square(resterms).squeeze(-1)
        assert resterms_prod.shape == (n_seeds, n_srf)

    # Computing the main loss
    loss_main = resterms_prod.mean(-1)
    assert loss_main.shape == (n_seeds,)

    # Target regularization: keep the model's values at the model points
    # close to the (non-differentiated) target's values.
    if do_bootstrap:
        with torch.no_grad():
            u_mdl_prime = target(points_mdl)
        loss_trgreg = torch.square(u_mdl - u_mdl_prime).mean([-3, -2, -1])
        assert loss_trgreg.shape == (n_seeds,)
    else:
        loss_trgreg = torch.zeros(n_seeds, device=tch_device, dtype=tch_dtype)
        assert loss_trgreg.shape == (n_seeds,)

    # The total loss
    loss = loss_main + w_trgreg * loss_trgreg
    assert loss.shape == (n_seeds,)

    # Seeds are independent, so summing their losses yields per-seed grads.
    loss_sum = loss.sum()
    loss_sum.backward()
    globals().update(locals())

def print_tb_fig():
    """Render the solution panels on the grid and log them to TensorBoard.

    Evaluates the ground-truth problem, the model and (with bootstrapping)
    the target on `grid_x`. Each solution is drawn into its pre-built figure
    and logged under 'viz/<tag>' at step `epoch`. Locals are exported as
    globals.

    The original body also re-defined `print_tb_perf`, `print_tb_scalars`,
    `print_stdout` and `take_checkpoint` as nested functions after the
    export. Those copies were never called or exported (the module-level
    versions are the ones in use), so they have been removed.
    """
    with torch.no_grad():
        grid_prbsol = get_prob_sol(problem, grid_x, n_eval=eval_bs,
            get_field=False, out_lib='numpy')
        assert grid_prbsol['v_np'].shape == (n_seeds, n_g)

        grid_mdlsol = get_nn_sol(model, grid_x, n_eval=eval_bs,
            get_field=False, out_lib='numpy')
        assert grid_mdlsol['v_np'].shape == (n_seeds, n_g)

        if do_bootstrap:
            grid_trgsol = get_nn_sol(target, grid_x, n_eval=eval_bs,
                get_field=False, out_lib='numpy')
            assert grid_trgsol['v_np'].shape == (n_seeds, n_g)

        # (tag, solution, figure, axis, colorbar axis, title) per panel.
        soltd_list = [('gt', grid_prbsol, fig_gt, ax_gt, cax_gt, 'Ground Truth'),
                      ('mdl', grid_mdlsol, fig_mdl, ax_mdl, cax_mdl, 'Prediction')]
        if do_bootstrap:
            soltd_list += [('trg', grid_trgsol, fig_trg, ax_trg, cax_trg, 'Target')]
        for sol_t, sol_dict, fig, ax, cax, ttl in soltd_list:
            plot_sol(x1_msh_np, x2_msh_np, sol_dict, fig=fig, ax=ax, cax=cax)
            ax.set_title(ttl)
            fig.set_tight_layout(True)
            tbwriter.add_figure(f'viz/{sol_t}', fig, epoch)
        tbwriter.flush()
    globals().update(locals())

def print_tb_perf():
    """Evaluate the model on random points inside sampled balls and log metrics.

    Draws `n_evlpnts` balls per seed from `volsampler`, then one point per
    ball: the radius is r * U**(1/dim) and the direction a normalized
    Gaussian vector, using `erng`. `eval_pnts` is evaluated at those points,
    and every returned metric whose key contains '/bc/' is logged (averaged)
    under 'perf/<key>' at step `epoch`. Locals are exported as globals.
    """
    with torch.no_grad():
        evols = volsampler(n=n_evlpnts)
        assert evols['type'] == 'ball'

        e_c = evols['centers']
        assert e_c.shape == (n_seeds, n_evlpnts, dim)

        e_r_ = evols['radii']
        assert e_r_.shape == (n_seeds, n_evlpnts)

        e_r = e_r_.unsqueeze(dim=-1)
        assert e_r.shape == (n_seeds, n_evlpnts, 1)

        untrd = erng.uniform((n_seeds, n_evlpnts, 1))
        assert untrd.shape == (n_seeds, n_evlpnts, 1)

        # U**(1/dim) gives radii that fill the ball uniformly (assuming
        # erng.uniform draws from [0, 1)).
        untr = untrd.pow(1.0 / dim)
        assert untr.shape == (n_seeds, n_evlpnts, 1)

        e_pntrs = untr * e_r
        assert e_pntrs.shape == (n_seeds, n_evlpnts, 1)

        etheta = erng.normal((n_seeds, n_evlpnts, dim))
        assert etheta.shape == (n_seeds, n_evlpnts, dim)

        # Normalized Gaussian vector = uniformly random direction.
        ethtilde = etheta / etheta.norm(dim=-1, keepdim=True)
        assert ethtilde.shape == (n_seeds, n_evlpnts, dim)

        e_pnts = e_c + ethtilde * e_pntrs
        assert e_pnts.shape == (n_seeds, n_evlpnts, dim)

    eperfs = eval_pnts(problem, model, target, e_pnts, do_bootstrap,
        n_seeds, n_evlpnts, dim, eval_bs)
    for kk, vv in eperfs.items():
        if '/bc/' in kk:
            tbwriter.add_scalar(f'perf/{kk}', vv.mean(), epoch)
    globals().update(locals())

def print_tb_scalars():
    """Update the loss EMAs and log the scalar losses to TensorBoard.

    Reads `loss`, `loss_main`, `loss_trgreg`, `normprods_mdl`,
    `normprods_trg` and `areas`, which `calc_grad` leaves as globals. Exports
    `normprods`, `npvm` and the EMA mean/std-mean pairs as globals.
    """
    # computing the normal product variances
    with torch.no_grad(): 
        normprods = torch.cat([normprods_mdl, normprods_trg], dim=-1)
        # Per seed: mean over surfaces of var(grad(u).n) * area^2.
        npvm = (normprods.var(dim=-1)*areas.square()).mean(-1)

    # Computing the loss moving averages
    loss_ema_mean, loss_ema_std_mean = ema('loss', loss)
    npvm_ema_mean, npvm_ema_std_mean = ema('npvm', npvm)

    tbwriter.add_scalar('loss/total', loss.mean(), epoch)
    tbwriter.add_scalar('loss/main', loss_main.mean(), epoch)
    tbwriter.add_scalar('loss/trgreg', loss_trgreg.mean(), epoch)
    tbwriter.add_scalar('loss/npvm', npvm.mean(), epoch)
    globals().update(locals())

def print_stdout():
    """Print a one-line progress summary (EMA loss and EMA flux variance).

    The error bars are twice the EMA std-of-mean values. The elapsed time is
    measured from `trn_sttime`. `print_str` is exported as a global.
    """
    print_str = (
        f'Epoch {epoch}, EMA loss = {loss_ema_mean:.4f}'
        f' +/- {2*loss_ema_std_mean:.4f}'
        f', EMA Field-Norm Product Variance = {npvm_ema_mean:.4f}'
        f' +/- {2*npvm_ema_std_mean:.4f} ({time.time()-trn_sttime:0.1f} s)'
    )
    print(print_str, flush=True)
    globals().update(locals())

def take_checkpoint():
    """Store CPU deep-copies of the model, target and problem state dicts.

    The copies go into `train_history` under the keys '<epoch>/mdl',
    '<epoch>/trg' and '<epoch>/prb', inserted in that order.
    """
    train_history.update({
        f'{epoch}/{tag}': deepcopy({k: v.cpu() for k, v
                                    in module.state_dict().items()})
        for tag, module in (('mdl', model), ('trg', target), ('prb', problem))
    })
    globals().update(locals())
    
### Utility functions
def sample_unit_ball(n_seeds, chrg_n, dim, rng):
    """Draw points uniformly from the ``dim``-dimensional unit ball.

    A standard-normal draw is normalised to a unit direction and scaled by a
    radius ``u ** (1 / dim)``, so points are uniform in volume rather than
    clustered near the centre (assuming ``rng.uniform`` draws from [0, 1)).

    Parameters
    ----------
    n_seeds, chrg_n, dim : int
        The output has shape ``(n_seeds, chrg_n, dim)``.
    rng :
        Object exposing ``normal(shape)`` and ``uniform(shape)`` samplers
        (e.g. the file's ``BatchRNG``).  ``normal`` is consumed before
        ``uniform``.

    Returns
    -------
    Array/tensor of shape ``(n_seeds, chrg_n, dim)``.
    """
    gauss = rng.normal((n_seeds, chrg_n, dim))
    direction = gauss / ((gauss ** 2).sum(-1, keepdims=True) ** 0.5)
    radius = rng.uniform((n_seeds, chrg_n, 1)) ** (1. / dim)
    return radius * direction

In [None]:
# First-order meta-learning over random charge locations: every epoch samples
# new tasks (charge locations), takes one inner SGD step on the task copies of
# the model, and updates shared meta-parameters with the task-averaged
# validation gradient at the adapted parameters.
set_configs()

make_rng()
chrg_n = 1
make_problem()
make_vol_samplers()
prep_tb()
make_model()

# Set the optimizer (inner-loop / task-adaptation step)
lr = 0.001
opt = torch.optim.SGD(model.parameters(), lr)

# Setting up the meta-parameters: the n_seeds model copies are split into
# n_mseeds meta-seeds, each shared by n_mtasks task copies.
n_mseeds = 1
n_mtasks = n_seeds // n_mseeds
assert (n_mseeds * n_mtasks) == n_seeds

# Initialise every meta-parameter (and meta-target) as the average of its
# task copies.
meta_model_sd = dict()
if do_bootstrap:
    meta_target_sd = dict()
for name, param in model.state_dict().items():
    pshape = param.shape[1:]
    param1 = param.reshape(n_mseeds, n_mtasks, *pshape)
    assert param1.shape == (n_mseeds, n_mtasks, *pshape)
    param2 = param1.mean(dim=1)
    assert param2.shape == (n_mseeds, *pshape)
    meta_model_sd[name] = torch.nn.Parameter(param2.detach())
    if do_bootstrap:
        meta_target_sd[name] = param2.clone().detach()
meta_opt = torch.optim.SGD(list(meta_model_sd.values()), lr=0.1)

pre_train_prep()

for epoch in range(n_epochs+1):
    # Broadcast each meta-parameter (n_mseeds, ...) to all n_seeds task copies.
    with torch.no_grad():
        mesd = {name: param.reshape(n_mseeds, 1, *param.shape[1:])
                           .expand(n_mseeds, n_mtasks, *param.shape[1:])
                           .reshape(n_seeds, *param.shape[1:])
                for name, param in meta_model_sd.items()}
    model.load_state_dict(mesd)
    if do_bootstrap:
        with torch.no_grad():
            tesd = {name: param.reshape(n_mseeds, 1, *param.shape[1:])
                               .expand(n_mseeds, n_mtasks, *param.shape[1:])
                               .reshape(n_seeds, *param.shape[1:])
                    for name, param in meta_target_sd.items()}
        target.load_state_dict(tesd)

    # Sampling new tasks: fresh charge locations inside the unit ball.
    problem.locations = None
    problem.locations_tch = sample_unit_ball(n_seeds, chrg_n, dim, rng)

    # Meta-train: one inner SGD step on a fresh batch of volume samples
    # (calc_grad presumably reads the global ``volsamps`` — confirm).
    opt.zero_grad()
    volsamps = volsampler(n=n_srf)
    calc_grad()
    if (epoch > 0):
        opt.step()

    # Polyak/EMA update of the task-level target network.
    model_sd = model.state_dict()
    if do_bootstrap and (epoch > 0):
        target_sd = target.state_dict()
        newtrg_sd = dict()
        with torch.no_grad():
            for key, param in model_sd.items():
                param_trg = target_sd[key]
                newtrg_sd[key] = tau * param_trg + (1-tau) * param
        target.load_state_dict(newtrg_sd)

    # Meta-validation: gradient at the adapted parameters on a new batch.
    # Clear the inner-step gradients first, otherwise calc_grad accumulates
    # into them and the meta-gradient becomes train + validation gradients
    # instead of the first-order MAML validation gradient alone.
    opt.zero_grad()
    volsamps = volsampler(n=n_srf)
    calc_grad()

    meta_opt.zero_grad()

    # Meta-gradient = validation gradient averaged over each meta-seed's tasks.
    mnp = dict(model.named_parameters())
    for name, mparam in meta_model_sd.items():
        param = mnp[name]
        pshape = param.shape[1:]
        assert param.shape == (n_seeds, *pshape)

        pgrad = param.grad
        if pgrad is None:
            pgrad = torch.zeros_like(param)
        assert pgrad.shape == (n_seeds, *pshape)

        mpgrad_ = pgrad.reshape(n_mseeds, n_mtasks, *pshape)
        assert mpgrad_.shape == (n_mseeds, n_mtasks, *pshape)

        mpgrad = mpgrad_.mean(dim=1)
        assert mpgrad.shape == (n_mseeds, *pshape)
        assert mparam.shape == (n_mseeds, *pshape)

        mparam.grad = mpgrad

    if epoch > 0:
        meta_opt.step()

    # EMA update of the meta-target.  Must run under no_grad: mparam is an
    # nn.Parameter, so without it every epoch would extend an autograd graph
    # through all previous meta-targets (unbounded memory growth).
    if do_bootstrap and (epoch > 0):
        with torch.no_grad():
            for name, mparam in meta_model_sd.items():
                mtparam = meta_target_sd[name]
                meta_target_sd[name] = tau * mtparam + (1-tau) * mparam

    if epoch % 1000 == 0:
        print_tb_fig()
    if epoch % eval_frq == 0:
        print_tb_perf()
    print_tb_scalars()
    if epoch % 100 == 0:
        print_stdout()
    if epoch % chkpnt_period == 0:
        take_checkpoint()

print(f'Training finished in {time.time() - trn_sttime:.1f} seconds.')
tbwriter.flush()

In [None]:
# Side-by-side plots of the ground-truth, predicted and (when bootstrapping)
# target solutions for the current problem.  The bool do_bootstrap adds a
# third column when True.
n_rows, n_cols = 1, 2 + do_bootstrap
fig, axes = plt.subplots(n_rows, n_cols, figsize=(
    n_cols * 3.5, n_rows * 3), dpi=72, sharex=True, sharey=True)
# No dedicated colorbar axis is passed to plot_sol.
cax = None

# Computing the model, target and ground truth solutions
prob_sol = get_prob_sol(problem, grid_x, n_eval=eval_bs, get_field=False)
with torch.no_grad():
    mdl_sol = get_nn_sol(model, grid_x, n_eval=eval_bs, get_field=False) 
    if do_bootstrap:
        trg_sol = get_nn_sol(target, grid_x, n_eval=eval_bs, get_field=False)

# (tag, solution, axis, title) for each panel, left to right.
soltd_list = [('gt', prob_sol, axes[0], 'Ground Truth'),
              ('mdl', mdl_sol, axes[1], 'Prediction')]
if do_bootstrap:
    soltd_list += [('trg', trg_sol, axes[2], 'Target')]
for sol_t, sol_dict, ax, ttl in soltd_list:
    plot_sol(x1_msh_np, x2_msh_np, sol_dict, fig=fig, ax=ax, cax=cax)
    ax.set_title(ttl)
# Bare expression so the notebook displays the figure (interactive mode is
# turned off with plt.ioff() at the top of the notebook).
fig

In [None]:
torch.save(train_history, f'{cfgstrg_dir}/train_history.pt')

# Record the run's hyper-parameters next to the training history so the
# storage directory is self-describing.  Note that ``lr`` here is the value
# set by the meta-learning cell.
hp_dict = {
    'dim': dim,
    'n_srf': n_srf,
    'n_srfpts_mdl': n_srfpts_mdl,
    'n_srfpts_trg': n_srfpts_trg,
    'do_detspacing': do_detspacing,
    'do_dblsampling': do_dblsampling,
    'do_bootstrap': do_bootstrap,
    'tau': tau,
    'w_trgreg': w_trgreg,
    'w_trg': w_trg,
    'opt_type': opt_type,
    'n_epochs': n_epochs,
    'lr': lr,
    'nn_dstr': nn_dstr,
    'nn_width': nn_width,
    'nn_hidden': nn_hidden,
    'nn_act': nn_act,
}

with open(f'{cfgstrg_dir}/config.json', "w") as outfile:
    json.dump(hp_dict, outfile)

# MCMC

In [None]:
# Example 1: Fitting to the field observations without query noise
# (set y_type = 'potential' to fit the potentials instead).
chrg_n = 3
query_noise = 0.00
n_mcmciter = 4000
y_type = 'field'

# Common settings
dim = 2
n_seeds = 10000
n_true = 50

tch_device = torch.device('cuda:0')
tch_dtype = torch.float32

# Creating the RNG (batched over the n_seeds chains)
seeds_arr = (np.arange(n_seeds) * 1000).tolist()
rng = BatchRNG(shape=(n_seeds,), lib='torch',
       device=tch_device, dtype=tch_dtype,
       unif_cache_cols=1_000,
       norm_cache_cols=5_000)
rng.seed(np.array(seeds_arr))

# Creating the true problem: unit-weight charges at fixed locations,
# replicated for every chain.
true_chrg_w = np.ones((n_seeds, chrg_n))
assert true_chrg_w.shape == (n_seeds, chrg_n)
true_chrg_mu = np.array([[-0.5, -0.5],
                         [ 0.5,  0.5],
                         [ 0.0,  0.0]])[:chrg_n, :]
true_chrg_mu = np.broadcast_to(true_chrg_mu[None, ...],
    (n_seeds, chrg_n, dim)).copy()
assert true_chrg_mu.shape == (n_seeds, chrg_n, dim)
true_problem = DeltaProblem(weights=true_chrg_w,
    locations=true_chrg_mu,
    tch_device=tch_device, tch_dtype=tch_dtype)
# Dimension of the searched latent vector (all charge coordinates).
sdim = chrg_n * dim

# Tensorboard: a new numbered, time-stamped run directory per execution.
storage_dir = './15_search/03_mcmc'
pathlib.Path(storage_dir).mkdir(parents=True, exist_ok=True)
strgidx = sum(isdir(f'{storage_dir}/{x}') for x in os.listdir(storage_dir))
dtnow = datetime.datetime.now().isoformat(timespec='seconds')
dtnow_ = dtnow[2:].replace('-', '').replace(':', '').replace('.', '')
cfgstrg_dir = f'{storage_dir}/{strgidx:02d}_{dtnow_}'
pathlib.Path(cfgstrg_dir).mkdir(parents=True, exist_ok=True)

if 'tbwriter' in locals():
    tbwriter.close()
tbwriter = tensorboardX.SummaryWriter(cfgstrg_dir)
logging.getLogger("tensorboardX.x2num").setLevel(logging.CRITICAL)

# MCMC Parameters
import torch.distributions
# Chains start uniformly in [-1, 1]^sdim.
mu_mcmc = rng.uniform((n_seeds, sdim)) * 2 - 1.0
# Std of the Gaussian random-walk proposal.
sig_mcmc = 0.01
logl_mcmc = None
# Multiplies the log-posterior difference in the acceptance test, so larger
# values make acceptance stricter (acts like an inverse temperature).
temp_mcmc = 100.0
# Chains per group; chains in a group share query points and are reset together.
n_cpg = 100
tch_normal = torch.distributions.normal.Normal
prior_mcmc = tch_normal(loc=torch.zeros(sdim, device=tch_device, dtype=tch_dtype),
                        scale=torch.ones(sdim, device=tch_device, dtype=tch_dtype))

# Every reset_period_mcmc iterations each group keeps its reset_topk_mcmc
# best chains (by log-posterior) and duplicates them to refill the group.
reset_period_mcmc = n_mcmciter // 2
reset_topk_mcmc = n_cpg // 2

n_grps = n_seeds // n_cpg
assert (n_grps * n_cpg) == n_seeds

# Creating the data: all chains in a group share the query points of the
# group's first chain.
true_x1 = rng.uniform((n_seeds, n_true, dim)) * 2 - 1
assert true_x1.shape == (n_seeds, n_true, dim)
true_x2 = true_x1.reshape(n_grps, n_cpg, n_true, dim)
assert true_x2.shape == (n_grps, n_cpg, n_true, dim)
true_x3 = true_x2[:, :1, ...]
assert true_x3.shape == (n_grps, 1, n_true, dim)
true_x4 = true_x3.expand(n_grps, n_cpg, n_true, dim)
assert true_x4.shape == (n_grps, n_cpg, n_true, dim)
true_x = true_x4.reshape(n_seeds, n_true, dim)
assert true_x.shape == (n_seeds, n_true, dim)

# Observations of the true problem; squared errors are scaled by mse_mul and
# clipped per query point at mse_clip.
if y_type == 'potential':
    y_dim = 1
    mse_mul = 100
    mse_clip = np.inf
    true_y = true_problem.potential(true_x).unsqueeze(-1)
    assert true_y.shape == (n_seeds, n_true, 1)
elif y_type == 'field':
    y_dim = dim
    mse_mul = 1
    mse_clip = 10
    true_y = true_problem.field(true_x)
    assert true_y.shape == (n_seeds, n_true, y_dim)
else:
    raise ValueError(f'y_type={y_type} undefined')


# NOTE: the traces keep every iteration's tensors on the device.
trace_mcmc = defaultdict(list)
for mcmc_iter in range(n_mcmciter):
    with torch.no_grad():
        # Random-walk proposal (iteration 0 just evaluates the initial state).
        mu_proposed = mu_mcmc
        if mcmc_iter > 0:
            mu_proposed = mu_proposed + sig_mcmc * rng.normal((n_seeds, sdim))
        assert mu_proposed.shape == (n_seeds, sdim)

        # Adding noise to the query to make the problem more challenging
        # NOTE(review): the noise is added to the proposal itself, so accepted
        # states include it.
        mu_proposed = mu_proposed + query_noise * rng.normal((n_seeds, sdim))
        assert mu_proposed.shape == (n_seeds, sdim)

        # Running a fake query system
        loc_proposed = mu_proposed.reshape(n_seeds, chrg_n, dim).detach().cpu().numpy()
        w_proposed = torch.ones(n_seeds, chrg_n).detach().cpu().numpy()

        problem_proposed = DeltaProblem(weights=w_proposed,
            locations=loc_proposed,
            tch_device=tch_device, tch_dtype=tch_dtype)

    if y_type == 'potential':
        y_proposed_ = problem_proposed.potential(true_x).unsqueeze(-1)
        assert y_proposed_.shape == (n_seeds, n_true, y_dim)
    elif y_type == 'field':
        y_proposed_ = problem_proposed.field(true_x)
        assert y_proposed_.shape == (n_seeds, n_true, y_dim)
    else:
        raise ValueError(f'y_type={y_type} undefined')

    with torch.no_grad():
        y_proposed = y_proposed_.reshape(n_seeds, n_true, y_dim)
        assert y_proposed.shape == (n_seeds, n_true, y_dim)

        y_err = y_proposed - true_y.reshape(n_seeds, n_true, y_dim)
        assert y_err.shape == (n_seeds, n_true, y_dim)

        y_mse_ = y_err.square().sum(dim=-1)
        assert y_mse_.shape == (n_seeds, n_true)

        y_mse_ = torch.clip(y_mse_ , 0.0, mse_clip)
        assert y_mse_.shape == (n_seeds, n_true)

        y_mse = mse_mul * y_mse_.mean(dim=-1)
        assert y_mse.shape == (n_seeds,)

        # Log-likelihood is the negative (scaled, clipped) MSE.
        logl_proposed = -y_mse
        assert logl_proposed.shape == (n_seeds,)
        logpri_proposed = prior_mcmc.log_prob(mu_proposed).sum(-1)
        assert logpri_proposed.shape == (n_seeds,)
        logpost_proposed = logpri_proposed + logl_proposed
        assert logpost_proposed.shape == (n_seeds,)

        logpri_mcmc = prior_mcmc.log_prob(mu_mcmc).sum(-1)
        assert logpri_mcmc.shape == (n_seeds,)
        if mcmc_iter == 0:
            logl_mcmc = logl_proposed
            assert logl_mcmc.shape == (n_seeds,)
            best_logl_mcmc = logl_mcmc
            assert best_logl_mcmc.shape == (n_seeds,)
        logpost_mcmc = logpri_mcmc + logl_mcmc
        assert logpost_mcmc.shape == (n_seeds,)
        if mcmc_iter == 0:
            best_logpost_mcmc = logpost_mcmc
            assert best_logpost_mcmc.shape == (n_seeds,)

        # Metropolis acceptance test (symmetric proposal).
        logpost_diff = logpost_proposed - logpost_mcmc
        assert logpost_diff.shape == (n_seeds,)
        acceptance_mcmc = (temp_mcmc * logpost_diff.unsqueeze(-1)) > rng.uniform((n_seeds,1)).log()
        assert acceptance_mcmc.shape == (n_seeds, 1)
        mu_mcmc = torch.where(acceptance_mcmc, mu_proposed, mu_mcmc)
        assert mu_mcmc.shape == (n_seeds, sdim)
        logl_mcmc = torch.where(acceptance_mcmc.squeeze(-1), logl_proposed, logl_mcmc)
        assert logl_mcmc.shape == (n_seeds,)
        # Keep the log-posterior in sync with the (possibly) updated chain
        # state; the best-so-far tracker and the reset ranking below use it.
        logpost_mcmc = torch.where(acceptance_mcmc.squeeze(-1), logpost_proposed, logpost_mcmc)
        assert logpost_mcmc.shape == (n_seeds,)

        best_logl_mcmc = torch.where(logl_mcmc > best_logl_mcmc, logl_mcmc, best_logl_mcmc)
        assert best_logl_mcmc.shape == (n_seeds,)
        best_logpost_mcmc = torch.where(logpost_mcmc > best_logpost_mcmc, logpost_mcmc, best_logpost_mcmc)
        assert best_logpost_mcmc.shape == (n_seeds,)

        trace_mcmc['mu'].append(mu_mcmc)
        trace_mcmc['acceptance'].append(acceptance_mcmc)
        trace_mcmc['logl'].append(logl_mcmc)

        # Per-group quantiles of the best log-likelihood, median over groups.
        blm = best_logl_mcmc.reshape(n_grps, n_cpg)
        tbwriter.add_scalar('acceptance', acceptance_mcmc.float().mean(), mcmc_iter)
        for q in np.linspace(0, 1, 5):
            tbwriter.add_scalar(f'best_logl/q{int(100*q):02d}',
                                blm.quantile(q, dim=-1).median(), mcmc_iter)
        tbwriter.add_scalar('logl', logl_mcmc.mean(), mcmc_iter)

        if mcmc_iter % 100 == 0:
            print(f'MCMC Iteration {mcmc_iter:04d}: Best MSE: {-best_logl_mcmc.max():.4f}')
            tbwriter.flush()

        # Periodic reset: replace each group's chains by copies of its top-k.
        if (mcmc_iter % reset_period_mcmc == 0) and (mcmc_iter > 0):
            mu_mcmc_ = mu_mcmc.reshape(n_grps, n_cpg, sdim)
            assert mu_mcmc_.shape == (n_grps, n_cpg, sdim)

            logpost_mcmc_ = logpost_mcmc.reshape(n_grps, n_cpg)
            assert logpost_mcmc_.shape == (n_grps, n_cpg)

            top_idx = torch.topk(logpost_mcmc_, reset_topk_mcmc, dim=1, largest=True, sorted=True).indices
            assert top_idx.shape == (n_grps, reset_topk_mcmc)

            top_chains = torch.take_along_dim(mu_mcmc_, top_idx.unsqueeze(-1), dim=-2)
            assert top_chains.shape == (n_grps, reset_topk_mcmc, sdim)

            aa = n_cpg // reset_topk_mcmc
            assert (aa * reset_topk_mcmc) == n_cpg
            mu_mcmc = top_chains.reshape(n_grps, 1, reset_topk_mcmc, sdim)
            mu_mcmc = mu_mcmc.expand(n_grps, aa, reset_topk_mcmc, sdim)
            mu_mcmc = mu_mcmc.reshape(n_seeds, sdim)
            assert mu_mcmc.shape == (n_seeds, sdim)

            # The cached log-likelihood must follow the chains it belongs to;
            # otherwise the next acceptance test compares against a stale value.
            top_logl = torch.take_along_dim(logl_mcmc.reshape(n_grps, n_cpg), top_idx, dim=1)
            assert top_logl.shape == (n_grps, reset_topk_mcmc)
            logl_mcmc = top_logl.reshape(n_grps, 1, reset_topk_mcmc)
            logl_mcmc = logl_mcmc.expand(n_grps, aa, reset_topk_mcmc)
            logl_mcmc = logl_mcmc.reshape(n_seeds)
            assert logl_mcmc.shape == (n_seeds,)

tbwriter.flush()
# Stack the traces along a leading iteration axis.
with torch.no_grad():
    trace_mcmc_ = dict()
    for key, tnesor_list in trace_mcmc.items():
        trace_mcmc_[key] = torch.stack(tnesor_list, dim=0)

In [None]:
# Scatter the last recorded MCMC state of every chain in one group, pooling
# all charges into a single point cloud.
trace_mu = trace_mcmc_['mu']
n_srchdraws, n_seeds, sdim = trace_mu.shape

# Split each flattened chain state back into per-charge locations.
trace_mu = trace_mu.reshape(n_srchdraws, n_seeds, chrg_n, dim)
assert trace_mu.shape == (n_srchdraws, n_seeds, chrg_n, dim)

trace_mu = trace_mu.detach().cpu().numpy()

# Last recorded iteration, all chains and charges pooled.
kdedata = trace_mu[-1].reshape(n_seeds*chrg_n, dim)
assert kdedata.shape == (n_seeds*chrg_n, dim)

# Keep only the chains of group index 1 (the second group of n_cpg chains).
# NOTE(review): group 0 may have been intended — confirm.
kdedata = kdedata.reshape(n_grps, n_cpg, chrg_n, dim)[1].reshape(n_cpg*chrg_n, dim)

%matplotlib inline
# seaborn is only needed by the commented-out KDE plot below.
import seaborn as sns
fig_srch, ax_srch = plt.subplots(1, 1, dpi=100)

# sns.kdeplot(
#     x=kdedata[:, 0], y=kdedata[:, 1],
#     fill=True, thresh=0, levels=100, cmap="RdBu",
# )

ax_srch.scatter(kdedata[:, 0], kdedata[:, 1], s=1)

In [None]:
# Plot the trajectory of each charge location along the chain that had the
# highest log-likelihood at recorded iteration 700.
trace_mu = trace_mcmc_['mu']
n_srchdraws, n_seeds, sdim = trace_mu.shape

# Index of the best chain (largest logl) at recorded iteration 700.
ii = trace_mcmc_['logl'][700].argmax().item()
# One trajectory per charge: (chrg_n, n_srchdraws, dim).
bst = trace_mu[:, ii, ...].reshape(n_srchdraws, chrg_n, dim).transpose(0, 1)
assert bst.shape == (chrg_n, n_srchdraws, dim)

bst = bst.detach().cpu().numpy()
# Truncate to iterations 0..699 (the selection iteration itself is excluded).
bst = bst[:, 0:700, :]

%matplotlib inline
fig_srch, ax_srch = plt.subplots(1, 1, dpi=100)

# One line per charge, in the (x1, x2) plane.
for i_chrg, chrg_traj in enumerate(bst):
    ax_srch.plot(chrg_traj[:, 0], chrg_traj[:, 1])
    
ax_srch.set_xlim(-1, 1)
ax_srch.set_ylim(-1, 1)