In [None]:
import jax
from jax import numpy as np
from jax import vmap
import optax
import numpy as onp
import sys
sys.path.append('../py')
%load_ext autoreload
%autoreload 2


import common.drifts as drifts
import common.networks as networks
import haiku as hk
import common.grid_utils as grid_utils
import launchers.few_particles_split_score as launcher
import matplotlib.animation as animation
from typing import Tuple, Dict, Callable
from ml_collections import config_dict
import functools
import time


import matplotlib as mpl
from matplotlib import pyplot as plt
import seaborn as sns


# Global figure style applied to every plot in this notebook.
mpl.rcParams['axes.grid']           = True
mpl.rcParams['axes.grid.which']     = 'both'
mpl.rcParams['xtick.minor.visible'] = True
mpl.rcParams['ytick.minor.visible'] = True
mpl.rcParams['axes.facecolor']      = 'white'
mpl.rcParams['grid.color']          = '0.8'    # grayscale level given as a string
mpl.rcParams['grid.alpha']          = 0.1
mpl.rcParams['text.usetex']         = True     # text is rendered through an external LaTeX install
mpl.rcParams['font.family']         = 'serif'
mpl.rcParams['figure.figsize']      = (8, 4)
mpl.rcParams['figure.titlesize']    = 7.5
mpl.rcParams['font.size']           = 10
mpl.rcParams['legend.fontsize']     = 7.5
mpl.rcParams['figure.dpi']          = 300


from tqdm.auto import tqdm
import dill as pickle

# Load data

In [None]:
base_folder = '/scratch/nb3397/results/mips/lowd_cates/'
v0s         = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]


## score matching + PINN
output_names = ["7_26_23/7_26_23_N64_smaller_dt_0",
                "8_6_23/compressed/8_6_23_N64_sm_pinn_only_a100_0",
                "8_15_23/8_15_23_N64_sm_pinn_only_a100_extend_8_4_23_1_0_56.0",
                "8_6_23/compressed/8_6_23_N64_sm_pinn_only_a100_2",
                "8_6_23/compressed/8_6_23_N64_sm_pinn_only_a100_3",
                "8_6_23/compressed/8_6_23_N64_sm_pinn_only_a100_4"]
assert len(output_names) == len(v0s), 'need exactly one results file per v0'


def _load_results(path: str) -> dict:
    """Deserialize one saved results dictionary, closing the file afterwards.

    NOTE: `pickle` is `dill` in this notebook; unpickling can execute arbitrary
    code, so only load result files you produced yourself.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# v0 -> results dictionary of the run trained at that activity
data_dicts = {
    v0: _load_results(f'{base_folder}/{name}.npy')
    for v0, name in zip(v0s, output_names)
}


N = data_dicts[0.0]['cfg'].N
d = data_dicts[0.0]['cfg'].d

In [None]:
# sanity check: shape of the stored samples for v0 = 0.2
onp.shape(data_dicts[0.2]['xgs'])

# Utility Functions

In [None]:
def make_cfg_hashable(cfg: config_dict.ConfigDict):
    """Return a FrozenConfigDict built from `cfg` after patching legacy fields.

    Older runs stored a `radii` entry and non-float `width` / `sig0x` values.
    This helper drops `radii` (if present) and re-stores `width` and `sig0x`
    as plain floats, then freezes the config so it can be used as a static
    (hashable) jit argument.

    NOTE: the patching is applied to `cfg` in place; only the returned
    FrozenConfigDict is a new object.
    """
    # legacy field; absence (or an immutable config) is fine
    try:
        del cfg.radii
    except Exception:
        pass

    # delete before re-assigning -- presumably because ConfigDict type-checks
    # assignments to an existing field
    for field in ('width', 'sig0x'):
        try:
            as_float = float(getattr(cfg, field))
            delattr(cfg, field)
            setattr(cfg, field, as_float)
        except Exception:
            print('Already frozen!')

    return config_dict.FrozenConfigDict(cfg)


def get_params(
    data_dict: dict,
    ema_fac: float
) -> "hk.Params":
    """Fetch trained network parameters from a results dictionary.

    Newer saving routines keep a history under `params_list` /
    `ema_params_list` (the last entry is used); older ones store a single
    `params` / `ema_params` entry. EMA parameters, keyed by the EMA factor,
    are used when `ema_fac > 0`; otherwise the raw parameters are returned.

    Only lookup failures (missing key, empty history, non-indexable entry)
    trigger the fallback; any other error propagates.
    """
    try:
        print("Trying to load via params_list.")
        if ema_fac > 0:
            params = data_dict['ema_params_list'][-1][ema_fac]
        else:
            params = data_dict['params_list'][-1]
        print('Success!')
    except (KeyError, IndexError, TypeError):
        print("Failed! Loading params directly.")
        if ema_fac > 0:
            params = data_dict['ema_params'][ema_fac]
        else:
            params = data_dict['params']
        print('Success!')

    return params

In [None]:
# `launcher.compute_output_info` vectorized over a batch of configurations
# (axis 0 of its first argument) and jitted. Arguments 2-4 -- the config and
# the two networks -- are static, so they must be hashable.
compute_traj_output_info = jax.jit(
    jax.vmap(launcher.compute_output_info, in_axes=(0, None, None, None, None)),
    static_argnums=(2, 3, 4)
)


def compute_particle_quantities(
    batch_size: int,
    average_xgs: np.ndarray, # [n_v0s, n_pts_average, 2N, d]
    ema_fac: float
) -> Dict[str, onp.ndarray]:
    """Evaluate per-particle diagnostics of the learned model on stored configurations.

    For each v0 (in the order of `v0s`) the config and parameters are taken
    from `data_dicts[v0]`, and `launcher.compute_output_info` is evaluated on
    `average_xgs[kk]` in batches of `batch_size`. The final batch may be
    smaller (this triggers one extra jit compilation).

    Args:
        batch_size: number of configurations per jitted call (must be > 0).
        average_xgs: configurations, shape [n_v0s, n_pts_average, 2N, d]; the
            first half along axis 2 is stored under 'xs'.
        ema_fac: EMA factor passed to `get_params` (<= 0 selects raw params).

    Returns:
        Dict of NumPy arrays: 'gdot', 'xdot', 'v', 'div_vg', 'div_vx', 'div_v'
        of shape [n_v0s, n_pts_average, N], and 'xs' of shape
        [n_v0s, n_pts_average, N, d].
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    n_v0s         = len(v0s)
    n_pts_average = average_xgs.shape[1]
    n_particles   = average_xgs.shape[2] // 2
    n_dims        = average_xgs.shape[-1]

    # output key -> position of that entry in compute_output_info's return tuple
    scalar_outputs = {'gdot': 2, 'xdot': 3, 'div_vg': -3, 'div_vx': -4, 'div_v': -5}

    ## obtain particle-based quantities for spatial averaging
    particle_quantities = {
        key: onp.zeros((n_v0s, n_pts_average, n_particles))
        for key in ('gdot', 'xdot', 'v', 'div_vg', 'div_vx', 'div_v')
    }
    particle_quantities['xs'] = onp.zeros((n_v0s, n_pts_average, n_particles, n_dims))

    for kk, v0 in enumerate(v0s):
        # unpack everything
        data_dict = data_dicts[v0]
        cfg       = make_cfg_hashable(data_dict['cfg'])
        params    = get_params(data_dict, ema_fac)
        xgs       = average_xgs[kk]
        particle_quantities['xs'][kk] = np.split(xgs, 2, axis=1)[0]
        score_net, particle_div_net, _, _ = launcher.construct_network(cfg)

        # compute quantity of interest for spatial averaging
        for lb in range(0, n_pts_average, batch_size):
            ub = min(lb + batch_size, n_pts_average)
            batch_outputs = compute_traj_output_info(xgs[lb:ub], params, cfg, score_net, particle_div_net)

            for key, out_idx in scalar_outputs.items():
                particle_quantities[key][kk, lb:ub] = batch_outputs[out_idx]
            # 'v' stores the norm of output 7 over its last axis
            particle_quantities['v'][kk, lb:ub] = np.linalg.norm(batch_outputs[7], axis=-1)

        print(f'Finished computing particle quantities for {kk+1}/{len(v0s)}')

    return particle_quantities

# On-Particle Plot

In [None]:
def compute_trajectories(
    n_steps: int,
    start_index: int,
    seed: int = None,
) -> onp.ndarray: # [nv0s, n_steps, 2*N, d]
    """Roll out the learned dynamics for every v0 from a stored configuration.

    Args:
        n_steps: number of rollout steps; a standard-normal array of shape
            (2N, d) is drawn for each step and passed to `launcher.rollout`.
        start_index: index into each `data_dict['xgs']` used as initial state.
        seed: optional seed for the noise. With the default `None` the global
            NumPy RNG is used (previous behavior); pass an int to make the
            trajectories reproducible.

    Returns:
        Array of shape (n_v0s, n_steps, 2N, d) holding the second output of
        `launcher.rollout` for each v0.
    """
    n_v0s = len(v0s)
    trajs = onp.zeros((n_v0s, n_steps, 2*N, d))
    rng   = None if seed is None else onp.random.default_rng(seed)

    for kk, v0 in enumerate(v0s):
        # unpack everything
        data_dict = data_dicts[v0]
        cfg       = make_cfg_hashable(data_dict['cfg'])
        init      = data_dict['xgs'][start_index]

        if rng is None:
            noises = onp.random.randn(n_steps, 2*N, d)
        else:
            noises = rng.standard_normal((n_steps, 2*N, d))

        _, trajs[kk] = launcher.rollout(init, noises, cfg)
        print(f'Finished trajectory {kk+1}/{n_v0s}.')

    return trajs

In [None]:
def radius_to_points(ax, radius):
    """Convert a radius in data units into a scatter marker size `s`.

    `ax.scatter(..., s=...)` takes the marker size in points**2, and for the
    'o' marker the circle's diameter in points is sqrt(s). A disk of radius
    `radius` (data units, measured along the x axis) therefore needs
    s = (2 * radius_in_points)**2.

    The result uses the axes' current data-to-display transform and its own
    figure's dpi, so call it after the axes limits/aspect have been set.
    """
    # data units -> display units (pixels), measured along x
    pixel_radius = (ax.transData.transform([(radius, radius)])
                    - ax.transData.transform([(0, 0)]))[0, 0]

    # pixels -> points: 1 point = 1/72 inch and the figure has `dpi` pixels per inch
    point_radius = pixel_radius * 72.0 / ax.figure.dpi

    return (2.0 * point_radius) ** 2

