In [None]:
import jax
from jax import numpy as np
from jax import vmap
import numpy as onp
import sys

import dill as pickle
sys.path.append('../py')
%load_ext autoreload
%autoreload 2


import common.drifts as drifts
import common.networks as networks
import common.grid_utils as grid_utils
import launchers.twod_ring_simulation as twod_ring_simulation
from launchers.twod_ring_simulation import *
from ml_collections import config_dict
import matplotlib.animation as animation
import functools


import matplotlib.pyplot as plt
import matplotlib as mpl


# Notebook-wide matplotlib defaults: light grid with minor ticks on both axes,
# LaTeX-rendered serif text, small titles/legends and high-dpi output.
# NOTE: 'text.usetex' requires a working LaTeX installation.
mpl.rcParams.update({
    'axes.grid':           True,
    'axes.grid.which':     'both',
    'xtick.minor.visible': True,
    'ytick.minor.visible': True,
    'axes.facecolor':      'white',
    'grid.color':          '0.8',
    'grid.alpha':          '0.1',
    'text.usetex':         True,
    'font.family':         'serif',
    'figure.figsize':      (8, 4),
    'figure.titlesize':    7.5,
    'font.size':           10,
    'legend.fontsize':     7.5,
    'figure.dpi':          300,
})

In [None]:
## shifted system, higher resolution
base_folder = '/scratch/nb3397/results/mips/2d_ring/8_15_23'
output_name = 'ring_8_15_23_shifted_system_symmetric_0'

## load the data
# The results file is a dill pickle despite its `.npy` extension. Unpickling
# can execute arbitrary code, so only load files from a trusted source.
# The `with` block closes the handle once loading finishes.
with open(f'{base_folder}/{output_name}.npy', 'rb') as data_file:
    data = pickle.load(data_file)

In [None]:
cfg                      = data['cfg']
twod_ring_simulation.cfg = cfg   # same object, so the field back-fills below are visible there too
xgs                      = data['xgs']
xs, gs                   = np.split(xgs, 2, axis=1)
xs, gs                   = np.squeeze(xs), np.squeeze(gs)
ema_params               = data['ema_params_list'][-1]
params                   = data['params_list'][-1]

# Older result files predate these config fields; fill in defaults when missing.
# Reading a missing ConfigDict field raises AttributeError, so catch only that --
# a bare `except:` would also hide KeyboardInterrupt and genuine errors.
try:
    print(f'symmetric: {cfg.symmetric_network}')
except AttributeError:
    print('Adding symmetric network.')
    cfg.symmetric_network=True

try:
    print(f'shift system: {cfg.shift_system}')
except AttributeError:
    print('Adding shift system.')
    cfg.shift_system=False

In [None]:
# Build the four networks, keep them as notebook globals, and also attach them
# to the simulation module so that it sees the same instances (presumably its
# helpers such as `compute_batch_output_info` read them as module globals -- confirm).
built_nets = tuple(construct_network())
score_net, particle_div_net, div_net, map_score_net = built_nets
for net_attr, net in zip(('score_net', 'particle_div_net', 'div_net', 'map_score_net'), built_nets):
    setattr(twod_ring_simulation, net_attr, net)

# Entropy Scatterplot

In [None]:
def make_entropy_scatter_plot(params: Dict[str, hk.Params]):
    """Make a full 4x3 scatter plot of velocity and score diagnostics.

    A random subsample of `cfg.plot_bs` trajectories from `data['xgs']` is
    passed through `compute_batch_output_info`. Each panel scatters the
    samples in the (x, g) plane, coloured by one pointwise quantity. The
    bottom two rows use a colour range symmetric about zero.

    Args:
        params: network parameters forwarded to `compute_batch_output_info`.
    """
    # compute quantities needed for plotting on a random subsample
    inds = onp.random.choice(onp.arange(cfg.ntrajs), size=cfg.plot_bs, replace=False)
    xs, gs, gdot_mags, xdot_mags, score_mags, \
        x_score_mags, g_score_mags, vs, v_times_s, \
            div_vs, div_vxs, div_vgs, div_sxs, div_sgs = \
                compute_batch_output_info(data['xgs'][inds], params)


    # compute seifert entropy v^T D^{-1} v with D = diag(eps, ..., gamma, ...)
    xscale          = np.array(cfg.d*[cfg.eps])
    gscale          = np.array(cfg.d*[cfg.gamma])
    scales          = np.concatenate((xscale, gscale))
    seifert_entropy = np.sum(vs/scales[None, :]*vs, axis=1)


    # common plot parameters
    plt.close('all')
    # NOTE: plt.style.use mutates global state; later figures also get the dark style
    plt.style.use('dark_background')
    sns.set_palette('deep')
    gmax, gmin = gs.max(), gs.min()
    xmax = cfg.width
    xmin = -cfg.width
    fw   = 4.0
    fh   = (gmax-gmin) / (xmax - xmin) * fw
    fraction  = 0.25
    shrink    = 0.5
    fontsize  = 17.5


    ###### main entropy figure
    # individual panels
    titles = [[r"$\Vert\dot{g}\Vert$",        r"$\Vert\dot{x}\Vert$", r"$\Vert s \Vert$"],
              [r"$\Vert v \Vert_{D^{-1}}^2$", r"$\Vert s_x\Vert$",    r"$\Vert s_g\Vert$"],
              [r"$\nabla\cdot v$",            r"$\nabla\cdot v_x$",   r"$\nabla\cdot v_g$"],
              [r"$v \cdot s$",                r"$\nabla\cdot s_x$",   r"$\nabla\cdot s_g$"]]

    cs     = [[gdot_mags,       xdot_mags,    score_mags],
              [seifert_entropy, x_score_mags, g_score_mags],
              [div_vs,          div_vxs,      div_vgs],
              [v_times_s,       div_sxs,      div_sgs]]

    # one colormap per row: sequential for magnitudes, diverging for signed rows
    cmaps  = [sns.color_palette('magma', as_cmap=True),
              sns.color_palette('magma', as_cmap=True),
              sns.color_palette('icefire', as_cmap=True),
              sns.color_palette('icefire', as_cmap=True)]

    nrows = len(titles)
    ncols = len(titles[0])
    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=True, sharey=True, constrained_layout=True
    )


    for ax in axs.ravel():
        ax.set_xlim([-cfg.width, cfg.width])
        ax.set_ylim([gmin, gmax])
        ax.grid(which='both', axis='both', color='0.90', alpha=0.2)
        ax.tick_params(axis='both', labelsize=fontsize)
        ax.set_aspect((gmax-gmin) / (xmax-xmin))


    # do the plotting
    for ii in range(nrows):
        for jj in range(ncols):
            title = titles[ii][jj]
            c = cs[ii][jj]
            ax = axs[ii, jj]
            ax.set_title(title, fontsize=fontsize)

            min_val = float(onp.min(c)) if onp.min(c) < 0 else 0
            max_val = float(onp.max(c))

            # make the colour range symmetric about zero for the signed (diverging) rows
            if ii >= nrows//2:
                min_val = min(min_val, -max_val)
                max_val = max(max_val, -min_val)

            norm = mpl.colors.Normalize(vmin=min_val, vmax=max_val, clip=True)
            ax.scatter(xs, gs, s=10, marker='o', c=c, cmap=cmaps[ii], norm=norm, alpha=0.1)

            # Build the colorbar from an opaque ScalarMappable rather than the
            # translucent scatter. It is then drawn at full alpha without the
            # `set_alpha` + `Colorbar.draw_all` trick; draw_all has been
            # deprecated since matplotlib 3.6.
            cbar = fig.colorbar(
                mpl.cm.ScalarMappable(norm=norm, cmap=cmaps[ii]), ax=ax,
                fraction=fraction, shrink=shrink, orientation='horizontal'
            )
            cbar.ax.tick_params(labelsize=fontsize)

            if ii == nrows-1:
                ax.set_xlabel(r"$x$", fontsize=fontsize)
            if jj == 0:
                ax.set_ylabel(r"$g$", fontsize=fontsize)


def make_reduced_entropy_scatter_plot(
    params: Dict[str, hk.Params],
    shift_plot: bool,
    output_name: str
):
    """Make a scaled-down 2x3 scatter plot for the paper.

    The top row shows |v_g|, |v_x| and |v|_{D^{-1}}. The bottom row shows the
    corresponding divergences. All panels are evaluated on a random subsample
    of `cfg.plot_bs` trajectories from `data['xgs']`.

    Args:
        params: network parameters forwarded to `compute_batch_output_info`.
        shift_plot: if True, wrap x onto [0, 2*cfg.width) and label the axis
            0..L. Otherwise plot on [-cfg.width, cfg.width] labelled -L/2..L/2.
            Here L = 2*cfg.width.
        output_name: path the figure is saved to; '' skips saving.
    """
    # compute quantities needed for plotting on a random subsample
    inds = onp.random.choice(onp.arange(cfg.ntrajs), size=cfg.plot_bs, replace=False)
    xs, gs, gdot_mags, xdot_mags, score_mags, \
        x_score_mags, g_score_mags, vs, v_times_s, \
            div_vs, div_vxs, div_vgs, div_sxs, div_sgs = \
                compute_batch_output_info(data['xgs'][inds], params)


    # compute seifert entropy v^T D^{-1} v with D = diag(eps, ..., gamma, ...)
    xscale          = np.array(cfg.d*[cfg.eps])
    gscale          = np.array(cfg.d*[cfg.gamma])
    scales          = np.concatenate((xscale, gscale))
    seifert_entropy = np.sum(vs/scales[None, :]*vs, axis=1)

    ### handle the shift and set up the x window / ticks (ring length L = 2*cfg.width)
    # ticks are derived from cfg.width so the L-fraction labels stay correct for any width
    half_width = cfg.width / 2
    if shift_plot:
        xs = (xs + 2*cfg.width) % (2*cfg.width)
        xmin, xmax = 0, 2*cfg.width
        xticks      = [0.0, half_width, cfg.width, 3*half_width, 2*cfg.width]
        xticklabels = [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"]
    else:
        xmin, xmax = -cfg.width, cfg.width
        xticks      = [-cfg.width, -half_width, 0, half_width, cfg.width]
        xticklabels = [r"$-L/2$", r"$-L/4$", 0, r"$L/4$", r"$L/2$"]

    # common plot parameters
    plt.close('all')
    sns.set_palette('deep')
    # g window: symmetric about zero, covering the g range of the full dataset
    gmax, gmin = data['xgs'][:, 1].max(), data['xgs'][:, 1].min()
    gmin = min(gmin, -gmax)
    gmax = max(-gmin, gmax)

    fw = 8.0
    fh = 5.0
    fraction = 0.15
    shrink = 0.5
    fontsize = 28


    ###### main entropy figure
    # individual panels
    titles = [[r"$|v_g|$", r"$|v_x|$",  r"$|v|_{D^{-1}}$"],
              [r"$\nabla_g\cdot v_g$", r"$\nabla_x\cdot v_x$", r"$\nabla\cdot v$"]]

    cs     = [[gdot_mags, xdot_mags, onp.sqrt(seifert_entropy)],
              [div_vgs,     div_vxs,                   div_vs]]

    cmap = sns.color_palette('mako', as_cmap=True)

    nrows = len(titles)
    ncols = len(titles[0])
    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=True, sharey=True, constrained_layout=True
    )


    for ax in axs.ravel():
        ax.set_xlim([xmin, xmax])
        ax.set_ylim([gmin, gmax])
        ax.set_xticks(xticks, xticklabels)

        ax.grid(which='both', axis='both', color='0.9', alpha=0.1)
        ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)

        for spine in ax.spines.values():
            spine.set_edgecolor('grey')


    # do the plotting
    for ii in range(nrows):
        for jj in range(ncols):
            ax = axs[ii, jj]
            c = cs[ii][jj]
            ax.set_title(titles[ii][jj], fontsize=fontsize)

            vmin = float(onp.min(c)) if onp.min(c) < 0 else 0
            vmax = float(onp.max(c))

            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
            mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
            mappable.set_array([])
            # paint the axes background with the colour of value 0
            ax.set_facecolor(mappable.to_rgba(0.0))
            ax.scatter(xs, gs, s=10, marker='o', c=c, cmap=cmap, norm=norm, alpha=0.75)

            # The colorbar uses the opaque `mappable` rather than the translucent
            # scatter, so no `set_alpha` + `Colorbar.draw_all` (deprecated since
            # matplotlib 3.6) is needed.
            cbar = fig.colorbar(mappable, ax=ax, fraction=fraction, shrink=shrink, orientation='horizontal')
            cbar.ax.tick_params(which='both', labelsize=fontsize, width=0, length=0)
            cbar.outline.set_edgecolor('grey')

            if ii == nrows-1:
                ax.set_xlabel(r"$x$", fontsize=fontsize)
            if jj == 0:
                ax.set_ylabel(r"$g$", fontsize=fontsize)


    if output_name != '':
        ## fig background transparency
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        fig.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())

In [None]:
# Render the paper's reduced scatter figure and write it to disk.
save_folder = '/scratch/nb3397/results/mips/2d_ring/figures/9_5_23'
save_name   = 'entropy_scatter_nospring.pdf'
ema_decay   = 0.9999   # key selecting one entry of `ema_params` (presumably the EMA decay rate -- confirm)
save_path   = f'{save_folder}/{save_name}'

make_reduced_entropy_scatter_plot(ema_params[ema_decay], shift_plot=True, output_name=save_path)

# Entropy Density

