In [None]:
import ripser, persim, warnings
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from joblib import Parallel, delayed
from functools import partial
from pathlib import Path

## Function that returns bottleneck/wasserstein distance and optionally plots persistence diagram matchings

In [None]:
def getDist(feature, homologyGroup, matching=False, dgm=False):
  '''
  Compare the persistence diagram of a 2D point cloud with that of a flat line.

  The reference line has the same x-extent and number of points as `feature`,
  evenly spaced in x, with every y value at zero.

  Parameters
  ----------
  feature : array of shape (N, 2) with columns x, y
  homologyGroup : int, homology dimension to compare (0 = connected components, 1 = loops)
  matching : if True, return (feature_dgm, line_dgm, matching) where `matching` is the
             bottleneck matching returned by persim.bottleneck(..., matching=True)
  dgm : if True, also plot the bottleneck matching between the two diagrams

  Returns
  -------
  The sliced Wasserstein distance between the two diagrams (default), or the tuple
  described under `matching`. Points with non-finite death times (e.g. the infinite
  H0 bar) are dropped before the sliced Wasserstein computation, which would
  otherwise evaluate inf - inf and return NaN for every H0 comparison.

  TODO: accept 3D time series
  '''
  # ripser computes diagrams up to maxdim (default 1); raise it only when a higher group is requested
  maxdim = max(1, homologyGroup)

  # Line has same width and number of x points as feature... but with y values at zero
  nPoints = len(feature)
  line = np.column_stack((np.linspace(feature[:, 0].min(), feature[:, 0].max(), nPoints),
                          np.zeros(nPoints)))
  feature_dgm = ripser.ripser(feature, maxdim=maxdim)['dgms'][homologyGroup]
  line_dgm = ripser.ripser(line, maxdim=maxdim)['dgms'][homologyGroup]

  with warnings.catch_warnings():
    # persim.bottleneck ignores points that never die and warns about it; that is expected here
    warnings.filterwarnings("ignore", message=r"dgm[12] has points with non-finite death times;ignoring those points")
    # Match the warning text only: the source line printed beneath a warning is not part of its message
    warnings.filterwarnings("ignore", message=r"invalid value encountered in subtract", category=RuntimeWarning)

    if dgm or matching:
      # Compute the matching once, even when it is both plotted and returned
      bottleneckMatching = persim.bottleneck(feature_dgm, line_dgm, matching=True)[1]
    if dgm:
      persim.bottleneck_matching(feature_dgm, line_dgm, bottleneckMatching, labels=['feature', 'line'])
    if matching:
      return feature_dgm, line_dgm, bottleneckMatching

    # sliced_wasserstein has no special handling for infinite deaths, so keep finite points only
    finite_feature_dgm = feature_dgm[np.isfinite(feature_dgm).all(axis=1)]
    finite_line_dgm = line_dgm[np.isfinite(line_dgm).all(axis=1)]
    return persim.sliced_wasserstein(finite_feature_dgm, finite_line_dgm)

## Helper functions (periodic buffering, windowing, per-position distances) for debugging and parallelization

In [None]:
def _periodicCopies(points, xRange):
  '''Stack copies of an (N, 2) point cloud shifted by -xRange, 0 and +xRange in x.'''
  lbuff, rbuff = points.copy(), points.copy()
  lbuff[:, 0] -= xRange
  rbuff[:, 0] += xRange
  return np.vstack((lbuff, points, rbuff))

def setupBuff(data):
  '''
  Since we have periodic boundary conditions, copy the point cloud onto each
  end of the domain so that windows near the edges can wrap around.
  * time, timeRange are not included in this function *

  Parameters
  ----------
  data : (N, 2) array of x, y points, or (T, N, 2) array holding one (N, 2)
         point cloud per timestep

  Returns
  -------
  x, y : coordinates of every input point (flattened over time for 3D input)
  xRange : x.max() - x.min() over all points; the shift applied to each copy
  buff : (3N, 2) for 2D input, or (T, 3N, 2) for 3D input, stacked as
         left copy, original, right copy

  Raises
  ------
  ValueError if data is neither 2D nor 3D.

  NOTE(review): shifting by max - min (not max - min + grid spacing) puts the left
  copy's right-most points at x = x.min(), the same x as the original's left-most
  points — confirm this overlap is intended.
  '''
  if data.ndim == 2:
    x, y = data.T
    xRange = x.max() - x.min()
    buff = _periodicCopies(data, xRange)
  elif data.ndim == 3:
    # Assumes dims are (time, point, coordinate); one shift computed over all timesteps is used for every frame
    x, y = data.reshape(-1, 2).T
    xRange = x.max() - x.min()
    buff = np.stack([_periodicCopies(frame, xRange) for frame in data])
  else:
    raise ValueError(f"setupBuff expects a 2D or 3D array, got {data.ndim} dimensions")
  return x, y, xRange, buff

                                        # 2D is default