In [None]:
def make_particle_plot(
    quantities: np.ndarray,
    xs_plot: np.ndarray,
    normalize: bool,
    clip_quantile: float,
    title: str,
    save_str: str
) -> None:
    """Scatter particles at `xs_plot`, colored by a per-particle quantity.

    One panel per v0 (the first `quantities.shape[0]` entries of `v0s`) with a
    single shared, quantile-clipped colorbar.

    Args:
        quantities: per-particle values indexed as [panel, particle].
        xs_plot: particle positions indexed as [panel, particle, xy].
        normalize: divide each panel with v0 > 0 by its v0 (the ratio is
            undefined for v0 == 0, so that panel is left unscaled) and append
            "/v_0" to the colorbar label.
        clip_quantile: colors saturate below this quantile and above
            1 - clip_quantile of all plotted values.
        title: colorbar label.
        save_str: output prefix; the figure is written to
            f'{save_str}_particles.pdf'. '' or None skips saving.
    """
    # common plot parameters
    plt.close('all')
    sns.set_palette('deep')
    fw, fh    = 4, 5.5
    fraction  = 0.15
    pad       = 0.025
    shrink    = 0.5
    fontsize  = 20

    quantities = onp.asarray(quantities)
    n_panels   = quantities.shape[0]

    # normalize by v0 (same convention as the gridded plots: skip v0 == 0)
    if normalize:
        panel_v0s  = onp.asarray(v0s[:n_panels], dtype=float)
        scale      = onp.where(panel_v0s > 0, panel_v0s, 1.0)
        quantities = quantities / scale[:, None]
        title      = title + r"$/v_0$"

    # individual panels
    titles = [rf"$v_0={v0}$" for v0 in v0s[:n_panels]]
    cmap   = sns.color_palette('mako', as_cmap=True)
    nrows  = 1
    ncols  = len(titles)

    # set up the figure; squeeze=False keeps `axs` an array even for one panel
    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=False, sharey=True, constrained_layout=True, squeeze=False
    )
    axs = axs[0]

    # single colorbar shared by all panels
    vmin, vmax = onp.quantile(quantities, q=clip_quantile), onp.quantile(quantities, q=(1-clip_quantile))
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
    mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    cbar = fig.colorbar(mappable=mappable, ax=axs.ravel(),
                        fraction=fraction, shrink=shrink, pad=pad,
                        orientation='vertical')
    cbar.set_label(title, fontsize=fontsize)
    cbar.ax.tick_params(which='both', labelsize=fontsize, width=0, length=0)

    # make scatter plot
    for kk, ax in enumerate(axs):
        # update basic visual aspects
        cfg = data_dicts[v0s[kk]]['cfg']
        ax.set_facecolor('k')
        ax.set_xlim([-0.8*cfg.width, 0.8*cfg.width])
        ax.set_ylim([-0.8*cfg.width, 0.8*cfg.width])
        ax.set_xticks([-0.5*cfg.width, 0, 0.5*cfg.width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.set_yticks([-0.5*cfg.width, 0, 0.5*cfg.width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.set_title(titles[kk], fontsize=fontsize)
        ax.axes.set_aspect(1.0)
        ax.grid(which='both', axis='both', color='0.90', alpha=0.2)
        ax.tick_params(axis='both', length=0, width=0, labelsize=fontsize)

        # marker size depends on the limits/aspect set just above
        marker_size = radius_to_points(ax, 1.5)
        ax.scatter(xs_plot[kk, :, 0], xs_plot[kk, :, 1], s=marker_size, marker='o',
                   c=quantities[kk], cmap=cmap, norm=norm)

        if kk == 0:
            ax.set_ylabel(r"$y$", fontsize=fontsize)
        ax.set_xlabel(r"$x$", fontsize=fontsize)

    if save_str:
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        fig.savefig(f'{save_str}_particles.pdf', dpi=300, bbox_inches='tight')

### Generate Data

In [None]:
n_steps   = 4096*16
skip      = 64   # keep every `skip`-th rollout snapshot for the per-particle quantities
SDE_trajs = onp.array(compute_trajectories(n_steps=n_steps, start_index=0))

particle_quantities = compute_particle_quantities(
    batch_size=32,
    average_xgs=SDE_trajs[:, ::skip],
    ema_fac=0.9999,
)
# keep the results in CPU memory
particle_quantities = jax.device_put(particle_quantities, jax.devices('cpu')[0])

### Plot individual quantity for visualization

In [None]:
# Pick one stored snapshot whose particle positions are drawn. The index must lie
# within the subsampled trajectory held in `particle_quantities`, not within n_steps
# (out-of-range indices on JAX arrays are silently clamped).
n_snapshots   = particle_quantities['xs'].shape[1]
plot_index    = onp.random.randint(n_snapshots)
npts_average  = 128
clip_quantile = 0.05
nv0s_plot     = 5

make_particle_plot(
    quantities=np.mean(particle_quantities['gdot'][:nv0s_plot, :npts_average], axis=1),
    xs_plot=particle_quantities['xs'][:nv0s_plot, plot_index],
    normalize=True,
    clip_quantile=clip_quantile,
    title='',
    save_str=None   # preview only: do not write a PDF
)

### Loop over all quantities for figure saving

In [None]:
base_folder  = '/scratch/nb3397/results/mips/lowd_cates/figures/'
date_folder  = '8_18_23'

plot_index   = 0     # snapshot whose particle positions are drawn
npts_average = 128   # number of snapshots averaged per particle
nv0s_plot    = 5


# colorbar label for each per-particle quantity
titles = {
    'gdot':   r"$|v_g|$",
    'xdot':   r"$|v_x|$",
    'v':      r"$|v|$", 
    'div_v':  r"$\nabla\cdot v$", 
    'div_vg': r"$\nabla_g\cdot v_g$", 
    'div_vx': r"$\nabla_x\cdot v_x$"
}


# one saved figure per quantity; 'xs' holds positions, not a quantity
for key in particle_quantities.keys():
    if key == 'xs':
        continue
    make_particle_plot(
        quantities=np.mean(particle_quantities[key][:nv0s_plot, :npts_average], axis=1),
        xs_plot=particle_quantities['xs'][:nv0s_plot, plot_index],
        normalize=True,
        clip_quantile=0.05,
        title=titles[key],
        save_str=f'{base_folder}/{date_folder}/{key}'
    )

# Spatial gridding

In [None]:
def make_gridded_entropy_plot(
    particle_quantities: np.ndarray,
    particle_xs: np.ndarray,
    batch_size: int,
    npts_grid: int,
    density_cutoff: float,
    clip_quantile: float,
    compute_density: bool,
    map_to_particle: bool,
    normalize: bool,
    title: str,
    save_str: str
) -> None:
    """Bin a per-particle quantity on a 2D Cartesian grid; one panel per v0.

    Args:
        particle_quantities: per-particle values indexed as [v0, sample, particle].
        particle_xs: positions indexed as [v0, sample, particle, xy].
        batch_size: unused here; kept so the signature matches the radial variant.
        npts_grid: number of histogram bins per spatial dimension.
        density_cutoff: with `compute_density`, bins with at most this many
            visits are zeroed out as too noisy.
        clip_quantile: colors saturate below this quantile and above
            1 - clip_quantile of the gridded values.
        compute_density: if True, show the per-bin mean (weighted sum / visit
            count); otherwise the weighted sum divided by the number of samples.
        map_to_particle: if True, color the particles of the first sample by the
            value of the bin they fall in, instead of drawing a filled contour.
        normalize: divide panels with v0 > 0 by their v0 and append "/v_0" to
            the colorbar label (the v0 == 0 panel is left unscaled).
        title: colorbar label.
        save_str: output path prefix; '' or None skips saving.
    """
    ## individual panels
    titles = [rf"$v_0={v0}$" for v0 in v0s[:particle_quantities.shape[0]]]
    nrows  = 1
    ncols  = len(titles)

    ### bin the data
    gridded_quantities       = onp.zeros((ncols, npts_grid, npts_grid))
    xgrid_plots, ygrid_plots = onp.zeros((ncols, npts_grid, npts_grid)), onp.zeros((ncols, npts_grid, npts_grid))
    xedges, yedges           = onp.zeros((ncols, npts_grid+1)), onp.zeros((ncols, npts_grid+1))
    for kk in range(ncols):
        xs_flat = particle_xs[kk, :, :, 0].ravel()
        ys_flat = particle_xs[kk, :, :, 1].ravel()
        gridded_quantities[kk], xedges[kk], yedges[kk] = np.histogram2d(
            xs_flat, ys_flat, weights=particle_quantities[kk].ravel(), bins=npts_grid
        )

        if compute_density:
            multiplicity = onp.array(np.histogram2d(
                xs_flat, ys_flat, bins=(xedges[kk], yedges[kk])
            )[0])

            # remove noisy regions: zero sparsely visited bins and give them a
            # unit count so the division below is well defined
            sparse = multiplicity <= density_cutoff
            gridded_quantities[kk][sparse] = 0.0
            multiplicity[sparse] = 1.0
            gridded_quantities[kk] /= multiplicity
        else:
            gridded_quantities[kk] /= xs_flat.size

        # bin centers used for contour plotting
        curr_xgrid, curr_ygrid = np.meshgrid(xedges[kk], yedges[kk], indexing='ij')
        xgrid_plots[kk] = curr_xgrid[:-1, :-1] + 0.5*np.diff(xedges[kk])[:, None]
        ygrid_plots[kk] = curr_ygrid[:-1, :-1] + 0.5*np.diff(yedges[kk])[None, :]

    # normalize by v0 (the ratio is undefined for v0 == 0, so skip those panels)
    if normalize:
        panel_v0s = onp.asarray(v0s[:ncols], dtype=float)
        nonzero   = panel_v0s > 0
        gridded_quantities[nonzero] /= panel_v0s[nonzero, None, None]
        title = title + r"$/v_0$"

    ### make figure
    # common plot parameters
    plt.close('all')
    sns.set_palette('deep')
    fw, fh    = 4, 5.5
    fraction  = 0.15
    pad       = 0.025
    shrink    = 0.5
    fontsize  = 21

    # define the overall figure; squeeze=False keeps `axs` an array for one panel
    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=False, sharey=True, constrained_layout=True, squeeze=False
    )
    axs = axs[0]

    ## single colorbar
    cmap = sns.color_palette('mako', as_cmap=True)
    vmin, vmax = onp.quantile(gridded_quantities, q=clip_quantile), onp.quantile(gridded_quantities, q=(1-clip_quantile))
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    cbar = fig.colorbar(mappable=mappable, ax=axs.ravel(), fraction=fraction,
                        shrink=shrink, pad=pad, orientation='vertical')
    cbar.ax.tick_params(which='both', labelsize=fontsize, width=0, length=0)
    cbar.ax.yaxis.get_offset_text().set(size=0.8*fontsize)
    cbar.set_label(title, fontsize=fontsize)

    ## manually clip to avoid weird background square artifact
    gridded_quantities[gridded_quantities < vmin] = vmin
    gridded_quantities[gridded_quantities > vmax] = vmax
    gridded_quantities = np.array(gridded_quantities)

    # make the actual plot
    for kk, ax in enumerate(axs):
        # set up simple plot parameters
        cfg = data_dicts[v0s[kk]]['cfg']
        ax.set_xlim([-cfg.width, cfg.width])
        ax.set_ylim([-cfg.width, cfg.width])

        ax.set_xticks([-0.5*cfg.width, 0, 0.5*cfg.width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.set_yticks([-0.5*cfg.width, 0, 0.5*cfg.width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])

        ax.grid(which='both', axis='both', color='0.90', alpha=0.1)
        ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)
        ax.axes.set_aspect(1.0)
        ax.set_title(titles[kk], fontsize=fontsize)
        ax.set_facecolor(mappable.to_rgba(0.0))

        if kk == 0:
            ax.set_ylabel(r"$y$", fontsize=fontsize)
        ax.set_xlabel(r"$x$", fontsize=fontsize)

        # make the plot
        if map_to_particle:
            # color the particles of the first stored sample by their bin's value
            particle_positions = particle_xs[kk, 0, :]
            iis, jjs           = jax.vmap(grid_utils.find_grid_pt, in_axes=(0, None, None))(particle_positions, xedges[kk], yedges[kk])
            particle_weights   = jax.vmap(lambda ii, jj: gridded_quantities[kk, ii, jj])(iis, jjs)
            marker_size        = radius_to_points(ax, 1.5)
            ax.scatter(particle_positions[:, 0], particle_positions[:, 1], s=marker_size,
                       marker='o', c=particle_weights, cmap=cmap, norm=norm)
        else:
            # plot real data
            ctr = ax.contourf(xgrid_plots[kk], ygrid_plots[kk], gridded_quantities[kk], cmap=cmap, norm=norm, levels=100)

            # fix strange aliasing artifact. matplotlib >= 3.8 makes the ContourSet
            # itself a Collection (and drops `.collections`), so handle both.
            if isinstance(ctr, mpl.collections.Collection):
                contour_artists = [ctr]
            else:
                contour_artists = ctr.collections
            for artist in contour_artists:
                artist.set_edgecolor("face")
                artist.set_rasterized(True)

    if save_str:
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        if map_to_particle:
            output_str = f'{save_str}_spatial_pp_on_particles'
        elif compute_density:
            output_str = f'{save_str}_spatial_pp'
        else:
            output_str = f'{save_str}_spatial'
        fig.savefig(f'{output_str}.pdf', dpi=300, bbox_inches='tight')

In [None]:
def _radial_bin_index(r, redges):
    """Index i of the radial bin [redges[i], redges[i+1]) that contains r.

    Radii at or beyond the last edge map to the last bin (the last histogram
    bin is closed). Radii below the first edge map to the first bin; an
    index of -1 would otherwise wrap around to the last bin.
    """
    n_bins = redges.shape[0] - 1
    return np.clip(np.searchsorted(redges, r, side='right') - 1, 0, n_bins - 1)


@jax.jit
@functools.partial(jax.vmap, in_axes=(None, 0, None, None), out_axes=1)
@functools.partial(jax.vmap, in_axes=(0, None, None, None), out_axes=0)
def find_radial_values(
    x: float,
    y: float,
    binned_quantity: np.ndarray,
    redges: np.ndarray
) -> float:
    """Look up the radial-bin value at every (x, y) grid point.

    Vectorized over `x` (output axis 0) and `y` (output axis 1), so 1D inputs
    of lengths nx and ny produce an (nx, ny) array. `binned_quantity` has one
    value per bin delimited by consecutive entries of `redges`.
    """
    r = np.sqrt(x**2 + y**2)
    return binned_quantity[_radial_bin_index(r, redges)]


@jax.jit
@functools.partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0)
def find_radial_vec(
    pos: np.ndarray,
    binned_quantity: np.ndarray,
    redges: np.ndarray
) -> float:
    """Look up each particle's radial-bin value, panel by panel.

    Inputs are indexed as pos[panel, particle, :], binned_quantity[panel, bin]
    and redges[panel, edge]; the result is indexed as [panel, particle].
    """
    r = np.linalg.norm(pos)
    return binned_quantity[_radial_bin_index(r, redges)]

In [None]:
def make_radial_gridded_entropy_plot(
    particle_quantities: np.ndarray,
    particle_xs: np.ndarray,
    batch_size: int,
    npts_grid: int,
    density_cutoff: float,
    clip_quantile: float,
    compute_density: bool,
    map_to_particle: bool,
    normalize: bool,
    title: str,
    save_str: str
) -> None:
    """Bin a per-particle quantity by radius |x|; one panel per v0.

    Args:
        particle_quantities: per-particle values indexed as [v0, sample, particle].
        particle_xs: positions indexed as [v0, sample, particle, xy].
        batch_size: unused; kept for signature compatibility.
        npts_grid: number of radial bins (and of angles in the polar plot grid).
        density_cutoff: with `compute_density`, radial bins with at most this
            many visits are zeroed out as too noisy.
        clip_quantile: colors saturate below this quantile and above
            1 - clip_quantile of the plotted values.
        compute_density: if True, show the per-bin mean (weighted sum / visit
            count); otherwise the weighted sum divided by the bin-center radius
            and by the number of samples.
        map_to_particle: if True, color the particles of one randomly chosen
            sample by their radial-bin value instead of drawing a contour.
        normalize: divide panels with v0 > 0 by their v0 and append "/v_0" to
            the colorbar label (applies to both the contour and the particles).
        title: colorbar label.
        save_str: output path prefix; '' or None skips saving.
    """
    ## individual panels
    nv0s   = particle_quantities.shape[0]
    titles = [rf"$v_0={v0}$" for v0 in v0s[:nv0s]]
    nrows  = 1
    ncols  = nv0s

    ## convert particle position data to radii; the unflattened positions are
    ## kept for mapping values back onto one sample of particles
    positions_by_sample = particle_xs
    ntrajs              = particle_xs.shape[1]
    flat_xs             = particle_xs.reshape((nv0s, -1, 2))
    flat_quantities     = particle_quantities.reshape((nv0s, -1))
    particle_rs         = jax.vmap(jax.vmap(np.linalg.norm))(flat_xs)

    ### bin the data radially
    binned_quantities = onp.zeros((ncols, npts_grid))
    redges            = onp.zeros((ncols, npts_grid+1))
    for kk in range(ncols):
        binned_quantities[kk], redges[kk] = np.histogram(particle_rs[kk], weights=flat_quantities[kk], bins=npts_grid)

        if compute_density:
            multiplicity = onp.array(np.histogram(particle_rs[kk], bins=redges[kk])[0])

            ### filter noisy regions; give them a unit count so the division is defined
            sparse = multiplicity <= density_cutoff
            binned_quantities[kk][sparse] = 0.0
            multiplicity[sparse] = 1.0
            binned_quantities[kk] /= multiplicity
        else:
            ## probability of landing in a radial bin with radius r scales like r --
            ## normalize by the bin *center* radius.
            rcenters               = redges[kk][:-1] + 0.5*onp.diff(redges[kk])
            binned_quantities[kk] /= rcenters
            binned_quantities[kk] /= particle_rs[kk].size

    ### normalize by v0 (skip v0 == 0, where the ratio is undefined). Done on the
    ### radial profiles so both the contour and the particle colors use it.
    if normalize:
        panel_v0s = onp.asarray(v0s[:nv0s], dtype=float)
        nonzero   = panel_v0s > 0
        binned_quantities[nonzero] /= panel_v0s[nonzero, None]
        title = title + r"$/v_0$"

    ### map the radial profiles onto a polar grid for contour plotting
    theta_bins               = onp.linspace(0.0, 2*onp.pi, npts_grid)
    rcenters_all             = redges[:, :-1] + 0.5*onp.diff(redges, axis=1)
    gridded_quantities       = onp.zeros((ncols, npts_grid, npts_grid))
    xgrid_plots, ygrid_plots = onp.zeros((ncols, npts_grid, npts_grid)), onp.zeros((ncols, npts_grid, npts_grid))
    for kk in range(ncols):
        rgrid, theta_grid      = onp.meshgrid(rcenters_all[kk], theta_bins)
        xgrid_plots[kk]        = rgrid*onp.cos(theta_grid)
        ygrid_plots[kk]        = rgrid*onp.sin(theta_grid)
        gridded_quantities[kk] = onp.tile(binned_quantities[kk], (npts_grid, 1))

    ### compute particle mappings
    if map_to_particle:
        sample_index       = onp.random.randint(ntrajs)
        particle_positions = onp.asarray(positions_by_sample[:, sample_index])
        particle_weights   = onp.array(find_radial_vec(particle_positions, np.array(binned_quantities), np.array(redges)))

    ### compute colorbar
    cmap = sns.color_palette('mako', as_cmap=True)
    plotted = particle_weights if map_to_particle else gridded_quantities
    vmin, vmax = onp.quantile(plotted, q=clip_quantile), onp.quantile(plotted, q=(1-clip_quantile))
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])

    ## manually clip to avoid weird background square artifact
    plotted[plotted < vmin] = vmin
    plotted[plotted > vmax] = vmax

    ### make figure
    # common plot parameters
    plt.close('all')
    sns.set_palette('deep')
    fw, fh    = 4, 5.5
    fraction  = 0.15
    pad       = 0.025
    shrink    = 0.5
    fontsize  = 21

    # define the overall figure; squeeze=False keeps `axs` an array for one panel
    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=False, sharey=True, constrained_layout=True, squeeze=False
    )
    axs = axs[0]

    # make the actual plot
    for kk, ax in enumerate(axs):
        # set up simple plot parameters
        cfg = data_dicts[v0s[kk]]['cfg']
        ax.set_xlim([-cfg.width, cfg.width])
        ax.set_ylim([-cfg.width, cfg.width])

        ax.set_xticks([-0.5*cfg.width, 0, 0.5*cfg.width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.set_yticks([-0.5*cfg.width, 0, 0.5*cfg.width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])

        ax.grid(which='both', axis='both', color='0.90', alpha=0.1)
        ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)
        ax.axes.set_aspect(1.0)
        ax.set_title(titles[kk], fontsize=fontsize)
        ax.set_facecolor(mappable.to_rgba(0.0))

        if kk == 0:
            ax.set_ylabel(r"$y$", fontsize=fontsize)
        ax.set_xlabel(r"$x$", fontsize=fontsize)

        # make the plot
        if map_to_particle:
            marker_size = radius_to_points(ax, 1.5)
            ax.scatter(particle_positions[kk, :, 0], particle_positions[kk, :, 1],
                       s=marker_size, marker='o', c=particle_weights[kk],
                       cmap=cmap, norm=norm)
        else:
            ctr = ax.contourf(xgrid_plots[kk], ygrid_plots[kk], gridded_quantities[kk], cmap=cmap, norm=norm, levels=100)

            # fix strange aliasing artifact. matplotlib >= 3.8 makes the ContourSet
            # itself a Collection (and drops `.collections`), so handle both.
            if isinstance(ctr, mpl.collections.Collection):
                contour_artists = [ctr]
            else:
                contour_artists = ctr.collections
            for artist in contour_artists:
                artist.set_edgecolor("face")
                artist.set_rasterized(True)

    cbar = fig.colorbar(mappable=mappable, ax=axs.ravel(), fraction=fraction, shrink=shrink, pad=pad, orientation='vertical')
    cbar.ax.tick_params(which='both', labelsize=fontsize, width=0, length=0)
    cbar.ax.yaxis.get_offset_text().set(size=0.8*fontsize)
    cbar.set_label(title, fontsize=fontsize)

    if save_str:
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        output_str = f'{save_str}_radial_pp' if compute_density else f'{save_str}_radial'
        # keep particle-mapped figures from overwriting the contour figures
        if map_to_particle:
            output_str = f'{output_str}_on_particles'
        fig.savefig(f'{output_str}.pdf', dpi=300, bbox_inches='tight')

### Compute particle data.

In [None]:
n_pts_average = 4096*16   # configurations drawn (with replacement) per v0
sample_seed   = 0         # fixed so the averaged quantities are reproducible
sample_rng    = onp.random.default_rng(sample_seed)
average_xgs   = onp.zeros((len(v0s), n_pts_average, 2*N, d))

## pick points to average over, uniformly among the stored samples
for kk, v0 in enumerate(v0s):
    curr_xgs        = data_dicts[v0]['xgs']
    inds            = sample_rng.integers(curr_xgs.shape[0], size=n_pts_average)
    average_xgs[kk] = curr_xgs[inds]


particle_quantities_grid = compute_particle_quantities(
    batch_size=32, average_xgs=average_xgs, ema_fac=0.9999
)

### Plot individual quantity for visualization

In [None]:
nv0s_plot = 5


# preview: density-normalized radial profile of gdot as a contour (not saved)
make_radial_gridded_entropy_plot(
    particle_quantities_grid['gdot'][:nv0s_plot],
    particle_quantities_grid['xs'][:nv0s_plot],
    batch_size=256,
    npts_grid=128,
    density_cutoff=25,
    clip_quantile=0.1,
    normalize=True,
    map_to_particle=False,
    compute_density=True,
    title='',
    save_str=None
)

In [None]:
nv0s_plot = 5


# preview: radial gdot profile mapped onto one sample of particles (not saved)
make_radial_gridded_entropy_plot(
    particle_quantities_grid['gdot'][:nv0s_plot],
    particle_quantities_grid['xs'][:nv0s_plot],
    batch_size=256,
    npts_grid=128,
    density_cutoff=5,
    clip_quantile=0.05,
    normalize=True,
    map_to_particle=True,
    compute_density=True,
    title='',
    save_str=None
)

In [None]:
nv0s_plot = 5


# preview: radial gdot profile without density normalization (not saved)
make_radial_gridded_entropy_plot(
    particle_quantities_grid['gdot'][:nv0s_plot],
    particle_quantities_grid['xs'][:nv0s_plot],
    batch_size=256,
    npts_grid=128,
    density_cutoff=25,
    clip_quantile=0.05,
    normalize=True,
    map_to_particle=False,
    compute_density=False,
    title='',
    save_str=None
)

In [None]:
nv0s_plot = 5


# preview: Cartesian-gridded div_v without density normalization (not saved)
make_gridded_entropy_plot(
    particle_quantities_grid['div_v'][:nv0s_plot],
    particle_quantities_grid['xs'][:nv0s_plot],
    batch_size=256,
    npts_grid=128,
    density_cutoff=0,
    clip_quantile=0,
    normalize=True,
    map_to_particle=False,
    compute_density=False,
    title='',
    save_str=None
)

In [None]:
nv0s_plot = 5


# preview: Cartesian-gridded, density-normalized div_v (not saved)
make_gridded_entropy_plot(
    particle_quantities_grid['div_v'][:nv0s_plot],
    particle_quantities_grid['xs'][:nv0s_plot],
    batch_size=256,
    npts_grid=128,
    density_cutoff=10,
    clip_quantile=0.01,
    normalize=True,
    map_to_particle=False,
    compute_density=True,
    title='',
    save_str=None
)

### Make all Cartesian grids.

In [None]:
base_folder  = '/scratch/nb3397/results/mips/lowd_cates/figures/'
date_folder  = '8_18_23'
nv0s_plot    = 5

# colorbar label for each per-particle quantity
titles = {
    'gdot':   r"$|v_g|$",
    'xdot':   r"$|v_x|$",
    'v':      r"$|v|$", 
    'div_v':  r"$\nabla\cdot v$", 
    'div_vg': r"$\nabla_g\cdot v_g$", 
    'div_vx': r"$\nabla_x\cdot v_x$"
}


for key in particle_quantities_grid.keys():
    # 'xs' holds positions, not a quantity to plot
    if key == 'xs':
        continue
    for compute_density in [True, False]:
        for map_to_particle in [True, False]:
            # particle mapping is only drawn for density-normalized grids
            if map_to_particle and not compute_density:
                continue
            print(f'Starting {key} with compute_density={compute_density} and map_to_particle={map_to_particle}')
            make_gridded_entropy_plot(
                particle_quantities_grid[key][:nv0s_plot],
                particle_quantities_grid['xs'][:nv0s_plot],
                batch_size=256,
                npts_grid=128,
                density_cutoff=5,
                normalize=True,
                clip_quantile=0.025,
                compute_density=compute_density,
                map_to_particle=map_to_particle,
                title=titles[key],
                save_str=f'{base_folder}/{date_folder}/{key}'
            )
print('Finished!')

### Make all radial grids.

In [None]:
titles = {
    'gdot':   r"$|v_g|$",
    'xdot':   r"$|v_x|$",
    'v':      r"$|v|$", 
    'div_v':  r"$\nabla\cdot v$", 
    'div_vg': r"$\nabla_g\cdot v_g$", 
    'div_vx': r"$\nabla_x\cdot v_x$"
}

nv0s_plot    = 5
base_folder  = '/scratch/nb3397/results/mips/lowd_cates/figures/'
date_folder  = '8_18_23'

# Radial figures are drawn as filled contours. Set this explicitly rather than
# inheriting a leftover loop variable from an earlier cell (NameError on a fresh kernel).
map_to_particle = False


for key in particle_quantities_grid.keys():
    # 'xs' holds positions, not a quantity to plot
    if key == 'xs':
        continue
    for compute_density in [True, False]:
        print(f'Starting {key} with compute_density={compute_density} and map_to_particle={map_to_particle}')
        make_radial_gridded_entropy_plot(
            particle_quantities_grid[key][:nv0s_plot],
            particle_quantities_grid['xs'][:nv0s_plot],
            batch_size=256,
            npts_grid=256,
            density_cutoff=10,
            normalize=True,
            clip_quantile=0.025,
            compute_density=compute_density,
            map_to_particle=map_to_particle,
            title=titles[key],
            save_str=f'{base_folder}/{date_folder}/{key}'
        )
print('Finished!')

# Combined Radial + Particle

In [None]:
def make_radial_gridded_and_particle_combined_entropy_plot(
    particle_quantities_grid: np.ndarray,
    particle_grid_xs: np.ndarray,
    particle_quantities: np.ndarray,
    particle_xs: np.ndarray,
    batch_size: int,
    npts_grid: int,
    density_cutoff: float,
    clip_quantile: float,
    colormap_str: str,
    radial: bool,
    particle_colorbar: bool,
    map_to_particle: bool,
    title: str,
    save_str: str
) -> None:
    """Plot a particle snapshot (top row) above a binned field of a quantity (bottom row).

    One column is drawn per value of v0 (the leading axis of the inputs). For every column
    after the first, the quantity is divided by v0. The first column is left unscaled
    (v0s[0] is 0.0 in this notebook).

    Relies on the notebook globals `v0s`, `data_dicts`, `N`, `d`, `find_radial_vec`,
    `radius_to_points` and `grid_utils`.

    Args:
        particle_quantities_grid: quantity samples used for binning; reshaped to (nv0s, -1).
        particle_grid_xs: positions matching `particle_quantities_grid`. They are reshaped to
            (nv0s, -1, 2) for binning and to (nv0s, -1, N, d) when a snapshot is picked.
        particle_quantities: per-particle values for the top row (used when not `map_to_particle`).
        particle_xs: positions for the top row (used when not `map_to_particle`).
        batch_size: unused; kept for interface compatibility.
        npts_grid: number of bins per axis (radial bins and angular samples when `radial`).
        density_cutoff: sparse bins are zeroed. In radial mode the test is count / r_center
            <= cutoff; in cartesian mode it is count <= cutoff.
        clip_quantile: lower/upper quantile that sets the colour range.
        colormap_str: seaborn colormap name.
        radial: bin by radius (True) or on a cartesian 2D grid (False).
        particle_colorbar: give the particle row its own colour range and colorbar.
        map_to_particle: colour a random binned snapshot by looking up the binned field at
            each particle, instead of using `particle_quantities` / `particle_xs`.
        title: colorbar label; r"$/v_0$" is appended.
        save_str: output path prefix; nothing is saved when empty.
    """
    ##### define figure and axes
    plt.close('all')
    sns.set_palette('deep')
    fw, fh    = 4, 4

    # colorbar geometry; per-row colorbars (particle_colorbar) can be taller than a shared one
    fraction  = 0.15
    pad       = 0.025
    shrink    = 0.75 if particle_colorbar else 0.5

    fontsize  = 21
    nv0s      = particle_quantities_grid.shape[0]
    titles    = [rf"$v_0={v0}$" for v0 in v0s[:nv0s]]
    nrows     = 2
    ncols     = nv0s
    title     = title + r"$/v_0$"
    cmap      = sns.color_palette(colormap_str, as_cmap=True)
    # squeeze=False keeps `axs` two-dimensional even when a single v0 is plotted
    fig, axs  = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=True, sharey=True, constrained_layout=True, squeeze=False
    )


    ## basic properties common to all axes
    cfg = data_dicts[v0s[0]]['cfg']
    for ax in axs.ravel():
        ax.set_xlim([-cfg.width, cfg.width])
        ax.set_ylim([-cfg.width, cfg.width])
        ax.grid(which='both', axis='both', color='0.90', alpha=0.1)
        ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)
        ax.axes.set_aspect(1.0)
        ax.set_xticks([-0.5*cfg.width, 0, 0.5*cfg.width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.set_yticks([-0.5*cfg.width, 0, 0.5*cfg.width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])

        for spine in ax.spines.values():
            spine.set_edgecolor('black')


    ######### binned grid #########
    ## flatten samples so each v0 has one long list of 2D points
    ntrajs                   = particle_grid_xs.shape[1]
    particle_grid_xs         = particle_grid_xs.reshape((nv0s, -1, 2))
    particle_quantities_grid = particle_quantities_grid.reshape((nv0s, -1))
    particle_rs              = jax.vmap(jax.vmap(np.linalg.norm))(particle_grid_xs)


    ### bin the data and map to grid data
    gridded_quantities       = onp.zeros((ncols, npts_grid, npts_grid))
    xgrid_plots, ygrid_plots = onp.zeros((ncols, npts_grid, npts_grid)), onp.zeros((ncols, npts_grid, npts_grid))
    if radial:
        binned_quantities = onp.zeros((ncols, npts_grid))
        redges            = onp.zeros((ncols, npts_grid+1))
        theta_edges       = np.linspace(0.0, 2*np.pi, npts_grid)
        for kk in range(ncols):
            ## radially bin: weighted sums per bin, then the sample count per bin
            binned_quantities[kk], redges[kk] = np.histogram(particle_rs[kk], weights=particle_quantities_grid[kk], bins=npts_grid)
            multiplicity                      = onp.array(np.histogram(particle_rs[kk], bins=redges[kk])[0])
            rcenters                          = redges[kk, :-1] + 0.5*np.diff(redges[kk])

            ## filter out noisy regions (count per unit radius below the cutoff)
            inds                         = (multiplicity / rcenters) <= density_cutoff
            multiplicity[inds]           = 1.0
            binned_quantities[kk][inds]  = 0.0
            binned_quantities[kk]       /= multiplicity

            ## sweep the radial profile around the circle to get x and y grids to plot
            rgrid, theta_grid      = np.meshgrid(rcenters, theta_edges)
            xgrid_plots[kk]        = rgrid*np.cos(theta_grid)
            ygrid_plots[kk]        = rgrid*np.sin(theta_grid)
            gridded_quantities[kk] = np.tile(binned_quantities[kk], (npts_grid, 1))
    else:
        xedges, yedges = onp.zeros((ncols, npts_grid+1)), onp.zeros((ncols, npts_grid+1))
        for kk in range(ncols):
            ## cartesian bin: weighted sums per cell, then the sample count per cell
            gridded_quantities[kk], xedges[kk], yedges[kk] = np.histogram2d(
                particle_grid_xs[kk, :, 0], particle_grid_xs[kk, :, 1], weights=particle_quantities_grid[kk], bins=npts_grid
            )

            multiplicity = onp.array(
                np.histogram2d(particle_grid_xs[kk, :, 0], particle_grid_xs[kk, :, 1], bins=(xedges[kk], yedges[kk]))[0]
            )

            ## filter out noisy regions; compute the mask once, before `multiplicity` is edited
            inds                          = multiplicity <= density_cutoff
            multiplicity[inds]            = 1.0
            gridded_quantities[kk][inds]  = 0.0
            gridded_quantities[kk]       /= multiplicity

            ## plot at the cell centers
            curr_xgrid, curr_ygrid = np.meshgrid(xedges[kk], yedges[kk], indexing='ij')
            xgrid_plots[kk] = curr_xgrid[:-1, :-1] + 0.5*np.diff(xedges[kk])[:, None]
            ygrid_plots[kk] = curr_ygrid[:-1, :-1] + 0.5*np.diff(yedges[kk])[None, :]


    ## normalize by v0 (first column excluded) and compute the colour range
    gridded_quantities[1:] = gridded_quantities[1:] / onp.array(v0s)[1:nv0s, None, None]
    vmin, vmax = onp.quantile(gridded_quantities, q=clip_quantile), onp.quantile(gridded_quantities, q=(1-clip_quantile))
    norm       = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    mappable   = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])


    ## manually clip to avoid weird background square artifact
    gridded_quantities[gridded_quantities < vmin] = vmin
    gridded_quantities[gridded_quantities > vmax] = vmax
    gridded_quantities = np.array(gridded_quantities)


    ## bottom row: binned field
    for kk, ax in enumerate(axs[1]):
        ax.set_xlabel(r"$x$", fontsize=fontsize)
        ax.set_facecolor(mappable.to_rgba(0.0))

        if kk == 0:
            ax.set_ylabel(r"$y$", fontsize=fontsize)

        ctr = ax.contourf(xgrid_plots[kk], ygrid_plots[kk], gridded_quantities[kk], cmap=cmap, norm=norm, levels=100)

        # fix strange aliasing artifact. Since matplotlib 3.8 the ContourSet is itself a
        # Collection and `.collections` is deprecated (later removed), so support both forms.
        contour_artists = [ctr] if isinstance(ctr, mpl.collections.Collection) else ctr.collections
        for c in contour_artists:
            c.set_edgecolor("face")
            c.set_rasterized(True)


    ######## plot particle realization ########
    if map_to_particle:
        # one random binned snapshot, using the same sample index for every v0
        particle_xs = particle_grid_xs.reshape((nv0s, -1, N, d))[:, onp.random.randint(ntrajs)]
        if radial:
            particle_quantities = onp.array(find_radial_vec(particle_xs, np.array(binned_quantities), np.array(redges)))
        else:
            # locate each particle's grid cell, over all particles and all values of v0
            iis, jjs = jax.vmap(
                jax.vmap(grid_utils.find_grid_pt, in_axes=(0, None, None))
            )(particle_xs, xedges, yedges)
            particle_quantities = onp.array(
                jax.vmap(
                    jax.vmap(
                        lambda ii, jj, grid_quantity: grid_quantity[ii, jj],
                        in_axes=(0, 0, None)
                    )
                )(iis, jjs, gridded_quantities)
            )

    ## normalize if it wasn't done already (gridded values were divided by v0 above)
    particle_quantities = onp.array(particle_quantities)
    if not map_to_particle:
        particle_quantities[1:nv0s] = particle_quantities[1:nv0s] / (np.array(v0s)[1:nv0s])[:, None]

    ## particle colour range: its own quantiles, or shared with the binned field
    if particle_colorbar:
        particle_vmin     = onp.quantile(particle_quantities, q=clip_quantile)
        particle_vmax     = onp.quantile(particle_quantities, q=(1-clip_quantile))
        particle_norm     = mpl.colors.Normalize(vmin=particle_vmin, vmax=particle_vmax)
        particle_mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=particle_norm)
        particle_mappable.set_array([])
    else:
        particle_vmin, particle_vmax = vmin, vmax
        particle_norm = norm

    ## manually clip to colorbar
    particle_quantities[particle_quantities < particle_vmin] = particle_vmin
    particle_quantities[particle_quantities > particle_vmax] = particle_vmax


    ## top row: particle snapshot
    for kk, ax in enumerate(axs[0]):
        ax.set_title(titles[kk], fontsize=fontsize)
        ax.set_facecolor(mappable.to_rgba(0.0))
        s = radius_to_points(ax, 1.5)
        ax.scatter(particle_xs[kk, :, 0], particle_xs[kk, :, 1],
                   s=s, marker='o', c=particle_quantities[kk],
                   cmap=cmap, norm=particle_norm)

        if kk == 0:
            ax.set_ylabel(r"$y$", fontsize=fontsize)

    ### construct the colorbar(s)
    if particle_colorbar:
        # axs[0] holds the particles (coloured with particle_norm) and axs[1] the binned field
        # (coloured with norm). Attach each colorbar to the row it actually describes.
        colorbar_specs = [(axs[0], particle_mappable, particle_norm), (axs[1], mappable, norm)]
    else:
        colorbar_specs = [(axs, mappable, norm)]

    for cbar_axs, curr_mappable, curr_norm in colorbar_specs:
        cbar = fig.colorbar(mappable=curr_mappable, norm=curr_norm, ax=cbar_axs.ravel(), fraction=fraction, shrink=shrink, pad=pad, orientation='vertical')
        cbar.ax.tick_params(which='both', labelsize=fontsize, width=0, length=0)
        cbar.ax.yaxis.get_offset_text().set(size=0.8*fontsize)
        cbar.set_label(title, fontsize=fontsize)
        cbar.outline.set_edgecolor('grey')


    if save_str != '':
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)

        base_str = f'{save_str}_combined'
        if radial:
            base_str += '_radial'
        if map_to_particle:
            base_str += '_on_particle'
        if particle_colorbar:
            base_str += '_particle_colorbar'

        plt.savefig(f'{base_str}.pdf', dpi=300, bbox_inches='tight')

