In [None]:
import copy
import numpy as np
import os
import sys
import time
import types

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.animation as manimation

from IPython.display import HTML, Video
from scipy.interpolate import CubicSpline

In [None]:
# Interactive figures (pan / zoom / 3D rotation) via the ipympl "widget" backend,
# instead of static images from `%matplotlib inline`.
%matplotlib widget

In [None]:
# Keep text editable when exported figures are opened in Illustrator / Inkscape:
# embed TrueType (Type 42) fonts in PDF/PS output, and keep SVG text as text rather than paths.
mpl.rcParams.update({
    'pdf.fonttype': 42,
    'svg.fonttype': 'none',
    'ps.fonttype': 42,
})

# Notebook setup: add the package root to `sys.path`, then import local modules

In [None]:
# Package root = parent of the kernel's working directory
# (assumes the notebook is run from a folder directly under the package root, e.g. <root>/notebooks).
PACKAGE_ROOT = os.path.dirname(os.path.abspath(''))
print(PACKAGE_ROOT)
# Append only once, so re-running this cell does not keep growing sys.path with duplicates.
if PACKAGE_ROOT not in sys.path:
    sys.path.append(PACKAGE_ROOT)

In [None]:
from class_cellgraph import CellGraph
from class_singlecell import SingleCell
from dynamics_vectorfields import PWL_g_of_x, vectorfield_PWL3_swap
from dynamics_generic import ode_solve_ivp
from preset_solver import PRESET_SOLVER
from preset_cellgraph import PRESET_CELLGRAPH
from preset_network_layouts import *
from run_cellgraph import create_cellgraph
from settings import IS_RUNNING_ON_CLUSTER
from utils_io import run_subdir_setup, pickle_load
from utils_networkx import draw_from_adjacency

In [None]:
# Plot palette: color of the cell-0 trajectory, then one color per state variable (x, y, z).
COLOR_CELL0, COLOR_X, COLOR_Y, COLOR_Z = '#66c2a5', '#ab329f', '#c9af45', '#33a5c3'

## Video Functions

In [None]:
def setup_ffmpeg_writer(fps=15, bitrate=None, title='Sample movie', artist='Matplotlib', comment='comment here'):
    """
    Build matplotlib's ffmpeg movie writer with the given frame rate, bitrate and metadata.

    Args:
        fps: frames per second of the output movie
        bitrate: forwarded unchanged to the writer (None defers to matplotlib's default)
        title, artist, comment: stored in the movie's metadata dict
    Returns:
        the writer instance created from matplotlib's registered 'ffmpeg' writer class
    """
    movie_metadata = {'title': title, 'artist': artist, 'comment': comment}
    writer_cls = manimation.writers['ffmpeg']
    return writer_cls(fps=fps, bitrate=bitrate, metadata=movie_metadata)

def vid_fpath_to_html(fpath, marker='notebooks'):
    """
    Build an HTML snippet that embeds an mp4 video in the notebook.

    The <video> tag needs a path relative to the notebook folder (it doesn't like absolute
    paths), so everything up to and including the LAST path component named `marker`
    is stripped from `fpath`.
    Usage:
        HTML(vid_fpath_to_html(fpath))
    Alternate method for absolute paths:
        Video(fpath)

    Args:
        fpath: path to the video file; must contain a directory component equal to `marker`
        marker: name of the folder the notebook is served from (default 'notebooks')
    Returns:
        str of HTML
    Raises:
        ValueError: if no component of `fpath` equals `marker`
    """
    import html

    parts = os.path.normpath(fpath).split(os.sep)
    if marker not in parts:
        raise ValueError("fpath %r has no '%s' directory component" % (fpath, marker))
    # Match whole path components (not substrings, e.g. 'my_notebooks'), and use the last
    # match so that a parent folder that happens to share the name is ignored.
    marker_pos = len(parts) - 1 - parts[::-1].index(marker)
    rel_path = '/'.join(parts[marker_pos + 1:])  # URLs always use forward slashes
    # Quote and escape the attribute so paths with spaces or special characters still work.
    s = f"""
        <div align="middle">
        <video width="80%" controls>
              <source src="{html.escape(rel_path, quote=True)}" type="video/mp4">
        </video></div>"""
    return s

# Load or run a CellGraph simulation
- If `load_instead_of_generate` is True: load a pickled cellgraph from a specific input folder
- Otherwise: simulate a new cellgraph (the generation block mirrors `run_cellgraph.py`)

In [None]:
# If True, load a pickled CellGraph from disk; if False, simulate a new one (see the next cells)
load_instead_of_generate = False

# Output folder for figures/movies produced by this notebook
NB_OUTPUT = os.path.join(PACKAGE_ROOT, 'notebooks', 'vis_cellgraph')
# exist_ok avoids the check-then-create race of `if not exists: makedirs` and is a no-op on re-runs
os.makedirs(NB_OUTPUT, exist_ok=True)

In [None]:
# Optional preset network layout: None, or a layout defined in preset_network_layouts.py.
# NOTE(review): not referenced by any cell visible in this section -- confirm whether a later cell uses it.
preset_layout_choice = None
#preset_layout_choice = preset_layout_drosophila_M12  # choose None or pick from preset_network_layouts.py

In [None]:
if load_instead_of_generate:

    # Load a previously simulated CellGraph ("classdump" pickle) from an archived run folder
    cellgraph_dir = PACKAGE_ROOT + os.sep + 'input' + os.sep + 'archive_cellgraph'
    #specific_dir = cellgraph_dir + os.sep + 'lacewingnegative_short_pulse'
    specific_dir = cellgraph_dir + os.sep + 'lacewing1_shortpulse'

    # load classdump pickle file from "specific dir"
    # NOTE(review): if pickle_load wraps pickle.load, only load run folders you trust (unpickling can execute code)
    fpath = specific_dir + os.sep + 'classdump.pkl'
    cellgraph = pickle_load(fpath)