In [None]:
def make_entropy_density_plot(
    params: Dict[str, hk.Params],
    xgrid_plot: np.ndarray,
    ggrid_plot: np.ndarray,
    xgrid: np.ndarray,
    ggrid: np.ndarray,
    nbatches: int,
    compute_density: bool
) -> None:
    """Plot 4x3 filled-contour maps of velocity/score diagnostics gridded in (x, g).

    All `cfg.ntrajs` trajectories are evaluated in `nbatches` batches. Each
    pointwise quantity is accumulated on the grid with
    `grid_utils.average_grid_quantity`.

    Args:
        params: network parameters forwarded to `compute_batch_output_info`.
        xgrid_plot, ggrid_plot: coordinates passed to `contourf`. Assumed to
            have shape (xgrid.size - 1, ggrid.size - 1), matching the gridded
            values -- TODO confirm.
        xgrid, ggrid: grid points along x and g passed to `average_grid_quantity`.
        nbatches: number of batches the trajectories are processed in. The
            last batch absorbs any remainder of cfg.ntrajs / nbatches.
        compute_density: if True, divide each cell's accumulated value by its
            multiplicity (empty cells use 1). Otherwise divide by cfg.ntrajs.
    """
    # compute quantities needed for plotting
    batch_size = cfg.ntrajs // nbatches
    # The last batch absorbs the remainder. Otherwise the trailing trajectories
    # would keep their zero placeholders and be gridded at (x, g) = (0, 0).
    batch_bounds = [
        (ib*batch_size, cfg.ntrajs if ib == nbatches - 1 else (ib + 1)*batch_size)
        for ib in range(nbatches)
    ]
    xs = onp.zeros(cfg.ntrajs)
    gs = onp.zeros(cfg.ntrajs)
    gdot_mags = onp.zeros(cfg.ntrajs)
    xdot_mags = onp.zeros(cfg.ntrajs)
    score_mags = onp.zeros(cfg.ntrajs)
    x_score_mags = onp.zeros(cfg.ntrajs)
    g_score_mags = onp.zeros(cfg.ntrajs)
    vs = onp.zeros((cfg.ntrajs, 2*cfg.d))
    v_times_s = onp.zeros(cfg.ntrajs)
    div_vs  = onp.zeros(cfg.ntrajs)
    div_vxs = onp.zeros(cfg.ntrajs)
    div_vgs = onp.zeros(cfg.ntrajs)
    div_sxs = onp.zeros(cfg.ntrajs)
    div_sgs = onp.zeros(cfg.ntrajs)

    for lb, ub in batch_bounds:
        xs[lb:ub], gs[lb:ub], gdot_mags[lb:ub], xdot_mags[lb:ub], score_mags[lb:ub], \
            x_score_mags[lb:ub], g_score_mags[lb:ub], vs[lb:ub], v_times_s[lb:ub], \
                div_vs[lb:ub], div_vxs[lb:ub], div_vgs[lb:ub], div_sxs[lb:ub], div_sgs[lb:ub] = \
                    compute_batch_output_info(data['xgs'][lb:ub], params)


    # compute seifert entropy v^T D^{-1} v with D = diag(eps, ..., gamma, ...)
    xscale          = np.array(cfg.d*[cfg.eps])
    gscale          = np.array(cfg.d*[cfg.gamma])
    scales          = np.concatenate((xscale, gscale))
    seifert_entropy = np.sum(vs/scales[None, :]*vs, axis=1)


    # common plot parameters
    plt.close('all')
    sns.set_palette('deep')


    fw = 8.0
    fh = 5.0
    fraction  = 0.15
    shrink    = 0.5
    fontsize  = 28


    ## grid quantities
    # (N, 1, 2) array of (x, g) sample positions, the layout passed to average_grid_quantity
    xgs = np.concatenate((xs.reshape((-1, 1, 1)), gs.reshape((-1, 1, 1))), axis=-1)
    pointwise_quantities = [[gdot_mags,    xdot_mags,    onp.sqrt(seifert_entropy)],
                            [g_score_mags, x_score_mags, score_mags],
                            [div_vgs,      div_vxs,      div_vs],
                            [div_sgs,      div_sxs,      v_times_s]]

    gridded_quantities = []
    for row in pointwise_quantities:
        grid_row = []
        for pointwise_quantity in row:
            curr_grid = onp.zeros((xgrid.size-1, ggrid.size-1))
            curr_multiplicity = onp.zeros((xgrid.size-1, ggrid.size-1))
            # accumulate per-batch grid sums and multiplicities (sum_rslts=True)
            for lb, ub in batch_bounds:
                batch_grid, batch_multiplicity = grid_utils.average_grid_quantity(
                    pointwise_quantity[lb:ub].reshape((-1, 1)), xgs[lb:ub], xgrid, ggrid, sum_rslts=True
                )
                curr_grid += onp.array(batch_grid)
                curr_multiplicity += onp.array(batch_multiplicity)

            if compute_density:
                curr_multiplicity[curr_multiplicity == 0] = 1.0 # avoid divide-by-zero
                grid_row.append(curr_grid / curr_multiplicity)  # no need to average because they cancel when averaging both
            else:
                grid_row.append(curr_grid / cfg.ntrajs)

        gridded_quantities.append(grid_row)


    ###### main entropy figure
    # individual panels
    titles = [[r"$\Vert\dot{g}\Vert$", r"$\Vert\dot{x}\Vert$", r"$\Vert v \Vert_{D^{-1}}$",],
              [r"$\Vert s_g\Vert$",    r"$\Vert s_x\Vert$",    r"$\Vert s \Vert$",],
              [r"$\nabla\cdot v_g$",   r"$\nabla\cdot v_x$",   r"$\nabla\cdot v$"],
              [r"$\nabla\cdot s_g$",   r"$\nabla\cdot s_x$",   r"$v \cdot s$"]]

    # all panels use 'mako' except div s_g, which uses the reversed map
    cmaps  = [[sns.color_palette('mako',   as_cmap=True), sns.color_palette('mako', as_cmap=True), sns.color_palette('mako', as_cmap=True)],
              [sns.color_palette('mako',   as_cmap=True), sns.color_palette('mako', as_cmap=True), sns.color_palette('mako', as_cmap=True)],
              [sns.color_palette('mako',   as_cmap=True), sns.color_palette('mako', as_cmap=True), sns.color_palette('mako', as_cmap=True)],
              [sns.color_palette('mako_r', as_cmap=True), sns.color_palette('mako', as_cmap=True), sns.color_palette('mako', as_cmap=True)]]

    nrows = len(titles)
    ncols = len(titles[0])
    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=True, sharey=True, constrained_layout=True
    )

    for ax in axs.ravel():
        ax.set_xlim([xgrid.min(), xgrid.max()])
        ax.set_ylim([ggrid.min(), ggrid.max()])
        ax.grid(which='both', axis='both', color='0.95', alpha=0.1)
        ax.tick_params(axis='both', labelsize=fontsize)

        for spine in ax.spines.values():
            spine.set_edgecolor('grey')

    # do the plotting
    for ii in range(nrows):
        for jj in range(ncols):
            grid_values = gridded_quantities[ii][jj]
            ax = axs[ii, jj]
            ax.set_title(titles[ii][jj], fontsize=fontsize)

            vmin = float(onp.min(grid_values))
            vmax = float(onp.max(grid_values))

            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
            ax.contourf(xgrid_plot, ggrid_plot, grid_values, cmap=cmaps[ii][jj], norm=norm, levels=100)
            cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmaps[ii][jj]), ax=ax, fraction=fraction, shrink=shrink, orientation='horizontal')
            cbar.ax.tick_params(which='both', labelsize=fontsize, width=0, length=0)
            cbar.outline.set_edgecolor('grey')

            if ii == nrows-1:
                ax.set_xlabel(r"$x$", fontsize=fontsize)
            if jj == 0:
                ax.set_ylabel(r"$g$", fontsize=fontsize)


def make_reduced_entropy_density_plot(
    params: Dict[str, hk.Params],
    npts: int,
    nbatches: int,
    compute_density: bool,
    shift_plot: bool,
    output_name: str,
    clip_quantile: float
) -> None:
    """Plot binned velocity magnitudes and divergences (2x3 panels) for the paper.

    All `cfg.ntrajs` trajectories are evaluated in `nbatches` batches and
    histogrammed on an `npts` x `npts` grid in (x, g).

    Args:
        params: network parameters forwarded to `compute_batch_output_info`.
        npts: number of histogram bins along each of x and g.
        nbatches: number of batches the trajectories are processed in. The
            last batch absorbs any remainder of cfg.ntrajs / nbatches.
        compute_density: if True, each bin shows the mean of the quantity over
            the samples in it (empty bins are 0). If False, each bin shows the
            quantity weighted by the sample density, sum(q in bin) / (N * bin area).
        shift_plot: if True, wrap x onto [0, 2*cfg.width) and label the axis
            0..L. Otherwise plot on [-cfg.width, cfg.width] labelled -L/2..L/2.
            Here L = 2*cfg.width.
        output_name: path the figure is saved to; '' skips saving.
        clip_quantile: colour limits are the [clip_quantile, 1 - clip_quantile]
            quantiles of each panel's values, so outliers do not wash out the map.
    """
    ### compute quantities needed for plotting
    batch_size = cfg.ntrajs // nbatches
    # The last batch absorbs the remainder. Otherwise the trailing trajectories
    # would keep their zero placeholders and be binned at (x, g) = (0, 0).
    batch_bounds = [
        (ib*batch_size, cfg.ntrajs if ib == nbatches - 1 else (ib + 1)*batch_size)
        for ib in range(nbatches)
    ]
    xs = onp.zeros(cfg.ntrajs)
    gs = onp.zeros(cfg.ntrajs)
    gdot_mags = onp.zeros(cfg.ntrajs)
    xdot_mags = onp.zeros(cfg.ntrajs)
    vs = onp.zeros((cfg.ntrajs, 2*cfg.d))
    div_vs  = onp.zeros(cfg.ntrajs)
    div_vxs = onp.zeros(cfg.ntrajs)
    div_vgs = onp.zeros(cfg.ntrajs)

    for lb, ub in batch_bounds:
        xs[lb:ub], gs[lb:ub], gdot_mags[lb:ub], xdot_mags[lb:ub], _, \
            _, _, vs[lb:ub], _, div_vs[lb:ub], div_vxs[lb:ub], div_vgs[lb:ub], _, _ = \
                    compute_batch_output_info(data['xgs'][lb:ub], params)


    ### compute seifert entropy v^T D^{-1} v with D = diag(eps, ..., gamma, ...)
    xscale          = np.array(cfg.d*[cfg.eps])
    gscale          = np.array(cfg.d*[cfg.gamma])
    scales          = np.concatenate((xscale, gscale))
    seifert_entropy = np.sum(vs/scales[None, :]*vs, axis=1)


    ### handle the shift and set up the x window / ticks (ring length L = 2*cfg.width)
    # ticks are derived from cfg.width so the L-fraction labels stay correct for any width
    half_width = cfg.width / 2
    if shift_plot:
        xs = (xs + 2*cfg.width) % (2*cfg.width)
        min_x, max_x = 0, 2*cfg.width
        xticks      = [0.0, half_width, cfg.width, 3*half_width, 2*cfg.width]
        xticklabels = [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"]
    else:
        min_x, max_x = -cfg.width, cfg.width
        xticks      = [-cfg.width, -half_width, 0, half_width, cfg.width]
        xticklabels = [r"$-L/2$", r"$-L/4$", 0, r"$L/4$", r"$L/2$"]


    ### grid quantities
    pointwise_quantities = [[gdot_mags, xdot_mags, onp.sqrt(seifert_entropy)],
                            [div_vgs,   div_vxs,                     div_vs]]
    nrows = len(pointwise_quantities)
    ncols = len(pointwise_quantities[0])

    # Every panel bins the same samples, so the bin edges, per-bin counts and
    # bin centres are shared and computed once.
    counts, xedges, gedges = (onp.asarray(a) for a in np.histogram2d(xs, gs, bins=npts))
    occupied = counts > 0
    bin_area = onp.diff(xedges)[:, None] * onp.diff(gedges)[None, :]
    xcenters = 0.5*(xedges[:-1] + xedges[1:])
    gcenters = 0.5*(gedges[:-1] + gedges[1:])
    xgrid_plot, ggrid_plot = onp.meshgrid(xcenters, gcenters, indexing='ij')

    gridded_quantities = onp.zeros((nrows, ncols, npts, npts))
    for ii in range(nrows):
        for jj in range(ncols):
            bin_sums = onp.asarray(np.histogram2d(
                xs, gs, weights=pointwise_quantities[ii][jj], bins=(xedges, gedges)
            )[0])
            if compute_density:
                # per-bin mean of the quantity; empty bins are set to zero
                gridded_quantities[ii, jj] = onp.where(
                    occupied, bin_sums / onp.where(occupied, counts, 1.0), 0.0
                )
            else:
                # Quantity times sample density. Normalise by the sample count:
                # histogram2d(density=True) divides by the *sum of the weights*,
                # which rescales positive quantities and can flip the sign of
                # (or blow up) signed ones such as the divergences.
                gridded_quantities[ii, jj] = bin_sums / (xs.size * bin_area)


    ###### main entropy figure
    # common plot parameters
    plt.close('all')
    sns.set_palette('deep')
    fw        = 8.0
    fh        = 6.25
    fraction  = 0.2
    shrink    = 0.5
    fontsize  = 28

    # g window symmetric about zero
    min_g, max_g = gs.min(), gs.max()
    min_g        = min(min_g, -max_g)
    max_g        = max(-min_g, max_g)


    # individual panels (titles are drawn as colorbar labels)
    if compute_density:
        titles = [[            r"$|v_g|$",             r"$|v_x|$",  r"$|v|_{D^{-1}}$"],
                  [r"$\nabla_g\cdot v_g$", r"$\nabla_x\cdot v_x$", r"$\nabla\cdot v$"]]
    else:
        titles = [[                  r"$|j_g|$",                   r"$|j_x|$",        r"$|j|_{D^{-1}}$"],
                  [r"$(\nabla_g\cdot v_g)\rho$", r"$(\nabla_x\cdot v_x)\rho$", r"$(\nabla\cdot v)\rho$"]]

    cmap = sns.color_palette('mako', as_cmap=True)

    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=False, sharey=True, constrained_layout=True
    )


    for ax in axs.ravel():
        ax.set_xlim([min_x, max_x])
        ax.set_ylim([min_g, max_g])
        ax.set_xticks(xticks, xticklabels)
        ax.grid(which='both', axis='both', color='0.9', alpha=0.1)
        ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)

        for spine in ax.spines.values():
            spine.set_edgecolor('grey')


    # do the plotting
    for ii in range(nrows):
        for jj in range(ncols):
            grid_values = gridded_quantities[ii, jj]
            ax = axs[ii, jj]

            ## clip the values to quantile to avoid washout
            vmin = onp.quantile(grid_values, q=clip_quantile)
            vmax = onp.quantile(grid_values, q=(1-clip_quantile))
            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
            mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
            mappable.set_array([])
            # paint the axes background with the colour of value 0
            ax.set_facecolor(mappable.to_rgba(0.0))
            ctr  = ax.contourf(xgrid_plot, ggrid_plot, grid_values, cmap=cmap, norm=norm, levels=100)
            cbar = fig.colorbar(mappable, ax=ax, fraction=fraction, shrink=shrink, orientation='horizontal')
            cbar.ax.tick_params(which='both', labelsize=fontsize, width=0, length=0)
            cbar.ax.set_xlabel(titles[ii][jj], fontsize=fontsize)
            cbar.outline.set_edgecolor('grey')

            # Fix strange aliasing artifact between contour levels. From
            # matplotlib 3.8 a ContourSet is itself a single Collection and
            # `.collections` is deprecated; older versions need the per-level loop.
            contour_artists = [ctr] if isinstance(ctr, mpl.collections.Collection) else ctr.collections
            for artist in contour_artists:
                artist.set_edgecolor("face")
                artist.set_rasterized(True)

            ax.set_xlabel(r"$x$", fontsize=fontsize)
            if jj == 0:
                ax.set_ylabel(r"$g$", fontsize=fontsize)


    if output_name != '':
        ## fig background transparency
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        fig.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())


