In [None]:
import jax
from jax import numpy as np
from jax import vmap
import numpy as onp
import sys
sys.path.append('../py')
%load_ext autoreload
%autoreload 24


import common.drifts as drifts
import common.networks as networks
import haiku as hk
from typing import Tuple, Dict, Callable
import functools
import time
import glob
from mpl_toolkits.axes_grid1 import make_axes_locatable


import matplotlib as mpl
import matplotlib.animation as animation
import matplotlib.gridspec as gridspec
from matplotlib import pyplot as plt
import seaborn as sns


## global matplotlib styling shared by every figure in this notebook
mpl.rcParams.update({
    'axes.grid':           True,
    'axes.grid.which':     'both',
    'xtick.minor.visible': True,
    'ytick.minor.visible': True,
    'axes.facecolor':      'white',
    'grid.color':          '0.8',
    'grid.alpha':          '0.1',
    'text.usetex':         True,
    'font.family':         'serif',
    'figure.figsize':      (8, 4),
    'figure.titlesize':    7.5,
    'font.size':           10,
    'legend.fontsize':     7.5,
    'figure.dpi':          300,
})


from tqdm.auto import tqdm
import dill as pickle

## Load Data

In [None]:
base_folder  = '/scratch/nb3397/results/mips/hypoelliptic'
dataset_file = '/scratch/nb3397/projects/mips_project/dataset/big_dataset_5_23_23/combined_dataset.npy'


## EMA transformer
output_name = '8_22_23/hypoelliptic_mips_sweep_trans_ema_8_22_23_2'

## pick off latest output: files are ordered by the number after the last underscore
candidate_files = glob.glob(f'{base_folder}/{output_name}*')
if not candidate_files:
    raise FileNotFoundError(f'No output files match {base_folder}/{output_name}*')
load_name = sorted(candidate_files, key=lambda x: float(x.split('_')[-1].split('.')[0]))[-1]


## load it up
## NOTE: pickle/dill executes arbitrary code on load -- only open trusted files.
with open(load_name, 'rb') as f:
    data_dict = pickle.load(f)
config = data_dict['config']


## the saved dict holds either a history of parameters ('params_list')
## or a single 'params' entry
params_list = data_dict.get('params_list')
if params_list:
    params = params_list[-1]
else:
    params = data_dict['params']

## every later cell evaluates the network with ema_params, so make sure the
## name always exists; fall back to the raw parameters when no EMA was stored
if 'ema_params' in data_dict:
    ema_params = data_dict['ema_params']
else:
    print('No ema_params in checkpoint; using params instead.')
    ema_params = params


with open(dataset_file, 'rb') as f:
    dataset_dict = pickle.load(f)
width        = dataset_dict['width']
N            = dataset_dict['N']
d            = dataset_dict['d']
xgs_data     = np.array(dataset_dict['trajs']['SDE'].reshape((-1, 2*N, d)))

## Assemble Network

In [None]:
## Rebuild the architecture described by the checkpoint's config so the loaded
## parameters can be applied to it.
if 'transformer' not in config['network_type']:
    print('Constructing two-particle network.')

    (
        particle_potential_net,
        particle_score_net,
        particle_div_net,
        two_particle_net,
    ) = networks.define_two_particle_networks(
        config['n_neighbors'],
        config['d'],
        config['width'],
        config['embed_n_hidden'],
        config['embed_n_neurons'],
        config['w0'],
        jax.nn.gelu,
        config['network_type'],
    )
else:
    print('Constructing transformer.')

    particle_score_net, particle_div_net = networks.define_transformer_networks(
        w0=config['w0'],
        n_neighbors=config['n_neighbors'],
        num_layers=config['num_layers'],
        embed_dim=config['embed_dim'],
        embed_n_hidden=config['embed_n_hidden'],
        embed_n_neurons=config['embed_n_neurons'],
        num_heads=config['num_heads'],
        dim_feedforward=config['dim_feedforward'],
        n_layers_feedforward=1,
        width=config['width'],
        network_type=config['network_type'],
        this_particle_pooling=config['this_particle_pooling'],
        scale_fac=config['scale_fac'],
    )

## Utility Functions

In [None]:
## Vectorize the per-particle networks over their third argument (mapped along
## axis 0); the first two arguments are shared across the batch.
map_particle_score = jax.jit(vmap(particle_score_net.apply, in_axes=(None, None, 0)))
map_particle_div   = jax.jit(vmap(particle_div_net.apply, in_axes=(None, None, 0)))