else:
    flag_preset = False
    subgraphIsoCheck = False  # if True - terminate loop when the graph can no longer reach a known structure
    flag_plotly_traj = False  # very slow currently; can be useful for debugging/trajectory inspection

    if flag_preset:

        # Build the graph from a named preset, e.g. drosophila_oneshort, nasonia_vitripennis
        cellgraph_preset_choice = 'lacewing_R3'

        io_dict = run_subdir_setup(run_subfolder='cellgraph')
        # Copy the preset dicts before editing them: assigning keys directly on the PRESET_* entries
        # would silently modify the shared module-level presets (and the edits persist across re-runs).
        solver_kwargs = dict(PRESET_SOLVER['solve_ivp_radau_strict']['kwargs'])
        solver_kwargs['vectorized'] = True

        cellgraph_preset = dict(PRESET_CELLGRAPH[cellgraph_preset_choice])
        cellgraph_preset['io_dict'] = io_dict
        cellgraph_preset['verbosity'] = 1
        cellgraph = create_cellgraph(**cellgraph_preset)

    else:
        # High-level initialization & graph settings
        style_ode = 'PWL3_swap'                      # styles: ['PWL2', 'PWL3', 'PWL3_swap', 'Yang2013', 'toy_flow', 'toy_clock']
        style_detection = 'manual_crossings_1d_mid'  # styles: ['ignore', 'scipy_peaks', 'manual_crossings_1d_mid', 'manual_crossings_1d_hl', 'manual_crossings_2d']
        style_division = 'plus_minus_delta_ndiv_bam' # styles: check style_division_valid in settings.py
        style_diffusion = 'xy'                       # styles: ['all', 'xy']
        M = 1
        alpha_sharing = -0.004
        beta_sharing = 0.0
        diffusion_arg = 0.300

        verbosity = 0  # in 0, 1, 2 (highest)

        # Main-loop-specific settings
        add_init_cells = 0

        # Initialization modifications for different cases
        if style_ode == 'PWL2':
            state_history = np.array([[100, 100]]).T     # None or array of shape (NM x times)
        elif style_ode == 'PWL3_swap':
            state_history = np.array([[0, 0, 0]]).T  # None or array of shape (NM x times)
        else:
            state_history = None

        # Specify time interval which is separate from solver kwargs (used in graph_trajectory explicitly)
        t0 = 0
        t1 = 200
        timeintervalrun = 1
        # Prepare io_dict
        io_dict = run_subdir_setup(run_subfolder='cellgraph')

        # Instantiate the graph and modify ode params if needed
        cellgraph = CellGraph(
            num_cells=M,
            style_ode=style_ode,
            style_detection=style_detection,
            style_division=style_division,
            style_diffusion=style_diffusion,
            state_history=state_history,
            alpha_sharing=alpha_sharing,
            beta_sharing=beta_sharing,
            diffusion_arg=diffusion_arg,
            t0=t0,
            t1=t1,
            timeintervalrun=timeintervalrun,
            io_dict=io_dict,
            verbosity=verbosity)
        if cellgraph.style_ode in ['PWL2', 'PWL3', 'PWL3_swap']:
            cellgraph.sc_template.params_ode['epsilon'] = 1e-2
            cellgraph.sc_template.params_ode['pulse_vel'] = 0.008375
            cellgraph.sc_template.params_ode['a1'] = 1.0
            cellgraph.sc_template.params_ode['a2'] = 0.25
            cellgraph.sc_template.params_ode['t_pulse_switch'] = 75.0

        # Add some cells through manual divisions (two different modes - linear or random) to augment initialization
        for idx in range(add_init_cells):
            # NOTE: dividing_idx (random mode) is only printed; the division below uses idx (linear mode)
            dividing_idx = np.random.randint(0, cellgraph.num_cells)
            print("Division event (idx, div idx):", idx, dividing_idx)
            # Mode choice (divide linearly or randomly)
            cellgraph = cellgraph.division_event(idx, 0)  # Mode 1 - linear division idx
            # Output plot & print
            cellgraph.print_state()
            print()

        # Setup solver kwargs for the graph trajectory wrapper
        solver_kwargs = {}
        solver_kwargs['method'] = 'Radau'
        solver_kwargs['t_eval'] = None
        solver_kwargs['max_step'] = np.inf  # the np.Inf alias was removed in NumPy 2.0
        solver_kwargs['atol'] = 1e-8
        solver_kwargs['rtol'] = 1e-4

    # Write initial CellGraph info to file
    cellgraph.print_state(msg='initialized in visualize_cellgraph.ipynb')
    cellgraph.write_metadata()
    cellgraph.write_state(fmod='init')
    if not IS_RUNNING_ON_CLUSTER:
        cellgraph.plot_graph(fmod='init')

    # From the initialized graph (after all divisions above), simulate graph trajectory
    print('\nExample trajectory for the graph...')
    start_time = time.time()
    event_detected, cellgraph = cellgraph.wrapper_graph_trajectory(subgraphIsoCheck=subgraphIsoCheck, **solver_kwargs)
    print("--- %s seconds ---" % (time.time() - start_time))
    print("\n in main: num cells after wrapper trajectory =", cellgraph.num_cells)

    # Plot the timeseries for each cell
    cellgraph.plot_state_unified(arrange_vertical=True, fmod='final')
    if not IS_RUNNING_ON_CLUSTER:
        cellgraph.plot_graph(fmod='final')
        cellgraph.plot_graph(fmod='final', spring_seed=0, by_degree=True, by_ndiv=False, by_last_div=False, by_age=False)
        cellgraph.plot_graph(fmod='final', gviz_prog='dot', by_degree=True, by_ndiv=False, by_last_div=False, by_age=False)
        cellgraph.plot_graph(fmod='final', gviz_prog='circo', by_degree=True, by_ndiv=False, by_last_div=False, by_age=False)
        cellgraph.plot_graph(fmod='final', gviz_prog='twopi', by_degree=True, by_ndiv=False, by_last_div=False, by_age=False)
        cellgraph.plot_xyz_state_for_specific_cell(plot_cell_index=0, decorate=True)
    if cellgraph.sc_dim_ode > 1:
        cellgraph.plot_xy_separate(fmod='final')
    if flag_plotly_traj:
        cellgraph.plotly_traj(fmod='final', show=True, write=True)
    # Plot walltimes between division events
    cellgraph.plot_walltimes(fmod='final')

    # Save class state as pickle object
    cellgraph.pickle_save('classdump.pkl')

# Briefly inspect the run

In [None]:
# Shorthands for the run's parameters, trajectory and lineage, used in the inspection cells below
pp = cellgraph.sc_template.params_ode
times, state = cellgraph.times_history, cellgraph.state_history
state_tensor = cellgraph.state_to_rectangle(state)  # dim_ode x cells x time
div_events = cellgraph.division_events
birthdays = cellgraph.cell_stats[:, 2]  # time index at which each cell appeared

# Summary printout via CellGraph's own print method
cellgraph.print_state()

In [None]:
# Division events recorded during the run; elsewhere each entry is unpacked as (mother, daughter, time index)
print(div_events)

In [None]:
# ODE parameters of the single-cell template (shown as this cell's output)
cellgraph.sc_template.params_ode

# 0) Utility functions used multiple times below

In [None]:
def limit_cycle_geometry(cellgraph):
    """
    Geometry of the oscillatory region (assuming a single cell / weak coupling).

    Only valid for the 'PWL3_swap' and 'PWL3_zstepdecay' ODE styles; every value is derived
    from the template ODE parameters a1, a2 and gamma.

    Returns:
        9-tuple (xlow, xhigh, xmid, ylow, yhigh, xhit_topright, xhit_botleft, zlow, zhigh)
    """
    assert cellgraph.style_ode in ['PWL3_swap', 'PWL3_zstepdecay']
    params = cellgraph.sc_template.params_ode
    a1, a2 = params['a1'], params['a2']

    # x thresholds and their midpoint
    xlow, xhigh = 0.5 * a1, 0.5 * (a1 + a2)
    xmid = 0.5 * (xlow + xhigh)
    # y extent
    ylow, yhigh = a1 - a2, a1
    # x coordinates where the cycle meets the top / bottom edges
    xhit_topright = yhigh / 2.0 + a2
    xhit_botleft = ylow / 2.0
    # z extent
    gamma = params['gamma']
    zlow = xlow + gamma * a1
    zhigh = xhigh + gamma * (a1 - a2)

    return (xlow, xhigh, xmid, ylow, yhigh, xhit_topright, xhit_botleft, zlow, zhigh)