## Visualize and play

In [None]:
nv0s     = 5
quantity = 'div_vx'
npts_avg = 128

# Per-particle median over the first `npts_avg` samples, drawn on the configuration stored
# at index `npts_avg` (assumes axis 1 of `particle_quantities[...]` indexes samples — TODO confirm).
median_quantities = np.median(particle_quantities[quantity][:nv0s, :npts_avg], axis=1)
snapshot_xs       = particle_quantities['xs'][:nv0s, npts_avg]

make_radial_gridded_and_particle_combined_entropy_plot(
    particle_quantities_grid[quantity][:nv0s],
    particle_quantities_grid['xs'][:nv0s],
    median_quantities,
    snapshot_xs,
    batch_size=256,
    npts_grid=64,
    density_cutoff=10,
    clip_quantile=0.05,
    colormap_str='mako',
    particle_colorbar=False,
    map_to_particle=True,
    radial=True,
    title='',
    save_str='',
)

## Loop and plot

In [None]:
titles = {
    'gdot':   r"$|v_g|$",
    'xdot':   r"$|v_x|$",
    'v':      r"$|v|$", 
    'div_v':  r"$\nabla\cdot v$", 
    'div_vg': r"$\nabla_g\cdot v_g$", 
    'div_vx': r"$\nabla_x\cdot v_x$"
}