def make_density_system_and_total_entropy_plot(
    params: Dict[str, hk.Params],
    npts: int,
    bins: int,
    nbatches: int,
    compute_density: bool,
    output_name: str,
    clip_quantile: float
) -> None:
    """Plot the sample density rho next to |v|_{D^{-1}} and div v (1x3 panels).

    The axes always span x in [0, L], with L = 2*cfg.width. The samples are
    therefore wrapped onto [0, 2*cfg.width) before binning; the wrap leaves
    samples that are already in that range unchanged.

    Args:
        params: network parameters forwarded to `compute_batch_output_info`.
        npts: number of histogram bins per axis for the |v| and div v panels.
        bins: number of histogram bins per axis for the rho panel.
        nbatches: number of batches the trajectories are processed in. The
            last batch absorbs any remainder of cfg.ntrajs / nbatches.
        compute_density: if True, the |v| and div v panels show per-bin means
            (empty bins are 0). If False, they show the quantity weighted by the
            sample density, sum(q in bin) / (N * bin area).
        output_name: path the figure is saved to; '' skips saving.
        clip_quantile: colour limits are the [clip_quantile, 1 - clip_quantile]
            quantiles of each panel's values, so outliers do not wash out the map.
    """
    ### compute quantities needed for plotting
    batch_size = cfg.ntrajs // nbatches
    # The last batch absorbs the remainder. Otherwise the trailing trajectories
    # would keep their zero placeholders and be binned at (x, g) = (0, 0).
    batch_bounds = [
        (ib*batch_size, cfg.ntrajs if ib == nbatches - 1 else (ib + 1)*batch_size)
        for ib in range(nbatches)
    ]
    xs         = onp.zeros(cfg.ntrajs)
    gs         = onp.zeros(cfg.ntrajs)
    vs         = onp.zeros((cfg.ntrajs, 2*cfg.d))
    div_vs     = onp.zeros(cfg.ntrajs)

    for lb, ub in batch_bounds:
        xs[lb:ub], gs[lb:ub], _, _, _, _, _, \
            vs[lb:ub], _, div_vs[lb:ub], _, _, _, _ = \
                compute_batch_output_info(data['xgs'][lb:ub], params)


    ### compute seifert entropy v^T D^{-1} v with D = diag(eps, ..., gamma, ...)
    xscale          = np.array(cfg.d*[cfg.eps])
    gscale          = np.array(cfg.d*[cfg.gamma])
    scales          = np.concatenate((xscale, gscale))
    seifert_entropy = np.sum(vs/scales[None, :]*vs, axis=1)


    ### wrap x onto the [0, 2*cfg.width) window used for the axes below
    # (as the shift_plot path of the other plots does); without it, samples
    # with x < 0 would fall outside the axis limits.
    period = 2*cfg.width
    xs = (xs + period) % period


    ### bin the data
    pointwise_quantities = [None, onp.sqrt(seifert_entropy), div_vs]
    gridded_quantities, xgrid_plots, ggrid_plots = [], [], []

    for weights in pointwise_quantities:
        if weights is None:
            # rho: unweighted, so density=True normalises by the sample count
            hist, xedges, gedges = np.histogram2d(xs, gs, bins=bins, density=True)
            gridded_quantity = onp.asarray(hist)
            xedges, gedges = onp.asarray(xedges), onp.asarray(gedges)
        else:
            counts, xedges, gedges = (onp.asarray(a) for a in np.histogram2d(xs, gs, bins=npts))
            bin_sums = onp.asarray(np.histogram2d(xs, gs, weights=weights, bins=(xedges, gedges))[0])
            if compute_density:
                # per-bin mean; empty bins are set to zero
                occupied = counts > 0
                gridded_quantity = onp.where(occupied, bin_sums / onp.where(occupied, counts, 1.0), 0.0)
            else:
                # Quantity times sample density. Normalise by the sample count:
                # histogram2d(density=True) divides by the *sum of the weights*,
                # which can flip the sign of (or blow up) the signed div v map.
                bin_area = onp.diff(xedges)[:, None] * onp.diff(gedges)[None, :]
                gridded_quantity = bin_sums / (xs.size * bin_area)

        gridded_quantities.append(gridded_quantity)
        xcenters = 0.5*(xedges[:-1] + xedges[1:])
        gcenters = 0.5*(gedges[:-1] + gedges[1:])
        xgrid_plot, ggrid_plot = onp.meshgrid(xcenters, gcenters, indexing='ij')
        xgrid_plots.append(xgrid_plot)
        ggrid_plots.append(ggrid_plot)


    ### make figure
    # common plot parameters
    plt.close('all')
    sns.set_palette('deep')
    fw        = 8.0
    fh        = 6.25
    fraction  = 0.2
    shrink    = 0.5
    fontsize  = 28


    # x window and ticks in units of the ring length L = 2*cfg.width
    min_x, max_x = 0, period
    half_width   = cfg.width / 2
    xticks       = [0.0, half_width, cfg.width, 3*half_width, period]
    xticklabels  = [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"]

    # g window symmetric about zero
    min_g, max_g = gs.min(), gs.max()
    min_g        = min(min_g, -max_g)
    max_g        = max(max_g, -min_g)


    # individual panels
    if compute_density:
        titles = [r"$\rho$", r"$|v|_{D^{-1}}$", r"$\nabla\cdot v$"]
    else:
        titles = [r"$\rho$", r"$|j|_{D^{-1}}$", r"$(\nabla\cdot v)\rho$"]

    cmap   = sns.color_palette('mako', as_cmap=True)
    nrows  = 1
    ncols  = len(titles)
    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=True, sharey=True, constrained_layout=True
    )


    for ax in axs.ravel():
        ax.set_xlim([min_x, max_x])
        if compute_density:
            ax.set_ylim([min_g, max_g])
        else:
            # hard-coded g window used for the flux figure
            ax.set_ylim([-2.5, 2.5])
        ax.set_xticks(xticks, xticklabels)
        ax.grid(which='both', axis='both', color='0.9', alpha=0.1)
        ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)

        for spine in ax.spines.values():
            spine.set_edgecolor('grey')


    ## now plot the entropies
    for ii in range(ncols):
        grid_values = gridded_quantities[ii]
        ax = axs[ii]
        ax.set_title(titles[ii], fontsize=fontsize)

        ## clip the values to quantile to avoid washout
        vmin = onp.quantile(grid_values, q=clip_quantile)
        vmax = onp.quantile(grid_values, q=(1-clip_quantile))
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
        mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        # paint the axes background with the colour of value 0
        ax.set_facecolor(mappable.to_rgba(0.0))
        ctr  = ax.contourf(xgrid_plots[ii], ggrid_plots[ii], grid_values, cmap=cmap, norm=norm, levels=100)
        cbar = fig.colorbar(mappable, ax=ax,
                            fraction=fraction, shrink=shrink, orientation='horizontal')
        cbar.ax.tick_params(which='both', labelsize=fontsize, width=0, length=0)
        cbar.outline.set_edgecolor('grey')
        cbar.outline.set_linewidth(0.25)

        # Fix strange aliasing artifact between contour levels. From
        # matplotlib 3.8 a ContourSet is itself a single Collection and
        # `.collections` is deprecated; older versions need the per-level loop.
        contour_artists = [ctr] if isinstance(ctr, mpl.collections.Collection) else ctr.collections
        for artist in contour_artists:
            artist.set_edgecolor("face")
            artist.set_rasterized(True)

        ax.set_xlabel(r"$x$", fontsize=fontsize)
        if ii == 0:
            ax.set_ylabel(r"$g$", fontsize=fontsize)


    if output_name != '':
        ## fig background transparency
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        fig.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())


def make_split_dof_entropy_plot(
    params: Dict[str, hk.Params],
    npts: int,
    nbatches: int,
    compute_density: bool,
    shift_plot: bool,
    output_name: str,
    clip_quantile: float,
    two_rows: bool
) -> None:
    """Plot per-degree-of-freedom velocity magnitudes and divergences over (x, g).

    ``compute_batch_output_info`` is evaluated on ``data['xgs']`` in batches.
    The per-sample |v_g|, |v_x|, div_g v_g and div_x v_x are binned onto an
    ``npts x npts`` histogram over (x, g). Each panel is drawn as a filled
    contour with a horizontal colorbar.

    Args:
        params: network parameters forwarded to ``compute_batch_output_info``.
        npts: number of histogram bins along each axis.
        nbatches: number of batches used to evaluate the samples. Batch
            boundaries are integer-rounded, so every sample is covered even when
            ``nbatches`` does not divide ``cfg.ntrajs``.
        compute_density: if True, each bin shows the mean of the quantity over
            the samples in that bin (empty bins are shown as 0). If False, the
            weighted histogram is normalized with ``density=True``.
        shift_plot: if True, wrap x into [0, 2*cfg.width) before binning.
        output_name: path to save the figure to; '' disables saving.
        clip_quantile: lower/upper quantile at which each color scale is clipped.
        two_rows: lay out the four panels as 2x2 instead of 1x4.
    """
    ### evaluate the per-sample quantities in batches
    ntrajs    = cfg.ntrajs
    xs        = onp.zeros(ntrajs)
    gs        = onp.zeros(ntrajs)
    gdot_mags = onp.zeros(ntrajs)
    xdot_mags = onp.zeros(ntrajs)
    div_vxs   = onp.zeros(ntrajs)
    div_vgs   = onp.zeros(ntrajs)

    # integer batch boundaries: equal to k*(ntrajs // nbatches) when nbatches
    # divides ntrajs, and still reach ntrajs when it does not
    bounds = [(kk*ntrajs) // nbatches for kk in range(nbatches + 1)]
    for lb, ub in zip(bounds[:-1], bounds[1:]):
        if ub <= lb:
            continue
        xs[lb:ub], gs[lb:ub], gdot_mags[lb:ub], xdot_mags[lb:ub], _, \
            _, _, _, _, _, div_vxs[lb:ub], div_vgs[lb:ub], _, _ = \
                compute_batch_output_info(data['xgs'][lb:ub], params)

    ### handle the shift and set up the x-axis (L = 2*cfg.width)
    L = 2*cfg.width
    if shift_plot:
        xs = (xs + L) % L
        min_x, max_x = 0, L
        xticks       = [0.0, L/4, L/2, 3*L/4, L]
        xticklabels  = [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"]
    else:
        min_x, max_x = -cfg.width, cfg.width
        xticks       = [-L/2, -L/4, 0, L/4, L/2]
        xticklabels  = [r"$-L/2$", r"$-L/4$", 0, r"$L/4$", r"$L/2$"]

    ### panels in row-major order: (quantity, title if compute_density, title otherwise)
    panels = [
        (gdot_mags, r"$|v_g|$",             r"$|j_g|$"),
        (xdot_mags, r"$|v_x|$",             r"$|j_x|$"),
        (div_vgs,   r"$\nabla_g\cdot v_g$", r"$(\nabla_g\cdot v_g)\rho$"),
        (div_vxs,   r"$\nabla_x\cdot v_x$", r"$(\nabla_x\cdot v_x)\rho$"),
    ]
    nrows, ncols = (2, 2) if two_rows else (1, 4)

    ### bin edges, counts and plotting grid depend only on (xs, gs): compute once
    counts, xedges, gedges = np.histogram2d(xs, gs, bins=npts)
    counts        = onp.array(counts)
    empty         = counts <= 0   # remove noisy (empty) regions
    counts[empty] = 1.0
    xgrid, ggrid  = np.meshgrid(xedges, gedges, indexing='ij')
    xgrid_plot    = xgrid[:-1, :-1] + 0.5*np.diff(xedges)[:, None]
    ggrid_plot    = ggrid[:-1, :-1] + 0.5*np.diff(gedges)[None, :]

    gridded_quantities = []
    for quantity, _, _ in panels:
        grid_values = onp.array(np.histogram2d(
            xs, gs, weights=quantity, bins=(xedges, gedges), density=(not compute_density)
        )[0])
        if compute_density:
            # weighted sum / count = per-bin mean of the quantity
            grid_values[empty] = 0.0
            grid_values /= counts
        gridded_quantities.append(grid_values)

    ###### main entropy figure
    plt.close('all')
    sns.set_palette('deep')

    fw       = 8.0
    fh       = 6.25
    fraction = 0.2
    shrink   = 0.5
    fontsize = 28
    cmap     = sns.color_palette('mako', as_cmap=True)

    # symmetric g-limits
    min_g, max_g = gs.min(), gs.max()
    min_g        = min(min_g, -max_g)
    max_g        = max(-min_g, max_g)

    fig, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=False, sharey=True, constrained_layout=True
    )
    axs = axs.reshape((nrows, ncols))

    for kk, ((_, density_title, current_title), grid_values) in enumerate(zip(panels, gridded_quantities)):
        ii, jj = divmod(kk, ncols)
        ax     = axs[ii, jj]
        title  = density_title if compute_density else current_title

        ## axis styling
        ax.set_xlim([min_x, max_x])
        ax.set_ylim([min_g, max_g])
        ax.set_xticks(xticks, xticklabels)
        ax.grid(which='both', axis='both', color='0.9', alpha=0.1)
        ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)
        for spine in ax.spines.values():
            spine.set_edgecolor('grey')

        ## clip the values to quantile to avoid washout
        vmin = onp.quantile(grid_values, q=clip_quantile)
        vmax = onp.quantile(grid_values, q=(1-clip_quantile))
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
        mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        ax.set_facecolor(mappable.to_rgba(0.0))
        ctr  = ax.contourf(xgrid_plot, ggrid_plot, grid_values, cmap=cmap, norm=norm, levels=100)
        cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, fraction=fraction, shrink=shrink, orientation='horizontal')
        cbar.ax.tick_params(which='both', labelsize=fontsize, width=0, length=0)
        cbar.ax.set_xlabel(title, fontsize=fontsize)
        cbar.outline.set_edgecolor('grey')

        # fix strange aliasing artifact
        for c in ctr.collections:
            c.set_edgecolor("face")
            c.set_rasterized(True)

        ax.set_xlabel(r"$x$", fontsize=fontsize)
        if jj == 0:
            ax.set_ylabel(r"$g$", fontsize=fontsize)

    if output_name != '':
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        plt.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())

### spatial density

In [None]:
# reduced entropy plot, weighted-density version (compute_density=False)
save_folder     = '/scratch/nb3397/results/mips/2d_ring/figures/9_16_23'
save_name       = 'entropy_hist_nospring'
compute_density = False
output_suffix   = '_density.pdf' if compute_density else '.pdf'
output_name     = f'{save_folder}/{save_name}{output_suffix}'