In [None]:
def get_limit_cycle_line_segments(cellgraph):
    """
    Return the 4 straight edges of the limit cycle in the (x, y) plane.

    Corners are visited c0 -> c1 -> c2 -> c3 -> c0, and edge i joins corner i to corner i+1:
        {i: {'A': start_corner, 'B': end_corner}} for i in 0..3
    """
    (xlow, xhigh, _, ylow, yhigh,
     xhit_topright, xhit_botleft, _, _) = limit_cycle_geometry(cellgraph)
    corners = [
        np.array([xhit_botleft, ylow]),    # c0
        np.array([xlow, yhigh]),           # c1
        np.array([xhit_topright, yhigh]),  # c2
        np.array([xhigh, ylow]),           # c3
    ]
    n_corners = len(corners)
    return {i: {'A': corners[i], 'B': corners[(i + 1) % n_corners]} for i in range(n_corners)}

In [None]:
def slice_state_and_time_for_cell_given_tlims(cellgraph, cell_idx, tlims_idx=None):
    """
    Extract one cell's trajectory and the matching times, never earlier than the cell's birth.

    Args:
        cellgraph: object providing state_history, times_history, cell_stats, state_to_rectangle
        cell_idx: index of the cell to extract
        tlims_idx: None or 2-tuple (left, right) of integer time indices; the result covers
                   times[left:right] (right excluded), clipped on the left at the birth index
    Returns:
        r_slice, times_slice
    Note:
        the state vector trajectory "r_slice" is returned as shape time x dim_ode;
        a cell born at or after tlims_idx[1] yields empty slices (nothing to plot)
    """
    dim_by_cell_by_time = cellgraph.state_to_rectangle(cellgraph.state_history)  # dim_ode x cells x time
    t_all = cellgraph.times_history
    traj = dim_by_cell_by_time[:, cell_idx, :].T                                  # time x dim_ode
    born_at = cellgraph.cell_stats[cell_idx, 2]  # time index at which the cell appeared

    if tlims_idx is None:
        window = slice(born_at, None)
    elif born_at >= tlims_idx[1]:
        # born after the window closes: identical bounds give the empty slices we want
        window = slice(tlims_idx[1], tlims_idx[1])
    else:
        # born before or inside the window: clip the left edge at the birthday
        window = slice(max(tlims_idx[0], born_at), tlims_idx[1])

    return traj[window, :], t_all[window]

In [None]:
def number_of_progeny(cellgraph, cell_idx):
    """
    Count all descendants of `cell_idx` in the lineage stored in cellgraph.adjacency.

    A daughter of cell c is any higher-indexed cell j with adjacency[c, j] == 1.0.
    E.g. cell 0 has (total_cells - 1) progeny; the tips of the network have 0 progeny
    Note: more efficient to work backwards and build a dictionary from the tips
    """
    adjacency = cellgraph.adjacency

    def count_descendants(mother):
        daughters = [j for j, val in enumerate(adjacency[mother, :]) if val == 1.0 and j > mother]
        total = len(daughters)
        for daughter in daughters:
            total += count_descendants(daughter)
        return total

    return count_descendants(cell_idx)

# Example:
#for idx in range(cellgraph.num_cells):
#    print(idx, '->', number_of_progeny(cellgraph, idx))

# 1) Visualize the (x,y,z) trajectory of the cell graph

### 3D trajectory plot with z or time as the third axis
- option A: plot the cell 0 trajectory, then decorate it with the limit-cycle geometry
- option B (open question): plot each cell as its own xyz trajectory