nv0s_plot    = 5
npts_avg     = 128
base_folder  = '/scratch/nb3397/results/mips/lowd_cates/figures/'
date_folder  = '9_16_23'


for key in titles.keys():
    # These inputs do not depend on the colorbar/mapping options. Compute them once per key
    # instead of once per plot (the plotting function copies what it modifies).
    curr_grid_quantities     = particle_quantities_grid[key][:nv0s_plot]
    curr_grid_xs             = particle_quantities_grid['xs'][:nv0s_plot]
    curr_particle_quantities = np.median(particle_quantities[key][:nv0s_plot, :npts_avg], axis=1)
    curr_particle_xs         = particle_quantities['xs'][:nv0s_plot, npts_avg]

    for particle_colorbar in [True, False]:
        for map_to_particle in [True, False]:
            print(f'Starting {key}!')
            make_radial_gridded_and_particle_combined_entropy_plot(
                curr_grid_quantities,
                curr_grid_xs,
                curr_particle_quantities,
                curr_particle_xs,
                batch_size=256,
                npts_grid=128,
                density_cutoff=5,
                clip_quantile=0.05,
                colormap_str='mako',
                particle_colorbar=particle_colorbar,
                map_to_particle=map_to_particle,
                radial=False,
                title=titles[key],
                save_str=f'{base_folder}/{date_folder}/{key}'
            )