make_reduced_entropy_density_plot(
    ema_params[0.9999],
    npts=64,
    nbatches=100,
    compute_density=compute_density,
    shift_plot=True,
    clip_quantile=0.005,
    output_name=output_name,
)

### per-particle (equivalent to scatterplot)

In [None]:
# reduced entropy plot, per-bin mean version (compute_density=True)
save_folder     = '/scratch/nb3397/results/mips/2d_ring/figures/9_16_23'
save_name       = 'entropy_hist_nospring'
compute_density = True
output_suffix   = '_density.pdf' if compute_density else '.pdf'
output_name     = f'{save_folder}/{save_name}{output_suffix}'

make_reduced_entropy_density_plot(
    ema_params[0.9999],
    npts=256,
    nbatches=100,
    compute_density=compute_density,
    shift_plot=True,
    clip_quantile=0.0075,
    output_name=output_name,
)

### system and total EPR only

In [None]:
# system and total EPR, per-bin mean version
save_folder     = '/scratch/nb3397/results/mips/2d_ring/figures/9_10_23'
compute_density = True
density_tag     = '_density' if compute_density else ''
save_name       = 'density_system_and_total_epr' + density_tag
output_name     = f'{save_folder}/{save_name}.pdf'

make_density_system_and_total_entropy_plot(
    ema_params[0.9999],
    bins=64,
    npts=256,
    nbatches=100,
    compute_density=compute_density,
    output_name=output_name,
    clip_quantile=0.0075,
)

In [None]:
# system and total EPR, weighted-density version
save_folder     = '/scratch/nb3397/results/mips/2d_ring/figures/9_10_23'
compute_density = False
density_tag     = '_density' if compute_density else ''
save_name       = 'density_system_and_total_epr' + density_tag
output_name     = f'{save_folder}/{save_name}.pdf'

make_density_system_and_total_entropy_plot(
    ema_params[0.9999],
    bins=64,
    npts=64,
    nbatches=100,
    compute_density=compute_density,
    output_name=output_name,
    clip_quantile=0.0075,
)

### split DOFs only

In [None]:
# split-DOF plot, 2x2 layout, per-bin mean version
save_folder     = '/scratch/nb3397/results/mips/2d_ring/figures/9_16_23'
two_rows        = True
compute_density = True
rows_tag        = '_two_rows' if two_rows else ''
density_tag     = '_density' if compute_density else ''
save_name       = 'split_dof_entropy' + rows_tag + density_tag
output_name     = f'{save_folder}/{save_name}.pdf'

make_split_dof_entropy_plot(
    ema_params[0.9999],
    npts=256,
    nbatches=100,
    compute_density=compute_density,
    shift_plot=True,
    output_name=output_name,
    clip_quantile=0.0075,
    two_rows=two_rows,
)

In [None]:
# split-DOF plot, 2x2 layout, weighted-density version
save_folder     = '/scratch/nb3397/results/mips/2d_ring/figures/9_16_23'
two_rows        = True
compute_density = False
rows_tag        = '_two_rows' if two_rows else ''
density_tag     = '_density' if compute_density else ''
save_name       = 'split_dof_entropy' + rows_tag + density_tag
output_name     = f'{save_folder}/{save_name}.pdf'

make_split_dof_entropy_plot(
    ema_params[0.9999],
    npts=64,
    nbatches=100,
    compute_density=compute_density,
    shift_plot=True,
    output_name=output_name,
    clip_quantile=0.0075,
    two_rows=two_rows,
)

### Single plot of $(\nabla\cdot v)\rho$

In [None]:
def make_div_v_single_pane(
    params: Dict[str, hk.Params],
    npts: int,
    nbatches: int,
    clip_quantile: float,
    output_name: str
) -> None:
    """Plot the binned, density-weighted divergence (div v) * rho on one (x, g) panel.

    Per-sample divergences come from ``compute_batch_output_info`` (10th output).
    They are histogrammed over (x, g) with ``density=True`` and drawn as a
    filled contour with a horizontal colorbar.

    Args:
        params: network parameters forwarded to ``compute_batch_output_info``.
        npts: number of histogram bins along each axis.
        nbatches: number of evaluation batches. Integer-rounded boundaries
            cover every sample even if ``nbatches`` does not divide ``cfg.ntrajs``.
        clip_quantile: lower/upper quantile at which the color scale is clipped.
        output_name: path to save the figure to; '' disables saving.
    """
    ## set up figure
    plt.close('all')
    sns.set_palette('deep')
    fw        = 8.0
    fh        = 6.25
    fraction  = 0.2
    shrink    = 0.5
    fontsize  = 28
    L         = 2*cfg.width
    min_x, max_x = 0.0, L
    cmap  = sns.color_palette('mako', as_cmap=True)
    title = r"$\left(\nabla\cdot v\right)\rho$"
    # a single Axes, bound to `ax` because every call below draws into it
    fig, ax = plt.subplots(figsize=(fw, fh), constrained_layout=True)

    ### compute quantities needed for plotting
    ntrajs = cfg.ntrajs
    xs     = onp.zeros(ntrajs)
    gs     = onp.zeros(ntrajs)
    div_vs = onp.zeros(ntrajs)
    bounds = [(kk*ntrajs) // nbatches for kk in range(nbatches + 1)]
    for lb, ub in zip(bounds[:-1], bounds[1:]):
        if ub <= lb:
            continue
        xs[lb:ub], gs[lb:ub], _, _, _, \
            _, _, _, _, div_vs[lb:ub], _, _, _, _ = \
                compute_batch_output_info(data['xgs'][lb:ub], params)

    ### bin the data and place values at bin centers
    grid_values, xedges, gedges = np.histogram2d(xs, gs, weights=div_vs, bins=npts, density=True)
    curr_xgrid, curr_ggrid = np.meshgrid(xedges, gedges, indexing='ij')
    xgrid_plot = curr_xgrid[:-1, :-1] + 0.5*np.diff(xedges)[:, None]
    ggrid_plot = curr_ggrid[:-1, :-1] + 0.5*np.diff(gedges)[None, :]

    ## basic plot setup (symmetric g-limits)
    ax.set_xlim([min_x, max_x])
    min_g, max_g = gs.min(), gs.max()
    min_g = min(min_g, -max_g)
    max_g = max(-min_g, max_g)
    ax.set_ylim([min_g, max_g])
    ax.set_xticks([0.0, L/4, L/2, 3*L/4, L],
                  [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"])
    ax.set_xlabel(r"$x$", fontsize=fontsize)
    ax.set_ylabel(r"$g$", fontsize=fontsize)
    ax.grid(which='both', axis='both', color='0.9', alpha=0.1)
    ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)
    for spine in ax.spines.values():
        spine.set_edgecolor('grey')

    ## do the plotting (clip to quantiles to avoid washout)
    vmin = onp.quantile(grid_values, q=clip_quantile)
    vmax = onp.quantile(grid_values, q=(1-clip_quantile))
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
    mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    ax.set_facecolor(mappable.to_rgba(0.0))
    ctr = ax.contourf(xgrid_plot, ggrid_plot, grid_values, cmap=cmap, norm=norm, levels=100)
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, fraction=fraction, shrink=shrink, orientation='horizontal')
    cbar.ax.tick_params(which='both', labelsize=fontsize, width=0, length=0)
    cbar.ax.set_xlabel(title, fontsize=fontsize)
    cbar.outline.set_edgecolor('grey')

    # fix strange aliasing artifact
    for c in ctr.collections:
        c.set_edgecolor("face")
        c.set_rasterized(True)

    if output_name != '':
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        plt.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())

In [None]:
# single-pane (div v) rho figure
save_folder = '/scratch/nb3397/results/mips/2d_ring/figures/9_16_23'
save_name   = 'div_v'
output_name = f'{save_folder}/{save_name}.pdf'

make_div_v_single_pane(
    ema_params[0.9999], npts=64, nbatches=10, clip_quantile=0.0, output_name=output_name
)

# Density and stochastic trajectories

In [None]:
@jax.jit
def rollout_full_trajs(
    init_xgs: np.ndarray, # [ntrajs, 2, d]
    noises: np.ndarray,   # [ntrajs, nsteps, 2, d]
) -> np.ndarray:
    """Roll out many SDE trajectories in parallel and return their paths.

    ``rollout`` is presumably provided by the star import of
    ``launchers.twod_ring_simulation``. It is applied to each pair of initial
    condition and noise sequence via ``jax.vmap``, and only element ``[1]`` of
    its output is kept. Callers index the result as [traj, step, {x, g}, dim].
    """
    def _path_of(init_xg: np.ndarray, traj_noises: np.ndarray) -> np.ndarray:
        outputs = rollout(init_xg, traj_noises)
        return outputs[1]

    batched_paths = jax.vmap(_path_of)
    return batched_paths(init_xgs, noises)

In [None]:
def make_density_plot(
    params: Dict[str, hk.Params],
    xs: np.ndarray,
    gs: np.ndarray,
    sde_trajs: np.ndarray,
    bins: int,
    shift_plot: bool,
    sde_skip: int,
    output_name: str
) -> None:
    """Filled-contour histogram of the (x, g) samples overlaid with SDE trajectories.

    Args:
        params: unused; kept for signature compatibility with the other plotters.
        xs, gs: sample positions / g-values; squeezed before binning.
        sde_trajs: SDE paths, indexed as ``sde_trajs[traj, step, {0: x, 1: g}, 0]``.
        bins: number of histogram bins per axis.
        shift_plot: if True, wrap ``xs`` into [0, 2*cfg.width) and label the
            axis 0..L. Otherwise use [-cfg.width, cfg.width] labelled -L/2..L/2.
            The SDE trajectories are plotted as given.
        sde_skip: stride along the time axis when drawing trajectories.
        output_name: path to save the figure to; '' disables saving.
    """
    plt.close('all')
    sns.set_palette('deep')

    L = 2*cfg.width
    if shift_plot:
        xs = (xs + L) % L
        min_x, max_x = 0, L
        xticks       = [0.0, L/4, L/2, 3*L/4, L]
        xticklabels  = [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"]
    else:
        min_x, max_x = -cfg.width, cfg.width
        xticks       = [-L/2, -L/4, 0, L/4, L/2]
        xticklabels  = [r"$-L/2$", r"$-L/4$", 0, r"$L/4$", r"$L/2$"]

    # symmetric g-limits
    min_g, max_g = gs.min(), gs.max()
    min_g        = min(-max_g, min_g)
    max_g        = max(-min_g, max_g)

    # compute histogram, evaluated at bin centers (left edge + half a bin width)
    hist, xedges, gedges = np.histogram2d(np.squeeze(xs), np.squeeze(gs), bins=bins, density=True)
    xcenters             = xedges[:-1] + 0.5*np.diff(xedges)
    gcenters             = gedges[:-1] + 0.5*np.diff(gedges)
    xgrid, ggrid         = np.meshgrid(xcenters, gcenters, indexing='ij')

    # define figure
    fw       = 8.88
    fh       = 5
    fraction = 0.1
    shrink   = 0.75
    fontsize = 22.0
    cmap     = sns.color_palette('mako', as_cmap=True)
    fig, ax  = plt.subplots(figsize=(fw, fh), constrained_layout=True)

    ax.set_xlim([min_x, max_x])
    ax.set_ylim([min_g, max_g])
    ax.grid(which='both', axis='both', color='0.9', alpha=0.025)
    ax.set_xlabel(r"$x$", fontsize=fontsize)
    ax.set_ylabel(r"$g$", fontsize=fontsize)
    ax.set_xticks(xticks, xticklabels)
    ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)
    ax.set_facecolor(sns.color_palette('mako', n_colors=150)[0])

    for spine in ax.spines.values():
        spine.set_edgecolor(None)

    ## plot histogram of samples
    vmin = float(hist.min())
    vmax = float(hist.max())
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    ctr  = ax.contourf(xgrid, ggrid, hist, cmap=cmap, norm=norm, levels=100)
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, fraction=fraction, shrink=shrink, orientation='vertical')
    cbar.ax.tick_params(which='both', length=0, labelsize=0)
    cbar.set_ticklabels([])
    cbar.outline.set_edgecolor('grey')
    cbar.outline.set_linewidth(0.25)

    ## plot SDE trajectories
    ax.plot(sde_trajs[:, ::sde_skip, 0, 0].T, sde_trajs[:, ::sde_skip, 1, 0].T, alpha=0.2, color='white', lw=0.5)

    # fix weird rasterization
    for c in ctr.collections:
        c.set_edgecolor("face")
        c.set_rasterized(True)

    if output_name != '':
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        plt.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())


def make_trajectory_plot(
    params: Dict[str, hk.Params],
    sde_trajs: np.ndarray,
    sde_skip: int,
    output_name: str
) -> None:
    """Plot SDE trajectories alone, in white on a black background.

    Args:
        params: unused; kept for signature compatibility with the other plotters.
        sde_trajs: SDE paths, indexed as ``sde_trajs[traj, step, {0: x, 1: g}, 0]``.
        sde_skip: stride along the time axis when drawing trajectories.
        output_name: path to save the figure to; '' disables saving.
    """
    plt.close('all')
    sns.set_palette('deep')

    L            = 2*cfg.width
    min_x, max_x = 0.0, L
    # symmetric g-limits taken from the trajectories themselves
    min_g, max_g = sde_trajs[:, :, 1].min(), sde_trajs[:, :, 1].max()
    min_g = min(min_g, -max_g)
    max_g = max(max_g, -min_g)

    # define figure
    fw       = 8
    fh       = 5
    fontsize = 22.0
    fig, ax  = plt.subplots(figsize=(fw, fh), constrained_layout=True)

    ax.set_xlim([min_x, max_x])
    ax.set_ylim([min_g, max_g])
    ax.grid(which='both', axis='both', color='0.9', alpha=0.025)
    ax.set_xlabel(r"$x$", fontsize=fontsize)
    ax.set_ylabel(r"$g$", fontsize=fontsize)
    ax.set_xticks([0.0, L/4, L/2, 3*L/4, L],
                  [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"])
    ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)
    ax.set_facecolor('black')

    for spine in ax.spines.values():
        spine.set_edgecolor(None)

    ## plot SDE trajectories
    ax.plot(sde_trajs[:, ::sde_skip, 0, 0].T, sde_trajs[:, ::sde_skip, 1, 0].T, alpha=0.2, color='white', lw=0.5)

    if output_name != '':
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        plt.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())