In [None]:
def plot_cells_traj_and_marks_2d_or_3d(cellgraph,
                                       ax=None,
                                       cells_to_traj=[0], tlims_traj=None,
                                       cells_to_mark='all', times_to_mark=[],
                                       mode='xy', traj_color_style='cell', quiver=True,
                                       decorate=True, show_zplanes=True, show_legend=True,
                                       proj2d=True, verbose=False, fmod='', save=True, show=False):
    """
    Plot cell trajectories in (x, y), (x, y, z) or (x, y, t) space and star-mark chosen (cell, time) points.

    Overview:
        (1) plot the trajectory of every cell in `cells_to_traj` (from its birthday onward, within tlims_traj)
        (2) for each paired (cell, time) in (cells_to_mark, times_to_mark), draw a star at that cell's state
    Notes:
        - a mark is skipped when its time precedes the cell's birthday
        - decorate=True draws the limit-cycle geometry; requires style_ode in ['PWL3_swap', 'PWL3_zstepdecay'],
          and in 3D it is only supported for mode='xyz'
        - the list defaults are never mutated (copies are made before modification)
    Args:
        cellgraph: CellGraph instance
        ax: existing Axes to draw into (3D axes for 'xyz'/'xyt'); None -> a new figure is created
        cells_to_traj: list of int in [0, num_cells - 1], usually [0] for cell 0 only
        tlims_traj: None (whole run) or [t_start, t_end]; trajectories use times[searchsorted(t_start):searchsorted(t_end)]
        cells_to_mark: 'all', or list of int with the same length as times_to_mark (paired element-wise)
        times_to_mark: list of time values (float); with cells_to_mark='all' every time is marked on every cell
        mode: plot style, must be in ['xy', 'xyz', 'xyt']
        traj_color_style: color by 'cell' index (qualitative, Set2) or by 'time' (quantitative, Spectral, with colorbar)
        quiver: draw arrows along the trajectory showing the direction of motion
        decorate: draw limit-cycle thresholds (2D) or z-planes / Poincaré section (3D)
        show_zplanes: (3D decorate only) draw the z_low / z_high planes
        show_legend: draw a legend (only used when traj_color_style == 'cell')
        proj2d: (3D only) also draw the trajectory projected onto the z=0 plane
        verbose: print a warning for each skipped mark
        fmod: suffix appended to the output filename (None -> no suffix)
        save: save the figure to NB_OUTPUT/trajmark_<mode><fmod>.pdf
        show: call plt.show()
    Returns:
        ax
    """
    state_tensor = cellgraph.state_to_rectangle(cellgraph.state_history)  # dim_ode x cells x time
    times = cellgraph.times_history
    birthdays = cellgraph.cell_stats[:, 2]
    num_cells = cellgraph.num_cells

    # (0) Input processing and asserts
    assert mode in ['xy', 'xyz', 'xyt']
    assert traj_color_style in ['time', 'cell']
    dim_plot = len(mode)

    # 0.1 - trajectory time window as [left, right) indices into `times`
    if tlims_traj is None:
        tlims_idx_traj = [0, len(times)]
    else:
        assert len(tlims_traj) == 2 and tlims_traj[0] <= tlims_traj[1]
        assert times[0] <= tlims_traj[0] <= times[-1]
        assert times[0] <= tlims_traj[1] <= times[-1]
        tlims_idx_traj = [np.searchsorted(times, tlims_traj[0]),
                          np.searchsorted(times, tlims_traj[1])]

    # 0.2 - index into `times` for each time to mark: the first i with times[i] >= tval
    # (list copy: `* num_cells` below must repeat the values, which would multiply them for an array)
    times_to_mark = list(times_to_mark)
    idx_time_values = []
    for tval in times_to_mark:
        assert times[0] <= tval <= times[-1]
        idx_time_values.append(np.searchsorted(times, tval))

    # 0.3 - cells_to_mark is either 'all' or an explicit list paired element-wise with times_to_mark
    if isinstance(cells_to_mark, str) and cells_to_mark == 'all':
        # mark every time on every cell: entry j pairs cell j // k with time j % k (k = len(times_to_mark))
        num_times = len(times_to_mark)
        cells_to_mark = [c for c in range(num_cells) for _ in range(num_times)]
        times_to_mark = times_to_mark * num_cells
        idx_time_values = idx_time_values * num_cells
    else:
        assert len(cells_to_mark) == len(times_to_mark)

    for cell_idx in cells_to_traj:
        # valid cell indices are 0 .. num_cells - 1
        assert 0 <= cell_idx < num_cells, 'cell index %d out of range (num_cells=%d)' % (cell_idx, num_cells)

    # (1) Prepare plotting
    traj_alpha = 1
    star_kwargs = dict(
        marker='*', edgecolors='k', linewidths=0.5, zorder=10, c='green',
    )
    traj_cmap = {'time': plt.get_cmap('Spectral'),
                 'cell': plt.get_cmap('Set2')
                }[traj_color_style]
    traj_ln_kwargs = dict(linewidth=3.0, alpha=traj_alpha)
    # 'time' coloring: fixed color range
    # NOTE(review): vmin/vmax are hard-coded to t in [40, 80]; adjust for other runs
    time_sc_kwargs = dict(alpha=traj_alpha, cmap=traj_cmap, vmin=40.0, vmax=80.0)
    proj2d_ln_kwargs = dict(linewidth=1.0, color='k')
    fs = 14

    if ax is None:
        fig = plt.figure(figsize=(8, 8))
        ax = plt.gca() if dim_plot == 2 else fig.add_subplot(projection='3d')
    else:
        # the figure that owns the given axes (plt.gcf() could be a different figure)
        fig = ax.figure

    def mode_plot_args(r_slice, t_slice):
        """Positional args for ax.plot and ax.quiver, depending on `mode`."""
        xt, yt, zt = r_slice[:, 0], r_slice[:, 1], r_slice[:, 2]

        # quiver arrows start part-way along each step and are shrunk, so they sit on the line
        # TODO try midpoints instead
        weighted_start = 0.5  # default: 0
        weighted_shrink = 0.5 # default: 1
        x0 = (xt[:-1] + weighted_start * xt[1:]) / (1 + weighted_start)
        y0 = (yt[:-1] + weighted_start * yt[1:]) / (1 + weighted_start)
        z0 = (zt[:-1] + weighted_start * zt[1:]) / (1 + weighted_start)
        t0 = (t_slice[:-1] + weighted_start * t_slice[1:]) / (1 + weighted_start)
        u0 = weighted_shrink * (xt[1:] - xt[:-1])
        v0 = weighted_shrink * (yt[1:] - yt[:-1])
        w0 = weighted_shrink * (zt[1:] - zt[:-1])
        q0 = weighted_shrink * (t_slice[1:] - t_slice[:-1])

        if mode == 'xy':
            return (xt, yt), (x0, y0, u0, v0)
        if mode == 'xyz':
            return (xt, yt, zt), (x0, y0, z0, u0, v0, w0)
        return (xt, yt, t_slice), (x0, y0, t0, u0, v0, q0)

    # (2) Plot the trajectory of each cell in cells_to_traj (ascending cell index)
    for idx in range(num_cells):
        if idx not in cells_to_traj:
            continue
        r_slice, times_slice = slice_state_and_time_for_cell_given_tlims(cellgraph, idx, tlims_idx=tlims_idx_traj)
        traj_args, quiver_args = mode_plot_args(r_slice, times_slice)

        # (2A) line through the trajectory points; in 'cell' mode each cell gets its own Set2 color
        # (Set2 color 0 is '#66c2a5' == COLOR_CELL0, so cell 0 keeps its usual color)
        if traj_color_style == 'cell':
            line_color = traj_cmap(idx % traj_cmap.N)
        else:
            line_color = COLOR_CELL0
        ax.plot(*traj_args, '-', color=line_color, **traj_ln_kwargs,
                label='cell %d trajectory' % idx)

        # (2B) 'time' mode: scatter the sampled points colored by time (this is what the colorbar shows)
        if traj_color_style == 'time':
            sc = ax.scatter(*traj_args, c=times_slice, **time_sc_kwargs)
            # time colorbar for ONLY the earliest cell being plotted
            if idx == min(cells_to_traj):
                cbar = fig.colorbar(sc, ax=ax)
                cbar.set_label(r'$t$')

        # (2C) direction arrows; need at least two points to form a step
        if quiver and len(times_slice) > 1:
            # NOTE: quiver has different behaviour in 2D or 3D
            if dim_plot == 2:
                ax.quiver(*quiver_args, color='k', alpha=traj_alpha,
                          angles='xy', scale=1, scale_units='xy')
            else:
                # TODO need to fix arrow head size too big when trajectory steps are large
                ax.quiver(*quiver_args, color='k', alpha=traj_alpha, arrow_length_ratio=0.2)

        # (2D) if 3d (xyz or xyt), also show the 2d projection onto the xy, z=0 plane
        if proj2d and dim_plot == 3:
            xt, yt = traj_args[0], traj_args[1]
            ax.plot(xt, yt, np.zeros_like(xt), '-', **proj2d_ln_kwargs)

    # (3) Mark all points explicitly specified in the paired lists: times_to_mark, cells_to_mark
    for tval, cell_to_mark, idx_tval in zip(times_to_mark, cells_to_mark, idx_time_values):
        # perform the "mark" only if the cell exists at that time
        if idx_tval < birthdays[cell_to_mark]:
            if verbose:
                print('Warning: skipping mark on cell %d at time %.2f -- cell not yet born' % (cell_to_mark, tval))
            continue
        xmark = state_tensor[0, cell_to_mark, idx_tval]
        ymark = state_tensor[1, cell_to_mark, idx_tval]
        zmark = state_tensor[2, cell_to_mark, idx_tval]
        args_mark_from_mode = {'xy': (xmark, ymark),
                               'xyz': (xmark, ymark, zmark),
                               'xyt': (xmark, ymark, tval)}
        ax.scatter(*args_mark_from_mode[mode], **star_kwargs)
    # TODO consider also marking birthdays on the plot, in similar way to above

    if decorate:
        assert cellgraph.style_ode in ['PWL3_swap', 'PWL3_zstepdecay']
        # geometry of the oscillatory region (assuming single cell / weak coupling)
        xlow, xhigh, xmid, ylow, yhigh, xhit_topright, xhit_botleft, zlow, zhigh = limit_cycle_geometry(cellgraph)
        if dim_plot == 2:
            ax.axvline(xlow, linestyle='--', c='gray', label=r'$x_{low}$')
            ax.axvline(xhigh, linestyle='--', c='gray', label=r'$x_{high}$')
            ax.axhline(ylow, linestyle='-.', c='gray', label=r'$y_{low}$')
            ax.axhline(yhigh, linestyle='-.', c='gray', label=r'$y_{high}$')
        else:
            # the 3D decorations are drawn in z-space, so only mode='xyz' is supported
            assert dim_plot == 3 and mode == 'xyz'
            # (A) planes: z-low and z-high
            if show_zplanes:
                alpha_zplane = 0.25
                xx, yy = np.meshgrid(np.linspace(xhit_botleft, xhit_topright, 100), np.linspace(ylow, yhigh, 100))
                zz_low = (zlow - 0*xx - 0*yy) / 1.0
                zz_high = (zhigh - 0*xx - 0*yy) / 1.0
                sA1 = ax.plot_surface(xx, yy, zz_low, alpha=alpha_zplane, color='gray', label=r'$z_{low}$')
                sA2 = ax.plot_surface(xx, yy, zz_high, alpha=alpha_zplane, color='gray', label=r'$z_{high}$')
                # hack so that 3d surfaces can appear in the matplotlib legend
                for s in [sA1, sA2]:
                    s._edgecolors2d = s._edgecolor3d
                    s._facecolors2d = s._facecolor3d

            # (B) plane: poincare section defined by midpoint in x amplitude
            alpha_poincare = 0.1  #1.0 for testing zorder
            ybuff = 0.25*ylow
            yy, zz = np.meshgrid(np.linspace(ylow-ybuff, yhigh+ybuff, 2), np.linspace(0.75*zlow, 1.25*zhigh, 2))
            xx_mid = (xmid - 0*yy - 0*zz) / 1.0
            sB = ax.plot_surface(xx_mid, yy, zz,
                                 alpha=alpha_poincare, color='gray',
                                 zorder=5)
            # invisible (alpha=0) wireframe; its edges are redrawn below as 4 lines for custom zorder/depth
            ax.plot_wireframe(xx_mid, yy, zz,
                              alpha=0.0, facecolor=None,
                              edgecolor='k', linewidth=2.0, linestyle='--',
                              zorder=20)
            # empty line whose only purpose is the legend entry for the section
            ax.plot([], [], '--k', linewidth=2.0, label=r'Poincaré section')
            ya, yb = yy[0, 0], yy[-1, -1]
            za, zb = zz[0, 0], zz[-1, -1]
            ax.plot([xmid, xmid], [ya, ya], [za, zb], '--k', linewidth=2.0, zorder=20)  # front, up z
            ax.plot([xmid, xmid], [yb, yb], [za, zb], '--k', linewidth=2.0, zorder=5)   # back, up z
            ax.plot([xmid, xmid], [ya, yb], [za, za], '--k', linewidth=2.0, zorder=5)   # lower, move y depth
            ax.plot([xmid, xmid], [ya, yb], [zb, zb], '--k', linewidth=2.0, zorder=20)  # upper, move y depth
            # ongoing hack for matplotlib legend to work with 3d surface
            for s in [sB]:
                s._edgecolors2d = s._edgecolor3d
                s._facecolors2d = s._facecolor3d

            if proj2d:
                # (C) plot line on xy plane representing poincare section projected to 2D
                ax.plot([xmid, xmid], [yy[0, 0], yy[-1, -1]], [0, 0], '--k', linewidth=1.5, zorder=20)

    # (4) Axes labels, ticks, legend; z-axis settings and camera angle exist only on 3D axes
    ax.set_xlabel(r'$x$', fontsize=fs)
    ax.set_ylabel(r'$y$', fontsize=fs)
    ax.xaxis.set_tick_params(labelsize=fs)
    ax.yaxis.set_tick_params(labelsize=fs)
    if dim_plot == 3:
        ax.set_zlabel(r'$%s$' % mode[2], fontsize=fs)  # this should be 'z' or 't'
        ax.zaxis.set_tick_params(labelsize=fs)
        view_ele = 23
        view_azi = 260
        ax.view_init(view_ele, view_azi)

    if traj_color_style == 'cell' and show_legend:
        ax.legend(fontsize=fs)

    # (5) Save and show plot
    if save:
        fpath = NB_OUTPUT + os.sep + 'trajmark_%s' % mode
        if fmod is not None:
            fpath += fmod
        fig.savefig(fpath + '.pdf')
        print('plot saved to...', fpath + '.pdf')
    if show:
        plt.show()
    return ax