print('Finished!')

# Statistics Visualization

In [None]:
@functools.partial(jax.jit, static_argnums=(2, 3, 4))
@functools.partial(jax.vmap, in_axes=(0, None, None, None, None))
def compute_batch_statistics(
    xgs: np.ndarray,
    params: Dict[str, hk.Params],
    cfg: config_dict.ConfigDict,
    score_net: Callable,
    particle_div_net: Callable,
) -> Tuple:
    """Compute the quantitative verification metrics over a batch.

    vmapped over the leading axis of `xgs` and jitted, with `cfg`, `score_net` and
    `particle_div_net` treated as static.

    Returns the per-sample tuple (div_v, div_vx, div_vg, v_times_s, ibp, pinn), where each
    entry is a sum over particles:
        pinn = div_v + v_times_s
        ibp  = sum(|s|^2) + sum(div s_x + div s_g)
    """
    xs, gs = np.split(xgs, 2)                              # ([N, d], [N, d])
    score_x = score_net.apply(params['x'], xs, gs, 'x')    # [N, d]
    score_g = score_net.apply(params['g'], xs, gs, 'g')    # [N, d]
    scores = np.hstack((score_x, score_g))                 # [N, 2*d]
    velocities = launcher.calc_vs(xgs, score_x, score_g, cfg)[-1]
    div_sx, div_sg, div_vx, div_vg, div_v_full = launcher.calc_divs(
        params, xs, gs, cfg, particle_div_net
    )

    total_div_v = np.sum(div_v_full)
    v_dot_s = np.sum(velocities*scores)
    ibp = np.sum(scores**2) + np.sum(div_sx + div_sg)

    return total_div_v, np.sum(div_vx), np.sum(div_vg), v_dot_s, ibp, total_div_v + v_dot_s