def make_trajectory_and_density_plot(
    params: Dict[str, hk.Params],
    xs: np.ndarray,
    gs: np.ndarray,
    sde_trajs: np.ndarray,
    sde_skip: int,
    bins: int,
    min_g: float,
    max_g: float,
    output_name: str
) -> None:
    """Two panels sharing axes: SDE trajectories (left) and the sample density (right).

    Args:
        params: unused; kept for signature compatibility with the other plotters.
        xs, gs: sample positions / g-values; squeezed before binning.
        sde_trajs: SDE paths, indexed as ``sde_trajs[traj, step, {0: x, 1: g}, 0]``.
        sde_skip: stride along the time axis when drawing trajectories.
        bins: number of histogram bins per axis.
        min_g, max_g: y-axis limits for both panels.
        output_name: path to save the figure to; '' disables saving.
    """
    plt.close('all')
    sns.set_palette('deep')

    ## define combined figure
    fw        = 8
    fh        = 5.56
    fraction  = 0.1
    shrink    = 0.75
    fontsize  = 22.0
    cmap      = sns.color_palette('mako', as_cmap=True)
    L         = 2*cfg.width
    fig, axs  = plt.subplots(nrows=1, ncols=2, figsize=(2*fw, fh), sharey=True, sharex=True, constrained_layout=True)
    traj_ax, density_ax = axs

    ## common plot parameters
    for kk, ax in enumerate(axs.ravel()):
        ax.set_xlim([0.0, L])
        ax.set_ylim([min_g, max_g])
        ax.grid(which='both', axis='both', color='0.9', alpha=0.025)
        ax.set_xlabel(r"$x$", fontsize=fontsize)
        if kk == 0:
            ax.set_ylabel(r"$g$", fontsize=fontsize)
        ax.set_xticks([0.0, L/4, L/2, 3*L/4, L],
                      [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"])
        ax.set_yticks([-2, -1, 0, 1, 2])
        ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)
        for spine in ax.spines.values():
            spine.set_edgecolor(None)

    ## compute histogram, evaluated at bin centers (left edge + half a bin width)
    hist, xedges, gedges = np.histogram2d(np.squeeze(xs), np.squeeze(gs), bins=bins, density=True)
    xcenters             = xedges[:-1] + 0.5*np.diff(xedges)
    gcenters             = gedges[:-1] + 0.5*np.diff(gedges)
    xgrid, ggrid         = np.meshgrid(xcenters, gcenters, indexing='ij')

    ## plot histogram of samples on the right panel
    density_ax.set_facecolor(sns.color_palette('mako', n_colors=150)[0])
    vmin = float(hist.min())
    vmax = float(hist.max())
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    ctr  = density_ax.contourf(xgrid, ggrid, hist, cmap=cmap, norm=norm, levels=100)
    # attach the colorbar explicitly to the density panel
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=density_ax, fraction=fraction, shrink=shrink, orientation='horizontal')
    cbar.ax.tick_params(which='both', length=0, labelsize=0)
    cbar.ax.set_xlabel(r"$\rho$", fontsize=fontsize)
    cbar.set_ticklabels([])
    cbar.outline.set_edgecolor('grey')
    cbar.outline.set_linewidth(0.25)

    # fix weird rasterization
    for c in ctr.collections:
        c.set_edgecolor("face")
        c.set_rasterized(True)

    ## plot SDE trajectories on the left panel
    traj_ax.set_facecolor('k')
    traj_ax.plot(sde_trajs[:, ::sde_skip, 0, 0].T, sde_trajs[:, ::sde_skip, 1, 0].T, alpha=0.4, color='white', lw=0.5)

    if output_name != '':
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        plt.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())

In [None]:
## sample initial conditions from the stored samples and roll out SDE trajectories
## (a seeded generator keeps the trajectory figures reproducible across re-runs)
ntrajs_rollout = 50
nsteps         = 1000
rollout_seed   = 0
rollout_rng    = onp.random.default_rng(rollout_seed)
inds           = rollout_rng.choice(xgs.shape[0], size=ntrajs_rollout, replace=False)
init_xgs       = xgs[inds]
noises         = rollout_rng.standard_normal((ntrajs_rollout, nsteps, 2, cfg.d))   # [ntrajs, nsteps, 2, d]
sde_trajs      = rollout_full_trajs(init_xgs, noises)

### Trajectories Alone

In [None]:
# trajectories only, not saved
make_trajectory_plot(
    params=ema_params[0.9999],
    sde_trajs=sde_trajs,
    sde_skip=1,
    output_name='',
)

### Density + Trajectories Overlaid

In [None]:
# density of the stored samples with the SDE trajectories drawn on top
xs, gs = np.split(xgs, indices_or_sections=2, axis=1)

bins        = 64
shift_plot  = True
save_folder = '/scratch/nb3397/results/mips/2d_ring/figures/9_10_23'
save_name   = 'density'
output_name = f'{save_folder}/{save_name}.pdf'

make_density_plot(
    params=ema_params[0.9999],
    xs=xs,
    gs=gs,
    sde_trajs=sde_trajs,
    bins=bins,
    shift_plot=shift_plot,
    sde_skip=3,
    output_name=output_name,
)

## Density + Trajectories Two Panes

In [None]:
# two panes: SDE trajectories next to the sample density
save_folder = '/scratch/nb3397/results/mips/2d_ring/figures/9_16_23'
save_name   = 'sde_trajs_and_density'
output_name = f'{save_folder}/{save_name}.pdf'

make_trajectory_and_density_plot(
    ema_params[0.9999], xs, gs, sde_trajs,
    sde_skip=1, bins=64,
    min_g=-3.0, max_g=3.0,
    output_name=output_name,
)

# Phase Portrait

In [None]:
def pflow_rhs(
    xg: np.ndarray, # [2, d]
    params: Dict[str, hk.Params]
) -> np.ndarray:
    """Velocity field used for the probability-flow rollouts at a single state.

    Returns ``concatenate((xdot, gdot))``, where
    ``xdot = calc_xdot(x, g) - eps * s_x`` and ``gdot = -gamma*g - gamma * s_g``.
    Here ``s_x`` and ``s_g`` are ``score_net`` outputs under ``params['x']`` and
    ``params['g']``.
    """
    x_part, g_part = np.split(xg, 2)
    drift_x = twod_ring_simulation.calc_xdot(x_part, g_part)
    score_x = score_net.apply(params['x'], x_part, g_part, 'x')
    score_g = score_net.apply(params['g'], x_part, g_part, 'g')
    vel_x   = drift_x - cfg.eps*score_x
    vel_g   = -cfg.gamma*g_part - cfg.gamma*score_g
    return np.concatenate((vel_x, vel_g))


def step_pflow(
    xg: np.ndarray,
    params: Dict[str, hk.Params],
    dt: float,
) -> np.ndarray:
    """One explicit Euler step of the flow defined by ``pflow_rhs``.

    The x component is wrapped periodically: modulo ``2*cfg.width`` when
    ``cfg.shift_system`` is set, otherwise via ``drifts.torus_project``.
    The g component is updated without wrapping.
    """
    x_curr, g_curr = np.split(xg, 2)
    vel_x, vel_g   = np.split(pflow_rhs(xg, params), 2)

    x_trial = x_curr + dt*vel_x
    x_new = (x_trial % (2*cfg.width)) if cfg.shift_system \
        else drifts.torus_project(x_trial, cfg.width)
    g_new = g_curr + dt*vel_g
    return np.concatenate((x_new, g_new))


def rollout_traj_pflow(
    init_xg: np.ndarray, # [2, d]
    steps: np.ndarray,   # [nsteps]
    params: Dict[str, hk.Params],
    dt: float
) -> np.ndarray:
    """Integrate the flow from ``init_xg`` for ``len(steps)`` Euler steps.

    Returns the stacked states after each step (the initial state is not
    included), shape ``[nsteps, 2, d]``. ``steps`` only sets the scan length;
    its values are ignored.
    """
    def advance(state: np.ndarray, _unused_step: np.ndarray):
        new_state = step_pflow(state, params, dt)
        return new_state, new_state

    _, trajectory = jax.lax.scan(advance, init_xg, steps)
    return trajectory


# Batched, jit-compiled wrappers:
#  - rollout_trajs_pflow maps rollout_traj_pflow over the leading axis of the
#    initial conditions; steps, params and dt are shared by every trajectory.
#  - map_pflow_rhs maps pflow_rhs over the two leading axes of its input
#    (e.g. [ntrajs, nsteps, 2, d] rollouts), with params shared.
_rollout_over_inits = jax.vmap(rollout_traj_pflow, in_axes=(0, None, None, None))
rollout_trajs_pflow = jax.jit(_rollout_over_inits)

_rhs_over_inner_axis = jax.vmap(pflow_rhs, in_axes=(0, None))
map_pflow_rhs        = jax.jit(jax.vmap(_rhs_over_inner_axis, in_axes=(0, None)))