In [None]:
%matplotlib widget

#plot_cells_traj_and_marks_2d_or_3d(cellgraph, 80, quiver=True, decorate=True, fmod=None, show=True)
plot_cells_traj_and_marks_2d_or_3d(
    cellgraph, 
    cells_to_traj=[0], tlims_traj=None,  # cells_to_mark='all', times_to_mark=[0,72.67], 
    cells_to_mark=[], times_to_mark=[], 
    mode='xyz', 
    traj_color_style='cell', 
    proj2d=True, 
    show_zplanes=False,
    quiver=False,
    decorate=True,
    show=True) 

In [None]:
def plot_two_cell_compare(cellgraph, idx_c1, idx_c2):
    """
    Compare the x, y, z timeseries of two cells over the period during which both exist.

    Layout: 3 x 3 grid, one column per state variable.
        row 0: cell idx_c1 vs cell idx_c2 -- if the cells are in sync this is just a diagonal
               line zigzagging between low/high amplitude
        row 1: cell idx_c1 vs time
        row 2: cell idx_c2 vs time
    Note: this function is not currently used by the main analysis.
    """
    state_tensor = cellgraph.state_to_rectangle(cellgraph.state_history)  # dim_ode x cells x time
    times = cellgraph.times_history
    birthdays = cellgraph.cell_stats[:, 2]

    fig, axarr = plt.subplots(3, 3, figsize=(8, 4))
    fs = 14

    # start once the younger of the two cells has been born
    idx_time_start = max(birthdays[idx_c1], birthdays[idx_c2])
    times_slice = times[idx_time_start:]
    r1 = state_tensor[:, idx_c1, idx_time_start:].T  # time x dim_ode
    r2 = state_tensor[:, idx_c2, idx_time_start:].T

    coords = [r'$x$', r'$y$', r'$z$']
    for sidx, coord in enumerate(coords):
        axarr[0, sidx].plot(r1[:, sidx], r2[:, sidx])
        axarr[0, sidx].set_xlabel('%s cell %d' % (coord, idx_c1), fontsize=fs)
        axarr[0, sidx].set_ylabel('%s cell %d' % (coord, idx_c2), fontsize=fs)

        axarr[1, sidx].plot(times_slice, r1[:, sidx])
        axarr[1, sidx].set_xlabel(r'$t$', fontsize=fs)
        axarr[1, sidx].set_ylabel('%s cell %d' % (coord, idx_c1), fontsize=fs)

        axarr[2, sidx].plot(times_slice, r2[:, sidx])
        axarr[2, sidx].set_xlabel(r'$t$', fontsize=fs)
        axarr[2, sidx].set_ylabel('%s cell %d' % (coord, idx_c2), fontsize=fs)

    fig.subplots_adjust(wspace=0.5, hspace=1)
    plt.show()

In [None]:
# Comparing cells 0 and 1 needs at least two cells (i.e. at least one division during the run);
# with a single cell, indexing cell 1 would fail and break Restart & Run All.
if cellgraph.num_cells > 1:
    plot_two_cell_compare(cellgraph, 0, 1)