def getWindow(pos, offset, buff, time=None, tPos=None, tOffset=None):
  '''
  Select the buffered points whose x lies strictly within (pos - offset, pos + offset).

  Parameters
  ----------
  pos : window centre in x
  offset : half-width of the window
  buff : (M, 2) buffered point cloud (x in column 0, y in column 1), e.g. from setupBuff
  time, tPos, tOffset : reserved for 3D time-series windows; currently unused

  Returns
  -------
  A new (K, 2) array; buff itself is never modified. An empty window is returned as-is.

  Raises
  ------
  ValueError for 3D buffers (not yet implemented) or any other non-2D input.
  '''
  if buff.ndim == 3:
    # TODO: figure out how to select 3D window
    raise ValueError("Windows for 3D time series not yet implemented")
  if buff.ndim != 2:
    raise ValueError(f"getWindow expects a 2D buffer, got {buff.ndim} dimensions")

  buffx = buff[:, 0]
  # Boolean indexing returns a copy, so the edge edits below never touch buff
  window = buff[(buffx > pos - offset) & (buffx < pos + offset)]
  if len(window) == 0:
    return window

  # Try to prevent 1-cell disconnected components on the edges: every point sharing
  # the x of the first (last) point in the window gets that point's y, collapsing
  # that edge column onto a single location.
  # This is a rare edge case - this mostly doesn't do anything
  # Doesn't seem to matter for homology groups greater than zero
  # NOTE(review): uses the first/last point's y; if the column minimum was intended,
  # use window[mask, 1].min() instead — confirm.
  windowx = window[:, 0]
  window[windowx == window[0, 0], 1] = window[0, 1]
  window[windowx == window[-1, 0], 1] = window[-1, 1]
  return window

def distsAtPos(pos, offsets, homologyGroups, buff):
  '''
  Distances (via getDist's default mode) between each window centred on `pos`
  and a flat line, for every (offset, homology group) pair.

  Each parallel job calls this once per alongshore position and gets back a
  float array of shape (len(offsets), len(homologyGroups)).
  '''
  rows = []
  for offset in offsets:
    # TODO : 3D case
    window = getWindow(pos, offset, buff)
    rows.append([getDist(window, hg) for hg in homologyGroups])
  # reshape keeps the 2D shape even when offsets or homologyGroups is empty
  return np.array(rows, dtype=float).reshape(len(offsets), len(homologyGroups))

## Below are functions for precomputing the distance plots across timesteps and animating them
The animation should look like `plot_heatmap()` in the old notebook.

To do:

* Refactor into:
  * "Pre-computation" function with parameters:
    * videoNum
    * noPonds
  * "animation" function with parameters:
    * npz filename
      * videoNum
      * noPonds
    * hspace
    * hg_to_animate
* I wish I could plot just one timestep with this method... not sure if any of the Gemini changes are problematic
  * Gemini uses set_title instead of colorbar label
  * not sure whether (or how) I can animate the colorbar label...

* Verify hspace with plot_heatmap() before animating
* Fix order of yMin and yMax outputs, parameters and arguments in the old notebook
* Merge/Use sliding_windows or sliding_windows_par in precompute()

## Function that loops over timesteps and stores window distances (sliced Wasserstein by default, via `getDist`) in a 4D array (timestep, homology group, window size, position)