In [None]:
def make_phase_portrait(
    params: Dict[str, hk.Params],
    pflow_trajs: np.ndarray,
    min_g: float,
    max_g: float,
    fw: float,
    fh: float,
    fontsize: float,
    arrow_skip_fac: int,
    nbatches: int,
    shift_plot: bool,
    plot_init: bool,
    plot_arrows: bool,
    plot_curves: bool,
    normalize_arrows: bool,
    output_name: str
) -> None:
    """Draw probability-flow trajectories in the (x, g) plane as a phase portrait.

    Args:
        params: network parameters passed to ``map_pflow_rhs``.
        pflow_trajs: trajectories, assumed [ntrajs, nsteps, 2, d] (as produced
            by ``rollout_trajs_pflow``); index 0 / 1 of axis 2 is x / g.
        min_g, max_g: y-axis limits.
        fw, fh: figure width / height.
        fontsize: font size for labels and ticks.
        arrow_skip_fac: stride along time between drawn arrows.
        nbatches: number of batches for evaluating the flow field. Integer-rounded
            boundaries cover every trajectory, even if ``nbatches`` does not divide
            ``ntrajs``.
        shift_plot: wrap x into [0, 2*cfg.width) before plotting.
        plot_init: mark each trajectory's initial point.
        plot_arrows: draw flow-direction arrows along each trajectory.
        plot_curves: draw the trajectories as curves.
        normalize_arrows: scale arrows to unit length. A zero flow vector would
            still divide by zero.
        output_name: path to save the figure to; '' disables saving.
    """
    plt.close('all')
    sns.set_palette('deep')
    ntrajs = pflow_trajs.shape[0]

    ## compute flow directions
    if plot_arrows:
        pflow_vals = onp.zeros_like(pflow_trajs)
        bounds = [(kk*ntrajs) // nbatches for kk in range(nbatches + 1)]
        for lb, ub in zip(bounds[:-1], bounds[1:]):
            if ub > lb:
                pflow_vals[lb:ub] = map_pflow_rhs(pflow_trajs[lb:ub], params)

        if normalize_arrows:
            # norm over the (2, d) state axes; keepdims avoids squeeze-based reshaping
            pflow_vals /= onp.sqrt(onp.sum(pflow_vals**2, axis=(2, 3), keepdims=True))

    ## re-map the data
    L = 2*cfg.width
    if shift_plot:
        plot_trajs = onp.copy(pflow_trajs)
        plot_trajs[:, :, 0] = (plot_trajs[:, :, 0] + L) % L
        min_x, max_x = 0, L
    else:
        plot_trajs = onp.array(pflow_trajs)
        min_x, max_x = -cfg.width, cfg.width

    if cfg.shift_system:
        min_x, max_x = 0.0, L

    ## set up figure
    fig, ax = plt.subplots(figsize=(fw, fh), constrained_layout=True)

    ## plot trajectories with arrows on top
    for curr_traj in range(ntrajs):
        if plot_curves:
            ax.plot(
                plot_trajs[curr_traj, :, 0],
                plot_trajs[curr_traj, :, 1],
                color='lavender',
                alpha=0.35,
                lw=0.2
            )

        if plot_arrows:
            ax.quiver(
                plot_trajs[curr_traj, ::arrow_skip_fac, 0],
                plot_trajs[curr_traj, ::arrow_skip_fac, 1],
                pflow_vals[curr_traj, ::arrow_skip_fac, 0],
                pflow_vals[curr_traj, ::arrow_skip_fac, 1],
                color='lavender',
                alpha=0.75,
                scale=100,
                headwidth=3,
                width=5e-4,
                minlength=0
            )

    ## mark every initial condition once (not once per trajectory)
    if plot_init:
        ax.scatter(
            plot_trajs[:, 0, 0],
            plot_trajs[:, 0, 1],
            s=0.05,
            alpha=0.25,
            marker='x',
            color='white'
        )

    ax.grid(which='both', axis='both', color='0.9', alpha=0.05)
    ax.set_facecolor('k')
    ax.set_xlabel(r"$x$", fontsize=fontsize)
    ax.set_ylabel(r"$g$", fontsize=fontsize)
    ax.set_xlim([min_x, max_x])
    ax.set_ylim([min_g, max_g])

    if shift_plot or cfg.shift_system:
        ax.set_xticks([0.0, L/4, L/2, 3*L/4, L],
                      [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"])
    else:
        ax.set_xticks([-L/2, -L/4, 0.0, L/4, L/2],
                      [r"$-L/2$", r"$-L/4$", 0, r"$L/4$", r"$L/2$"])

    ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)

    if output_name != '':
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        plt.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())

### spring force

In [None]:
### set up initial conditions for phase portrait
ntrajs_pflow = 150
n_half       = ntrajs_pflow // 2
n_rest       = ntrajs_pflow - n_half   # second half absorbs the extra row when ntrajs_pflow is odd

## pick some specific initial conditions to map out the phase portrait
lb_x, ub_x = 1.5, 6.5
lb_g, ub_g = 0.0, 0.0
inits             = onp.zeros((ntrajs_pflow, 2, cfg.d))
inits[:n_half, 0] = onp.linspace( lb_x,  ub_x, n_half).reshape((-1, cfg.d))
inits[:n_half, 1] = onp.linspace( lb_g,  ub_g, n_half).reshape((-1, cfg.d))
inits[n_half:, 0] = onp.linspace(-ub_x, -lb_x, n_rest).reshape((-1, cfg.d))
inits[n_half:, 1] = onp.linspace( lb_g,  ub_g, n_rest).reshape((-1, cfg.d))

## roll out trajectories (abs() keeps the step count positive for a backward-time dt)
dt_pflow       = 1e-3
rollout_fac    = 20.0
nsteps_pflow   = int((rollout_fac / cfg.gamma) / abs(dt_pflow) + 0.5)
pflow_trajs    = onp.array(rollout_trajs_pflow(inits, np.arange(nsteps_pflow), ema_params[0.9999], dt_pflow))

## compute bounds (symmetric g-limits from the stored samples)
xs, gs         = np.split(xgs, indices_or_sections=2, axis=1)
min_g, max_g   = gs.min(), gs.max()
min_g          = min(min_g, -max_g)
max_g          = max(max_g, -min_g)

### no spring, shifted system

In [None]:
### set up initial conditions for phase portrait
ntrajs_pflow = 40
n_half       = ntrajs_pflow // 2
n_rest       = ntrajs_pflow - n_half   # second half absorbs the extra row when ntrajs_pflow is odd

## symmetric line, plus a copy shifted by cfg.width
inits      = onp.zeros((ntrajs_pflow, 2, cfg.d))
lb_x, ub_x = cfg.width/3, 2*cfg.width/3
lb_g, ub_g = -1.5, 1.5

inits[:n_half, 0] = onp.linspace(lb_x, ub_x, n_half).reshape((-1, cfg.d))
inits[:n_half, 1] = onp.linspace(lb_g, ub_g, n_half).reshape((-1, cfg.d))
inits[n_half:, 0] = onp.linspace(lb_x+cfg.width, ub_x+cfg.width, n_rest).reshape((-1, cfg.d))
inits[n_half:, 1] = onp.linspace(lb_g, ub_g, n_rest).reshape((-1, cfg.d))


## roll out trajectories
dt_pflow     = 1e-2
rollout_fac  = 50.0
nsteps_pflow = int((rollout_fac / cfg.gamma) / np.abs(dt_pflow) + 0.5)
pflow_trajs  = onp.array(rollout_trajs_pflow(inits, np.arange(nsteps_pflow), ema_params[0.9999], dt_pflow))


## compute bounds (symmetric g-limits from the flow trajectories)
pflow_xs, pflow_gs = np.split(pflow_trajs, indices_or_sections=2, axis=2)
min_g, max_g       = pflow_gs.min(), pflow_gs.max()
min_g              = min(min_g, -max_g)
max_g              = max(max_g, -min_g)

### phase portrait alone

In [None]:
# phase portrait on its own
save_folder = '/scratch/nb3397/results/mips/2d_ring/figures/9_10_23'
save_name   = 'phase_portrait'
output_name = f'{save_folder}/{save_name}.pdf'

make_phase_portrait(
    ema_params[0.9999], pflow_trajs,
    min_g=-2.5, max_g=2.5,
    fw=8, fh=5, fontsize=22.0,
    arrow_skip_fac=500, nbatches=1,
    shift_plot=False, plot_init=False,
    plot_arrows=True, normalize_arrows=True, plot_curves=True,
    output_name=output_name,
)

### phase portrait + SDE trajs

In [None]:
def make_phase_portrait_and_SDE_trajs(
    params: Dict[str, hk.Params],
    pflow_trajs: np.ndarray,
    sde_trajs: np.ndarray,
    sde_skip: int,
    min_g: float,
    max_g: float,
    fw: float,
    fh: float,
    fontsize: float,
    arrow_skip_fac: int,
    nbatches: int,
    shift_plot: bool,
    plot_init: bool,
    plot_arrows: bool,
    plot_curves: bool,
    normalize_arrows: bool,
    output_name: str
) -> None:
    """Side-by-side panels: probability-flow phase portrait (left), SDE trajectories (right).

    Args:
        params: network parameters passed to ``map_pflow_rhs``.
        pflow_trajs: flow trajectories, assumed [ntrajs, nsteps, 2, d]
            (as produced by ``rollout_trajs_pflow``); index 0 / 1 of axis 2 is x / g.
        sde_trajs: SDE paths, indexed as ``sde_trajs[traj, step, {0: x, 1: g}, 0]``.
            Plotted as given (not shifted).
        sde_skip: stride along time when drawing SDE trajectories.
        min_g, max_g: y-axis limits for both panels.
        fw, fh: width / height of each panel.
        fontsize: font size for labels and ticks.
        arrow_skip_fac: stride along time between drawn arrows.
        nbatches: number of batches for evaluating the flow field. Integer-rounded
            boundaries cover every trajectory, even if ``nbatches`` does not divide
            ``ntrajs``.
        shift_plot: wrap the flow trajectories' x into [0, 2*cfg.width).
        plot_init: mark each flow trajectory's initial point.
        plot_arrows: draw flow-direction arrows along each trajectory.
        plot_curves: draw the flow trajectories as curves.
        normalize_arrows: scale arrows to unit length. A zero flow vector would
            still divide by zero.
        output_name: path to save the figure to; '' disables saving.
    """
    plt.close('all')
    sns.set_palette('deep')
    ntrajs = pflow_trajs.shape[0]

    ## compute flow directions
    if plot_arrows:
        pflow_vals = onp.zeros_like(pflow_trajs)
        bounds = [(kk*ntrajs) // nbatches for kk in range(nbatches + 1)]
        for lb, ub in zip(bounds[:-1], bounds[1:]):
            if ub > lb:
                pflow_vals[lb:ub] = map_pflow_rhs(pflow_trajs[lb:ub], params)

        if normalize_arrows:
            # norm over the (2, d) state axes; keepdims avoids squeeze-based reshaping
            pflow_vals /= onp.sqrt(onp.sum(pflow_vals**2, axis=(2, 3), keepdims=True))

    ## re-map the data
    L = 2*cfg.width
    if shift_plot:
        plot_trajs = onp.copy(pflow_trajs)
        plot_trajs[:, :, 0] = (plot_trajs[:, :, 0] + L) % L
        min_x, max_x = 0, L
    else:
        plot_trajs = onp.array(pflow_trajs)
        min_x, max_x = -cfg.width, cfg.width

    if cfg.shift_system:
        min_x, max_x = 0.0, L

    ## set up figure
    fig, axs = plt.subplots(figsize=(2*fw, fh), nrows=1, ncols=2, sharex=True, sharey=True, constrained_layout=True)
    portrait_ax, sde_ax = axs

    ### phase portrait: trajectories with arrows on top
    for curr_traj in range(ntrajs):
        if plot_curves:
            portrait_ax.plot(
                plot_trajs[curr_traj, :, 0],
                plot_trajs[curr_traj, :, 1],
                color='lavender',
                alpha=0.35,
                lw=0.2
            )

        if plot_arrows:
            portrait_ax.quiver(
                plot_trajs[curr_traj, ::arrow_skip_fac, 0],
                plot_trajs[curr_traj, ::arrow_skip_fac, 1],
                pflow_vals[curr_traj, ::arrow_skip_fac, 0],
                pflow_vals[curr_traj, ::arrow_skip_fac, 1],
                color='lavender',
                alpha=0.75,
                scale=100,
                headwidth=3,
                width=5e-4,
                minlength=0
            )

    ## mark every initial condition once (not once per trajectory)
    if plot_init:
        portrait_ax.scatter(
            plot_trajs[:, 0, 0],
            plot_trajs[:, 0, 1],
            s=0.05,
            alpha=0.25,
            marker='x',
            color='white'
        )

    ### SDE trajs
    sde_ax.plot(sde_trajs[:, ::sde_skip, 0, 0].T, sde_trajs[:, ::sde_skip, 1, 0].T, alpha=0.4, color='white', lw=0.5)

    for kk, ax in enumerate(axs):
        ax.grid(which='both', axis='both', color='0.9', alpha=0.15)
        ax.set_facecolor('k')
        ax.set_xlabel(r"$x$", fontsize=fontsize)

        if kk == 0:
            ax.set_ylabel(r"$g$", fontsize=fontsize)

        ax.set_xlim([min_x, max_x])
        ax.set_ylim([min_g, max_g])

        if shift_plot or cfg.shift_system:
            ax.set_xticks([0.0, L/4, L/2, 3*L/4, L],
                          [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"])
        else:
            ax.set_xticks([-L/2, -L/4, 0.0, L/4, L/2],
                          [r"$-L/2$", r"$-L/4$", 0, r"$L/4$", r"$L/2$"])

        ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)

        for spine in ax.spines.values():
            spine.set_edgecolor(None)

    if output_name != '':
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        plt.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())

In [None]:
## fresh SDE rollouts for the phase-portrait comparison
## (a seeded generator keeps the figure reproducible; a different seed than the earlier rollout)
ntrajs_rollout = 50
nsteps         = 1000
rollout_seed   = 1
rollout_rng    = onp.random.default_rng(rollout_seed)
inds           = rollout_rng.choice(xgs.shape[0], size=ntrajs_rollout, replace=False)
init_xgs       = xgs[inds]
noises         = rollout_rng.standard_normal((ntrajs_rollout, nsteps, 2, cfg.d))   # [ntrajs, nsteps, 2, d]
sde_trajs      = rollout_full_trajs(init_xgs, noises)

In [None]:
# phase portrait next to SDE trajectories
save_folder = '/scratch/nb3397/results/mips/2d_ring/figures/9_10_23'
save_name   = 'phase_portrait_and_SDE_trajs'
output_name = f'{save_folder}/{save_name}.pdf'

make_phase_portrait_and_SDE_trajs(
    ema_params[0.9999], pflow_trajs, sde_trajs,
    sde_skip=1,
    min_g=-2.5, max_g=2.5,
    fw=8, fh=5, fontsize=22.0,
    arrow_skip_fac=500, nbatches=1,
    shift_plot=False, plot_init=False,
    plot_arrows=True, normalize_arrows=True, plot_curves=True,
    output_name=output_name,
)

### phase portrait + current

In [None]:
def make_phase_portrait_and_current_plot(
    params: Dict[str, hk.Params],
    pflow_trajs: np.ndarray,
    bins: int,
    clip_quantile: float,
    min_g: float,
    max_g: float,
    fw: float,
    fh: float,
    fontsize: float,
    arrow_skip_fac: int,
    nbatches: int,
    shift_plot: bool,
    plot_init: bool,
    plot_arrows: bool,
    plot_curves: bool,
    normalize_arrows: bool,
    output_name: str
) -> None:
    """Plot the probability current (left) next to the probability-flow phase portrait (right).

    Left panel: filled contour of a density-normalized 2D histogram over (x, g) of the
    samples in ``data['xgs']``, weighted by ``|v| / cfg.gamma`` where ``v`` is the
    velocity returned by ``compute_batch_output_info``.
    Right panel: the trajectories in ``pflow_trajs`` with optional curves, initial
    points and flow-direction arrows (from ``map_pflow_rhs``).

    Args:
        params: network parameters forwarded to ``compute_batch_output_info`` and ``map_pflow_rhs``.
        pflow_trajs: probability-flow trajectories indexed as [traj, step, 0 (x) / 1 (g), ...]
            (presumably [ntrajs, nsteps, 2, cfg.d] -- TODO confirm).
        bins: number of histogram bins per axis.
        clip_quantile: lower/upper quantile of the current used to clip the colour scale.
        min_g, max_g: vertical axis limits.
        fw: width of one panel; fh: figure height (inches).
        fontsize: font size for labels and ticks.
        arrow_skip_fac: draw one arrow every ``arrow_skip_fac`` steps along a trajectory.
        nbatches: number of batches for the network evaluations.
        shift_plot: wrap x into [0, 2*cfg.width) before plotting the trajectories.
        plot_init: scatter the initial point of every trajectory.
        plot_arrows: draw flow-direction arrows along the trajectories.
        plot_curves: draw the trajectories themselves.
        normalize_arrows: rescale the arrows to unit length.
        output_name: file to save the figure to; '' disables saving.
    """
    ## set up figure
    plt.close('all')
    sns.set_palette('deep')
    fig, axs = plt.subplots(figsize=(2*fw, fh), nrows=1, ncols=2, sharex=True, sharey=True, constrained_layout=True)
    cmap     = sns.color_palette('mako', as_cmap=True)
    fraction = 0.2
    shrink   = 0.5


    ####### current plot
    ### evaluate the network on all samples, batch by batch
    ntrajs     = cfg.ntrajs
    batch_size = max(ntrajs // nbatches, 1)
    xs         = onp.zeros(ntrajs)
    gs         = onp.zeros(ntrajs)
    vs         = onp.zeros((ntrajs, 2*cfg.d))

    # stepping by batch_size and clipping ub also covers the remainder when
    # nbatches does not divide ntrajs (those samples used to stay at zero)
    for lb in range(0, ntrajs, batch_size):
        ub = min(lb + batch_size, ntrajs)
        xs[lb:ub], gs[lb:ub], _, _, _, _, _, \
            vs[lb:ub], _, _, _, _, _, _ = \
                compute_batch_output_info(data['xgs'][lb:ub], params)

    ### bin |v| / gamma on an (x, g) grid
    current, xedges, gedges = np.histogram2d(xs, gs, weights=np.linalg.norm(vs, axis=-1)/cfg.gamma, bins=bins, density=True)
    xgrid, ggrid = np.meshgrid(xedges, gedges, indexing='ij')
    # bin centres: left edge + half a bin width
    xgrid_plot   = xgrid[:-1, :-1] + 0.5*np.diff(xedges)[:, None]
    ggrid_plot   = ggrid[:-1, :-1] + 0.5*np.diff(gedges)[None, :]

    ## plot the current
    vmin = onp.quantile(current, q=clip_quantile)
    vmax = onp.quantile(current, q=(1-clip_quantile))
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
    ctr  = axs[0].contourf(xgrid_plot, ggrid_plot, current, cmap=cmap, norm=norm, levels=100)
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=axs[0],
                        fraction=fraction, shrink=shrink, orientation='horizontal')
    cbar.ax.tick_params(which='both', labelsize=fontsize, width=0, length=0)
    cbar.ax.set_xlabel(r"$|j|_{D^{-1}}$", fontsize=fontsize)
    cbar.outline.set_edgecolor('grey')
    cbar.outline.set_linewidth(0.25)

    # fix strange aliasing artifact between contour levels.
    # From matplotlib 3.8 a ContourSet is itself a Collection (and `.collections`
    # is removed in 3.10); older versions expose one collection per level.
    contour_artists = [ctr] if isinstance(ctr, mpl.collections.Collection) else ctr.collections
    for c in contour_artists:
        c.set_edgecolor("face")
        c.set_rasterized(True)


    ###### phase portrait
    ntrajs_plot = pflow_trajs.shape[0]

    ## compute flow directions (only needed for the arrows)
    if plot_arrows:
        bs         = max(ntrajs_plot // nbatches, 1)
        pflow_vals = onp.zeros_like(pflow_trajs)
        for lb in range(0, ntrajs_plot, bs):
            ub = min(lb + bs, ntrajs_plot)
            pflow_vals[lb:ub] = map_pflow_rhs(pflow_trajs[lb:ub], params)

        if normalize_arrows:
            # unit-length arrows; norm over the trailing axes with keepdims (instead of
            # squeeze()) so that a single trajectory or single step also works
            pflow_vals /= onp.sqrt(onp.sum(pflow_vals**2, axis=(2, 3), keepdims=True))

    ## re-map the data and choose the x-range; L = 2*cfg.width is the system length
    L          = 2*cfg.width
    plot_trajs = onp.array(pflow_trajs)
    if shift_plot:
        plot_trajs[:, :, 0] = (plot_trajs[:, :, 0] + L) % L
        min_x, max_x = 0, L
    else:
        min_x, max_x = -cfg.width, cfg.width

    if cfg.shift_system:
        min_x, max_x = 0.0, L

    # tick positions derived from cfg.width instead of hard-coding L = 16
    if shift_plot or cfg.shift_system:
        xticks      = [0.0, L/4, L/2, 3*L/4, L]
        xticklabels = [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"]
    else:
        xticks      = [-L/2, -L/4, 0.0, L/4, L/2]
        xticklabels = [r"$-L/2$", r"$-L/4$", 0, r"$L/4$", r"$L/2$"]

    ## plot trajectories with arrows on top; iterate over the trajectories actually
    ## passed in rather than the global `ntrajs_pflow`
    for curr_traj in range(ntrajs_plot):
        # plot cycles
        if plot_curves:
            axs[1].plot(
                plot_trajs[curr_traj, :, 0],
                plot_trajs[curr_traj, :, 1],
                color='lavender',
                alpha=0.35,
                lw=0.2
            )

        if plot_arrows:
            axs[1].quiver(
                plot_trajs[curr_traj, ::arrow_skip_fac, 0],
                plot_trajs[curr_traj, ::arrow_skip_fac, 1],
                pflow_vals[curr_traj, ::arrow_skip_fac, 0],
                pflow_vals[curr_traj, ::arrow_skip_fac, 1],
                color='lavender',
                alpha=0.75,
                scale=100,
                headwidth=3,
                width=5e-4,
                minlength=0
            )

    # initial points of all trajectories, drawn once (not once per trajectory)
    if plot_init:
        axs[1].scatter(
            plot_trajs[:, 0, 0],
            plot_trajs[:, 0, 1],
            s=0.05,
            alpha=0.25,
            marker='x',
            color='white'
        )

    for kk, ax in enumerate(axs):
        ax.grid(which='both', axis='both', color='0.9', alpha=0.15)
        ax.set_facecolor('k')
        ax.set_xlabel(r"$x$", fontsize=fontsize)

        if kk == 0:
            ax.set_ylabel(r"$g$", fontsize=fontsize)

        # limits first: set_xticks would otherwise widen the view to fit the ticks
        ax.set_xlim([min_x, max_x])
        ax.set_ylim([min_g, max_g])
        ax.set_xticks(xticks, xticklabels)
        ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)

        for spine in ax.spines.values():
            spine.set_edgecolor(None)


    if output_name:
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        fig.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())

In [None]:
# output location for the current + phase-portrait figure
save_folder = '/scratch/nb3397/results/mips/2d_ring/figures/9_16_23'
save_name   = 'phase_portrait_and_current'
output_name = save_folder + '/' + save_name + '.pdf'

make_phase_portrait_and_current_plot(
    params=ema_params[0.9999], pflow_trajs=pflow_trajs,
    # histogram / colour scale
    bins=64, clip_quantile=0.0,
    # figure geometry
    min_g=-3.0, max_g=3.0, fw=8, fh=5.556, fontsize=22,
    # trajectory overlay
    arrow_skip_fac=500, nbatches=10, shift_plot=False, plot_init=False,
    plot_arrows=True, normalize_arrows=True, plot_curves=True,
    output_name=output_name,
)

# Probability Flow Movie

### Movie Initial Conditions

In [None]:
### initial conditions for the movie: ntrajs_pflow points spread uniformly along the
### segment from (cfg.width/3, -1.5) to (2*cfg.width/3, 1.5) in the (x, g) plane
ntrajs_pflow = 9

lb_x, ub_x  = cfg.width/3, 2*cfg.width/3
lb_g, ub_g  = -1.5, 1.5
inits       = onp.zeros((ntrajs_pflow, 2, cfg.d))
inits[:, 0] = onp.linspace(lb_x, ub_x, ntrajs_pflow).reshape((-1, cfg.d))
inits[:, 1] = onp.linspace(lb_g, ub_g, ntrajs_pflow).reshape((-1, cfg.d))


## roll out probability-flow trajectories over rollout_fac / cfg.gamma time units
dt_pflow     = 1e-2
rollout_fac  = 50.0
nsteps_pflow = int((rollout_fac / cfg.gamma) / np.abs(dt_pflow) + 0.5)  # round to nearest step count
time_steps   = np.arange(nsteps_pflow)
pflow_trajs  = onp.array(rollout_trajs_pflow(inits, time_steps, ema_params[0.9999], dt_pflow))

In [None]:
def map_to_circle(x: float, g: float):
    """Embed a point (x, g) of the periodic system onto a circle in the plane.

    x is treated as a periodic coordinate with period 2 * cfg.width and mapped to the
    angle theta = pi * x / cfg.width on a circle of radius cfg.width / pi (so the
    circumference equals the period). g is mapped to the tangent vector
    g * (-sin(theta), cos(theta)) at that point.

    Returns:
        (x_circle, y_circle, tangent_x, tangent_y)
    """
    radius = cfg.width / np.pi
    angle  = x * np.pi / cfg.width
    cos_a  = np.cos(angle)
    sin_a  = np.sin(angle)
    return radius*cos_a, radius*sin_a, -g*sin_a, g*cos_a


def make_movie(
    trajs: np.ndarray, # [ntrajs, nsteps, 2, 1]
    save_str: str,
    dt: float,
    nframes: int = 250,
    particle_size: float = 1250
) -> None:
    """Animate trajectories as particles moving on a ring and save the result as an mp4.

    The first ``nrows**2`` trajectories (``nrows = int(sqrt(ntrajs))``) each get their own
    panel containing the ring of radius ``cfg.width / pi``, a moving particle at
    ``map_to_circle`` of the trajectory state, and a fixed particle at x = 0.
    Frames are taken every ``max(nsteps // nframes, 1)`` steps.

    Args:
        trajs: trajectories indexed as [traj, step, 0 (x) / 1 (g), ...].
        save_str: output path without extension; '.mp4' is appended.
        dt: time between consecutive steps of ``trajs`` (used for the title).
        nframes: target number of frames.
        particle_size: marker area passed to ``scatter``.
    """
    plt.close('all')
    sns.set_palette('deep')
    fw, fh   = 4, 4
    fontsize = 20
    nrows    = int(onp.sqrt(trajs.shape[0]))
    ncols    = nrows
    nsteps   = trajs.shape[1]
    # squeeze=False keeps `axs` a 2D array even for a single (1x1) panel
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows), sharex=True, sharey=True,
                            constrained_layout=True, squeeze=False)


    ## set axis limits, grids, the ring, and the initial particles
    rad    = cfg.width/np.pi
    thetas = np.linspace(0, 2*np.pi, 100)
    fixed_x, fixed_y, _, _ = map_to_circle(0.0, 0.0)
    scats  = []
    for ii in range(nrows):
        for jj in range(ncols):
            ax = axs[ii, jj]
            ax.set_xlim([-1.5*rad, 1.5*rad])
            ax.set_ylim([-1.5*rad, 1.5*rad])
            ax.grid(which='both', axis='both', color='0.90', alpha=0.2)
            ax.tick_params(which='both', width=0, length=0, labelsize=fontsize)
            ax.set_facecolor('k')
            ax.plot(rad*np.cos(thetas), rad*np.sin(thetas), color='gray', lw=1, alpha=0.5, zorder=0)
            ax.set_aspect(1.0)

            if jj == 0:
                ax.set_ylabel(r"$y$", fontsize=fontsize)

            if ii == nrows-1:
                ax.set_xlabel(r"$x$", fontsize=fontsize)

            # moving particle (its position is updated in place by `animate`)
            index = jj + ii*ncols
            x, y, _, _ = map_to_circle(trajs[index, 0, 0], trajs[index, 0, 1])
            scats.append(ax.scatter(x, y, s=particle_size, marker='o', color='C4', alpha=0.75, zorder=1))

            # fixed particle at x = 0
            ax.scatter(fixed_x, fixed_y, s=particle_size, marker='o', color='C4', alpha=0.75, zorder=1)


    fig.suptitle(rf"$t=${0:0.3f}", fontsize=fontsize)

    # guards: nsteps < nframes used to give step 0, and nsteps < 10 a zero modulus
    step         = max(nsteps // nframes, 1)
    frames       = onp.arange(0, nsteps, step)
    report_every = max(nsteps // 10, 1)


    def init():
        return scats


    def animate(frame: int):
        ## move each particle to its position at this step; `scats` was filled in
        ## the same row-major order, so list position == trajectory index
        for index, scat in enumerate(scats):
            x, y, _, _ = map_to_circle(trajs[index, frame, 0], trajs[index, frame, 1])
            scat.set_offsets(onp.column_stack((onp.ravel(x), onp.ravel(y))))

        # `frame` is the step index, so the physical time is dt * frame
        fig.suptitle(rf"$t=${dt*frame:0.3f}", fontsize=fontsize)

        if frame % report_every == 0:
            print(f'Finished animation on frame {frame}/{nsteps}')

        return scats

    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=frames,
                                   interval=0.1, blit=True, repeat=True, cache_frame_data=True)

    save_name = f'{save_str}.mp4'
    anim.save(save_name, fps=60)
    plt.show()

In [None]:
save_folder = '/scratch/nb3397/results/mips/2d_ring/figures/9_5_23'
save_name   = 'pflow_movie'
output_name = save_folder + '/' + save_name  # make_movie appends the '.mp4' extension

make_movie(
    pflow_trajs,
    save_str=output_name,
    dt=dt_pflow,
    nframes=250,
    particle_size=4250,
)

# Combined density + phase portrait

In [None]:
def make_density_phase_portrait_plot(
    params: Dict[str, hk.Params],
    pflow_trajs: np.ndarray,
    xs: np.ndarray,
    gs: np.ndarray,
    min_g: float,
    max_g: float,
    bins: int,
    narrows: int,
    arrow_skip_fac: int,
    nbatches: int,
    plot_init: bool,
    plot_arrows: bool,
    plot_curves: bool,
    shift_plot: bool,
    output_name: str,
    horizontal: bool
) -> None:
    """Plot the probability-flow phase portrait (axs[0]) next to the sample density (axs[1]).

    axs[0]: trajectories in ``pflow_trajs`` with optional curves, initial points and
    unit-length flow-direction arrows (from ``map_pflow_rhs``).
    axs[1]: filled contour of the density-normalized 2D histogram of ``(xs, gs)``,
    optionally overlaid with unit-length score arrows (from ``map_score_net``) at
    ``narrows`` randomly chosen samples.

    Args:
        params: network parameters; ``params['x']`` / ``params['g']`` are passed to
            ``map_score_net`` and ``params`` to ``map_pflow_rhs``.
        pflow_trajs: probability-flow trajectories indexed as [traj, step, 0 (x) / 1 (g), ...]
            (presumably [ntrajs, nsteps, 2, cfg.d] -- TODO confirm).
        xs, gs: sample coordinates used for the histogram and score arrows.
        min_g, max_g: vertical axis limits.
        bins: number of histogram bins per axis.
        narrows: number of score arrows (0 disables them).
        arrow_skip_fac: draw one flow arrow every ``arrow_skip_fac`` steps.
        nbatches: number of batches for the flow evaluation.
        plot_init: scatter the initial point of every trajectory.
        plot_arrows: draw flow-direction arrows along the trajectories.
        plot_curves: draw the trajectories themselves.
        shift_plot: wrap x into [0, 2*cfg.width) before plotting.
        output_name: file to save the figure to; '' disables saving.
        horizontal: panels side by side (True) or stacked (False).
    """
    plt.close('all')
    sns.set_palette('deep')
    fw       = 8.0
    fh       = 5
    fraction = 0.1
    shrink   = 0.75
    fontsize = 20

    ####### x-range and ticks; L = 2*cfg.width is the system length.
    # The ticks follow the chosen range (they used to be the shifted ones even
    # for shift_plot=False, which also widened the view to [-L/2, L]).
    L = 2*cfg.width
    if shift_plot:
        xs = (xs + L) % L
        min_x, max_x = 0, L
        xticks      = [0.0, L/4, L/2, 3*L/4, L]
        xticklabels = [0, r"$L/4$", r"$L/2$", r"$3L/4$", r"$L$"]
    else:
        min_x, max_x = -cfg.width, cfg.width
        xticks      = [-L/2, -L/4, 0.0, L/4, L/2]
        xticklabels = [r"$-L/2$", r"$-L/4$", 0, r"$L/4$", r"$L/2$"]


    # compute histogram, evaluated at the bin centres (left edge + half a bin width)
    hist, xedges, gedges = np.histogram2d(np.squeeze(xs), np.squeeze(gs), bins=bins, density=True)
    xcenters             = xedges[:-1] + 0.5*np.diff(xedges)
    gcenters             = gedges[:-1] + 0.5*np.diff(gedges)
    xgrid, ggrid         = np.meshgrid(xcenters, gcenters, indexing='ij')


    # compute unit-length score directions at `narrows` random samples
    inds = onp.random.choice(onp.arange(cfg.ntrajs), size=narrows, replace=False)
    if narrows > 0:
        score_x = onp.squeeze(map_score_net(params['x'], xs[inds], gs[inds], 'x'))
        score_g = onp.squeeze(map_score_net(params['g'], xs[inds], gs[inds], 'g'))
        # normalize the (x, g) score vector jointly; dividing each component by its
        # own magnitude would reduce every arrow to (+-1, +-1)
        score_norm = onp.sqrt(score_x**2 + score_g**2)
        score_x    = score_x / score_norm
        score_g    = score_g / score_norm


    # define figure
    cmap = sns.color_palette('mako', as_cmap=True)
    if horizontal:
        fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(2*fw, fh), constrained_layout=True, sharey=True)
    else:
        fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(fw, 1.25*fh), constrained_layout=True, sharex=True)

    # limits before ticks: set_xticks would otherwise widen the view to fit the ticks
    axs[1].set_xlim([min_x, max_x])
    axs[1].set_ylim([min_g, max_g])
    axs[1].grid(which='both', axis='both', color='0.9', alpha=0.025)
    axs[1].set_xlabel(r"$x$", fontsize=fontsize)
    if not horizontal:
        axs[1].set_ylabel(r"$g$", fontsize=fontsize)
    axs[1].set_xticks(xticks, xticklabels)
    axs[1].tick_params(which='both', width=0, length=0, labelsize=fontsize)
    axs[1].set_facecolor(sns.color_palette('mako', n_colors=150)[0])


    # plot histogram of samples
    norm = mpl.colors.Normalize(vmin=float(hist.min()), vmax=float(hist.max()))
    ctr  = axs[1].contourf(xgrid, ggrid, hist, cmap=cmap, norm=norm, levels=100)
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=axs[1], fraction=fraction, shrink=shrink,
                        orientation='vertical' if horizontal else 'horizontal')
    cbar.ax.tick_params(which='both', length=0, labelsize=fontsize)
    cbar.set_ticklabels([])
    cbar.outline.set_edgecolor('grey')

    # fix weird rasterization. From matplotlib 3.8 a ContourSet is itself a
    # Collection (and `.collections` is removed in 3.10); older versions expose
    # one collection per level.
    contour_artists = [ctr] if isinstance(ctr, mpl.collections.Collection) else ctr.collections
    for c in contour_artists:
        c.set_edgecolor("face")
        c.set_rasterized(True)


    # plot score arrows on top
    if narrows > 0:
        axs[1].quiver(xs[inds], gs[inds], score_x, score_g, color='white', alpha=0.1, scale=2e2)


    ##### phase portrait
    ntrajs_plot = pflow_trajs.shape[0]

    ## compute unit-length flow directions (only needed for the arrows)
    if plot_arrows:
        bs         = max(ntrajs_plot // nbatches, 1)
        pflow_vals = onp.zeros_like(pflow_trajs)
        # stepping by bs and clipping ub also covers the remainder batch
        for lb in range(0, ntrajs_plot, bs):
            ub = min(lb + bs, ntrajs_plot)
            pflow_vals[lb:ub] = map_pflow_rhs(pflow_trajs[lb:ub], params)
        # keepdims instead of squeeze() so a single trajectory also works
        pflow_vals /= onp.sqrt(onp.sum(pflow_vals**2, axis=(2, 3), keepdims=True))

    ## re-map the data
    plot_trajs = onp.array(pflow_trajs)
    if shift_plot:
        plot_trajs[:, :, 0] = (plot_trajs[:, :, 0] + L) % L

    ## plot trajectories with arrows on top; iterate over the trajectories actually
    ## passed in rather than the global `ntrajs_pflow`
    for curr_traj in range(ntrajs_plot):
        # plot cycles
        if plot_curves:
            axs[0].plot(
                plot_trajs[curr_traj, :, 0],
                plot_trajs[curr_traj, :, 1],
                color='lavender',
                alpha=0.25,
                lw=0.2
            )

        if plot_arrows:
            axs[0].quiver(
                plot_trajs[curr_traj, ::arrow_skip_fac, 0],
                plot_trajs[curr_traj, ::arrow_skip_fac, 1],
                pflow_vals[curr_traj, ::arrow_skip_fac, 0],
                pflow_vals[curr_traj, ::arrow_skip_fac, 1],
                color='lavender',
                alpha=0.25,
                scale=100,
                headwidth=3,
                width=5e-4,
                minlength=0
            )

    # initial points of all trajectories, drawn once (not once per trajectory)
    if plot_init:
        axs[0].scatter(
            plot_trajs[:, 0, 0],
            plot_trajs[:, 0, 1],
            s=0.05,
            alpha=0.25,
            marker='x',
            color='white'
        )

    axs[0].grid(which='both', axis='both', color='0.9', alpha=0.05)
    axs[0].set_facecolor('k')
    axs[0].set_ylabel(r"$g$", fontsize=fontsize)
    axs[0].set_xlim([min_x, max_x])
    axs[0].set_ylim([min_g, max_g])
    axs[0].tick_params(which='both', width=0, length=0, labelsize=fontsize)

    if horizontal:
        axs[0].set_xlabel(r"$x$", fontsize=fontsize)
        axs[0].set_xticks(xticks, xticklabels)

    if output_name:
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        fig.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())