else:
    print('Skipping two-cell comparison: only %d cell(s) in the graph' % cellgraph.num_cells)

In [None]:
def plot_xyz_vs_t_one_cell(cellgraph, cell_idx, ax=None, draw_divisions=True, fill_z=True, save=True, show=False,
                           xlim=(0, 152), plot_zlines=True):
    """
    Plot the x, y, z state components of a single cell against time.

    Args:
        cellgraph: cell graph providing the state/time history, division events and ODE parameters
        cell_idx: index of the cell to plot
        ax: matplotlib Axes to draw on; if None a new (6, 3) figure is created
        draw_divisions: draw a dashed vertical line at every division event in which cell_idx is the mother
        fill_z: shade the horizontal band [zlow, zhigh] returned by limit_cycle_geometry(cellgraph)
        save: save the figure that contains `ax` as .svg and .pdf under NB_OUTPUT
        show: call plt.show() at the end
        xlim: x-axis limits; the default (0, 152) is the previously hard-coded value; None keeps autoscaling
        plot_zlines: draw horizontal lines at a1*zL and a1*zR (computed from sc_template.params_ode)
    Returns:
        ax: the Axes that was drawn on
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 3))
    # save the figure that actually holds `ax` (plt.savefig would save whatever figure is current)
    fig = ax.figure
    fs = 14
    fs_legend = 14

    r_slice, times_slice = slice_state_and_time_for_cell_given_tlims(cellgraph, cell_idx, tlims_idx=None)
    components = ((r'$x$', COLOR_X), (r'$y$', COLOR_Y), (r'$z$', COLOR_Z))
    for comp_idx, (label, color) in enumerate(components):
        ax.plot(times_slice, r_slice[:, comp_idx], '-', label=label, color=color, zorder=11, linewidth=0.75)

    if draw_divisions:
        # each row of division_events is [mother_idx, daughter_idx, time_idx]
        for mother, _daughter, t_idx in cellgraph.division_events:
            if mother == cell_idx:
                ax.axvline(cellgraph.times_history[t_idx], linestyle='--', linewidth=1.2, c='k', zorder=2)

    if fill_z:
        xlow, xhigh, xmid, ylow, yhigh, xhit_topright, xhit_botleft, zlow, zhigh = limit_cycle_geometry(cellgraph)
        # full band between zlow and zhigh over the plotted time range
        ax.fill_between(times_slice, zlow, zhigh, facecolor='green', alpha=0.1, zorder=10,
                        label='Oscillatory \nregime')

    ax.set_xlabel(r'Time since pulse initiation', fontsize=fs)
    ax.set_ylabel(r'State of cell %d' % cell_idx, fontsize=fs)
    # act on `ax` only (plt.xticks/plt.yticks would restyle whichever axes is current)
    ax.tick_params(axis='both', labelsize=fs)

    ax.legend(fontsize=fs_legend)
    ax.grid(c='#e5e5e5', zorder=1)
    if xlim is not None:
        ax.set_xlim(*xlim)

    if plot_zlines:
        params = cellgraph.sc_template.params_ode
        a1, a2, gamma = params['a1'], params['a2'], params['gamma']
        m = 2  # NOTE(review): hard-coded constant inherited from the original; meaning not visible here
        zL = (1 + gamma * m) / 2
        ax.axhline(a1 * zL)
        zR = zL + (a2 / a1) / 2 * (1 - gamma * m)
        ax.axhline(a1 * zR)

    if save:
        fpath = NB_OUTPUT + os.sep + 'traj_xyz_vs_t_%d' % cell_idx
        fig.savefig(fpath + '.svg')
        fig.savefig(fpath + '.pdf')
        print('plot saved to...', fpath + '.pdf')
    if show:
        plt.show()

    return ax

In [None]:
%matplotlib inline

ax = plot_xyz_vs_t_one_cell(cellgraph, 0, draw_divisions=True, fill_z=False, save=True, show=True)

In [None]:
ax = plot_xyz_vs_t_one_cell(cellgraph, cell_idx=3, fill_z=False, draw_divisions=True, show=True, save=True)

# 2) Visualize the cell graph as a network
### Plot networks before/after division events
### TODO - make a 2D plot of the limit cycle, colored by phase1 (or x, etc.)
This could act as a circular colorbar.

In [None]:
# relies on `birthdays` defined in an earlier cell
# right after each division: the birth time index of every cell
moment_after_division = [t_birth for t_birth in birthdays]
# right before each division: one step before each birth; the first entry is fixed at 0
moment_before_division = [0]
moment_before_division.extend(t_birth - 1 for t_birth in birthdays[1:])

print(moment_before_division)
print(moment_after_division)

In [None]:
def regenerate_adjacency_timeseries_from_cellgraph(cellgraph):
    """
    Rebuild the sequence of adjacency matrices the cell graph went through as it grew.

    - We don't store the adjacency timeseries as a cellgraph attribute.
    - It can however be regenerated from the attribute division_events:
        - self.division_events: arr (d x 3); for each division event, append row: [mother_idx, daughter_idx, time_idx]
    - The graph must have started from a single cell, i.e. num_cells == d + 1.
    - Each division appends one new cell as the last row/column, connected to its mother by an undirected edge.

    Returns:
        list of length num_cells; entry k is the (k+1) x (k+1) float adjacency matrix after k divisions
        (entry 0 is the 1 x 1 zero matrix of the initial single cell)
    """
    assert cellgraph.num_cells == cellgraph.division_events.shape[0] + 1  # i.e. must start from single cell graph

    current_adj = np.zeros((1, 1))
    list_of_adjacencies = [current_adj]
    for mother_idx, _daughter_idx, _time_idx in cellgraph.division_events:
        n = current_adj.shape[0]  # number of cells before this division == index of the new cell
        expanded_adj = np.zeros((n + 1, n + 1))
        expanded_adj[:n, :n] = current_adj
        mother = int(mother_idx)
        expanded_adj[mother, n] = 1
        expanded_adj[n, mother] = 1
        list_of_adjacencies.append(expanded_adj)
        current_adj = expanded_adj

    return list_of_adjacencies

In [None]:
def get_closest_point_line_segment(pvec, avec, bvec):
    """
    Given point p, find the closest point on the line segment connecting avec and bvec.

    Consider the line V(t) = A*(1-t) + B*t; the segment is the part with 0 <= t <= 1.
    - Project p onto the infinite line: t* = <p - A, B - A> / |B - A|^2
    - If t* lies strictly inside (0, 1), the projection is the closest point.
    - Otherwise the nearer endpoint is: A for t* <= 0, B for t* >= 1.
    - Degenerate segment (A == B): the closest point is A, with t = 0.

    Args:
        pvec, avec, bvec: 2D points (array-like of length 2)
    Returns:
        (distance, closest_point, t) with closest_point == A*(1-t) + B*t;
        at an endpoint the input avec / bvec object itself is returned as closest_point
    """
    p = np.asarray(pvec, dtype=float)
    a = np.asarray(avec, dtype=float)
    b = np.asarray(bvec, dtype=float)

    segment = b - a
    segment_len_sq = float(np.dot(segment, segment))
    if segment_len_sq == 0.0:
        # degenerate segment: previously this divided by zero and then failed an assert
        return np.linalg.norm(p - a), avec, 0.0

    t_value = float(np.dot(p - a, segment)) / segment_len_sq
    if t_value <= 0.0:
        return np.linalg.norm(p - a), avec, 0.0
    if t_value >= 1.0:
        return np.linalg.norm(p - b), bvec, 1.0

    proj_coord = a + t_value * segment
    return np.linalg.norm(p - proj_coord), proj_coord, t_value

def get_closest_point_limit_cycle(x0, y0, z0, cg=None):
    """
    Project the (x, y) part of a state onto the piecewise-linear limit cycle.

    Take the min of { min(dist_A, dist_B, dist_normal) } over the four line segments returned by
    get_limit_cycle_line_segments(cg).

    Args:
        x0, y0: coordinates of the point to project; z0 is accepted but not used
        cg: cell graph defining the limit cycle; defaults to the notebook-level global `cellgraph`
            (which is what this function always used before)
    Returns:
        segment_coord: closest point on the limit cycle (2-vector)
        segment_scalar: segment index + position t in [0, 1] along it, i.e. a value in [0, 4]
        segment_index: index (0..3) of the closest segment
        dist_to_segment: distance from (x0, y0) to segment_coord
    """
    if cg is None:
        cg = cellgraph  # notebook global, preserves the original behaviour
    pvec = np.array([x0, y0])
    line_segments_limitcycle = get_limit_cycle_line_segments(cg)

    best = None
    for idx in range(4):
        segment = line_segments_limitcycle[idx]
        dist, loc_vec, loc_scalar = get_closest_point_line_segment(pvec, segment['A'], segment['B'])
        # strict '<' keeps the first segment on ties
        if best is None or dist < best[3]:
            # previously segment_coord stayed at zeros; now it is the actual closest point
            best = (loc_vec, idx + loc_scalar, idx, dist)

    return best


In [None]:
def plot_networks_build_cellgraph(cellgraph, time_indices, 
                                  ax=None,
                                  node_color='x', 
                                  unified_colorbar=False, 
                                  clims_global=None,
                                  layout_fn=None,
                                  gviz_prog='dot', 
                                  vis_directed=True, 
                                  fmod='',
                                  stack_vertical=False,
                                  save=True, 
                                  show=True):
    """
    Draw the cell graph network as it was at each requested time index, one panel per time index.

    Args:
        cellgraph: cell graph whose division history is used to rebuild the adjacency matrices
            (via regenerate_adjacency_timeseries_from_cellgraph) and whose state_history colors the nodes
        time_indices: list of time indices into cellgraph.times_history; one panel per entry
        ax: [default: None]
            (A) - ax is None - elaborate but static figure with subpanels
            (B) - ax given   - simple animated figure of one network (assert len(time_indices) == 1 and not unified_colorbar)
        node_color in ['fixed', 'degree', 'phase1', 'x', 'y', 'z']
            - 'fixed': cell 0 gets its own color, all other cells share one base color
            - 'degree': node degree in the adjacency matrix at that time
            - 'phase1': position (0..4) along the 4-segment idealized limit cycle, see get_phase_from_xyz
            - 'x', 'y', 'z': that state component of each cell at that time
        unified_colorbar: make single unified colorbar for all the ax plots; place it below the networks
            (or in a column to the right when stack_vertical=True)
        clims_global: None or 2-tuple, pass clims to networkx plot fn
            note: if unified_colorbar, clims_global will be inferred if not explicitly specified;
            if not unified_colorbar, clims_global is reset to None (ignored)
        layout_fn: function that takes adjacency A, returns pos->(x,y) dict and xlims, ylims
        gviz_prog: graphviz program name passed through to draw_from_adjacency
        vis_directed: if True, pass the division events so far to draw_from_adjacency as draw_division
        fmod: suffix appended to the saved filename
        stack_vertical: arrange the panels in one column instead of one row
        save: save the current figure as a .pdf under NB_OUTPUT
        show: call plt.show() at the end
    Returns:
        None
    """
    list_of_adjacencies = regenerate_adjacency_timeseries_from_cellgraph(cellgraph)
    # column 2 of cell_stats is time_idx_birth of each cell
    birthdays = cellgraph.cell_stats[:, 2]
    state_tensor = cellgraph.state_to_rectangle(cellgraph.state_history)  # dim_ode x cells x time

    num_plots = len(time_indices)
    
    if unified_colorbar:
        
        # reserve an extra (thin) grid column/row for the colorbar
        if stack_vertical:
            cbar_col_ratio = 0.1
            gridspec_kw = {'width_ratios': [1, cbar_col_ratio]}
            figsize = (2 + 2 * cbar_col_ratio, num_plots*2) 
            ncols = 2
            nrows = num_plots
            orientation = 'vertical'
        else:            
            cbar_row_ratio = 0.1
            gridspec_kw = {'height_ratios': [1, cbar_row_ratio]}
            figsize = (num_plots*2, 2 + 2 * cbar_row_ratio) 
            nrows = 2
            ncols = num_plots
            orientation = 'horizontal'
        
        if clims_global is None:
            print('Note: unified_colorbar=True but clims_global not specified; auto-generating...')
            if node_color == 'fixed':
                # NOTE(review): clow == chigh gives a zero-width color normalization
                clow = 1
                chigh = 1
            elif node_color == 'degree':
                # degree auto-range set to 1, max(degree_vec) of final graph
                clow = 1
                chigh = max(cellgraph.degree)
            elif node_color == 'phase1':
                # phase1 values are segment index + position along segment, i.e. in [0, 4]
                clow = 0.0
                chigh = 4.0
            else:
                # take max of state value over all times in cellgraph.state_history
                assert node_color in ['x', 'y', 'z']
                state_idx = {'x': 0, 'y': 1, 'z': 2}[node_color]
                clow = 0.0
                chigh = np.max(state_tensor[state_idx, :, :])
                #chigh = np.max(state_tensor[state_idx, :, min(time_indices) : max(time_indices)])
            clims_global = (clow, chigh)
            print('clims_global =', clims_global, 'for node_color = %s' % node_color)
    else:

        # no colorbar: one grid cell per panel, per-panel color limits
        if stack_vertical:
            clims_global = None
            gridspec_kw = {}
            figsize = (2, num_plots*2) 
            ncols = 1
            nrows = num_plots
        else:            
            clims_global = None
            gridspec_kw = {}
            figsize = (num_plots*2, 2) 
            nrows = 1
            ncols = num_plots
            
    # initiate the plot (which depends on if we are adding a colorbar)
    # - nrows, figsize, and gridspec_kw will depend on colorbar presence
    if ax is None:
        fig = plt.figure(constrained_layout=False, figsize=figsize)
        gspec = fig.add_gridspec(ncols=ncols, nrows=nrows, **gridspec_kw)
        figure_simple = False
    else:
        assert len(time_indices)==1 and not unified_colorbar
        figure_simple = True
    
    def get_phase_from_xyz(state_at_time, style='phase1'):
        """
        Multiple ways to define phase given x, y, z:
        'phase1': position on 4 line segments 
            (1) project xy onto idealized limit cycle (i.e. one of four line segments)
            (2) define phase as number from 0 to 4 with 0...1, 1...2, 2..3, 3..4=0 being line segments
        Returns one phase value per column (cell) of state_at_time.
        """
        assert style in ['phase1']
        num_cells_at_time = state_at_time.shape[1]
        node_val = np.zeros(num_cells_at_time)
        for idx in range(num_cells_at_time):
            x0, y0, z0 = state_at_time[:, idx]
            # NOTE(review): `cellgraph` is not passed here; get_closest_point_limit_cycle builds the
            # limit cycle from the notebook-global `cellgraph`, which may differ from this function's argument
            segment_coord, segment_scalar, segment_index, distance = get_closest_point_limit_cycle(x0, y0, z0)
            node_val[idx] = segment_scalar
            
        return node_val
    
    def get_plot_cmap_and_node_vals(adj, time_idx):
        """Return (node values, colormap) for the cells present in `adj` at `time_idx`, per node_color."""
        state_at_time = state_tensor[:, :, time_idx] # dim_ode x cells
        assert node_color in ['fixed', 'degree', 'phase1', 'x', 'y', 'z']
        if node_color == 'fixed':            
            node_val = np.arange(adj.shape[0])
            # hacky cmap using ListedColormap
            color_base = "#dae2ec"
            color_list = cellgraph.num_cells * [color_base]
            color_list[0] = '#7dbfa6'
            cmap = mpl.colors.ListedColormap(color_list)
        elif node_color == 'degree':            
            # diag(diag(v)) round-trips back to v, so node_val is just the row sums of adj
            degree = np.diag(np.sum(adj, axis=1))
            node_val = np.diag(degree)
            cmap = 'Pastel1'
        elif node_color == 'phase1':
            node_val = get_phase_from_xyz(state_at_time[:, 0:adj.shape[0]], style='phase1')
            # cyclic cmaps:
            # - try hsv, N=8
            # - try twilight, twilight_shifted
            cmap = plt.get_cmap('Spectral', 4)  
        elif node_color == 'x':
            node_val_all = state_at_time[0, :]  # this needs to be truncated based on size of adj
            node_val = node_val_all[0:adj.shape[0]]
            cmap = plt.get_cmap('Spectral', 6)
        elif node_color == 'y':
            node_val_all = state_at_time[1, :]
            node_val = node_val_all[0:adj.shape[0]]
            cmap = plt.get_cmap('Spectral', 6)
        else:
            assert node_color == 'z'
            node_val_all = state_at_time[2, :]
            node_val = node_val_all[0:adj.shape[0]]
            cmap = plt.get_cmap('Spectral', 6)
        return node_val, cmap
            
    def plot_network_on_ax(A, node_val, cmap, ax, title=''):
        """Draw adjacency A on ax via draw_from_adjacency, using the outer layout/clims settings."""
        draw_edge_labels = False  
        draw_division = None
        explicit_layout = None
        
        if vis_directed:
            if A.shape[0] > 1:
                # a graph with A.shape[0] cells has gone through the first A.shape[0]-1 division events
                draw_division = cellgraph.division_events[:A.shape[0]-1, :]
        
        if layout_fn is not None:
            explicit_layout, xlims, ylims = layout_fn(A)
        else:
            xlims = None
            ylims = None
        
        draw_from_adjacency(
            A, 
            title=title, 
            node_color=node_val, 
            labels=None, 
            cmap=cmap,
            gviz_prog=gviz_prog,
            draw_edge_labels=draw_edge_labels,
            draw_division=draw_division,
            explicit_layout=explicit_layout,
            xlims=xlims,
            ylims=ylims,
            clims=clims_global,
            ax=ax)
        return
    
    def get_adj_for_time_idx(time_idx):
        """Return the adjacency matrix of the graph as it was at time_idx."""
        # count of cells born at or before time_idx; assumes birthdays is sorted ascending (TODO confirm)
        num_cells_at_time = np.searchsorted(birthdays, time_idx, side='right')
        appropriate_adj = list_of_adjacencies[num_cells_at_time - 1]
        return appropriate_adj
    
    for ax_idx, time_idx in enumerate(time_indices):
        
        if not figure_simple:
            # generate ax from gridspec
            if stack_vertical:
                ax = fig.add_subplot(gspec[ax_idx, 0])
            else:
                ax = fig.add_subplot(gspec[0, ax_idx])
        
        adj = get_adj_for_time_idx(time_idx)
        node_val, cmap = get_plot_cmap_and_node_vals(adj, time_idx)
        plot_network_on_ax(
            adj, 
            node_val,
            cmap,
            ax, 
            title='t=%.2f' % cellgraph.times_history[time_idx])
    
    # with unified_colorbar, plot colorbar below
    # NOTE(review): uses the cmap of the last panel drawn; raises NameError if time_indices is empty
    if unified_colorbar:
        if stack_vertical:
            ax = fig.add_subplot(gspec[:, 1])
        else:
            ax = fig.add_subplot(gspec[1, :])
        norm = mpl.colors.Normalize(vmin=clims_global[0], vmax=clims_global[1])
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        cbar = fig.colorbar(sm, cax=ax, orientation=orientation) 
    
    # NOTE(review): plt.savefig saves the *current* figure; with a user-supplied ax that is whatever figure is current
    if save:
        fpath = NB_OUTPUT + os.sep + 'networks_n%s_c%s_unifiedc%d%s' % (len(time_indices), node_color, unified_colorbar, fmod)
        plt.savefig(fpath + '.pdf')
    if show:
        plt.show()
    return

In [None]:
networkx_kwargs = {
    'node_color': 'fixed',
    'gviz_prog': 'twopi',
    'vis_directed': True,
    'layout_fn': preset_layout_choice,
    'unified_colorbar': True,
    'clims_global': (0, 12),  # for x: None or (3,6) for the limit cycle
}

# same settings, once just before and once just after each division
for division_moments, suffix in [(moment_before_division, '_pre-div'),
                                 (moment_after_division, '_post-div')]:
    plot_networks_build_cellgraph(cellgraph, division_moments, fmod=suffix, **networkx_kwargs)

In [None]:
networkx_kwargs = {
    'node_color': 'fixed',
    'gviz_prog': 'twopi',
    'vis_directed': True,
    'layout_fn': preset_layout_choice,
    'unified_colorbar': True,
    'clims_global': (0, 12),  # for x: None or (3,6) for the limit cycle
    'stack_vertical': True,
}

# same settings, once just before and once just after each division (panels stacked vertically)
for division_moments, suffix in [(moment_before_division, '_pre-div'),
                                 (moment_after_division, '_post-div')]:
    plot_networks_build_cellgraph(cellgraph, division_moments, fmod=suffix, **networkx_kwargs)

In [None]:
networkx_kwargs = {
    'node_color': 'phase1',
    'gviz_prog': 'twopi',
    'vis_directed': True,
    'layout_fn': preset_layout_choice,
    'unified_colorbar': True,
    'clims_global': None,  # auto-range; for x one might use (3, 6) for the limit cycle
}

# color nodes by limit-cycle phase, once just before and once just after each division
for division_moments, suffix in [(moment_before_division, '_pre-div'),
                                 (moment_after_division, '_post-div')]:
    plot_networks_build_cellgraph(cellgraph, division_moments, fmod=suffix, **networkx_kwargs)