In [None]:
def compute_entropy_measures(
    xgs: np.ndarray,
    params: hk.Params,
    map_particle_score: Callable[[hk.Params, np.ndarray, np.ndarray], np.ndarray],
    map_particle_div: Callable[[hk.Params, np.ndarray, np.ndarray], np.ndarray],
    N: int,
    batch_size: int
) -> Tuple:
    """Evaluate the score and divergence networks on one configuration, in batches.

    Parameters:
    - xgs: configuration array; it is split into two equal halves along axis 0
      and the second half is used as ``gs``. The size of the last axis is used
      as the dimension ``d``.
    - params: network parameters passed through to both mapped networks.
    - map_particle_score: batched score network, called as
      ``map_particle_score(params, xgs, indices)``.
    - map_particle_div: batched divergence network, called the same way.
    - N: number of particles; indices ``0..N-1`` are evaluated.
    - batch_size: number of particle indices evaluated per network call. It no
      longer has to divide ``N``; a final shorter batch covers the remainder.

    Returns:
    - ``(gdots, div_vgs, gs, sgs)`` where ``sgs`` is the score network output,
      ``gdots = -gs - sgs`` and ``div_vgs = -d - div`` with ``div`` the
      divergence network output.

    Raises:
    - ValueError: if ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be positive, got {batch_size}.')

    ## unpack the data (only the second half is needed here)
    _, gs = np.split(xgs, 2)

    ## read the dimension off the data rather than relying on a notebook global
    dim = xgs.shape[-1]

    ## compute entropies batch by batch into host-side buffers
    sgs     = onp.zeros((N, dim))
    div_vgs = onp.zeros(N)
    for lb in range(0, N, batch_size):
        ub   = min(lb + batch_size, N)
        inds = np.arange(lb, ub)

        sgs[lb:ub]     =  map_particle_score(params, xgs, inds)
        div_vgs[lb:ub] = -dim - map_particle_div(params, xgs, inds)

    gdots = -gs - sgs
    return gdots, div_vgs, gs, sgs

## MIPS Figure

### Define Figure

In [None]:
def radius_to_points(ax, radius):
    """Convert a radius in data units into a scatter marker size.

    Despite the name, the return value is an *area* in points^2, which is
    the unit of the ``s`` argument of ``Axes.scatter``.

    Parameters:
    - ax: axes the scatter plot will be drawn on. Its current data limits and
      size determine the result, so set limits and aspect before calling.
    - radius: desired marker radius in data units (measured along x).

    Returns:
    - Marker area in points^2.
    """
    # Convert the radius from data units to display units (pixels);
    # only the x-extent is used.
    origin_px, offset_px = ax.transData.transform([(0, 0), (radius, radius)])
    pixel_radius = offset_px[0] - origin_px[0]

    # Compute the area of the circle in pixels^2
    area_pixels = np.pi * (pixel_radius ** 2)

    # Convert the area from pixels^2 to points^2 (72 points per inch), using
    # the DPI of the figure that owns `ax` rather than whichever figure
    # happens to be current.
    area_points = area_pixels / (ax.figure.dpi / 72) ** 2

    return area_points


def radius_to_area(ax, radius):
    """
    Turn a marker radius given in data units into an area in display units.

    Parameters:
    - ax: The axes object where the scatter plot will be drawn.
    - radius: Desired radius in data units.

    Returns:
    - s: pi * r^2, with r the radius measured in display coordinates.
    """
    # Map the origin and a point one radius to its right into display space.
    center_and_edge = np.array([[0, 0], [radius, 0]])
    center_px, edge_px = ax.transData.transform(center_and_edge)

    # Their separation is the radius expressed in display units (like pixels).
    radius_px = np.linalg.norm(edge_px - center_px)

    return np.pi * radius_px**2


def test_solid(
    xi: np.ndarray,
    xs: np.ndarray,
) -> bool:
    norms = np.linalg.norm(xs - xi[None, :], axis=1)
    return np.sum(norms < 2.4) >= 5


def test_gas(
    xi: np.ndarray,
    xs: np.ndarray,
) -> bool:
    norms = np.linalg.norm(xs - xi[None, :], axis=1)
    return np.sum(norms < 2.4) <= 2


@jax.jit
def find_solid_inds(
    xs: np.ndarray
) -> np.ndarray:
    """Apply `test_solid` to every row of `xs`, returning one boolean per particle."""
    solid_test_per_particle = jax.vmap(test_solid, in_axes=(0, None))
    return solid_test_per_particle(xs, xs)


@jax.jit
def find_gas_inds(
    xs: np.ndarray
) -> np.ndarray:
    """Apply `test_gas` to every row of `xs`, returning one boolean per particle."""
    gas_test_per_particle = jax.vmap(test_gas, in_axes=(0, None))
    return gas_test_per_particle(xs, xs)

In [None]:
def make_particle_plot(
    q1: np.ndarray,
    q2: np.ndarray,
    label1: str,
    label2: str,
    xs_plot: np.ndarray,
    index: int,
    clip_quantile: float,
    width: float,
    s_fac: float,
    plot_cluster: bool,
    center_cluster: bool,
    rasterized: bool = False,
    cmap_str: str = 'mako',
    extension: str = 'pdf',
    save_str: str = None
) -> None:
    """Scatter the particles side by side, colored by two per-particle quantities.

    Parameters:
    - q1, q2: per-particle scalars used to color the two panels.
    - label1, label2: colorbar labels for the two panels.
    - xs_plot: particle positions; columns 0 and 1 are plotted.
    - index: identifier appended to the saved file name.
    - clip_quantile: the color scale is clipped at the (1 - clip_quantile)
      quantile, and at the clip_quantile quantile when the quantity takes
      negative values (otherwise the lower limit is 0).
    - width: axes span [-width, width] in both directions.
    - s_fac: marker radius in data units, converted by `radius_to_points`.
    - plot_cluster: if True, prepend an uncolored panel showing the particles
      in black on a white background.
    - center_cluster: if True, shift positions by the mean position of the
      particles flagged by `find_solid_inds` and re-project with
      `drifts.torus_project`.
    - rasterized: forwarded to `Axes.scatter`.
    - cmap_str: seaborn palette name used as the colormap.
    - extension: file extension of the saved figure.
    - save_str: path prefix of the output file; None skips saving. The file is
      written to ``{save_str}_{network_type}_{cmap_str}_{index}.{extension}``,
      where ``network_type`` is read from the notebook-level ``config``.

    The trailing parameters have defaults so that quick interactive calls do
    not have to spell out every output option.
    """
    ## set up basic plot parameters
    plt.close('all')
    sns.set_palette('deep')
    fw, fh    = 4, 5
    fraction  = 0.2
    pad       = 0.025
    shrink    = 0.5
    fontsize  = 16.0
    nrows     = 1


    if plot_cluster:
        titles     = [ r"", label1, label2]
        quantities = [None,     q1,     q2]
    else:
        titles     = [label1, label2]
        quantities = [    q1,     q2]


    ncols     = len(titles)
    cmap      = sns.color_palette(cmap_str, as_cmap=True)
    fig, axs  = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=False, sharey=True, constrained_layout=True
    )

    ## center xs
    if center_cluster:
        com     = onp.mean(xs_plot[find_solid_inds(xs_plot)], axis=0)
        xs_plot = drifts.torus_project(xs_plot - com[None, :], width)

    ## common to all axes
    for kk, ax in enumerate(axs.ravel()):
        ax.set_facecolor('k')
        ax.set_xlim([-width, width])
        ax.set_ylim([-width, width])
        ax.set_xticks([-0.5*width, 0, 0.5*width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.set_yticks([-0.5*width, 0, 0.5*width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.grid(which='both', axis='both', color='0.90', alpha=0.1)
        ax.tick_params(which='both', length=0, labelsize=fontsize)

        ax.axes.set_aspect(1.0)
        if kk == 0:
            ax.set_ylabel(r"$y$", fontsize=fontsize)
        ax.set_xlabel(r"$x$", fontsize=fontsize)

        for spine in ax.spines.values():
            spine.set_edgecolor(None)


    ## make scatter plot
    for kk, ax in enumerate(axs):
        quantity = quantities[kk]
        ## marker size depends on the axis limits set above
        s        = radius_to_points(ax, s_fac)

        if quantity is not None:
            vmin = onp.quantile(quantity, q=clip_quantile) if quantity.min() < 0.0 else 0.0
            vmax = onp.quantile(quantity, q=(1-clip_quantile))
        else:
            vmin, vmax = 0.0, 1.0

        ## a colorbar is created for every panel (and hidden for the uncolored one)
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
        mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        cbar = fig.colorbar(mappable=mappable, norm=norm, ax=ax, 
                            fraction=fraction, shrink=shrink, pad=pad, orientation='horizontal')
        cbar.set_label(titles[kk], fontsize=fontsize)
        cbar.ax.tick_params(which='both', labelsize=0.8*fontsize, width=0, length=0)
        cbar.outline.set_edgecolor('grey')

        if quantity is None:
            ## plot without colorbar
            ax.set_facecolor('white')
            ax.scatter(xs_plot[:, 0], xs_plot[:, 1], s=s, marker='o', c='black', linewidths=0.5, alpha=1.0, rasterized=rasterized)
            cbar.ax.set_visible(False)
        else:
            ## plot with color bar
            ax.scatter(xs_plot[:, 0], xs_plot[:, 1], s=s, marker='o', c=quantity, cmap=cmap, norm=norm, alpha=1.0, rasterized=rasterized)


    if save_str is not None:
        ## transparent figure background for the saved file
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        plt.savefig(f'{save_str}_{config["network_type"]}_{cmap_str}_{index}.{extension}', dpi=600, bbox_inches='tight')

## Play around

In [None]:
## draw a random snapshot and evaluate the learned quantities on it
index  = onp.random.randint(xgs_data.shape[0])
xgs    = xgs_data[index]
xs, gs = np.split(xgs, 2)

gdots, div_vgs, gs, sgs = compute_entropy_measures(
    xgs, ema_params, map_particle_score, map_particle_div, N=N, batch_size=256
)

## per-particle magnitudes used for coloring
vg_norms, sg_norms, g_norms = (
    onp.linalg.norm(vectors, axis=1) for vectors in (gdots, sgs, gs)
)

In [None]:
## quick look only: rasterized/extension are required by make_particle_plot,
## and save_str=None keeps the figure from being written to disk
make_particle_plot(
    q1=vg_norms, 
    q2=div_vgs, 
    label1=r"$|v_g|/\gamma$", 
    label2=r"$\nabla_g\cdot v_g /\gamma$", 
    xs_plot=xs, 
    clip_quantile=0.025,
    width=dataset_dict['width'],
    s_fac=1.0,
    index=index, 
    cmap_str='mako', 
    plot_cluster=False,
    center_cluster=True,
    rasterized=True,
    extension='png',
    save_str=None
)

In [None]:
## quick look only: rasterized/extension are required by make_particle_plot,
## and save_str=None keeps the figure from being written to disk
make_particle_plot(
    q1=g_norms, 
    q2=sg_norms, 
    label1=r"$|g|$", 
    label2=r"$|s_g|$", 
    xs_plot=xs, 
    clip_quantile=0.025,
    width=dataset_dict['width'],
    s_fac=1.0,
    index=index, 
    cmap_str='mako', 
    plot_cluster=True,
    center_cluster=True,
    rasterized=True,
    extension='png',
    save_str=None
)

## Loop and make figure.

In [None]:
good_indices  = [0, 87077, 1176, 117399, 75483, 48061, 21530]
base_folder   = '/scratch/nb3397/results/mips/hypoelliptic/figures'
output_folder = '9_22_23'


for index in good_indices:
    ## evaluate the learned quantities on this snapshot
    xgs                     = xgs_data[index]
    xs, gs                  = np.split(xgs, 2)
    gdots, div_vgs, gs, sgs = compute_entropy_measures(xgs, ema_params, map_particle_score, map_particle_div, N=N, batch_size=256)
    vg_norms                = onp.linalg.norm(gdots, axis=1)
    sg_norms                = onp.linalg.norm(sgs, axis=1)
    g_norms                 = onp.linalg.norm(gs, axis=1)


    ## one figure per pair of quantities
    for (q1, q2), (label1, label2) in [((vg_norms, div_vgs), (r"$|v_g|/\gamma$", r"$\nabla_g\cdot v_g /\gamma$")),
                                       ((g_norms, sg_norms), (r"$|g|$", r"$|s_g|$"))]:


        if label1 == r"$|g|$":
            save_name = 'particle_g_sg'
            plot_cluster = False
        else:
            save_name = 'particle_EPR'
            plot_cluster = True


        print(f'Making figure of ({label1}, {label2}). plot_cluster={plot_cluster}')
        for extension in ['pdf', 'png']:
            ## make_particle_plot appends the extension itself, so it must not
            ## be part of the path prefix
            make_particle_plot(
                q1=q1,
                q2=q2,
                label1=label1,
                label2=label2,
                xs_plot=xs, 
                clip_quantile=0.025, 
                width=dataset_dict['width'],
                s_fac=1.0, 
                index=index, 
                cmap_str='mako',
                plot_cluster=plot_cluster,
                center_cluster=True,
                rasterized=True,
                extension=extension,
                save_str=f'{base_folder}/{output_folder}/{save_name}'
            )
    
        print(f'Finished making figure for index {index}.')
print('Finished!')

## Make a reduced set of figures

In [None]:
good_indices  = [117399]
base_folder   = '/scratch/nb3397/results/mips/hypoelliptic/figures'
output_folder = '9_22_23'


for index in good_indices:
    ## evaluate the learned quantities on this snapshot
    xgs    = xgs_data[index]
    xs, gs = np.split(xgs, 2)
    gdots, div_vgs, gs, sgs = compute_entropy_measures(
        xgs, ema_params, map_particle_score, map_particle_div, N=N, batch_size=256
    )
    vg_norms, sg_norms, g_norms = (
        onp.linalg.norm(vectors, axis=1) for vectors in (gdots, sgs, gs)
    )

    ## each entry: (pair of per-particle quantities, pair of colorbar labels)
    panel_specs = [
        ((vg_norms, div_vgs), (r"$|v_g|/\gamma$", r"$\nabla_g\cdot v_g /\gamma$")),
        ((g_norms, sg_norms), (r"$|g|$", r"$|s_g|$")),
    ]

    for (q1, q2), (label1, label2) in panel_specs:
        ## the (g, s_g) figure is saved under its own name, without the cluster panel
        is_g_pair    = (label1 == r"$|g|$")
        save_name    = 'particle_g_sg' if is_g_pair else 'particle_EPR'
        plot_cluster = not is_g_pair

        print(f'Making figure of ({label1}, {label2}). plot_cluster={plot_cluster}')
        for extension in ['pdf', 'png']:
            make_particle_plot(
                q1=q1,
                q2=q2,
                label1=label1,
                label2=label2,
                xs_plot=xs,
                index=index,
                clip_quantile=0.025,
                width=dataset_dict['width'],
                s_fac=1.0,
                plot_cluster=plot_cluster,
                center_cluster=True,
                rasterized=True,
                cmap_str='mako',
                extension=extension,
                save_str=f'{base_folder}/{output_folder}/{save_name}'
            )

        print(f'Finished making figure for index {index}.')
print('Finished!')

# vg Visualization

In [None]:
def make_vg_plot(
    vgs: np.ndarray,
    sgs: np.ndarray,
    xs_plot: np.ndarray,
    gs_plot: np.ndarray,
    clip_quantile: float,
    width: float,
    s_fac: float,
    plot_score: bool,
    center_cluster: bool,
    rasterized: bool,
    min_norm: float,
    save_str: str
) -> None:
    """Scatter the particles colored by vector magnitudes, with arrows overlaid.

    Parameters:
    - vgs, sgs, gs_plot: per-particle vectors; each panel colors the particles
      by the norm of one of them and draws the vectors with `Axes.quiver`.
      ``sgs`` is only shown when ``plot_score`` is True.
    - xs_plot: particle positions; columns 0 and 1 are plotted.
    - clip_quantile: the color scale is clipped at the (1 - clip_quantile)
      quantile of the norms.
    - width: axes span [-width, width] in both directions.
    - s_fac: marker radius in data units, converted by `radius_to_points`.
    - plot_score: whether to add the third panel for ``sgs``.
    - center_cluster: if True, shift positions by the mean position of the
      particles flagged by `find_solid_inds` and re-project with
      `drifts.torus_project`.
    - rasterized: forwarded to the scatter and quiver calls.
    - min_norm: in the first panel only, arrows are drawn just for particles
      whose vector norm is at least this value.
    - save_str: output path without extension ('.pdf' is appended); None skips
      saving.
    """
    ## set up basic plot parameters
    plt.close('all')
    sns.set_palette('deep')
    fw, fh   = 4, 5
    fraction = 0.2
    pad      = 0.025
    shrink   = 0.5
    fontsize = 16.0
    nrows    = 1

    ## use the gs_plot argument rather than a notebook-level `gs`
    if plot_score:
        titles     = [r"$|v_g|$", r"$|g|$", r"$|s_g|$"]
        quantities = [      vgs,   gs_plot,        sgs]
    else:
        titles     = [r"$|v_g|$", r"$|g|$"]
        quantities = [       vgs,  gs_plot]


    ncols      = len(titles)
    cmap       = sns.color_palette('mako', as_cmap=True)
    fig, axs   = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=False, sharey=True, constrained_layout=True
    )


    ## center xs
    if center_cluster:
        com     = onp.mean(xs_plot[find_solid_inds(xs_plot)], axis=0)
        xs_plot = drifts.torus_project(xs_plot - com[None, :], width)


    ## common to all axes
    for kk, ax in enumerate(axs.ravel()):
        ax.set_facecolor('k')
        ax.set_xlim([-width, width])
        ax.set_ylim([-width, width])
        ax.set_xticks([-0.5*width, 0, 0.5*width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.set_yticks([-0.5*width, 0, 0.5*width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.grid(which='both', axis='both', color='0.90', alpha=0.1)
        ax.tick_params(which='both', length=0, labelsize=fontsize)
        
        ax.axes.set_aspect(1.0)
        if kk == 0:
            ax.set_ylabel(r"$y$", fontsize=fontsize)
        ax.set_xlabel(r"$x$", fontsize=fontsize)
        
        for spine in ax.spines.values():
            spine.set_edgecolor(None)


    ## make scatter plot
    for kk, ax in enumerate(axs):
        quantity = onp.array(quantities[kk])
        norms    = onp.linalg.norm(quantity, axis=-1)
        s        = radius_to_points(ax, s_fac)
        vmin     = onp.quantile(norms, q=clip_quantile) if quantity.min() < 0.0 else 0.0
        vmax     = onp.quantile(norms, q=(1-clip_quantile))            
        norm     = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
        mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        cbar     = fig.colorbar(mappable=mappable, norm=norm, ax=ax, 
                                fraction=fraction, shrink=shrink, pad=pad, 
                                orientation='horizontal')
        cbar.set_label(titles[kk], fontsize=fontsize)
        cbar.ax.tick_params(which='both', labelsize=0.8*fontsize, width=0, length=0)
        cbar.outline.set_edgecolor('grey')


        ## plot norm with color bar
        ax.scatter(xs_plot[:, 0], xs_plot[:, 1], s=s, marker='o',
                   c=norms, cmap=cmap, norm=norm, alpha=1.0, rasterized=rasterized)


        ## plot direction: the first panel only shows arrows above min_norm
        if kk == 0:
            quiver_inds = (norms >= min_norm)
        else:
            quiver_inds = onp.arange(vgs.shape[0])

        ax.quiver(xs_plot[quiver_inds, 0], xs_plot[quiver_inds, 1], quantity[quiver_inds, 0], 
                  quantity[quiver_inds, 1], color='white', alpha=0.75, rasterized=rasterized)
        

    if save_str is not None:
        ## transparent figure background for the saved file
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        plt.savefig(f'{save_str}.pdf', dpi=300, bbox_inches='tight')

In [None]:
dataset_folder = '/scratch/nb3397/projects/mips_project/dataset/transfer_learn/phi_transfer'
snapshot_N     = 16384
data_file      = f'OU_v0=0.025_N={snapshot_N}_gamma=0.0001_phi=0.5_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0_3.npy'

## NOTE: pickle/dill executes arbitrary code on load -- only open trusted files.
with open(f'{dataset_folder}/{data_file}', 'rb') as f:
    snapshot_dict = pickle.load(f)
snapshot_index = 5
xgs            = snapshot_dict['traj'][snapshot_index]
xs, gs         = np.split(xgs, 2)


## set up the network: same architecture as the checkpoint, but built with
## the width stored alongside this snapshot
curr_particle_score_net, curr_particle_div_net = networks.define_transformer_networks(
    w0=config['w0'],
    n_neighbors=config['n_neighbors'],
    num_layers=config['num_layers'],
    embed_dim=config['embed_dim'],
    embed_n_hidden=config['embed_n_hidden'],
    embed_n_neurons=config['embed_n_neurons'],
    num_heads=config['num_heads'],
    dim_feedforward=config['dim_feedforward'],
    n_layers_feedforward=1,
    width=snapshot_dict['width'],
    network_type=config['network_type'],
    this_particle_pooling=config['this_particle_pooling'],
    scale_fac=config['scale_fac']
)


curr_map_particle_score = jax.jit(jax.vmap(curr_particle_score_net.apply, in_axes=(None, None, 0)))
curr_map_particle_div   = jax.jit(jax.vmap(curr_particle_div_net.apply, in_axes=(None, None, 0)))
vgs, div_vgs, gs, sgs   = compute_entropy_measures(xgs, ema_params, curr_map_particle_score, curr_map_particle_div, N=snapshot_N, batch_size=256)

In [None]:
base_save_folder = '/scratch/nb3397/results/mips/hypoelliptic/figures/9_22_23'


## two-panel figure (|v_g| and |g|) for the large snapshot loaded above
make_vg_plot(
    ## data
    vgs=vgs,
    sgs=sgs,
    xs_plot=xs,
    gs_plot=gs,
    ## appearance
    width=snapshot_dict['width'],
    s_fac=0.5,
    clip_quantile=0.05,
    min_norm=0.9,
    plot_score=False,
    center_cluster=True,
    rasterized=True,
    ## output
    save_str=f'{base_save_folder}/vg_N={snapshot_N}'
)

# Stochastic Movie

In [None]:
def make_movie(
    frame_data: np.ndarray,  # [nsteps, N, d]
    v_data: np.ndarray,      # [nsteps, N]
    div_vg_data: np.ndarray, # [nsteps, N]
    clip_quantile: float,
    width: float,
    s_fac: float,
    save_str: str,
) -> None:
    """Render a three-panel movie of the particles and save it as an mp4.

    The first panel shows the particles in black on white; the other two
    color them by ``v_data`` and ``div_vg_data`` respectively.

    Parameters:
    - frame_data: particle positions per frame.
    - v_data, div_vg_data: per-particle scalars per frame.
    - clip_quantile: the color scale runs from the global minimum of each
      quantity to its (1 - clip_quantile) quantile over all frames.
    - width: axes span [-width, width]; also passed to `drifts.torus_project`.
    - s_fac: marker radius in data units, converted by `radius_to_points`.
    - save_str: output path without extension ('.mp4' is appended).

    Positions in every frame are shifted by the mean position of the
    particles flagged by `find_solid_inds` in the first frame.
    """
    plt.close('all')
    sns.set_palette('deep')


    ## set up figure
    nframes  = frame_data.shape[0]
    fw, fh    = 4, 5
    fraction  = 0.2
    pad       = 0.025
    shrink    = 0.5
    fontsize  = 16.0
    lw        = 0.5
    nrows     = 1
    ncols     = 3
    fig, axs  = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=False, sharey=True, constrained_layout=True
    )
    quantities = [None,          v_data,                     div_vg_data]
    titles     = [None, r"$|v_g|/\gamma$", r"$\nabla_g \cdot v_g /\gamma$"]
    cmap=sns.color_palette('mako', as_cmap=True)
    

    ## center all xs by the initial COM
    com        = onp.mean(frame_data[0][find_solid_inds(frame_data[0])], axis=0)
    frame_data = drifts.torus_project(frame_data - com[None, None, :], width)


    ## set axis limits, grids, and particle scaling
    for kk, ax in enumerate(axs.ravel()):
        ax.set_facecolor('k')
        ax.set_xlim([-width, width])
        ax.set_ylim([-width, width])
        ax.set_xticks([-0.5*width, 0, 0.5*width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.set_yticks([-0.5*width, 0, 0.5*width],
                      [r"$-L/4$", r"$0.0$", r"$L/4$"])
        ax.grid(which='both', axis='both', color='0.90', alpha=0.1)
        ax.tick_params(which='both', length=0, labelsize=fontsize)
        
        ax.axes.set_aspect(1.0)
        if kk == 0:
            ax.set_ylabel(r"$y$", fontsize=fontsize)
        ax.set_xlabel(r"$x$", fontsize=fontsize)
        
        for spine in ax.spines.values():
            spine.set_edgecolor(None)
            
            
    ## set up initial scatter plots
    scats   = []
    norms   = []
    xs_plot = frame_data[0]
    for kk, ax in enumerate(axs):
        quantity       = quantities[kk]
        s              = radius_to_points(ax, s_fac)

        ## color limits are fixed once, over all frames, so colors are
        ## comparable throughout the movie
        if quantity is not None:
            frame_quantity = quantity[0]
            vmin = quantity.min()
            vmax = onp.quantile(quantity, q=(1-clip_quantile))
        else:
            vmin, vmax = 0.0, 1.0
            
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
        norms.append(norm)
        mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        cbar = fig.colorbar(mappable=mappable, norm=norm, ax=ax, 
                            fraction=fraction, shrink=shrink, pad=pad, 
                            orientation='horizontal')
        cbar.set_label(titles[kk], fontsize=fontsize)
        cbar.ax.tick_params(which='both', labelsize=0.8*fontsize, width=0, length=0)
        cbar.outline.set_edgecolor('grey')
        
        if quantity is None:
            ## plot without colorbar
            ax.set_facecolor('white')
            scat = ax.scatter(xs_plot[:, 0], xs_plot[:, 1], s=s, marker='o', 
                              c='black', linewidths=lw, alpha=1.0)
            scats.append(scat)
            cbar.ax.set_visible(False)
        else:
            ## plot with color bar
            scat = ax.scatter(xs_plot[:, 0], xs_plot[:, 1], s=s, marker='o', 
                              c=frame_quantity, linewidths=lw, cmap=cmap, norm=norm, alpha=1.0)
            scats.append(scat)


    ## progress is reported roughly every 10% of the frames; with fewer than
    ## ten frames the interval is zero and nothing is printed
    report_every = nframes // 10


    def init():
        return scats


    def animate(frame: int):
        ## remove scatter for speed
        for path_collection in scats:
            path_collection.remove()


        ## update scatter plot
        xs_plot = frame_data[frame]
        for kk, ax in enumerate(axs):
            s = radius_to_points(ax, s_fac)
            if kk > 0:
                frame_quantity = quantities[kk][frame]
            else:
                frame_quantity = None
    
            if frame_quantity is None:
                scats[kk] = ax.scatter(xs_plot[:, 0], xs_plot[:, 1], s=s, marker='o', 
                                       c='black', linewidths=lw, alpha=1.0)
            else:
                scats[kk] = ax.scatter(xs_plot[:, 0], xs_plot[:, 1], s=s, marker='o', c=frame_quantity, 
                                       linewidths=lw, cmap=cmap, norm=norms[kk], alpha=1.0)        
        

        ## report progress
        if report_every and frame % report_every == 0:
            print(f'[{frame}/{nframes}]')

        ## return what changed in the plot
        return scats


    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=onp.arange(nframes), 
                                   interval=0.1, blit=True, repeat=False, cache_frame_data=False)


    save_name = f'{save_str}.mp4'
    anim.save(save_name, fps=60, dpi=300)
    plt.show()

### experiment with individual movie

In [None]:
dataset_folder = '/scratch/nb3397/projects/mips_project/dataset/real_mips_datasets/'
traj_file      = 'N=4096_v0=0.025_gamma=0.0001_phi=0.5_dt=0.1_beta=7.5_tspace_fac=0.00100_thermalize=0.00100_A=0.0_k=1_eps=0.npy'

## NOTE: pickle/dill executes arbitrary code on load -- only open trusted files.
with open(f'{dataset_folder}/{traj_file}', 'rb') as f:
    traj_data = pickle.load(f)['trajs']

In [None]:
nframes     = 25
frame_data  = traj_data[:nframes]
v_data      = onp.zeros((nframes, N))
div_vg_data = onp.zeros((nframes, N))

## print timing roughly every 10% of the frames (never when nframes < 10)
report_every = nframes // 10


start_time = time.time()
for curr_frame in range(nframes):
    ## compute_entropy_measures returns (gdots, div_vgs, gs, sgs); only the
    ## first two are needed here
    gdots, div_vg_data[curr_frame], _, _ = compute_entropy_measures(
        frame_data[curr_frame], ema_params, map_particle_score, map_particle_div, N=N, batch_size=256
    )
    v_data[curr_frame] = np.linalg.norm(gdots, axis=-1)
    
    if report_every and ((curr_frame+1) % report_every) == 0:
        end_time = time.time()
        print(f'[{curr_frame+1}/{nframes}], {end_time-start_time}s')
        start_time = time.time()

In [None]:
base_folder   = '/scratch/nb3397/results/mips/hypoelliptic/figures'
output_folder = '9_16_23'
clip_quantile = 0.075

## number of particles: must match the per-particle quantities used for coloring
movie_N       = v_data.shape[1]
save_name     = f'sde_movie_N={movie_N}'

## NOTE(review): make_movie requires a box width; this assumes the trajectory
## shares the training dataset's `width` loaded at the top of the notebook --
## confirm against the trajectory file.
make_movie(
    frame_data[:, :movie_N],
    v_data,
    div_vg_data,
    clip_quantile=clip_quantile,
    width=width,
    s_fac=1.0,
    save_str=f'{base_folder}/{output_folder}/{save_name}'
)

### Loop and make movies all at once

In [None]:
dataset_folder = '/scratch/nb3397/projects/mips_project/dataset/transfer_learn/movie_trajs/backup_trajs'
base_folder    = '/scratch/nb3397/results/mips/hypoelliptic/figures'
output_folder  = '9_18_23'
clip_quantile  = 0.075
# movie_Ns       = [4096, 8192, 16384, 32768]
movie_Ns       = [32768]
nframes        = 1000

## print timing roughly every 10% of the frames (never when nframes < 10)
report_every   = nframes // 10


for kk, movie_N in enumerate(movie_Ns):
    ## load in the data
    ## NOTE: pickle/dill executes arbitrary code on load -- only open trusted files.
    total_start = time.time()
    traj_file   = f'OU_v0=0.025_N={movie_N}_gamma=0.0001_phi=0.5_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0.npy'
    save_name   = f'sde_movie_N={movie_N}'
    with open(f'{dataset_folder}/{traj_file}', 'rb') as f:
        traj_data = pickle.load(f)
    frame_data  = traj_data['traj'][:nframes]


    ## set up the network with the width stored alongside this trajectory
    curr_particle_score_net, curr_particle_div_net = networks.define_transformer_networks(
        w0=config['w0'],
        n_neighbors=config['n_neighbors'],
        num_layers=config['num_layers'],
        embed_dim=config['embed_dim'],
        embed_n_hidden=config['embed_n_hidden'],
        embed_n_neurons=config['embed_n_neurons'],
        num_heads=config['num_heads'],
        dim_feedforward=config['dim_feedforward'],
        n_layers_feedforward=1,
        width=traj_data['width'],
        network_type=config['network_type'],
        this_particle_pooling=config['this_particle_pooling'],
        scale_fac=config['scale_fac']
    )
    curr_map_particle_score = jax.jit(jax.vmap(curr_particle_score_net.apply, in_axes=(None, None, 0)))
    curr_map_particle_div   = jax.jit(jax.vmap(curr_particle_div_net.apply, in_axes=(None, None, 0)))
    
    
    ## compute particle EPR data
    v_data      = onp.zeros((nframes, movie_N))
    div_vg_data = onp.zeros((nframes, movie_N))
    start_time = time.time()
    for curr_frame in range(nframes):
        gdots, div_vg_data[curr_frame], _, _ = compute_entropy_measures(
            frame_data[curr_frame], ema_params, curr_map_particle_score, curr_map_particle_div, N=movie_N, batch_size=256
        )
        v_data[curr_frame] = np.linalg.norm(gdots, axis=-1)

        if report_every and ((curr_frame+1) % report_every) == 0:
            end_time = time.time()
            print(f'[{curr_frame+1}/{nframes}], {end_time-start_time}s')
            start_time = time.time()


    ## the first movie_N rows of each frame are passed as positions
    make_movie(
        frame_data[:, :movie_N],
        v_data,
        div_vg_data,
        clip_quantile=clip_quantile,
        width=traj_data['width'],
        s_fac=1.0,
        save_str=f'{base_folder}/{output_folder}/{save_name}'
    )
    
    print(f'Finished N={movie_N}. [{kk+1}/{len(movie_Ns)}]. {(time.time() - total_start)/60}m.')

# Transfer Learning

## Packing Fraction

In [None]:
## load in all datasets
base_folder = '/scratch/nb3397/projects/mips_project/dataset/transfer_learn/phi_transfer'
# phi_N = 4096
phi_N = 8192
# phi_N = 16384
# phi_N = 32768


## one dataset file per phi value, keyed by the label used later
file_names = {
    fr"\phi=0.1": f'OU_v0=0.025_N={phi_N}_gamma=0.0001_phi=0.1_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0_1.npy',
    fr"\phi=0.25": f'OU_v0=0.025_N={phi_N}_gamma=0.0001_phi=0.25_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0_2.npy',
    fr"\phi=0.5": f'OU_v0=0.025_N={phi_N}_gamma=0.0001_phi=0.5_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0_3.npy',
    fr"\phi=0.75": f'OU_v0=0.025_N={phi_N}_gamma=0.0001_phi=0.75_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0_4.npy',
    fr"\phi=0.9": f'OU_v0=0.025_N={phi_N}_gamma=0.0001_phi=0.9_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0_5.npy'
}


## load each file with an explicit context manager so every handle is closed
## NOTE: pickle/dill executes arbitrary code on load -- only open trusted files.
transfer_trajs = {}
for key, file_name in file_names.items():
    with open(f'{base_folder}/{file_name}', 'rb') as f:
        transfer_trajs[key] = pickle.load(f)

In [None]:
## assemble individual networks to correctly take into account the periodicity
## Each dataset carries its own width, so a separate network pair is built per
## dataset and wrapped with the same vmap + jit as the base network.
map_particle_scores = {}
map_particle_divs   = {}
for key in transfer_trajs.keys():
    curr_particle_score_net, curr_particle_div_net = networks.define_transformer_networks(
        w0=config['w0'],
        n_neighbors=config['n_neighbors'],
        num_layers=config['num_layers'],
        embed_dim=config['embed_dim'],
        embed_n_hidden=config['embed_n_hidden'],
        embed_n_neurons=config['embed_n_neurons'],
        num_heads=config['num_heads'],
        dim_feedforward=config['dim_feedforward'],
        n_layers_feedforward=1,
        width=transfer_trajs[key]['width'],
        network_type=config['network_type'],
        this_particle_pooling=config['this_particle_pooling'],
        scale_fac=config['scale_fac'],
    )

    map_particle_scores[key] = jax.jit(jax.vmap(curr_particle_score_net.apply, in_axes=(None, None, 0)))
    map_particle_divs[key]   = jax.jit(jax.vmap(curr_particle_div_net.apply, in_axes=(None, None, 0)))

In [None]:
## compute the entropy measures for each
## Evaluate the learned quantities on one snapshot of every dataset and keep
## the per-particle scalars, keyed by their plot labels.
index = 0
epr_data = {}


for key in transfer_trajs.keys():
    curr_traj = transfer_trajs[key]['traj'][index]
    xs, gs    = np.split(curr_traj, 2)

    gdots, div_vgs, gs, sgs = compute_entropy_measures(
        curr_traj,
        ema_params,
        map_particle_scores[key],
        map_particle_divs[key],
        N=phi_N,
        batch_size=256,
    )

    ## per-particle magnitudes
    vg_norms, sg_norms, g_norms = (
        onp.linalg.norm(vectors, axis=1) for vectors in (gdots, sgs, gs)
    )

    epr_data[key] = dict([
        (r"$\nabla_g \cdot v_g / \gamma$", div_vgs),
        (r"$|v_g| / \gamma$", vg_norms),
        (r"$|s_g|$", sg_norms),
        (r"$|g|$", g_norms),
    ])

In [None]:
def make_phi_transfer_plot(
    epr_data: dict,
    transfer_trajs: dict,
    index: int,
    clip_quantile: float,
    s_fac: float,
    rasterized: bool,
    plot_vg: bool,
    extension: str,
    save_str: str
) -> None:
    """Plot one snapshot per packing fraction, colored by the learned quantities.

    The figure has one column per entry of the hard-coded ``phi_order`` list and
    one row per plotted quantity. The first row draws the particles in black;
    every other row colors the particles by the matching ``epr_data`` entry,
    with a color scale shared across the row.

    Args:
        epr_data: maps each key in ``phi_order`` to a dict of per-particle
            arrays keyed by their LaTeX label.
        transfer_trajs: maps each key in ``phi_order`` to a dict providing
            ``'width'`` (axis half-extent) and ``'traj'`` (indexable snapshots).
        index: which entry of ``'traj'`` to plot.
        clip_quantile: row color limits are the ``clip_quantile`` and
            ``1 - clip_quantile`` quantiles of the values pooled over the row.
        s_fac: marker radius handed to ``radius_to_area`` for phi = 0.5; other
            columns scale it by ``sqrt(phi / 0.5)``.
        rasterized: forwarded to ``ax.scatter``.
        plot_vg: if True, add a third row colored by the v_g-norm entry.
        extension: file extension of the saved figure.
        save_str: output path without extension; ``''`` disables saving.
    """
    ## set up basic plot parameters
    plt.close('all')
    sns.set_palette('deep')
    fw, fh     = 5, 4
    fraction   = 0.2
    pad        = 0.025
    shrink     = 0.5
    fontsize   = 24

    ## raw strings: '\p' is not a recognized escape sequence, so the key values are
    ## unchanged, but the non-raw literals trigger an invalid-escape warning
    ## (a \phi=0.01 column was previously listed here and is currently disabled)
    phi_order  = [r'\phi=0.1', r'\phi=0.25', r'\phi=0.5', r'\phi=0.75', r'\phi=0.9']
    phis       = [        0.1,         0.25,         0.5,         0.75,         0.9]
    base_phi   = 0.5

    if plot_vg:
        plot_keys  = [None, r"$\nabla_g \cdot v_g / \gamma$", r"$|v_g| / \gamma$"]
    else:
        plot_keys  = [None, r"$\nabla_g \cdot v_g / \gamma$"]

    nrows      = len(plot_keys)
    ncols      = len(phi_order)
    cmap       = sns.color_palette('mako', as_cmap=True)
    fig, axs   = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=False, sharey=False, constrained_layout=True
    )


    ## make the plot
    for row, plot_key in enumerate(plot_keys):
        snapshot_row = plot_key is None

        ## compute shared colorbar
        if snapshot_row:
            vmin, vmax = 0.0, 1.0
        else:
            ## pool the row's values once (as float64) instead of growing an array per column
            global_quantity = onp.concatenate(
                [onp.asarray(epr_data[phi_key][plot_key], dtype=float) for phi_key in phi_order]
            )
            vmin = onp.quantile(global_quantity, q=clip_quantile)
            vmax = onp.quantile(global_quantity, q=(1-clip_quantile))

        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
        mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        cbar = fig.colorbar(mappable=mappable, norm=norm, ax=axs[row].ravel(), 
                            fraction=fraction, shrink=shrink, pad=pad, 
                            orientation='vertical')
        cbar.ax.tick_params(which='both', labelsize=0.8*fontsize, width=0, length=0)
        cbar.outline.set_edgecolor('grey')

        if snapshot_row:
            ## the colorbar is still created so the row keeps the same layout, but hidden
            cbar.ax.set_visible(False)
        else:
            ## plot_key already carries its own $...$ delimiters; wrapping it again
            ## would produce $$...$$, which matplotlib's usetex mode does not support
            cbar.set_label(plot_key, fontsize=fontsize)


        for col, (phi_key, curr_phi) in enumerate(zip(phi_order, phis)):
            ## set up basic plot parameters
            ax = axs[row, col]

            if row == 0:
                ax.set_facecolor('white')
                ax.set_title(fr"${phi_key}$", fontsize=fontsize)
            else:
                ax.set_facecolor('k')


            ## axis limits follow each dataset's own box width
            traj_data = transfer_trajs[phi_key]
            width     = traj_data['width']
            ax.set_xlim([-width, width])
            ax.set_ylim([-width, width])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.grid(which='both', axis='both', color='0.90', alpha=0.1)
            ax.tick_params(which='both', length=0, labelsize=fontsize)
            ax.set_aspect(1.0)
            for spine in ax.spines.values():
                spine.set_edgecolor(None)


            ## make the scatter plot; the marker size is computed after the limits
            ## and aspect are set because radius_to_area receives the axis
            curr_xgs  = traj_data['traj'][index]
            xs, _     = onp.split(curr_xgs, 2)
            s         = radius_to_area(ax, s_fac*onp.sqrt(curr_phi/base_phi))
            if snapshot_row:
                ax.scatter(xs[:, 0], xs[:, 1], s=s, c='black', linewidths=0.5, alpha=1.0, rasterized=rasterized)
            else:
                quantity = epr_data[phi_key][plot_key]
                ax.scatter(xs[:, 0], xs[:, 1], s=s, c=quantity, linewidths=0.5, cmap=cmap, norm=norm, alpha=1.0, rasterized=rasterized)


    if save_str != '':
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        plt.savefig(f'{save_str}.{extension}', dpi=300, bbox_inches='tight')

In [None]:
## output location for the packing-fraction transfer figure
base_save_folder = '/scratch/nb3397/results/mips/hypoelliptic/figures/10_1_23'
save_name        = f'phi_transfer_N={phi_N}'


## NOTE(review): the original note beside s_fac lists phi_N=4k, 8k, 16k --
## presumably the same value was used for all three; confirm
make_phi_transfer_plot(
    epr_data=epr_data,
    transfer_trajs=transfer_trajs,
    index=index,
    s_fac=0.265,
    clip_quantile=0.05,
    plot_vg=False,
    rasterized=True,
    save_str=f'{base_save_folder}/{save_name}',
    extension='pdf',
)

### Number of Particles

In [None]:
## load in all datasets
base_folder = '/scratch/nb3397/projects/mips_project/dataset/transfer_learn/'


## one file per particle count; the keys double as column titles when plotting
file_names = {
    r"N=4096":  'phi_transfer/OU_v0=0.025_N=4096_gamma=0.0001_phi=0.5_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0_3.npy',
    r"N=8192":  'phi_transfer/OU_v0=0.025_N=8192_gamma=0.0001_phi=0.5_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0_3.npy',
    r"N=16384": 'phi_transfer/OU_v0=0.025_N=16384_gamma=0.0001_phi=0.5_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0_3.npy',
    r"N=32768": 'phi_transfer/OU_v0=0.025_N=32768_gamma=0.0001_phi=0.5_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0_3.npy'
}


## load every trajectory file, closing each handle as soon as it has been read
## NOTE(review): unpickling executes arbitrary code -- only load trusted files
transfer_trajs = {}
for key, file_name in file_names.items():
    with open(f'{base_folder}/{file_name}', 'rb') as traj_file:
        transfer_trajs[key] = pickle.load(traj_file)

In [None]:
## build one (score, divergence) network pair per dataset: each dataset carries
## its own box width, which is handed to the network constructor
map_particle_scores = {}
map_particle_divs   = {}
for key in transfer_trajs:
    curr_particle_score_net, curr_particle_div_net = networks.define_transformer_networks(
        w0=config['w0'],
        n_neighbors=config['n_neighbors'],
        num_layers=config['num_layers'],
        embed_dim=config['embed_dim'],
        embed_n_hidden=config['embed_n_hidden'],
        embed_n_neurons=config['embed_n_neurons'],
        num_heads=config['num_heads'],
        dim_feedforward=config['dim_feedforward'],
        n_layers_feedforward=1,
        width=transfer_trajs[key]['width'],
        network_type=config['network_type'],
        this_particle_pooling=config['this_particle_pooling'],
        scale_fac=config['scale_fac'],
    )

    ## map over the leading axis of the third argument only; the first two
    ## arguments are shared by every mapped call
    map_particle_scores[key] = jax.jit(
        jax.vmap(curr_particle_score_net.apply, in_axes=(None, None, 0))
    )
    map_particle_divs[key] = jax.jit(
        jax.vmap(curr_particle_div_net.apply, in_axes=(None, None, 0))
    )

In [None]:
## evaluate the learned quantities on one snapshot of every dataset
index    = 0
epr_data = {}

for key in transfer_trajs:
    curr_traj = transfer_trajs[key]['traj'][index]
    xs, gs    = np.split(curr_traj, 2)  # two equal halves; gs is rebound by the call below
    gdots, div_vgs, gs, sgs = compute_entropy_measures(
        curr_traj,
        ema_params,
        map_particle_scores[key],
        map_particle_divs[key],
        N=xs.shape[0],  # row count of the first half of this snapshot
        batch_size=256,
    )

    ## Euclidean norm of every row (axis=1)
    vg_norms, sg_norms, g_norms = (
        onp.linalg.norm(arr, axis=1) for arr in (gdots, sgs, gs)
    )

    ## keys double as the LaTeX labels used when plotting
    epr_data[key] = {
        r"$\nabla_g \cdot v_g / \gamma$": div_vgs,
        r"$|v_g| / \gamma$": vg_norms,
        r"$|s_g|$": sg_norms,
        r"$|g|$": g_norms,
    }

In [None]:
def make_N_transfer_plot(
    epr_data: dict,
    transfer_trajs: dict,
    index: int,
    clip_quantile: float,
    rasterized: bool,
    plot_vg: bool,
    s_facs: list,
    save_str: str,
    extension: str = 'pdf'
) -> None:
    """Plot one snapshot per particle count, colored by the learned quantities.

    The figure has one column per entry of the hard-coded ``N_order`` list and
    one row per plotted quantity. The first row draws the particles in black;
    every other row colors the particles by the matching ``epr_data`` entry,
    with a color scale shared across the row.

    Args:
        epr_data: maps each key in ``N_order`` to a dict of per-particle arrays
            keyed by their LaTeX label.
        transfer_trajs: maps each key in ``N_order`` to a dict providing
            ``'width'`` (axis half-extent) and ``'traj'`` (indexable snapshots).
        index: which entry of ``'traj'`` to plot.
        clip_quantile: row color limits are the ``clip_quantile`` and
            ``1 - clip_quantile`` quantiles of the values pooled over the row.
        rasterized: forwarded to ``ax.scatter``.
        plot_vg: if True, add a third row colored by the v_g-norm entry.
        s_facs: marker radius handed to ``radius_to_area``, one per column, in
            the order of ``N_order``.
        save_str: output path without extension; ``''`` disables saving.
        extension: file extension of the saved figure (defaults to the
            previously hard-coded ``'pdf'``).
    """
    ## set up basic plot parameters
    plt.close('all')
    sns.set_palette('deep')
    fw, fh     = 5, 4
    fraction   = 0.2
    pad        = 0.025
    shrink     = 0.5
    fontsize   = 24
    N_order    = [r"N=4096", r"N=8192", r"N=16384", r"N=32768"]

    if plot_vg:
        plot_keys  = [None, r"$\nabla_g \cdot v_g / \gamma$", r"$|v_g| / \gamma$"]
    else:
        plot_keys  = [None, r"$\nabla_g \cdot v_g / \gamma$"]

    nrows      = len(plot_keys)
    ncols      = len(N_order)
    cmap       = sns.color_palette('mako', as_cmap=True)
    fig, axs   = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=False, sharey=False, constrained_layout=True
    )


    ## make the plot
    for row, plot_key in enumerate(plot_keys):
        snapshot_row = plot_key is None

        ## compute shared colorbar
        if snapshot_row:
            vmin, vmax = 0.0, 1.0
        else:
            ## pool the row's values once (as float64) instead of growing an array per column
            global_quantity = onp.concatenate(
                [onp.asarray(epr_data[N_key][plot_key], dtype=float) for N_key in N_order]
            )
            vmin = onp.quantile(global_quantity, q=clip_quantile)
            vmax = onp.quantile(global_quantity, q=(1-clip_quantile))

        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
        mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        cbar = fig.colorbar(mappable=mappable, norm=norm, ax=axs[row].ravel(), 
                            fraction=fraction, shrink=shrink, pad=pad, 
                            orientation='vertical')
        cbar.ax.tick_params(which='both', labelsize=0.8*fontsize, width=0, length=0)
        cbar.outline.set_edgecolor('grey')

        if snapshot_row:
            ## the colorbar is still created so the row keeps the same layout, but hidden
            cbar.ax.set_visible(False)
        else:
            ## plot_key already carries its own $...$ delimiters; wrapping it again
            ## would produce $$...$$, which matplotlib's usetex mode does not support
            cbar.set_label(plot_key, fontsize=fontsize)


        for col, N_key in enumerate(N_order):
            ## set up basic plot parameters
            ax = axs[row, col]

            if row == 0:
                ax.set_facecolor('white')
                ax.set_title(fr"${N_key}$", fontsize=fontsize)
            else:
                ax.set_facecolor('k')


            ## axis limits follow each dataset's own box width
            traj_data = transfer_trajs[N_key]
            width     = traj_data['width']
            ax.set_xlim([-width, width])
            ax.set_ylim([-width, width])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.grid(which='both', axis='both', color='0.90', alpha=0.1)
            ax.tick_params(which='both', length=0, labelsize=fontsize)
            ax.set_aspect(1.0)
            for spine in ax.spines.values():
                spine.set_edgecolor(None)


            ## make the scatter plot; the marker size is computed after the limits
            ## and aspect are set because radius_to_area receives the axis
            curr_xgs  = traj_data['traj'][index]
            xs, _     = onp.split(curr_xgs, 2)
            s         = radius_to_area(ax, s_facs[col])
            if snapshot_row:
                ax.scatter(xs[:, 0], xs[:, 1], s=s, marker='o', c='black', linewidths=0.5, alpha=1.0, rasterized=rasterized)
            else:
                quantity = epr_data[N_key][plot_key]
                ax.scatter(xs[:, 0], xs[:, 1], s=s, marker='o', c=quantity, cmap=cmap, linewidths=0.5, norm=norm, alpha=1.0, rasterized=rasterized)


    if save_str != '':
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        plt.savefig(f'{save_str}.{extension}', dpi=300, bbox_inches='tight')

In [None]:
## output location for the particle-number transfer figure
base_save_folder = '/scratch/nb3397/results/mips/hypoelliptic/figures/10_1_23'
save_name        = 'N_transfer'


## rasterized and plot_vg are required parameters of make_N_transfer_plot (no
## defaults), so omitting them raises a TypeError; the values mirror the
## packing-fraction plot call
make_N_transfer_plot(
    epr_data=epr_data,
    transfer_trajs=transfer_trajs,
    index=index,
    clip_quantile=0.05,
    rasterized=True,
    plot_vg=False,
    s_facs=[0.35, 0.315, 0.28, 0.25],  # one marker radius per column (N=4096 ... N=32768)
    save_str=f'{base_save_folder}/{save_name}'
)

# Attention Map

## Separate Network Functions

In [None]:
@hk.without_apply_rng
@hk.transform
def g_embedding(inp: np.ndarray):
    """Apply the MLP stack named 'g_embedding' to `inp`.

    The layers are built from the global `config`; the requested output
    dimension is half of config['embed_dim'].
    """
    mlp_layers = networks.construct_mlp_layers(
        config['embed_n_hidden'],
        config['embed_n_neurons'],
        jax.nn.gelu,
        w0=config['w0'],
        output_dim=config['embed_dim'] // 2,
        use_layer_norm=False,
        use_residual_connections=True,
        name='g_embedding',
    )
    return hk.Sequential(mlp_layers)(inp)


@hk.without_apply_rng
@hk.transform
def x_embedding(inp: np.ndarray):
    """Apply the MLP stack named 'x_embedding' to `inp`.

    The layers are built from the global `config`; the requested output
    dimension is half of config['embed_dim'].
    """
    mlp_layers = networks.construct_mlp_layers(
        config['embed_n_hidden'],
        config['embed_n_neurons'],
        jax.nn.gelu,
        w0=config['w0'],
        output_dim=config['embed_dim'] // 2,
        use_layer_norm=False,
        use_residual_connections=True,
        name='x_embedding',
    )
    return hk.Sequential(mlp_layers)(inp)


@hk.without_apply_rng
@hk.transform
def transformer(inp: np.ndarray,):
    """Run `inp` through a networks.Transformer configured from the global `config`."""
    # NOTE(review): other cells in this notebook hard-code n_layers_feedforward=1,
    # whereas this one reads it from config -- confirm the key exists and agrees.
    model = networks.Transformer(
        num_layers=config['num_layers'],
        input_dim=config['embed_dim'],
        num_heads=config['num_heads'],
        dim_feedforward=config['dim_feedforward'],
        n_layers_feedforward=config['n_layers_feedforward'],
        n_inducing_points=0,
        w0=config['w0'],
    )
    return model(inp)


@hk.without_apply_rng
@hk.transform
def get_transformer_attention(inp: np.ndarray):
    """Return the attention maps a networks.Transformer produces for `inp`.

    The transformer is configured from the global `config`, with a single
    feedforward layer and no inducing points.
    """
    model = networks.Transformer(
        num_layers=config['num_layers'],
        input_dim=config['embed_dim'],
        num_heads=config['num_heads'],
        dim_feedforward=config['dim_feedforward'],
        n_layers_feedforward=1,
        n_inducing_points=0,
        w0=config['w0'],
    )
    return model.get_attention_maps(inp)


def attention_rollout(
    attention_maps: np.ndarray # [L, N, N]
) -> np.ndarray:
    """Accumulate per-layer attention into rollout attention.

    Each layer's map is averaged with the identity, 0.5*(A + I), and
    left-multiplied onto the rollout of the layer before it; the first layer's
    rollout is just its own averaged map. Expects maps that have already been
    averaged over heads. Returns a NumPy array with the same shape as the input.
    """
    n_tokens = attention_maps[0].shape[0]
    identity = onp.eye(n_tokens)
    rollout  = onp.zeros(attention_maps.shape)

    for layer, layer_attn in enumerate(attention_maps):
        mixed = 0.5*(layer_attn + identity)
        if layer == 0:
            rollout[layer] = mixed
        else:
            rollout[layer] = mixed @ rollout[layer - 1]

    return rollout

## Compute Attention

In [None]:
## load in all datasets
base_folder         = '/scratch/nb3397/projects/mips_project/dataset/transfer_learn/'
attn_N              = 4096*2
file_name           = f'phi_transfer/OU_v0=0.025_N={attn_N}_gamma=0.0001_phi=0.5_dt=0.1_beta=7.5_tspace=10.0_A=0.0_k=1.0_eps=0.0_3.npy'

## context manager so the file handle is closed once the pickle has been read
## NOTE(review): unpickling executes arbitrary code -- only load trusted files
with open(f'{base_folder}/{file_name}', 'rb') as attention_file:
    attention_data_dict = pickle.load(attention_file)

## first snapshot, split into two equal halves
attention_snapshot  = attention_data_dict['traj'][0]
xs, gs              = np.split(attention_snapshot, 2)

In [None]:
## split particle indices into solid, gas, and interface (everything in neither set)
_all_particle_inds = onp.arange(attn_N)
solid_inds         = _all_particle_inds[find_solid_inds(xs)]
gas_inds           = _all_particle_inds[find_gas_inds(xs)]
interface_inds     = onp.setdiff1d(
    _all_particle_inds,
    onp.union1d(solid_inds, gas_inds),
)

In [None]:
## pick one random particle from each region
## NOTE(review): the NumPy global RNG is not seeded in this cell, so the chosen
## particles change from run to run unless a seed is set elsewhere
particle_indices = [
    solid_inds[onp.random.randint(solid_inds.size)], 
    gas_inds[onp.random.randint(gas_inds.size)], 
    interface_inds[onp.random.randint(interface_inds.size)]
]
attn_maps, rollout_attn_maps, inds_arr = [], [], []


## separate the parameters by module name; this depends only on ema_params, so
## it is done once rather than on every loop iteration
separate_params = {
    'g_embedding_params': {key: value for key, value in ema_params.items() if 'g_embedding' in key},
    'x_embedding_params': {key: value for key, value in ema_params.items() if 'x_embedding' in key},
    'trans_params'      : {key: value for key, value in ema_params.items() if 'transformer' in key},
}


## unpack the parameters
x_embedding_params = separate_params['x_embedding_params']
g_embedding_params = separate_params['g_embedding_params']
trans_params       = separate_params['trans_params']


for particle_index in particle_indices:
    ## compute the nearest neighbors: top_k on the negated distances selects the
    ## n_neighbors+1 smallest ones
    xi, gi = xs[particle_index], gs[particle_index]
    xdiffs = jax.vmap(lambda xj: drifts.wrapped_diff(xi, xj, attention_data_dict['width']))(xs)
    norms  = np.linalg.norm(xdiffs, axis=1)
    inds   = jax.lax.top_k(-norms, config['n_neighbors']+1)[1]
    inds_arr.append(inds)


    ## compute transformer input; each embedding is built with
    ## output_dim=embed_dim // 2, so concatenating a g- and an x-embedding
    ## presumably yields embed_dim features per row -- TODO confirm
    embedded_gi     = g_embedding.apply(g_embedding_params, gi)           # [embed_dim // 2]
    embedded_xs     = x_embedding.apply(x_embedding_params, xdiffs[inds]) # [n_neighbors+1, embed_dim // 2]
    embedded_gs     = g_embedding.apply(g_embedding_params, gs[inds])     # [n_neighbors+1, embed_dim // 2]
    this_particle   = np.concatenate((embedded_gi, embedded_xs[0]))
    other_particles = np.hstack((embedded_gs[1:], embedded_xs[1:]))
    inp             = np.vstack((this_particle, other_particles))         # row 0 is the chosen particle


    ## compute attention map; axis 1 is averaged before the rollout
    ## (presumably the head axis, as attention_rollout expects head-averaged maps)
    attn_map         = np.array(get_transformer_attention.apply(trans_params, inp))
    rollout_attn_map = attention_rollout(np.mean(attn_map, axis=1))

    attn_maps.append(attn_map)
    rollout_attn_maps.append(rollout_attn_map)

In [None]:
def make_attn_plot(
    particle_indices: list,
    inds_arr: list,
    rollout_attn_maps: list,
    xs_plot: np.ndarray,
    width: float,
    s_fac: float,
    center_cluster: bool,
    extension: str,
    save_str: str
) -> None:
    """Overlay rollout attention for selected particles on a snapshot.

    All particles in ``xs_plot`` are drawn in black; for each selected particle
    its neighbors are redrawn on top, colored by row 0 of the final-layer
    rollout attention map.

    Args:
        particle_indices: the selected particles; only used to pair up with
            ``inds_arr`` and ``rollout_attn_maps`` (its values are not read).
        inds_arr: for each selected particle, the indices into ``xs_plot`` of
            the particles to color.
        rollout_attn_maps: for each selected particle, the rollout attention
            maps; ``[-1][0]`` supplies the colors.
        xs_plot: particle positions; columns 0 and 1 are plotted.
        width: axis half-extent.
        s_fac: marker radius handed to ``radius_to_points``.
        center_cluster: currently unused.
        extension: file extension of the saved figure.
        save_str: output path prefix; ``None`` disables saving.
    """
    ## set up basic plot parameters
    plt.close('all')
    sns.set_palette('deep')
    fw, fh    = 4, 4
    fontsize  = 12.0
    nrows     = 1
    ncols     = 1
    cmap      = sns.color_palette('mako', as_cmap=True)
    fig, ax  = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fw*ncols, fh*nrows),
        sharex=False, sharey=True, constrained_layout=True
    )

    ## basic axis setup; the axis spans [-width, width], so +-0.5*width is labelled L/4
    ax.set_facecolor('white')
    ax.set_xlim([-width, width])
    ax.set_ylim([-width, width])
    ax.set_xticks([-0.5*width, 0, 0.5*width],
                  [r"$-L/4$", r"$0.0$", r"$L/4$"])
    ax.set_yticks([-0.5*width, 0, 0.5*width],
                  [r"$-L/4$", r"$0.0$", r"$L/4$"])
    ax.grid(which='both', axis='both', color='0.90', alpha=0.1)
    ax.tick_params(which='both', length=0, labelsize=fontsize)
    ax.axes.set_aspect(1.0)
    ax.set_ylabel(r"$y$", fontsize=fontsize)
    ax.set_xlabel(r"$x$", fontsize=fontsize)
    for spine in ax.spines.values():
        spine.set_edgecolor(None)


    ## plot background cluster
    s = radius_to_points(ax, s_fac)
    ax.scatter(xs_plot[:, 0], xs_plot[:, 1], s=s, marker='o', c='black', linewidths=0.5, alpha=1.0)


    ## plot attention maps, reading positions from the xs_plot argument rather
    ## than a notebook-level global so the overlay always matches the background
    for _particle_index, neighbor_inds, rollout_attn_map in zip(particle_indices, inds_arr, rollout_attn_maps):
        attn = rollout_attn_map[-1][0]
        ax.scatter(xs_plot[neighbor_inds, 0], xs_plot[neighbor_inds, 1], s=s, marker='o', c=attn, cmap=cmap, linewidths=0.5, zorder=1)

    if save_str is not None:
        fig.patch.set_facecolor('k')
        fig.patch.set_alpha(0.0)
        # NOTE(review): config, cmap_str and index are not parameters here; they
        # must exist as notebook globals or this line raises NameError -- confirm
        plt.savefig(f'{save_str}_{config["network_type"]}_{cmap_str}_{index}.{extension}', dpi=600, bbox_inches='tight')

In [None]:
## draw the attention overlay on the loaded snapshot; save_str=None skips saving
make_attn_plot(
    xs_plot=xs,
    width=attention_data_dict['width'],
    particle_indices=particle_indices,
    inds_arr=inds_arr,
    rollout_attn_maps=rollout_attn_maps,
    s_fac=0.5,
    center_cluster=True,
    save_str=None,
    extension='pdf',
)