### probability flow samples for combined plot

In [None]:
### set up initial conditions for phase portrait: two line segments in (x, g),
### the second shifted by cfg.width in x
ntrajs_pflow = 80

## symmetric line
inits      = onp.zeros((ntrajs_pflow, 2, cfg.d))
lb_x, ub_x = cfg.width/3, 2*cfg.width/3
lb_g, ub_g = -1.5, 1.5

# the second segment takes the leftover trajectory when ntrajs_pflow is odd
# (splitting both halves as ntrajs_pflow//2 gave a shape mismatch then)
n_first  = ntrajs_pflow // 2
n_second = ntrajs_pflow - n_first
inits[:n_first, 0] = onp.linspace(lb_x, ub_x, n_first).reshape((-1, cfg.d))
inits[:n_first, 1] = onp.linspace(lb_g, ub_g, n_first).reshape((-1, cfg.d))
inits[n_first:, 0] = onp.linspace(lb_x+cfg.width, ub_x+cfg.width, n_second).reshape((-1, cfg.d))
inits[n_first:, 1] = onp.linspace(lb_g, ub_g, n_second).reshape((-1, cfg.d))


## roll out trajectories
dt_pflow     = 1e-2
rollout_fac  = 50.0
nsteps_pflow = int((rollout_fac / cfg.gamma) / np.abs(dt_pflow) + 0.5)
pflow_trajs  = onp.array(rollout_trajs_pflow(inits, np.arange(nsteps_pflow), ema_params[0.9999], dt_pflow))