def compute_statistics(
    batch_size: int,
    n_samples: int,
    ema_fac: float
) -> dict:
    """Compute the quantitative verification metrics for every value of v0.

    The same `n_samples` sample indices (drawn with replacement) are evaluated for every v0.

    Args:
        batch_size: number of samples per call to `compute_batch_statistics`.
        n_samples: total number of samples to evaluate per v0.
        ema_fac: EMA factor passed to `get_params`.

    Returns:
        dict mapping 'div_v', 'div_vx', 'div_vg', 'v_times_s', 'ibp', 'pinn' to arrays of
        shape (len(v0s), n_samples).
    """
    ## pick samples for averaging. Draw from the smallest dataset so the indices are valid
    ## for every v0, not just the last one.
    ntrajs = min(data_dicts[v0]['xgs'].shape[0] for v0 in v0s)
    inds = onp.random.choice(onp.arange(ntrajs), size=n_samples)
    n_batches = -(-n_samples // batch_size)  # ceiling division

    ## arrays to hold statistics of interest; order matches compute_batch_statistics' outputs
    stat_names = ('div_v', 'div_vx', 'div_vg', 'v_times_s', 'ibp', 'pinn')
    n_v0s = len(v0s)
    statistics = {name: onp.zeros((n_v0s, n_samples)) for name in stat_names}

    for kk, v0 in enumerate(v0s):
        # unpack everything
        data_dict = data_dicts[v0]
        cfg       = make_cfg_hashable(data_dict['cfg'])
        params    = get_params(data_dict, ema_fac)
        xgs       = data_dict['xgs'][inds]

        # set up needed functions
        score_net, particle_div_net, _, _ = launcher.construct_network(cfg)

        # evaluate the metrics batch by batch
        for curr_batch in range(n_batches):
            start_time = time.time()
            lb = curr_batch*batch_size
            ub = lb + batch_size

            batch_stats = compute_batch_statistics(xgs[lb:ub], params, cfg, score_net, particle_div_net)
            for name, values in zip(stat_names, batch_stats):
                statistics[name][kk, lb:ub] = values
            end_time = time.time()

            print(f'[{curr_batch+1}/{n_batches}], [{kk+1}/{len(v0s)}], {end_time-start_time}s. ')

    return statistics


def make_histogram_plot(
    statistics: np.ndarray,
    v0s_plot: list,
    bins: int,
    sup_title: str,
    save_str: str
) -> None:
    """Histogram one statistic per value of v0, normalized by 2*N*d.

    `N` and `d` are notebook globals.

    Args:
        statistics: array whose row kk holds the samples plotted in panel kk.
        v0s_plot: v0 values; set the number of panels and their titles.
        bins: number of histogram bins.
        sup_title: figure title.
        save_str: output path. When empty, the figure is shown instead of saved.
    """
    ## make the figure
    plt.close('all')
    fw, fh    = 4, 4
    fontsize  = 22
    titles    = [rf"$v_0=${v0}" for v0 in v0s_plot]
    nv0s_plot = len(v0s_plot)
    n_dof     = 2*N*d


    with sns.axes_style('whitegrid'), plt.rc_context(rc={'text.usetex': True, 'font.family': 'serif'}):
        # at least 6 colours (unchanged palette for <= 6 panels) but never fewer than panels
        cmap = sns.cubehelix_palette(n_colors=max(6, nv0s_plot))
        # squeeze=False so a single panel still gives an iterable row of axes
        fig, axs = plt.subplots(nrows=1, ncols=nv0s_plot, sharey=True, constrained_layout=True,
                                figsize=(nv0s_plot*fw, fh), squeeze=False)
        fig.suptitle(sup_title, fontsize=fontsize)


        for kk, ax in enumerate(axs[0]):
            normalized = statistics[kk] / n_dof
            ax.grid(which='both', axis='both', color='0.75', alpha=0.5)
            ax.tick_params(axis='both', labelsize=fontsize)
            ax.hist(normalized, bins=bins, color=cmap[kk])
            ax.set_title(titles[kk], fontsize=fontsize)
            ax.xaxis.offsetText.set_fontsize(0.8*fontsize)
            mean, std = np.mean(normalized), np.std(normalized)
            ax.axvline(mean,     color='k',          alpha=0.25)
            ax.text(0.05, 0.9, rf"$\mu={mean:.1e}$",   transform=ax.transAxes, fontsize=0.65*fontsize)
            ax.text(0.65, 0.9, rf"$\sigma={std:.1e}$", transform=ax.transAxes, fontsize=0.65*fontsize)


            if kk == 0:
                ax.set_ylabel('count', fontsize=fontsize)

        if save_str != '':
            fig.patch.set_facecolor('k')
            fig.patch.set_alpha(0.0)
            fig.savefig(save_str, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())
        else:
            # shouldn't need this, but seems that we do anyways for some reason.
            # moreover, it seems to save only a white figure if we do this outside the else block.
            plt.show()

#### generate statistics

In [None]:
statistics = compute_statistics(ema_fac=0.9999, batch_size=64, n_samples=16*4096)  # 65536 samples per v0

#### visualize individually

In [None]:
nv0s_plot = 5

In [None]:
make_histogram_plot(statistics['pinn'][:nv0s_plot], v0s_plot=v0s[:nv0s_plot], bins=50, sup_title='', save_str='')

In [None]:
make_histogram_plot(statistics['div_v'][:nv0s_plot], v0s_plot=v0s[:nv0s_plot], bins=50, sup_title='', save_str='')

In [None]:
make_histogram_plot(statistics['ibp'][:nv0s_plot], v0s_plot=v0s[:nv0s_plot], bins=50, sup_title='', save_str='')

In [None]:
make_histogram_plot(statistics['v_times_s'][:nv0s_plot], v0s_plot=v0s[:nv0s_plot], bins=50, sup_title='', save_str='')

#### loop and plot

In [None]:
nv0s_plot = 5

titles = {
    'div_v':     r"$\nabla\cdot v$",
    'v_times_s': r"$v\cdot h$",
    'ibp':       r"$|h|^2 + \nabla \cdot h$",
    'pinn':      r"$\nabla\cdot v + v\cdot h$"
}

base_output_folder = '/scratch/nb3397/results/mips/lowd_cates/figures/9_20_23'


# one saved histogram figure per statistic
for key, sup_title in titles.items():
    make_histogram_plot(
        statistics[key][:nv0s_plot],
        v0s_plot=v0s[:nv0s_plot],
        bins=50,
        sup_title=sup_title,
        save_str=f'{base_output_folder}/{key}_statistics.pdf',
    )

# Probability Flow Visualization

## Rollout functions

In [None]:
def pflow_rhs(
    xgs: np.ndarray, # [2N, d]
    score_net: Callable,
    cfg: config_dict.FrozenConfigDict,
    params: Dict[str, hk.Params]
) -> np.ndarray:
    """Velocity of the probability-flow ODE for the stacked state [xs; gs].

    x-part: launcher.calc_xdots(xgs, cfg) - cfg.eps * s_x
    g-part: -cfg.gamma * gs - cfg.gamma * s_g
    where s_x / s_g are the learned scores. Returns the concatenated [2N, d] velocity.
    """
    xs, gs = np.split(xgs, 2)  # [N, d], [N, d]
    score_x = score_net.apply(params['x'], xs, gs, 'x')
    score_g = score_net.apply(params['g'], xs, gs, 'g')
    x_velocity = launcher.calc_xdots(xgs, cfg) - cfg.eps*score_x
    g_velocity = -cfg.gamma*gs - cfg.gamma*score_g
    return np.concatenate((x_velocity, g_velocity))


@functools.partial(jax.jit, static_argnums=(1, 2))
def step_pflow(
    xgs: np.ndarray, # [2N, d]
    score_net: Callable,
    cfg: config_dict.FrozenConfigDict,
    params: Dict[str, hk.Params],
    dt: float,
) -> np.ndarray:
    """Take one forward-Euler step of the probability flow.

    Positions are wrapped with `drifts.torus_project(..., cfg.width)`; g is updated directly.
    """
    velocity = pflow_rhs(xgs, score_net, cfg, params)
    xs, gs = np.split(xgs, 2)
    xdots, gdots = np.split(velocity, 2)
    return np.concatenate((
        drifts.torus_project(xs + dt*xdots, cfg.width),
        gs + dt*gdots,
    ))


@functools.partial(jax.jit, static_argnums=(2, 3))
def rollout_traj_pflow(
    init_xgs: np.ndarray, # [2N, d]
    steps: np.ndarray,    # [nsteps]
    score_net: Callable,
    cfg: config_dict.FrozenConfigDict,
    params: Dict[str, hk.Params],
    dt: float,
) -> np.ndarray:
    """Roll out len(steps) probability-flow steps from `init_xgs`.

    Returns the stacked states after every step ([nsteps, 2N, d]). The initial state is not
    included. `steps` only sets the number of iterations; its values are ignored.
    """
    def advance(state: np.ndarray, _unused_step: np.ndarray):
        new_state = step_pflow(state, score_net, cfg, params, dt)
        return new_state, new_state

    _, trajectory = jax.lax.scan(advance, init_xgs, steps)
    return trajectory


def compute_pflow_trajs(
    tf: float,
    nbatches: int,
    ema_fac: float,
    dt: float
) -> onp.ndarray:
    """Integrate the learned probability flow for each value of v0 from one shared initial state.

    Args:
        tf: final time.
        nbatches: number of jitted rollout chunks; must divide round(tf / dt).
        ema_fac: EMA factor passed to `get_params`.
        dt: time step.

    Returns:
        array of shape (len(v0s), nsteps+1, 2N, d) with nsteps = round(tf / dt). Index 0
        holds the initial state.
    """
    # One random sample index, shared by every v0. It is drawn from the smallest dataset so
    # it is valid for all of them.
    nsamples   = min(data_dicts[v0]['xgs'].shape[0] for v0 in v0s)
    init_index = onp.random.randint(nsamples)

    ## set up arrays for storing trajectories
    n_v0s  = len(v0s)
    # round instead of truncating: tf / dt can land just below an integer in floating point
    nsteps = int(round(tf / dt))
    assert nsteps % nbatches == 0, f'nsteps={nsteps} must be divisible by nbatches={nbatches}'
    bs     = nsteps // nbatches
    trajs  = onp.zeros((n_v0s, nsteps+1, 2*N, d))

    for kk, v0 in enumerate(v0s):
        # unpack everything
        data_dict    = data_dicts[v0]
        cfg          = make_cfg_hashable(data_dict['cfg'])
        params       = get_params(data_dict, ema_fac)
        trajs[kk, 0] = data_dict['xgs'][init_index]
        score_net    = launcher.construct_network(cfg)[0]

        # compute the trajectory chunk by chunk; each chunk starts from the last stored state
        for curr_batch in range(nbatches):
            lb = curr_batch*bs
            ub = lb + bs

            start_time = time.time()
            trajs[kk, 1+lb:1+ub] = rollout_traj_pflow(trajs[kk, lb], np.arange(bs), score_net, cfg, params, dt)
            end_time = time.time()

            print(f'[{kk+1}/{len(v0s)}], [{curr_batch+1}/{nbatches}], {end_time-start_time}s.')


    return trajs

### Test v0 = 0

In [None]:
# v0 = 0 sanity check, using the raw (not hashable-converted) config
test_cfg       = data_dicts[0.0]['cfg']
test_xgs       = data_dicts[0.0]['xgs']
nsamples       = test_xgs.shape[0]
# presumably the EMA(0.999) parameters of the last saved entry — confirm against the training script
test_params    = data_dicts[0.0]['ema_params_list'][-1][0.999]
test_score_net = launcher.construct_network(test_cfg)[0]

In [None]:
ind                  = onp.random.randint(nsamples)
rhs                  = pflow_rhs(test_xgs[ind], test_score_net, test_cfg, test_params)
test_xdot, test_gdot = np.split(rhs, 2)
(test_xdot.min(), test_xdot.max(), np.linalg.norm(test_xdot))

## Compute Trajectories

In [None]:
## integrate for `tfac` multiples of 1/gamma
gamma       = data_dicts[v0s[0]]['cfg'].gamma
tfac        = 5.0
tf          = tfac / gamma
dt          = 0.01
nbatches    = 10
ema_fac     = 0.9999
pflow_trajs = compute_pflow_trajs(tf=tf, nbatches=nbatches, ema_fac=ema_fac, dt=dt)

## Make Movie

In [None]:
def make_movie(
    trajs: np.ndarray, # [ntrajs, nsteps, 2N, d]
    save_str: str,
    dt: float,
    nframes: int = 250
) -> None:
    """Animate the particle positions (the first N entries of each state) and save an mp4.

    Args:
        trajs: [ntrajs, nsteps, 2N, d]; trajectory kk is drawn in column kk, titled with v0s[kk].
        save_str: output path without extension ('.mp4' is appended).
        dt: time between consecutive stored states.
        nframes: approximate number of frames; every max(1, nsteps // nframes)-th state is drawn.
    """
    plt.close('all')
    sns.set_palette('deep')
    fw, fh   = 4, 4
    fontsize = 20
    r_fac    = 1.4
    nrows    = 1
    ncols    = trajs.shape[0]
    nsteps   = trajs.shape[1]
    # squeeze=False so a single trajectory still yields an iterable row of axes
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows), sharex=True, sharey=True, constrained_layout=True, squeeze=False)
    axs      = axs[0]


    ## set axis limits, grids, and particle scaling
    cfg   = data_dicts[0.0]['cfg']
    gamma = cfg.gamma  # the time label is reported in units of 1/gamma
    for ax in axs:
        ax.set_xlim([-cfg.width, cfg.width])
        ax.set_ylim([-cfg.width, cfg.width])

        ax.set_xticks([-0.5*cfg.width, 0, 0.5*cfg.width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.set_yticks([-0.5*cfg.width, 0, 0.5*cfg.width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])

        ax.grid(which='both', axis='both', color='0.90', alpha=0.2)
        ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)
        ax.set_facecolor('k')
        ax.axes.set_aspect(1.0)


    ## frames to draw; guard against nframes > nsteps and very short trajectories
    start, stop = 0, nsteps
    step        = max(1, nsteps // nframes)
    print_every = max(1, stop // 10)


    ## set up initial scatter plots
    scats = []
    for kk, (traj, ax) in enumerate(zip(trajs, axs)):
        scats.append(
            ax.scatter(
                traj[0, :N, 0], traj[0, :N, 1], zorder=1, s=radius_to_points(ax, r_fac), marker='o', color='C4', alpha=0.75
            )
        )
        ax.set_title(rf"$v_0=${v0s[kk]}", fontsize=fontsize)
    fig.suptitle(rf"$t=${0:0.3f}$/\gamma$", fontsize=fontsize)


    def init():
        return scats


    def animate(frame: int):
        ## remove scatter for speed
        for path_collection in scats:
            path_collection.remove()

        ## redraw the particles at this frame
        for kk, (traj, ax) in enumerate(zip(trajs, axs)):
            scats[kk] = ax.scatter(
                traj[frame, :N, 0], traj[frame, :N, 1], zorder=1, s=radius_to_points(ax, r_fac), marker='o', color='C4', alpha=0.75
            )

        # physical time frame*dt, written as a multiple of 1/gamma
        t = gamma*dt*frame
        fig.suptitle(rf"$t=${t:0.3f}$/\gamma$", fontsize=fontsize)

        if frame % print_every == 0:
            print(f'[{frame}/{stop}]')

        ## return what changed in the plot
        return scats

    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=onp.arange(start, stop, step), 
                                   interval=0.1, blit=True, repeat=False, cache_frame_data=False)

    save_name = f'{save_str}.mp4'
    anim.save(save_name, fps=60)
    plt.show()

In [None]:
nv0s_plot          = 5
base_output_folder = '/scratch/nb3397/results/mips/lowd_cates/figures/8_18_23/'

# one column per v0; the movie is written to <base_output_folder>/pflow.mp4
make_movie(
    pflow_trajs[:nv0s_plot],
    f'{base_output_folder}/pflow',
    nframes=500,
    dt=dt,
)

# Attention Map Visualization

### Separate Parameters

In [None]:
# For each v0 and each score network ('x' -> 'sx', 'g' -> 'sg'), split the parameters into
# the x-embedding, g-embedding and transformer groups by substring match on the module name.
separate_params_dict = {}
for v0 in v0s:
    curr_params = get_params(data_dicts[v0], ema_fac=0.9999)
    separate_params_dict[v0] = {}

    for net_key in ['x', 'g']:
        net_params = curr_params[net_key]
        separate_params_dict[v0][f's{net_key}'] = {
            'g_embedding_params': {name: value for name, value in net_params.items() if 'g_embedding' in name},
            'x_embedding_params': {name: value for name, value in net_params.items() if 'x_embedding' in name},
            'trans_params'      : {name: value for name, value in net_params.items() if 'transformer' in name},
        }

### Separate network for attention map computation

In [None]:
@hk.without_apply_rng
@hk.transform
def g_embedding(
    inp: np.ndarray, 
    cfg: config_dict.FrozenConfigDict
):
    """Residual GELU MLP embedding of the g-coordinates, output width cfg.embed_dim // 2.

    The layers are named 'g_embedding' so that the parameters filtered on that substring
    (presumably from the trained score network) can be applied to it.
    """
    layers = networks.construct_mlp_layers(
        cfg.embed_n_hidden,
        cfg.embed_n_neurons,
        jax.nn.gelu,
        w0=cfg.w0,
        output_dim=cfg.embed_dim // 2,
        use_layer_norm=False,
        use_residual_connections=True,
        name='g_embedding',
    )
    return hk.Sequential(layers)(inp)


@hk.without_apply_rng
@hk.transform
def x_embedding(
    inp: np.ndarray,
    cfg: config_dict.FrozenConfigDict
):
    """Residual GELU MLP embedding of the x-coordinates, output width cfg.embed_dim // 2.

    The layers are named 'x_embedding' so that the parameters filtered on that substring
    (presumably from the trained score network) can be applied to it.
    """
    layers = networks.construct_mlp_layers(
        cfg.embed_n_hidden,
        cfg.embed_n_neurons,
        jax.nn.gelu,
        w0=cfg.w0,
        output_dim=cfg.embed_dim // 2,
        use_layer_norm=False,
        use_residual_connections=True,
        name='x_embedding',
    )
    return hk.Sequential(layers)(inp)


@hk.without_apply_rng
@hk.transform
def transformer(
    inp: np.ndarray,
    cfg: config_dict.FrozenConfigDict
):
    """Apply the `networks.Transformer` configured from `cfg` to already-embedded inputs."""
    model = networks.Transformer(
        num_layers=cfg.num_layers,
        input_dim=cfg.embed_dim,
        num_heads=cfg.num_heads,
        dim_feedforward=cfg.dim_feedforward,
        n_layers_feedforward=cfg.n_layers_feedforward,
        n_inducing_points=cfg.n_inducing_points,
        w0=cfg.w0,
    )
    return model(inp)


@hk.without_apply_rng
@hk.transform
def get_transformer_attention(
    inp: np.ndarray,
    cfg: config_dict.FrozenConfigDict
):
    """Return the attention maps of the `networks.Transformer` configured from `cfg`.

    The module is built with the same arguments as `transformer`, so the same parameters apply.
    """
    model = networks.Transformer(
        num_layers=cfg.num_layers,
        input_dim=cfg.embed_dim,
        num_heads=cfg.num_heads,
        dim_feedforward=cfg.dim_feedforward,
        n_layers_feedforward=cfg.n_layers_feedforward,
        n_inducing_points=cfg.n_inducing_points,
        w0=cfg.w0,
    )
    return model.get_attention_maps(inp)


def attention_rollout(
    attention_maps: np.ndarray # [L, N, N]
) -> np.ndarray:
    """Attention rollout across layers (heads are assumed to be averaged already).

    Each layer's map is mixed with the identity as 0.5*(A_l + I), and the result is composed
    with the rollout of the layers below it:
        R_0 = 0.5*(A_0 + I),    R_l = 0.5*(A_l + I) @ R_{l-1}.

    Returns a numpy array of shape [L, N, N] whose entry l is R_l.
    """
    n_layers = attention_maps.shape[0]
    identity = onp.eye(attention_maps[0].shape[0])
    rollout  = onp.zeros(attention_maps.shape)

    for layer in range(n_layers):
        mixed = 0.5*(attention_maps[layer] + identity)
        rollout[layer] = mixed if layer == 0 else mixed @ rollout[layer - 1]

    return rollout

### compute attention maps

In [None]:
# Attention maps for the first stored sample, per v0 and per score network ('sx', 'sg').
attention_dict = {}
for v0 in v0s:
    attention_dict[v0] = {}
    curr_cfg = data_dicts[v0]['cfg']
    # Only the first sample is embedded. Split just that sample instead of the whole dataset
    # (equivalent to curr_xs[0], curr_gs[0] from a full split along axis 1).
    curr_xs, curr_gs = np.split(data_dicts[v0]['xgs'][0], 2)

    for key in ['sx', 'sg']:
        curr_params_dict  = separate_params_dict[v0][key]
        embedded_xs       = x_embedding.apply(curr_params_dict['x_embedding_params'], curr_xs, curr_cfg)
        embedded_gs       = g_embedding.apply(curr_params_dict['g_embedding_params'], curr_gs, curr_cfg)
        transformer_input = np.hstack((embedded_xs, embedded_gs))

        attention_maps = np.array(get_transformer_attention.apply(curr_params_dict['trans_params'], transformer_input, curr_cfg))
        # keep each network's maps; previously the 'sx' results were silently overwritten
        attention_dict[v0][key] = {
            'attention_maps': attention_maps,
            'rollout_attn':   attention_rollout(np.mean(attention_maps, axis=1)),
        }

    # Backwards compatibility: downstream cells read attention_dict[v0]['attention_maps'] and
    # ['rollout_attn'], which previously ended up holding the 'sg' network's results.
    attention_dict[v0]['attention_maps'] = attention_dict[v0]['sg']['attention_maps']
    attention_dict[v0]['rollout_attn']   = attention_dict[v0]['sg']['rollout_attn']

## Single-Particle

In [None]:
## standard figure configuration
plt.close('all')
plt.style.use('dark_background')
sns.set_palette('deep')
fraction = 0.95
shrink   = 0.75
fontsize = 12.5
cmap     = sns.color_palette('mako', as_cmap=True)


## what to plot
layer   = 0
head    = 3
v0      = 0.3
curr_xs = np.split(data_dicts[v0]['xgs'], 2, axis=1)[0]
curr_cfg = data_dicts[v0]['cfg']
particle_identity = onp.random.randint(curr_cfg.N)
curr_attention_map = attention_dict[v0]['attention_maps'][layer][head]
# curr_attention_map = attention_dict[v0]['rollout_attn'][layer]


## visualize attention map itself
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(5, 2), constrained_layout=True)
vmin = np.min(curr_attention_map)
vmax = np.max(curr_attention_map)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
im   = axs[0].imshow(curr_attention_map, norm=norm, cmap=cmap)
cbar = fig.colorbar(im, ax=axs[0], fraction=fraction, shrink=shrink)
cbar.ax.tick_params(labelsize=fontsize)
axs[0].grid(which='both', axis='both', color='0.90', alpha=0.0)


## visualize attention on the particles
c    = curr_attention_map[particle_identity]
vmin = np.min(c)
vmax = np.max(c)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
scat = axs[1].scatter(curr_xs[0, :, 0], curr_xs[0, :, 1], s=75, marker='o', c=c, norm=norm, cmap=cmap)
cbar = fig.colorbar(scat, ax=axs[1], fraction=fraction, shrink=shrink)
cbar.ax.tick_params(labelsize=fontsize)
axs[1].scatter(curr_xs[0, particle_identity, 0], curr_xs[0, particle_identity, 1], s=75, marker='o', color='white', alpha=0.75)
axs[1].grid(which='both', axis='both', color='0.90', alpha=0.2)
axs[1].set_xlim([-curr_cfg.width//2, curr_cfg.width//2])
axs[1].set_ylim([-curr_cfg.width//2, curr_cfg.width//2])
axs[1].grid(which='both', axis='both', color='0.90', alpha=0.2)
axs[1].axes.set_aspect(1.0)
scale = axs[1].transData.get_matrix()[0, 0]
axs[1].tick_params(axis='both', labelsize=fontsize)

## Multi-Particle Visualization

In [None]:
def make_multiparticle_attention_map_plot(
    attention_map: np.ndarray,
    xs: np.ndarray,
    nparticles_plot: int,
    cfg: config_dict.FrozenConfigDict
) -> None:
    """Show how randomly chosen particles distribute their attention over all particles.

    An nparticles_plot x nparticles_plot grid of panels is drawn. Each panel picks a particle
    i, colours every particle j by attention_map[i, j] / sum_j attention_map[i, j], and
    outlines particle i in white.

    Args:
        attention_map: square attention matrix; row i is the attention of particle i.
        xs: particle positions, indexed as xs[:, 0] and xs[:, 1].
        nparticles_plot: number of panel rows and columns.
        cfg: provides `N` (number of particles) and `width` (used for axis limits).
    """
    ## plot parameters
    plt.close('all')
    sns.set_palette('deep')

    fraction   = 0.25
    shrink     = 0.25
    fontsize   = 10.0
    cmap       = sns.color_palette('mako', as_cmap=True)
    nrows      = ncols = nparticles_plot
    nparticles_total = nparticles_plot**2
    radius_fac = 1.25

    ## which particles to plot: distinct particles whenever there are enough of them
    particle_inds = onp.random.choice(
        onp.arange(cfg.N), size=nparticles_total, replace=(nparticles_total > cfg.N)
    )

    ## visualize many particles at once; squeeze=False keeps axs 2D even for a single panel
    fw, fh   = 1.5, 1.5
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows), sharex=True, sharey=True, constrained_layout=True, squeeze=False)

    ## basic plot properties
    for ax in axs.ravel():
        ax.set_facecolor('k')
        ax.set_xlim([-0.75*cfg.width, 0.75*cfg.width])
        ax.set_ylim([-0.75*cfg.width, 0.75*cfg.width])
        ax.axes.set_aspect(1.0)
        ax.grid(which='both', axis='both', color='0.90', alpha=0.25)
        ax.set_xticks([])
        ax.set_yticks([])

    ## visualize attention-colored particles
    for ii in range(nrows):
        for jj in range(ncols):
            ax = axs[ii, jj]

            # identify the particle and the corresponding (normalized) row of the attention map
            particle_identity = particle_inds[jj + ii*ncols]
            attn = attention_map[particle_identity]
            attn = attn / np.sum(attn)

            # per-panel colour scaling
            norm = mpl.colors.Normalize(vmin=np.min(attn), vmax=np.max(attn))

            # plot all particles
            s = radius_to_points(ax, radius_fac)
            ax.scatter(xs[:, 0], xs[:, 1], s=s, marker='o', c=attn, norm=norm, cmap=cmap, alpha=1.0)

            # plot this particle, circled
            ax.scatter(
                xs[particle_identity, 0],
                xs[particle_identity, 1],
                s=s,
                c=attn[particle_identity],
                marker='o',
                edgecolor='white',
                linewidth=0.25,
            )

    # Each panel has its own colour scale. The colorbar uses the last panel's norm with its
    # ticks hidden, so it only indicates the colour direction (low -> high attention).
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=axs.ravel().tolist(), shrink=shrink, fraction=fraction, orientation='horizontal')
    cbar.ax.xaxis.set_ticks([])
    cbar.set_label('attention', fontsize=fontsize)

In [None]:
v0    = 0.4
layer = 3
head  = 3
cfg   = data_dicts[v0]['cfg']
xs    = np.split(data_dicts[v0]['xgs'][0], 2)[0]  # positions of the first stored sample


## raw attention (single layer and head):
# make_multiparticle_attention_map_plot(attention_dict[v0]['attention_maps'][layer][head], xs, nparticles_plot=2, cfg=cfg)

## rollout attention through `layer`
make_multiparticle_attention_map_plot(
    attention_dict[v0]['rollout_attn'][layer],
    xs,
    nparticles_plot=2,
    cfg=cfg,
)

## Multi-Particle, Multi-v0

In [None]:
def make_v0_attention_map_plot(
    v0s_plot: list,
    nparticles_plot: int,
    clip_quantile: float,
    save_str: str,
    seed=None,
) -> None:
    """Plot rollout attention for a few randomly chosen particles across several v0 values.

    Builds a grid with one row per selected particle and one column per entry of
    ``v0s_plot``. In each panel every particle is scattered at its (x, y)
    coordinates and colored by row ``particle_ind`` of the last entry of
    ``attention_dict[v0]['rollout_attn']`` (presumably the final layer), after that
    row is normalized to sum to one. The selected particle is then redrawn with a
    white outline. All panels share one color scale, clipped to the
    ``[clip_quantile, 1 - clip_quantile]`` quantiles of every plotted value, and a
    single vertical colorbar without ticks.

    Closes all open figures first. Relies on the notebook globals ``N``,
    ``attention_dict``, ``data_dicts`` and ``radius_to_points``.

    Args:
        v0s_plot: v0 values (keys of ``data_dicts`` / ``attention_dict``), one per column.
        nparticles_plot: number of distinct particles (rows) to draw; must be in [1, N].
        clip_quantile: lower quantile of the shared color range; the upper one is
            ``1 - clip_quantile``.
        save_str: output path without extension, or None. When not None the figure
            is saved to ``f'{save_str}.pdf'`` with a transparent background.
        seed: optional seed for choosing the particles. ``None`` (default) draws from
            NumPy's global RNG, exactly as before; an int makes the choice reproducible.

    Raises:
        ValueError: if ``v0s_plot`` is empty or ``nparticles_plot`` is outside [1, N].
    """
    if len(v0s_plot) == 0:
        raise ValueError("v0s_plot must contain at least one v0 value")
    if not 1 <= nparticles_plot <= N:
        raise ValueError(f"nparticles_plot must be in [1, {N}], got {nparticles_plot}")

    ## plot parameters
    plt.close('all')
    sns.set_palette('deep')
    fraction    = 0.1
    shrink      = 0.5
    fontsize    = 20
    titles      = [rf"$v_0={v0}$" for v0 in v0s_plot]
    cmap        = sns.color_palette('mako', as_cmap=True)
    nrows       = nparticles_plot
    ncols       = len(v0s_plot)
    tick_labels = [r"$-3L/8$", r"$0.0$", r"$3L/8$"]

    ## one panel per (particle, v0); squeeze=False keeps axs 2-D even for a 1x1 grid
    fw, fh   = 4, 5.0
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols,
                            figsize=(fw*ncols, fh*nrows),
                            sharex=True, sharey=True,
                            squeeze=False,
                            constrained_layout=True)

    ## pre-select distinct particles, one per row
    if seed is None:
        particle_inds = onp.random.choice(onp.arange(N), replace=False, size=nparticles_plot)
    else:
        particle_inds = onp.random.default_rng(seed).choice(N, replace=False, size=nparticles_plot)

    ## normalize every attention row to sum to one, then clip all rows to one shared range
    normalized_attns = onp.zeros((nrows, ncols, N))
    for ii, particle_ind in enumerate(particle_inds):
        for jj, v0 in enumerate(v0s_plot):
            curr_attn = attention_dict[v0]['rollout_attn'][-1][particle_ind]
            normalized_attns[ii, jj] = curr_attn / onp.sum(curr_attn)

    vmin = onp.quantile(normalized_attns, q=clip_quantile)
    vmax = onp.quantile(normalized_attns, q=(1 - clip_quantile))
    normalized_attns = onp.clip(normalized_attns, vmin, vmax)
    norm     = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
    mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    mappable.set_array([])  # older matplotlib versions require an array before colorbar()

    ## draw the panels
    for ii, particle_identity in enumerate(particle_inds):
        for jj, v0 in enumerate(v0s_plot):
            cfg       = data_dicts[v0]['cfg']
            xs        = np.split(data_dicts[v0]['xgs'][0], 2)[0]
            ax        = axs[ii, jj]
            tick_locs = [-0.75*cfg.width, 0, 0.75*cfg.width]

            ## panel styling
            ax.set_facecolor('k')
            ax.set_xlim([-0.85*cfg.width, 0.85*cfg.width])
            ax.set_ylim([-0.85*cfg.width, 0.85*cfg.width])
            ax.set_xticks(tick_locs, tick_labels)
            ax.set_yticks(tick_locs, tick_labels)
            ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)
            if ii == 0:
                ax.set_title(titles[jj], fontsize=fontsize)
            ax.axes.set_aspect(1.0)
            ax.grid(which='both', axis='both', color='0.90', alpha=0.05)
            # marker size from the notebook helper; computed after the aspect is fixed
            s = radius_to_points(ax, 1.35)

            ## color all particles by the selected particle's attention row
            attn = normalized_attns[ii, jj]
            ax.scatter(xs[:, 0], xs[:, 1], s=s, c=attn, marker='o',
                       norm=norm, cmap=cmap, alpha=0.75)
            # redraw the selected particle on top, outlined in white
            ax.scatter(xs[particle_identity, 0], xs[particle_identity, 1], s=s,
                       c=attn[particle_identity], marker='o', norm=norm, cmap=cmap,
                       edgecolor='white', linewidth=0.1)

    ## single shared colorbar (norm/cmap come from the mappable)
    cbar = fig.colorbar(mappable, ax=axs.ravel().tolist(), shrink=shrink,
                        fraction=fraction, orientation='vertical')
    cbar.ax.xaxis.set_ticks([])
    cbar.ax.yaxis.set_ticks([])
    cbar.set_label('attention', fontsize=fontsize)

    if save_str is not None:
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        fig.savefig(f'{save_str}.pdf', dpi=300, bbox_inches='tight')

In [None]:
import os

## Write the multi-v0 attention figure into a dated folder, creating it if needed
## (savefig raises FileNotFoundError when the directory does not exist yet).
base_output_folder = '/scratch/nb3397/results/mips/lowd_cates/figures/9_20_23/'
os.makedirs(base_output_folder, exist_ok=True)
make_v0_attention_map_plot(
    v0s_plot=v0s[1:-1],  # drop the first and last entries of v0s
    nparticles_plot=1,
    clip_quantile=0.02,
    save_str=os.path.join(base_output_folder, 'attention'),
)