In [None]:
def precompute(videoNum, noPondsStr="_noPonds", dx=50, maxWindow=200, step=50, minWindow=None,
               homologyGroups=(0, 1), dataDir="data", n_jobs=-1):
  '''
  Compute window-vs-line distances for every timestep of one video.

  Parameters
  ----------
  videoNum : selects {dataDir}/jgrf217-sup-000{videoNum+1}-ms0{videoNum}{noPondsStr}.npz;
             change it to precompute a different video
  noPondsStr : "_noPonds" by default; set to empty to use full point cloud
  dx : spacing between window centres (same units as the x coordinates)
  maxWindow, step : window sizes are step, 2*step, ... up to maxWindow
  minWindow : smallest window size kept (defaults to `step`)
  homologyGroups : homology groups to compute (0 and 1 for 2D data)
  dataDir : folder holding the input .npz
  n_jobs : number of joblib workers (threading backend), -1 for all cores

  Returns
  -------
  fnBase, dists, positions, sizes, homology_groups, maxY, minY
  where dists has shape (num_timesteps, num_homology_groups, num_offsets, num_positions)
  and maxY/minY are the y-limits over all timesteps (for consistent scatter plots).
  '''
  fnBase = f"jgrf217-sup-000{videoNum+1}-ms0{videoNum}{noPondsStr}"
  print(fnBase)
  # `with` closes the archive even if reading one of its arrays fails
  with np.load(f"{dataDir}/{fnBase}.npz") as readin:
    thisCoast = [readin[key] for key in readin.keys()]

  # y-limits across all timesteps for consistent scatter plot limits
  maxY = max(frame[:, 1].max() for frame in thisCoast)
  minY = min(frame[:, 1].min() for frame in thisCoast)

  # Global x-range defines one constant set of window positions for the whole animation
  global_min_x = min(frame[:, 0].min() for frame in thisCoast)
  global_max_x = max(frame[:, 0].max() for frame in thisCoast)
  global_x_range = global_max_x - global_min_x

  if minWindow is None:
    minWindow = step  # Default from sliding_windows_par
  positions = dx * np.arange(global_x_range // dx + 2) + global_min_x
  offsets = step * np.arange(1, maxWindow // step + 1) / 2
  offsets = offsets[offsets >= minWindow / 2]
  sizes = offsets * 2

  homology_groups = list(homologyGroups)
  all_dists_per_timestep = []

  # TODO: Check lengths of inputs and match case for loops with different partial funcs
  print("Starting pre-computation for animation...")
  for i, data_timestep in enumerate(thisCoast):
    print(f"Processing timestep {i+1}/{len(thisCoast)}", flush=True)

    # Periodic buffer for this timestep only
    buff_i = setupBuff(data_timestep)[3]
    distsAtPos_partial = partial(distsAtPos, offsets=offsets, homologyGroups=homology_groups, buff=buff_i)

    # One job per position; each returns (num_offsets, num_homology_groups)
    allPosDists_i = Parallel(n_jobs=n_jobs, backend='threading')(
        delayed(distsAtPos_partial)(pos) for pos in positions)

    # (num_positions, num_offsets, num_hg) -> (num_hg, num_offsets, num_positions)
    all_dists_per_timestep.append(np.array(allPosDists_i).transpose(2, 1, 0))

  print("Pre-computation complete.")

  # Shape: (num_timesteps, num_homology_groups, num_offsets, num_positions)
  return fnBase, np.array(all_dists_per_timestep), positions, sizes, homology_groups, maxY, minY

## Ideas for more plots

**If we have 3 homology groups, we can make RGB movies of the bottleneck distances changing over time!**

This only works with a single window size at each timestep.

For now, try animating plot_heatmap()

Can I increase step and dx to reduce the number of positions and offsets to save computation time?
I don't think so: they are about half the size of the patterns (Nyquist frequency), so if step and dx were any larger we could see anomalies but not trends.
Actually, doubling step to 50 (4 sizes) works fine! 3 sizes is fine too, but much less impressive.

I think the actual dx is 100m. Should probably go back and change in the video processing notebook and the data files. Or is it better to use 0-indexed integers for axes in math?

Are the data files even accurate? dy looks larger than dx. Need to check the Ashton paper for dx, dy, dt and the size of the domain. If these data files are bad, we need to regenerate them from new model runs :( Probably won't do this for the class project...

Not in this notebook: Make a figure with 8 subplots for the pretty argmax plots stacked over the noPonds scatterplots of the final frame of each video

In [None]:
class PillowWriterNG(PillowWriter):
  '''
  PillowWriter variant that saves the GIF with loop=None.
  Not used by default; pass writer=PillowWriterNG() to animate() to try it.
  NOTE(review): presumably meant to produce a GIF that plays once — confirm with the installed Pillow.
  '''
  def finish(self):
    self._frames[0].save(
        self.outfile, save_all=True, append_images=self._frames[1:],
        duration=int(1000 / self.fps), loop=None)

def _animateHomologyGroup(thisCoast, hgDists, hg, extent, x_extent, ylim, hspace_val, distLabel):
  '''Build one heatmap-over-scatter FuncAnimation (its own figure) for a single homology group.'''
  fig, (ax1, ax2) = plt.subplots(2, 1)#, figsize=(10, 8))

  def cbLabel(frame):
    return f"{distLabel} for Homology Group {hg} at Time {frame+1}"

  im = ax1.imshow(hgDists[0], aspect='auto', origin='lower', extent=extent,
                  vmin=0, vmax=np.nanmax(hgDists))
  cb = fig.colorbar(im, ax=ax1, orientation='horizontal', label=cbLabel(0),
                    location='top', aspect=50)
  ax1.set_ylabel('Window Size')
  # Static title; the changing timestep is shown in the colorbar label
  ax1.set_title(f'Homology Group {hg}')
  # Keep the bottom ticks but hide their labels; the scatter plot below carries the x labels
  ax1.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=False)

  sc = ax2.scatter(thisCoast[0][:, 0], thisCoast[0][:, 1], s=1)
  ax2.set_xlabel('Position')
  ax2.set_ylabel('Y')
  ax2.set_aspect('equal', adjustable='box')
  ax2.set_xlim(*x_extent)  # same constant x-extent as the heatmap
  ax2.set_ylim(ylim)

  fig.subplots_adjust(hspace=hspace_val)
  # NOTE(review): tight_layout() recomputes the subplot parameters, including hspace,
  # so hspace_val is likely overridden here — verify against plot_heatmap().
  fig.tight_layout()

  def update(frame):
    im.set_data(hgDists[frame])
    sc.set_offsets(thisCoast[frame])
    cb.set_label(cbLabel(frame))
    return [im, sc]

  # blit=False: the colorbar label changes every frame, and a Colorbar is not an
  # Artist, so it could not be returned for blitting anyway.
  return FuncAnimation(fig, update, frames=len(thisCoast), interval=200, blit=False)

def animate(videoNum, noPondsStr="_noPonds", homologyGroups=None, hspace_val=-1, save=True,
            distsSuffix="_distsWasserstein", distLabel="Sliced Wasserstein Distance", writer="pillow"):
  '''
  Animate precomputed window distances (heatmap) above the coastline point cloud.

  Parameters
  ----------
  videoNum : int, or path to a coastline / distance .npz file. An int selects
             data/jgrf217-sup-000{videoNum+1}-ms0{videoNum}{noPondsStr}.npz; for a path,
             a trailing `distsSuffix` is stripped from its stem to find the coastline
             file in the same folder.
  noPondsStr : "_noPonds" by default; set to empty to use full point cloud (int videoNum only)
  homologyGroups : homology group(s) to animate, e.g. [0, 1] or 1 (0 = connected
                   components, 1 = loops); None animates every group in the distance file
  hspace_val : vertical space between subplots - depends on yMax-yMin and len(sizes).
               Verify hspace with plot_heatmap() before animating
  save : if True, write each animation to "{fnBase}_H{hg}_wasserstein.gif" in the working directory
  distsSuffix : suffix of the precomputed distance file ("{fnBase}{distsSuffix}.npz")
  distLabel : name of the distance shown in the colorbar label
  writer : matplotlib animation writer used when saving

  Returns
  -------
  List of FuncAnimation objects, one per animated homology group, in request order.

  Raises
  ------
  ValueError if a requested homology group is not in the distance file.
  '''
  if isinstance(videoNum, (int, np.integer)):
    dataDir = Path("data")
    fnBase = f"jgrf217-sup-000{videoNum+1}-ms0{videoNum}{noPondsStr}"
  else:
    dataDir = Path(videoNum).parent
    fnBase = Path(videoNum).stem
    if distsSuffix and fnBase.endswith(distsSuffix):
      fnBase = fnBase[:-len(distsSuffix)]
  print(fnBase)

  with np.load(dataDir / f"{fnBase}.npz") as readin:
    thisCoast = [readin[key] for key in readin.keys()]
  # Saved positionally by the precompute cell, so keys come back in write order
  with np.load(dataDir / f"{fnBase}{distsSuffix}.npz") as readin:
    all_dists, positions, sizes, storedHGs, maxY, minY = [readin[key] for key in readin.keys()]

  storedHGs = [int(hg) for hg in np.atleast_1d(storedHGs)]
  if homologyGroups is None:
    homologyGroups = storedHGs
  else:
    homologyGroups = [int(hg) for hg in np.atleast_1d(homologyGroups)]
  missing = [hg for hg in homologyGroups if hg not in storedHGs]
  if missing:
    raise ValueError(f"Homology groups {missing} not in {fnBase}{distsSuffix}.npz (has {storedHGs})")

  # Constant x-extent shared by the heatmap and the scatter plot
  dx = positions[1] - positions[0]
  x_extent = [positions[0] - dx / 2, positions[-1] + dx / 2]

  # Constant y-extent for imshow
  step_y_extent = sizes[1] - sizes[0] if len(sizes) > 1 else sizes[0] # Handle single window size case
  y_extent = [sizes[0] - step_y_extent, sizes[-1]]
  constant_extent = x_extent + y_extent

  scatpad = (maxY - minY) / 20
  ylim = [minY - scatpad, maxY + scatpad]

  anims = []
  for hg in homologyGroups:
    # all_dists axis 1 is ordered like storedHGs
    hgDists = all_dists[:, storedHGs.index(hg)]
    anim = _animateHomologyGroup(thisCoast, hgDists, hg, constant_extent, x_extent, ylim,
                                 hspace_val, distLabel)
    if save:
      anim.save(f"{fnBase}_H{hg}_wasserstein.gif", writer=writer)
    anims.append(anim)

  return anims

### Function that animates precomputed files as they appear (Run in background)

In [None]:
from watchfiles import watch, Change
import re
import threading

hspace_vals = [-1.1, -1.1, -1, -1, -1]  # one per video, indexed by videoNum - 1

def watch_folder(folder="data", suffix="_distsWasserstein.npz"):
    '''
    Animate each precomputed distance file written into `folder`.

    Blocks forever, so run it in a background thread. The video number is read from
    the "-ms<N>" part of the filename and selects hspace_vals[N - 1]. Failures are
    reported with warnings.warn so one bad file does not stop the watcher.
    '''
    # watch() yields a *set* of (Change, path) pairs for each batch of filesystem events
    for changes in watch(folder):
        # One write can report both "added" and "modified" in a batch; handle each path once.
        # Modified files are included so re-running the precompute cell re-animates them.
        paths = sorted({path for change, path in changes
                        if change in (Change.added, Change.modified)
                        and path.endswith(suffix) and Path(path).is_file()})
        for path in paths:
            match = re.search(r"-ms(\d+)", Path(path).name)
            videoNum = int(match.group(1)) if match else 0
            if not 1 <= videoNum <= len(hspace_vals):
                warnings.warn(f"No hspace value for video number in {path}; skipping")
                continue
            try:
                # Can't keep these outputs for display in stdout :(
                animate(path, hspace_val=hspace_vals[videoNum - 1])
            except Exception as e:
                warnings.warn(f"Error generating animation in background: {e}")

# NOTE(review): pyplot is not thread-safe; animations built here can race with plotting in other cells.
# Start the watcher only once: re-running this cell would otherwise add a duplicate watcher
# (restart the kernel to pick up edits to watch_folder).
if not any(t.name == "watch_folder" and t.is_alive() for t in threading.enumerate()):
    watch_thread = threading.Thread(target=watch_folder, args=("data",), daemon=True, name="watch_folder")
    watch_thread.start()

### Precompute files

In [None]:
# Precompute window distances for every video, with ("") and without ("_noPonds") ponds
for videoNum in range(1, 6):
  for noPondsStr in ("", "_noPonds"):
    results = precompute(videoNum, noPondsStr)
    fnBase, all_dists, positions, sizes, homologyGroups, maxY, minY = results
    # Saved positionally (arr_0..arr_5) in the order animate() unpacks them
    np.savez_compressed(f"data/{fnBase}_distsWasserstein.npz", *results[1:])

### Finally, run animations again (just to display them in the notebook)

In [None]:
%matplotlib inline
plt.rcParams["animation.html"] = "jshtml"

# Already defined in watch_folder
# hspace_vals = [-1.1, -1.1, -1, -1, -1]
for hspaceIdx, videoNum in enumerate(range(1,6)):
  for noPondsStr in ["", "_noPonds"]:                                                 # Don't forget!
    anims = animate(videoNum, noPondsStr=noPondsStr, hspace_val=hspace_vals[hspaceIdx], save=False)
    # Display the animation (Colab will render this automatically if %matplotlib inline and jshtml is set)
    anims[0]
    plt.show()
    anims[1]
    plt.show()