## symmetric g-bounds of the trajectories: [-max|g|, max|g|]
pflow_xs, pflow_gs = onp.split(pflow_trajs, indices_or_sections=2, axis=2)
max_abs_g          = onp.abs(pflow_gs).max()
min_g, max_g       = -max_abs_g, max_abs_g

In [None]:
# output location for the combined density + phase-portrait figure
save_folder = '/scratch/nb3397/results/mips/2d_ring/figures/9_5_23'
save_name   = 'density_phase_portrait'
output_name = save_folder + '/' + save_name + '.pdf'

# NOTE: the g-limits are fixed here rather than taken from the min_g/max_g computed in the previous cell
make_density_phase_portrait_plot(
    ema_params[0.9999], pflow_trajs, xs, gs,
    min_g=-2.75, max_g=2.75,
    bins=64, narrows=0,
    arrow_skip_fac=250, nbatches=10,
    plot_init=False, plot_arrows=True, plot_curves=True,
    shift_plot=True, horizontal=True,
    output_name=output_name,
)

# Entropy Histogram

In [None]:
# calc_divs vectorized over the samples (params shared across the batch)
batch_calc_divs = jax.jit(jax.vmap(calc_divs, in_axes=(None, 0, 0)))

def make_entropy_histogram(
    params: Dict[str, hk.Params],
    xs: np.ndarray,
    gs: np.ndarray,
    bins: int,
    nbatches: int,
    output_name: str
) -> None:
    """Plot log-scale histograms of the velocity-field divergences over the samples.

    ``batch_calc_divs`` is evaluated on ``(xs, gs)`` for the first ``cfg.ntrajs`` samples,
    batch by batch; its outputs 2, 3 and 4 are histogrammed in panels titled
    div v_x, div v_g and div v, each annotated with its mean and standard deviation.

    Args:
        params: network parameters forwarded to ``calc_divs``.
        xs, gs: per-sample x and g components, batched along axis 0.
        bins: number of histogram bins.
        nbatches: number of evaluation batches.
        output_name: file to save the figure to; '' disables saving.
    """
    ## compute quantities needed for plotting
    ntrajs     = cfg.ntrajs
    batch_size = max(ntrajs // nbatches, 1)
    div_vs     = onp.zeros((ntrajs, 1))
    div_vxs    = onp.zeros((ntrajs, 1))
    div_vgs    = onp.zeros((ntrajs, 1))

    # stepping by batch_size and clipping ub also covers the remainder when nbatches
    # does not divide ntrajs; those entries used to stay at zero and bias the statistics
    for lb in range(0, ntrajs, batch_size):
        ub = min(lb + batch_size, ntrajs)
        _, _, div_vxs[lb:ub], div_vgs[lb:ub], div_vs[lb:ub] = batch_calc_divs(params, xs[lb:ub], gs[lb:ub])


    ## make the figure
    plt.close('all')
    sns.set_style('whitegrid')
    sns.set_palette("Set2")
    # re-assert LaTeX/serif fonts after changing the seaborn style
    # (NOTE(review): presumably because set_style overrides font settings -- confirm)
    mpl.rcParams['text.usetex'] = True
    mpl.rcParams['font.family'] = 'serif'
    fw, fh   = 4, 4
    fontsize = 22
    vals     = [div_vxs,                           div_vgs,             div_vs]
    titles   = [r"$\nabla\cdot v_x$", r"$\nabla\cdot v_g$", r"$\nabla\cdot v$"]
    fig, axs = plt.subplots(nrows=1, ncols=len(vals), sharey=True, constrained_layout=True, figsize=(len(vals)*fw, fh))


    for kk, ax in enumerate(axs):
        ax.grid(which='both', axis='both', color='0.75', alpha=0.5)
        ax.tick_params(axis='both', labelsize=fontsize)
        ax.hist(vals[kk], bins=bins, color=f'C{kk}')
        ax.set_title(titles[kk], fontsize=fontsize)
        ax.set_yscale('log')

        mean, std = np.mean(vals[kk]), np.std(vals[kk])
        ax.axvline(mean, color='k', alpha=0.25)
        ax.text(0.1, 0.8, rf"$\mu={mean:.03f}$",   transform=ax.transAxes, fontsize=fontsize)
        ax.text(0.1, 0.6, rf"$\sigma={std:.03f}$", transform=ax.transAxes, fontsize=fontsize)

        if kk == 0:
            ax.set_ylabel('count', fontsize=fontsize)

    if output_name:
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        fig.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())

In [None]:
## x and g components of the samples (unsqueezed), kept under their own names so the
## squeezed `xs`, `gs` from the data-loading cell are not overwritten for later re-runs
xs_split, gs_split = np.split(xgs, indices_or_sections=2, axis=1)
save_folder = '/scratch/nb3397/results/mips/2d_ring/figures/9_5_23'
save_name   = 'entropy_hist'
output_name = f'{save_folder}/{save_name}.pdf'

make_entropy_histogram(ema_params[0.9999], xs_split, gs_split, nbatches=10, bins=75, output_name=output_name)

# Statistics Histogram

In [None]:
@jax.jit
@functools.partial(jax.vmap, in_axes=(0, None))
def compute_batch_statistics(
    xg: np.ndarray,
    params: Dict[str, hk.Params]
) -> Tuple:
    """Compute the quantitative verification metrics, vmapped over the leading axis of ``xg``.

    With s = (s_x, s_g) the outputs of ``score_net`` and v the last output of ``calc_vs``,
    returns, per sample:
        pinn      = div(v) + sum(v * s)
        v_times_s = sum(v * s)
        ibp_x     = sum(s_x**2) + div_x(s_x)
        ibp_g     = sum(s_g**2) + div_g(s_g)
    where the divergences are taken from ``calc_divs``.
    """
    x, g = np.split(xg, 2)

    score_x  = score_net.apply(params['x'], x, g, 'x')
    score_g  = score_net.apply(params['g'], x, g, 'g')
    velocity = calc_vs(xg, score_x, score_g)[-1]

    # calc_divs returns five values; only positions 0, 1 and 4 are needed here
    div_sx, div_sg, _, _, div_v = calc_divs(params, x, g)

    v_times_s = np.sum(velocity * np.hstack((score_x, score_g)))
    ibp_x     = np.sum(score_x**2) + div_sx
    ibp_g     = np.sum(score_g**2) + div_sg

    return div_v + v_times_s, v_times_s, ibp_x, ibp_g

In [None]:
def make_convergence_histogram(
    params: Dict[str, hk.Params],
    xgs: np.ndarray,
    bins: int,
    nbatches: int,
    output_name: str
) -> None:
    """Plot log-scale histograms of the per-sample verification metrics.

    ``compute_batch_statistics`` is evaluated on the first ``cfg.ntrajs`` rows of ``xgs``,
    batch by batch, and each of its four outputs (pinn, v.h, and the two IBP terms) gets
    its own panel annotated with its mean and standard deviation.

    Args:
        params: network parameters forwarded to ``compute_batch_statistics``.
        xgs: joint samples, batched along axis 0.
        bins: number of histogram bins.
        nbatches: number of evaluation batches.
        output_name: file to save the figure to; '' disables saving.
    """
    ## compute quantities needed for plotting
    ntrajs     = cfg.ntrajs
    batch_size = max(ntrajs // nbatches, 1)
    pinn       = onp.zeros((ntrajs, 1))
    v_times_s  = onp.zeros((ntrajs))
    ibp_x      = onp.zeros((ntrajs, 1))
    ibp_g      = onp.zeros((ntrajs, 1))

    # stepping by batch_size and clipping ub also covers the remainder when nbatches
    # does not divide ntrajs; those entries used to stay at zero and bias the statistics
    for lb in range(0, ntrajs, batch_size):
        ub = min(lb + batch_size, ntrajs)
        pinn[lb:ub], v_times_s[lb:ub], ibp_x[lb:ub], ibp_g[lb:ub] = compute_batch_statistics(xgs[lb:ub], params)


    ## make the figure
    plt.close('all')
    sns.set_style('whitegrid')
    sns.set_palette("Set2")
    # re-assert LaTeX/serif fonts after changing the seaborn style
    # (NOTE(review): presumably because set_style overrides font settings -- confirm)
    mpl.rcParams['text.usetex'] = True
    mpl.rcParams['font.family'] = 'serif'
    fw, fh   = 4, 4
    fontsize = 22
    vals     = [                         pinn,      v_times_s,                            ibp_x,                            ibp_g]
    titles   = [r"$\nabla\cdot v + v\cdot h$", r"$v \cdot h$", r"$|h_x|^2 + \nabla_x\cdot h_x$", r"$|h_g|^2 + \nabla_g\cdot h_g$"]
    # one fw-wide panel per metric (was 3*fw for 4 panels, copied from the 3-panel entropy plot)
    fig, axs = plt.subplots(nrows=1, ncols=len(vals), sharey=True, constrained_layout=True, figsize=(len(vals)*fw, fh))


    for kk, ax in enumerate(axs):
        ax.grid(which='both', axis='both', color='0.75', alpha=0.5)
        ax.tick_params(axis='both', labelsize=fontsize)
        ax.hist(vals[kk], bins=bins, color=f'C{kk}')
        ax.set_title(titles[kk], fontsize=fontsize)
        ax.set_yscale('log')

        mean, std = np.mean(vals[kk]), np.std(vals[kk])
        ax.axvline(mean, color='k', alpha=0.25)
        ax.text(0.35, 0.8, rf"$\mu={mean:.03f}$",   transform=ax.transAxes, fontsize=fontsize)
        ax.text(0.35, 0.6, rf"$\sigma={std:.03f}$", transform=ax.transAxes, fontsize=fontsize)

        if kk == 0:
            ax.set_ylabel('count', fontsize=fontsize)

    if output_name:
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        fig.savefig(output_name, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())

In [None]:
save_folder = '/scratch/nb3397/results/mips/2d_ring/figures/9_20_23'
save_name   = 'quantitative_stats_hist'
output_name = f'{save_folder}/{save_name}.pdf'

# works on the joint samples `xgs` directly, so no x/g split is needed here
# (splitting into the globals `xs`, `gs` would only clobber the loading-cell values)
make_convergence_histogram(ema_params[0.9999], xgs, nbatches=10, bins=75, output_name=output_name)