In [None]:
import html
import os
import sys
import types

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.animation as manimation

from IPython.display import HTML, Video

In [None]:
# Keep text editable (not converted to outlines) when exported figures are
# opened in Illustrator / Inkscape: TrueType (type 42) fonts for PDF and PS,
# and raw text (no glyph paths) for SVG.
mpl.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'svg.fonttype': 'none',
})

# Notebook setup (add the package root to `sys.path`) and local imports

In [None]:
# Package root = parent of the kernel's working directory.
# NOTE(review): assumes the notebook is run from <root>/notebooks — confirm.
PACKAGE_ROOT = os.path.dirname(os.path.abspath(''))
print(PACKAGE_ROOT)
# Guard so re-running this cell does not keep appending duplicate entries.
if PACKAGE_ROOT not in sys.path:
    sys.path.append(PACKAGE_ROOT)

In [None]:
from load_sweep import wrapper_load_or_digest_sweep, NAMES_IMPORTANT_GRAPHS

from settings import DIR_OUTPUT
from utils_io import pickle_load
from utils_networkx import check_tree_isomorphism, draw_from_adjacency, check_tree_isomorphism_with_insect

## Specify output directory

In [None]:
# Notebook-local output folder <PACKAGE_ROOT>/notebooks/sweep, created if missing.
NB_OUTPUT = os.sep.join([PACKAGE_ROOT, 'notebooks', 'sweep'])
if not os.path.exists(NB_OUTPUT):
    os.makedirs(NB_OUTPUT)

## Video Functions

In [None]:
def setup_ffmpeg_writer(fps=15, bitrate=None, title='Sample movie', artist='Matplotlib', comment='comment here'):
    """
    Build a matplotlib ffmpeg movie writer.

    fps, bitrate -- passed straight to the writer
    title, artist, comment -- stored as the movie's metadata
    Returns the writer instance.
    """
    writer_cls = manimation.writers['ffmpeg']
    return writer_cls(
        fps=fps,
        bitrate=bitrate,
        metadata={'title': title, 'artist': artist, 'comment': comment},
    )

def vid_fpath_to_html(fpath):
    """
    Build an HTML snippet that embeds an mp4 video in the notebook.

    - Show video using
        HTML(vid_fpath_to_html(fpath))
    - Absolute paths did not play here, so `fpath` is rewritten relative to
      its last 'notebooks' directory, e.g.
        <root>/notebooks/sweep/a.mp4  ->  sweep/a.mp4
    - Alternate method for absolute path:
        Video(fpath)

    Raises ValueError if `fpath` has no 'notebooks' directory component.
    """
    if os.altsep:
        fpath = fpath.replace(os.altsep, os.sep)
    parts = fpath.split(os.sep)
    dir_parts = parts[:-1]
    if 'notebooks' not in dir_parts:
        raise ValueError("expected a path inside a 'notebooks' directory, got %r" % fpath)
    # position of the LAST 'notebooks' directory (the package root may also contain one)
    nb_idx = len(dir_parts) - 1 - dir_parts[::-1].index('notebooks')
    # URLs always use '/', whatever the OS path separator is
    rel_url = '/'.join(parts[nb_idx + 1:])
    # quoted + escaped so paths with spaces or quotes stay a single attribute value
    src = html.escape(rel_url, quote=True)
    s = f"""
        <div style="text-align: center">
        <video width="80%" controls>
              <source src="{src}" type="video/mp4">
        </video></div>"""
    return s

# Select sweep to load and visualize
- The directory must contain a file called `sweep.pkl`, an instance of `SweepCellGraph`.
- It may additionally contain the following two files, which avoid having to re-analyze the sweep (slow):
    - `unique_networks_dict.pkl`
    - `results_of_theta.npz`

In [None]:
# Archived sweep data lives on ceph, with selected sweeps on OneDrive
# (and copied into the project input directory).
archive_sweeps = os.sep.join([PACKAGE_ROOT, 'input', 'archive_sweeps'])

In [None]:
dir_sweep_selected = os.sep.join([archive_sweeps, '3D', 'small_80x200x61_beta0_zpulse_divPlusMinus'])
# Alternative (larger) sweep:
#dir_sweep_selected = os.sep.join([archive_sweeps, '3D', 'big_100x160x181_beta0_zpulse_divPlusMinus'])

### Load the sweep files, generating the two digest files if they do not yet exist (slow)

In [None]:
# Load the sweep plus its two digest objects; per the heading above, the digest
# files are generated (slowly) if they are not already in dir_sweep_selected.
sweep_cellgraph, unique_networks_dict, results_of_theta = wrapper_load_or_digest_sweep(dir_sweep_selected)

### Print properties of the sweep class and the loaded instance

In [None]:
# Public names only: skip anything with a leading or trailing underscore.
properties = [name for name in dir(sweep_cellgraph)
              if not (name.startswith('_') or name.endswith('_'))]

print('SweepCellGraph Methods:')
print([name for name in properties
       if isinstance(getattr(sweep_cellgraph, name), types.MethodType)])

print('\nSweepCellGraph Attributes:')
print([name for name in properties
       if not isinstance(getattr(sweep_cellgraph, name), types.MethodType)])

base_cellgraph = sweep_cellgraph.base_cellgraph
print('\nsweep_cellgraph.base_cellgraph Attributes:')
properties_cellgraph = [name for name in dir(base_cellgraph)
                        if not (name.startswith('_') or name.endswith('_'))]
print([name for name in properties_cellgraph
       if not isinstance(getattr(base_cellgraph, name), types.MethodType)])
for attr_name in ('diffusion', 'diffusion_arg', 'style_division'):
    print(getattr(base_cellgraph, attr_name))

In [None]:
sweep_cellgraph.printer()

print('\nresults_dict[i,j,k,...] points to dict with keys:')
first_run_key = (0,) * sweep_cellgraph.k_vary  # index of the first run in every swept axis
print(list(sweep_cellgraph.results_dict[first_run_key]))

### Aliases for the digested sweep data

In [None]:
# Short names for the per-run arrays stored in the digested sweep results.
(arr_num_cells,
 arr_unique_network_id,
 arr_isos,
 arr_exitstatus,
 arr_end_time) = (results_of_theta[key] for key in
                  ('num_cells', 'unique_network_id', 'isos', 'exitstatus', 'end_time'))

In [None]:
# Short aliases used by the plotting cells below.
sweep = sweep_cellgraph
results = sweep.results_dict

# Visualization

### Ideas: add a slider for the third axis, then show three scatter/imshow plots in a row or column (one per array)

## Local test of 3D plotting with 2D imshow slice

In [None]:
# Choose which array to visualize (uncomment exactly one).
# NOTE(review): imshow_data is not referenced by the plotting functions in this
# section (they read the arr_* aliases directly) — confirm it is still needed.
#imshow_data = arr_num_cells.copy()
imshow_data = arr_unique_network_id.copy()
#imshow_data = arr_isos.copy()

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable


def _imshow_param_slice(ax, data, extent, **imshow_kwargs):
    """Draw a 2D sweep slice on `ax` with parameter 0 along x and parameter 1 along y.

    `data` is indexed data[i_param0, i_param1]; it is transposed because imshow
    puts the first array axis on the vertical axis.
    """
    return ax.imshow(np.transpose(data),
                     origin='lower',
                     aspect='auto',
                     interpolation='none',  # one pixel per run; no smoothing between runs
                     extent=extent,
                     **imshow_kwargs)


def _append_side_axis(ax):
    """Carve a narrow axis off the right of `ax` (colorbar slot, or blank spacer)."""
    divider = make_axes_locatable(ax)
    return divider.append_axes('right', size='5%', pad=0.05)


def _legend_per_value(ax, im, values, label_of):
    """Add a legend to `ax` with one color patch (proxy artist) per value drawn in `im`."""
    patches = [mpatches.Patch(color=im.cmap(im.norm(val)), label=label_of(val))
               for val in values]
    ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2,
              borderaxespad=0., fontsize=10., ncol=1)


def slice_2d_multi_imshow(p3_idx=None, fpath=None):
    """
    Plot five stacked heatmaps of one 2D slice of the sweep and save them as a figure.

    Panels (top to bottom): final num_cells, unique graph ID, biological
    isomorphism class (legend), exit status (legend), simulation end time.
    Parameter 0 is on the x-axis, parameter 1 on the y-axis.

    Reads the notebook globals `sweep`, `arr_num_cells`, `arr_unique_network_id`,
    `arr_isos`, `arr_exitstatus`, `arr_end_time` and `NAMES_IMPORTANT_GRAPHS`.

    Args:
        p3_idx: index along the third parameter for a 3-parameter sweep
            (None -> 0, with a message). Must be None for a 2-parameter sweep.
        fpath: output file; defaults to DIR_OUTPUT/slice2d.pdf (the same
            output directory the plotly functions below use).
    """
    assert sweep.k_vary in [2, 3]
    arrays = (arr_num_cells, arr_unique_network_id, arr_isos, arr_exitstatus, arr_end_time)

    if sweep.k_vary == 2:
        assert p3_idx is None
        data0, data1, data2, data3, data4 = (arr[:, :] for arr in arrays)
        suptitle = 'full 2d slice'
    else:
        if p3_idx is None:
            p3_idx = 0
            print('overriding p3_idx to 0')
        data0, data1, data2, data3, data4 = (arr[:, :, p3_idx] for arr in arrays)
        suptitle = '2d slice on p3_idx=%d; val=%.3f' % (p3_idx, sweep.params_values[2][p3_idx])

    # first/last parameter values are mapped to the outer edges of the image
    extent = [sweep.params_values[0][0],
              sweep.params_values[0][-1],
              sweep.params_values[1][0],
              sweep.params_values[1][-1]]

    fig, axarr = plt.subplots(nrows=5, ncols=1, figsize=(9, 16), sharex=True)

    # Plot 1 - num cells
    im0 = _imshow_param_slice(axarr[0], data0, extent)
    fig.colorbar(im0, cax=_append_side_axis(axarr[0]), orientation='vertical', label='final num_cells')

    # Plot 2 - unique integer
    im1 = _imshow_param_slice(axarr[1], data1, extent)
    fig.colorbar(im1, cax=_append_side_axis(axarr[1]), orientation='vertical', label='unique graph ID')

    # Plot 3 - biological string: categorical, so a legend instead of a colorbar
    im2 = _imshow_param_slice(axarr[2], data2, extent, cmap='tab20')
    _append_side_axis(axarr[2]).axis('off')  # blank spacer keeps all panels the same width
    _legend_per_value(axarr[2], im2, np.unique(data2),
                      lambda val: NAMES_IMPORTANT_GRAPHS[int(val)])

    # Plot 4 - exitstatus bool
    exit_labels = {True: '(True) run completed',
                   False: '(False) early exit'}
    im3 = _imshow_param_slice(axarr[3], data3, extent)
    _append_side_axis(axarr[3]).axis('off')  # blank spacer keeps all panels the same width
    _legend_per_value(axarr[3], im3, np.unique(data3),
                      lambda val: exit_labels[bool(val)])

    # Plot 5 - end time of sim
    im4 = _imshow_param_slice(axarr[4], data4, extent)
    fig.colorbar(im4, cax=_append_side_axis(axarr[4]), orientation='vertical', label='sim end time')

    # Axis labels: sharex, so only the bottom panel gets an xlabel; every panel gets a ylabel
    axarr[-1].set_xlabel(sweep.params_name[0])
    for ax in axarr:
        ax.set_ylabel(sweep.params_name[1])

    fig.suptitle(suptitle)
    fig.tight_layout()
    if fpath is None:
        os.makedirs(DIR_OUTPUT, exist_ok=True)
        fpath = DIR_OUTPUT + os.sep + 'slice2d.pdf'
    fig.savefig(fpath)
    plt.show()


slice_2d_multi_imshow(p3_idx=None)

In [None]:
if sweep.k_vary == 3:
    run_idx = (0, 0, 1)
    print('Look at test point', run_idx, '...')

    print(results[run_idx])
    print([sweep.params_name[axis] for axis in range(3)])
    print([values[pos] for values, pos in zip(sweep.params_values, run_idx)])

    print('\nresults_of_theta values at test point')
    for key in results_of_theta.keys():
        print('\t', key, results_of_theta[key][run_idx])

### Check unique_networks_dict -- important output of digest_sweep()

In [None]:
# Every run should be assigned to exactly one unique network.
totalruns, totalunique = 0, 0
for num_cells in np.unique(results_of_theta['num_cells']):
    variants = unique_networks_dict[num_cells]
    print('num_cells = %d' % num_cells)
    print(variants.keys())
    for key, info in variants.items():
        nruns = len(info['runs'])
        totalunique += 1
        totalruns += nruns
        print('\t', key, '->', 'nruns=%d\n\t\t' % nruns,
              info['iso'], info['unique_int'], info['bio_int'], info['adjacency'].shape)

print('Should be equal:', totalruns, np.prod(sweep.sizes))
print('Total unique graphs:', totalunique)

## Plotly methods

In [None]:
# (see here) https://plotly.com/python/3d-volume-plots/
# (see here) https://plotly.com/python/line-and-scatter/
# (see here for 2d slider of slices) https://plotly.com/python/imshow/ ("Animations of xarray datasets")
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go

In [None]:
def sweep_plotly_goVolume(sweep, plotstyle='num_cells', swap_xy=True):
    """
    Render a 3-parameter sweep as a translucent plotly volume and save it as HTML in DIR_OUTPUT.

    sweep     - sweep object varying exactly 3 parameters
    plotstyle - 'num_cells' or 'isos': which array to render
                (reads the notebook globals arr_num_cells / arr_isos; should instead be passed in)
    swap_xy   - True (default): parameter 1 on the plot x-axis and parameter 0 on y;
                False: parameter 0 on x and parameter 1 on y. Parameter 2 is always z.
    """
    assert plotstyle in ['num_cells', 'isos']
    # params_values[2] is used below, so a 2-parameter sweep cannot be drawn
    assert sweep.k_vary == 3

    plotdata = (arr_num_cells if plotstyle == 'num_cells' else arr_isos).copy()

    # indexing='ij' gives grids shaped like plotdata, so param_grids[a][i, j, k] is the
    # value of parameter a at run (i, j, k) -- correct whichever plot axis it is drawn on
    param_grids = np.meshgrid(sweep.params_values[0],
                              sweep.params_values[1],
                              sweep.params_values[2],
                              indexing='ij')
    kx, ky = (1, 0) if swap_xy else (0, 1)

    fig = go.Figure(data=go.Volume(
        x=param_grids[kx].flatten(),
        y=param_grids[ky].flatten(),
        z=param_grids[2].flatten(),
        value=plotdata.flatten(),
        opacity=0.2,  # needs to be small to see through all surfaces
        surface_count=21,  # needs to be a large number for good volume rendering
    ))

    title_text = "sweep %s (%s)" % (sweep.sweep_label, plotstyle)
    if swap_xy:
        title_text += "; care for x,y axis flip"
    fig.update_layout(
        title={
            'text': title_text,
            'y': 0.9,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top'},
        scene=dict(
            xaxis_title=sweep.params_name[kx],
            yaxis_title=sweep.params_name[ky],
            zaxis_title=sweep.params_name[2]),
    )
    fig.write_html(DIR_OUTPUT + os.sep + "sweep_plotly_goVol_%s_%s.html" % (sweep.sweep_label, plotstyle))
    fig.show()

In [None]:
def sweep_plotly_scatter(sweep, plotstyle='num_cells'):
    """
    3D scatter plot of every run of a 3-parameter sweep, saved as HTML in DIR_OUTPUT.

    plotstyle - 'isos': color points by biological network name;
                'num_cells': color points by final number of cells.
    Reads the notebook globals arr_num_cells, arr_isos and NAMES_IMPORTANT_GRAPHS
    (should instead be passed in).
    """
    assert plotstyle in ['num_cells', 'isos']
    assert sweep.k_vary == 3

    plotdata_numcells = arr_num_cells.copy()
    plotdata_isos = arr_isos.copy()

    x_label = sweep.params_name[0]
    y_label = sweep.params_name[1]
    z_label = sweep.params_name[2]

    # One row per run. indexing='ij' grids and ravel() both walk the runs in
    # C order, i.e. the same order as np.ndindex(*plotdata_numcells.shape).
    grid_x, grid_y, grid_z = np.meshgrid(sweep.params_values[0],
                                         sweep.params_values[1],
                                         sweep.params_values[2],
                                         indexing='ij')
    assert grid_x.shape == plotdata_numcells.shape == plotdata_isos.shape

    flat_isos = plotdata_isos.ravel()
    num_runs_slice = flat_isos.size
    df = pd.DataFrame({'index': range(num_runs_slice),
                       'num_cells': plotdata_numcells.ravel().astype(float),
                       'isos': flat_isos.astype(float),
                       # plain Python strings: a fixed-width 'U20' array truncated longer names
                       'bionames': [NAMES_IMPORTANT_GRAPHS[int(v)] for v in flat_isos],
                       x_label: grid_x.ravel().astype(float),
                       y_label: grid_y.ravel().astype(float),
                       z_label: grid_z.ravel().astype(float)})

    fig = px.scatter_3d(df, x=x_label, y=y_label, z=z_label,
                        color='bionames' if plotstyle == 'isos' else 'num_cells',
                        title="sweep scatter3d %s (%s)" % (sweep.sweep_label, plotstyle),
                        opacity=0.2)

    fig.write_html(DIR_OUTPUT + os.sep + "sweep_plotly_scatter3d_%s_%s.html" % (sweep.sweep_label, plotstyle))
    fig.show()

In [None]:
def sweep_plotly_imshow_slice2d(sweep, plotstyle='num_cells', kslice=2, idxslice=0):
    """
    Interactive plotly heatmap of one 2D slice of the sweep, saved as HTML in DIR_OUTPUT.

    kslice - axis on which to slice (0, 1 or 2); ignored for a 2-parameter sweep
    idx slice - the index of the kslice parameter on which to slice; ignored for a 2-parameter sweep

    The two remaining parameters go on x (lower parameter index) and y (higher index).
    Hover text shows num cells, bio iso int, bio name and sim end time of each run.
    Reads the notebook globals arr_num_cells, arr_isos, arr_end_time and NAMES_IMPORTANT_GRAPHS.
    """
    assert plotstyle in ['num_cells', 'isos']
    k_vary = sweep.k_vary
    assert k_vary in [2, 3]

    sources = (arr_num_cells, arr_isos, arr_end_time)
    if k_vary == 3:
        assert kslice in (0, 1, 2)
        # np.take(arr, idxslice, axis=kslice) fixes the sliced parameter;
        # .T puts the higher-index remaining parameter on the rows (y)
        plotdata_numcells, plotdata_isos, plotdata_endtime = (
            np.take(arr, idxslice, axis=kslice).T for arr in sources)
        k_x, k_y = [k for k in range(3) if k != kslice]
        plottitle = "sweep imshow2d slice ax%d idx%d %s (%s)" % (kslice, idxslice, sweep.sweep_label, plotstyle)
        fname = "sweep_plotly_imshow2d_slice_ax%d_idx%d_%s_%s.html" % (kslice, idxslice, sweep.sweep_label, plotstyle)
    else:
        plotdata_numcells, plotdata_isos, plotdata_endtime = (np.asarray(arr).T for arr in sources)
        k_x, k_y = 0, 1
        plottitle = "sweep imshow2d %s (%s)" % (sweep.sweep_label, plotstyle)
        fname = "sweep_plotly_imshow2d_%s_%s.html" % (sweep.sweep_label, plotstyle)

    if plotstyle == 'num_cells':
        imdata, c_label = plotdata_numcells, 'num cells'
    else:
        imdata, c_label = plotdata_isos, 'bio iso int'

    x_vals = sweep.params_values[k_x]
    y_vals = sweep.params_values[k_y]
    x_label = sweep.params_name[k_x]
    y_label = sweep.params_name[k_y]

    print(imdata.shape, len(x_vals), len(y_vals))
    fig = px.imshow(imdata,
                    x=x_vals, y=y_vals,
                    aspect="auto", origin="lower",
                    labels=dict(x=x_label, y=y_label, color=c_label),
                    color_continuous_scale='spectral',
                    title=plottitle,
                    animation_frame=None)

    # Custom hover text. np.array infers the string width from the data, so
    # long network names are not truncated (a fixed 'U20' dtype cut them off).
    plotdata_names = np.array(
        [NAMES_IMPORTANT_GRAPHS[int(v)] for v in plotdata_isos.ravel()]
    ).reshape(plotdata_isos.shape)
    hovertemplate = ("(x) %{x}"
                     "<br>(y) %{y}"
                     "<br>(val) %{z}"
                     "<br>(num cells) %{customdata[0]}"
                     "<br>(bio int) %{customdata[1]}"
                     "<br>(bio name) %{customdata[2]}"
                     "<br>(sim endtime) %{customdata[3]}")
    fig.update(
        data=[{'customdata': np.dstack((plotdata_numcells, plotdata_isos, plotdata_names, plotdata_endtime)),
               'hovertemplate': hovertemplate}]
    )
    fig.write_html(DIR_OUTPUT + os.sep + fname)
    fig.show()

In [None]:
def sweep_plotly_imshow_slice2dAnimate(sweep, plotstyle='num_cells', kslice=2):
    """
    Plotly heatmap of the 2D slices of a 3-parameter sweep, with a slider over the
    sliced parameter; saved as HTML in DIR_OUTPUT.

    kslice - axis on which to slice (0, 1 or 2); each animation frame is one index along it

    The two remaining parameters go on x (lower parameter index) and y (higher index).
    Hover text shows num cells, bio iso int, bio name and sim end time in every frame.
    Reads the notebook globals arr_num_cells, arr_isos, arr_end_time and NAMES_IMPORTANT_GRAPHS.
    """
    assert plotstyle in ['num_cells', 'isos']
    assert sweep.k_vary == 3
    assert kslice in (0, 1, 2)

    # Permute the two non-sliced axes (the kslice axis keeps its position) so each
    # frame is a (y, x) image with y = higher remaining parameter index, x = lower one.
    axis_order = {0: (0, 2, 1), 1: (2, 1, 0), 2: (1, 0, 2)}[kslice]
    k_x, k_y = [k for k in range(3) if k != kslice]
    plotdata_numcells, plotdata_isos, plotdata_endtime = (
        np.transpose(arr, axis_order) for arr in (arr_num_cells, arr_isos, arr_end_time))

    if plotstyle == 'num_cells':
        imdata, c_label = plotdata_numcells, 'num cells'
    else:
        imdata, c_label = plotdata_isos, 'bio iso int'

    x_vals = sweep.params_values[k_x]
    y_vals = sweep.params_values[k_y]
    z_vals = sweep.params_values[kslice]
    x_label = sweep.params_name[k_x]
    y_label = sweep.params_name[k_y]
    z_label = sweep.params_name[kslice]

    print(imdata.shape, len(x_vals), len(y_vals))
    fig = px.imshow(imdata,
                    x=x_vals, y=y_vals,
                    aspect="auto", origin="lower",
                    labels=dict(x=x_label, y=y_label, color=c_label),
                    color_continuous_scale='spectral',
                    title="sweep imshow2d slice ax%d %s (%s)" % (kslice, sweep.sweep_label, plotstyle),
                    animation_frame=kslice)

    # np.array infers the string width, so long network names are not truncated
    plotdata_names = np.array(
        [NAMES_IMPORTANT_GRAPHS[int(v)] for v in plotdata_isos.ravel()]
    ).reshape(plotdata_isos.shape)
    hover_sources = (plotdata_numcells, plotdata_isos, plotdata_names, plotdata_endtime)

    def frame_customdata(idx):
        """Stack the four hover fields of frame `idx` into a (y, x, 4) array."""
        return np.dstack([np.take(arr, idx, axis=kslice) for arr in hover_sources])

    hovertemplate = ("(x) %{x}"
                     "<br>(y) %{y}"
                     "<br>(val) %{z}"
                     "<br>(num cells) %{customdata[0]}"
                     "<br>(bio int) %{customdata[1]}"
                     "<br>(bio name) %{customdata[2]}"
                     "<br>(sim endtime) %{customdata[3]}")

    # The initially displayed trace shows frame 0, so it must get frame 0's 2D
    # hover data (stacking the full 3D arrays would misalign every hover value).
    fig.update(data=[{'customdata': frame_customdata(0), 'hovertemplate': hovertemplate}])

    for idx, frame in enumerate(fig.frames):
        frame.data[0].customdata = frame_customdata(idx)
        frame.data[0].hovertemplate = hovertemplate
        frame['layout'].update(title_text='%s 2D slice: %s=%.3f (idx %d)' % (plotstyle, z_label, z_vals[idx], idx))

    fig.write_html(DIR_OUTPUT + os.sep + "sweep_plotly_imshow2d_slice_ax%d_%s_%s.html" % (kslice, sweep.sweep_label, plotstyle))
    fig.show()

In [None]:
# 3D scatter of every run, colored by biological network name (3-parameter sweeps only).
if sweep.k_vary == 3:
    sweep_plotly_scatter(sweep, plotstyle='isos')

In [None]:
# These single 2D sweeps are not as useful as the "slider" animated 2d imshows below
if sweep.k_vary == 2:
    for style_name in ('num_cells', 'isos'):
        sweep_plotly_imshow_slice2d(sweep, plotstyle=style_name)

In [None]:
# Slider over parameter 2
if sweep.k_vary == 3:
    for style_name in ('num_cells', 'isos'):
        sweep_plotly_imshow_slice2dAnimate(sweep, plotstyle=style_name, kslice=2)

In [None]:
# Slider over parameter 1
if sweep.k_vary == 3:
    for style_name in ('num_cells', 'isos'):
        sweep_plotly_imshow_slice2dAnimate(sweep, plotstyle=style_name, kslice=1)

In [None]:
# Slider over parameter 0
if sweep.k_vary == 3:
    for style_name in ('num_cells', 'isos'):
        sweep_plotly_imshow_slice2dAnimate(sweep, plotstyle=style_name, kslice=0)

# Stats

In [None]:
def printer_array_statistics(arr_unique_network_id, arr_num_cells, labels=None, hist=True):
    """
    Count how often each distinct value occurs across all runs, print a table and
    optionally draw a bar chart.

    arr_unique_network_id: array of ints, size sweep.sizes (the values to count)
    arr_num_cells:         array of ints, size sweep.sizes; for each distinct value the
                           num_cells shown is taken from its first run (flattened order)
    labels: (dict or list) str labels for the elements of arr
    hist: if True, show a bar chart of the counts

    Returns (outx, outy): the distinct values sorted by descending count
    (ties keep ascending value order) and their counts.
    """
    flat_ids = np.ravel(arr_unique_network_id)
    flat_num_cells = np.ravel(arr_num_cells)

    outx, out_idx, out_counts = np.unique(flat_ids, return_index=True, return_counts=True)
    # descending frequency; stable so ties are ordered deterministically
    order = np.argsort(-out_counts, kind='stable')
    outx, out_idx, out_counts = outx[order], out_idx[order], out_counts[order]

    if hist:
        fig, ax = plt.subplots()
        ax.bar(outx, out_counts)
        if labels is not None:
            ax.set_xticks(outx)
            ax.set_xticklabels([labels[v] for v in outx], rotation=45, ha='right', rotation_mode='anchor')
        ax.set_ylabel('observations')
        plt.show()

    total = flat_ids.size
    for val, first_idx, count in zip(outx, out_idx, out_counts):
        ncells = flat_num_cells[first_idx]
        variety_pct = 100 * float(count) / total
        if labels is not None:
            variety_str = "val=%d (num_cells: %d) (label: %s)" % (val, ncells, labels[val])
        else:
            variety_str = "val=%d (num_cells: %d)" % (val, ncells)
        amount_str = "events=%d (%.2f%%)" % (count, variety_pct)
        print(f"{variety_str: <55}{amount_str:<55}")

    return outx, out_counts

# Frequency of each biological isomorphism class over all runs, labeled by name.
outx, outy = printer_array_statistics(arr_isos, arr_num_cells, labels=NAMES_IMPORTANT_GRAPHS, hist=True)

In [None]:
# Same statistics for the unique network IDs (no text labels; histogram on by default).
outx, outy = printer_array_statistics(arr_unique_network_id, arr_num_cells)

# Plot specific networks
`gviz_prog` options: `twopi`, `circo`, `dot`
   (alternatively, request a spring layout by passing an integer `spring_seed`)

In [None]:
from utils_networkx import draw_from_adjacency

def plot_network_by_degree(A, fpath, title='', spring_seed=None, gviz_prog='dot', figsize=(4, 4)):
    """
    Draw the graph with adjacency matrix `A`, coloring each node by its degree.

    A           - 2D adjacency matrix; its row sums are used as node degrees
    fpath       - output path stem; '_Degree' is appended
    title       - plot title; ' (degree)' is appended
    spring_seed, gviz_prog, figsize - forwarded to draw_from_adjacency
    """
    degree_vec = np.sum(A, axis=1)  # row sums == diagonal of the degree matrix
    draw_from_adjacency(
        A, title=title + ' (degree)', node_color=degree_vec,
        labels=None, cmap='Pastel1',
        fpath=fpath + '_Degree',
        figsize=figsize, spring_seed=spring_seed,
        gviz_prog=gviz_prog)


def _ncells_from_uid(uid):
    """Parse the cell count from a unique network ID, e.g. 'M15_v9' -> 15."""
    prefix = uid.split('_')[0]
    if not prefix.startswith('M'):
        raise ValueError("unique network ID should look like 'M<ncells>_v<k>', got %r" % uid)
    return int(prefix[1:])


def _plot_unique_network(uid, info, outdir, spring_seed=None, gviz_prog='dot'):
    """Plot one unique_networks_dict entry `info` to outdir/unetwork_<uid>."""
    title = 'Network %s: %d events' % (uid, len(info['runs']))
    fpath = outdir + os.sep + 'unetwork_%s' % uid
    plot_network_by_degree(info['adjacency'], fpath, title=title,
                           spring_seed=spring_seed, gviz_prog=gviz_prog)


def plot_unique_networks_given_ncells(ncells, spring_seed=None, gviz_prog='dot'):
    """Plot every unique network with `ncells` cells into DIR_OUTPUT/variants_<ncells>cell."""
    info_subdict = unique_networks_dict[ncells]  # look up first so a bad ncells creates no folder
    outdir = DIR_OUTPUT + os.sep + 'variants_%dcell' % ncells
    os.makedirs(outdir, exist_ok=True)
    for uid, info in info_subdict.items():
        # one-line summary (the full dict includes the adjacency matrix)
        print(uid, 'nruns=%d' % len(info['runs']), info['iso'], info['unique_int'], info['bio_int'])
        _plot_unique_network(uid, info, outdir, spring_seed=spring_seed, gviz_prog=gviz_prog)


def plot_unique_networks_given_ID(uid, spring_seed=None, gviz_prog='dot'):
    """Plot a single unique network (e.g. uid='M15_v9') into DIR_OUTPUT/unique_networks."""
    info = unique_networks_dict[_ncells_from_uid(uid)][uid]
    outdir = DIR_OUTPUT + os.sep + 'unique_networks'
    os.makedirs(outdir, exist_ok=True)
    _plot_unique_network(uid, info, outdir, spring_seed=spring_seed, gviz_prog=gviz_prog)

In [None]:
ncell = 8
# one plot per graphviz layout, then a spring layout with a fixed seed
for layout_prog in ('dot', 'twopi', 'circo'):
    plot_unique_networks_given_ncells(ncell, gviz_prog=layout_prog)
plot_unique_networks_given_ncells(ncell, spring_seed=0)

In [None]:
# Cell counts present as outer keys of unique_networks_dict.
print(unique_networks_dict.keys())

In [None]:
# Largest final cell count over the whole sweep.
arr_num_cells.max()

In [None]:
for uu in ['M15_v9', 'M12_v11', 'M20_v22']:
    # graphviz layouts first, then a seeded spring layout
    for layout_prog in ('dot', 'twopi', 'circo'):
        plot_unique_networks_given_ID(uu, gviz_prog=layout_prog)
    plot_unique_networks_given_ID(uu, spring_seed=0)

In [None]:
# Cell counts available in unique_networks_dict (same listing as earlier).
print(unique_networks_dict.keys())

In [None]:
# Check how a unique network ID is parsed into its cell count, then show that entry.
uid = 'M15_v9'
uid_1 = uid.split('_')[0]
uid_ncell = int(uid_1.split('M')[1])
print(uid_1.split('M'))
unique_networks_dict[uid_ncell][uid]

# Testing "voxel" plot approach for 3D data

In [None]:
# Subsample every VOXEL_STRIDE-th sweep point along each axis for the voxel plot.
# .copy() so the subsample is independent of (does not alias) arr_isos.
VOXEL_STRIDE = 3
mini_arr_isos = arr_isos[::VOXEL_STRIDE, ::VOXEL_STRIDE, ::VOXEL_STRIDE].copy()

In [None]:
# Matplotlib backend for the voxel-plot cells below.
# widget seemed less stable...
#%matplotlib widget 
%matplotlib inline

In [None]:
def color_name_to_hex_alpha(color_str, alpha):
    """
    Convert a matplotlib colour spec plus an opacity into an '#rrggbbaa' hex string.

    Args:
        color_str: any colour matplotlib's to_rgba accepts (e.g. 'red', '#1f77b4').
        alpha: opacity in [0, 1]; it replaces any alpha already carried by color_str.
    Returns:
        str: '#rrggbbaa' hex string (alpha channel kept).
    """
    assert 0 <= alpha <= 1
    # convert the requested colour (not a fixed one) and force the requested alpha
    return mpl.colors.to_hex(mpl.colors.to_rgba(color_str, alpha=alpha), keep_alpha=True)

print('Example color in rgba hex:')
print(color_name_to_hex_alpha('red', 0.5))

In [None]:
# Axis titles for the three sweep parameters, ordered (x, y, z).
# Plain-text variant:
XYZ_LABEL_TEXT = [
    'pulse slope',         # x
    'diffusion',           # y
    'division asymmetry',  # z
]
# Math-mode (LaTeX) variant of the same three labels:
XYZ_LABEL_MATH = ['$v$', '$c$', r'$\alpha$']

In [None]:
# Display names of the reference networks. A name's position in this list is its integer
# isomorphism index (NAMES_TO_INT maps name -> index).
NAMES_IMPORTANT_GRAPHS_OVERRIDE = [
    'background', 'onecell',
    # species networks (cell count in parentheses)
    'L. humile (16)', 'D. melanogaster (16)', 'L. humile (32)', 'B. terrestris (64)',
    'G. natator (8)', 'C. perla (12)', 'C. perla (13)', 'C. perla (14)', 'O. labronica (2)',
    'P. sauteri (3)', 'P. communis (4)', 'D. parthogeneticus (8)',
    # linear networks of 5 to 16 cells
    *['%dcellcyst_linear' % ncells for ncells in range(5, 17)],
    'H. juglandis (32)',
    'fulvicephalus1', 'fulvicephalus2', 'fulvicephalus3',
]
NAMES_TO_INT = dict(zip(NAMES_IMPORTANT_GRAPHS_OVERRIDE, range(len(NAMES_IMPORTANT_GRAPHS_OVERRIDE))))

print('total names:', len(NAMES_IMPORTANT_GRAPHS_OVERRIDE))
for idx, name in enumerate(NAMES_IMPORTANT_GRAPHS_OVERRIDE):
    print('\t', idx, '->', name)

In [None]:
# Groupings of isomorphism indices for voxel_mpl_3D's `iso_collections` argument:
# each entry becomes a single legend item drawn in a single colour.
ISO_COLLECTION_A = {
    'linear networks (8 to 16 cell)': {
        'group': [NAMES_TO_INT['%dcellcyst_linear' % n] for n in range(8, 17)],
        'color': '#faf0e680',
    },
    'linear networks (6, 7 cell)': {
        'group': [NAMES_TO_INT['6cellcyst_linear'], NAMES_TO_INT['7cellcyst_linear']],
        'color': 'red',
    },
}

ISO_COLLECTION_H3 = {
    'Maximally branched (8, 16, 32)': {
        'group': [NAMES_TO_INT[name] for name in ('G. natator (8)', 'D. melanogaster (16)', 'H. juglandis (32)')],
        'color': '#1f77b4A6',
    },
    '2 cell': {
        'group': [NAMES_TO_INT['O. labronica (2)']],
        'color': '#7f7f7fA6',
    },
    '3 cell': {
        'group': [NAMES_TO_INT['P. sauteri (3)']],
        'color': '#9467bdA6',
    },
    '4 cell linear': {
        'group': [NAMES_TO_INT['P. communis (4)']],
        'color': '#8c564bA6',
    },
    'Other linear networks': {
        'group': [NAMES_TO_INT['%dcellcyst_linear' % n] for n in range(5, 17)],
        'color': '#ff7f0e80',
    },
    r'$\mathit{C. perla}$': {
        'group': [NAMES_TO_INT['C. perla (12)']],
        'color': '#2ca02cA6',
    },
}


In [None]:
iso_color_default = '#00000080'
# Voxel colour ('#RRGGBBAA') for every isomorphism index 0..29, one per entry of
# NAMES_IMPORTANT_GRAPHS_OVERRIDE: start from the default, then override selected indices.
ISO_TO_COLOR_BASE = dict.fromkeys(range(30), iso_color_default)
ISO_TO_COLOR_BASE.update({
    2: '#1f77b4A6',   # L. humile (16)
    3: '#ff7f0e80',   # D. melanogaster (16)
    6: '#2ca02c80',   # G. natator (8)
    7: '#d62728A6',   # C. perla (12)
    8: '#9467bdA6',   # C. perla (13)
    10: '#e377c2A6',  # O. labronica (2)
    11: '#7f7f7fA6',  # P. sauteri (3)
    12: '#8c564b80',  # P. communis (4)
    14: '#17becf80',  # 5cellcyst_linear
    15: '#00000080',  # 6cellcyst_linear (same value as the default)
    16: '#bf00bf80',  # 7cellcyst_linear
    26: '#ffd700A6',  # H. juglandis (32)
})
# the 8- to 16-cell linear networks share one colour
ISO_TO_COLOR_BASE.update(dict.fromkeys(range(17, 26), '#faf0e680'))


In [None]:
# custom x,y,z axes ticks - original arr_isos
def custom_xyz_ticks(arr_isos, sweep, xlims=None, ylims=None, zlims=None):
    """
    Build the `xyz_labels_dict` used by voxel_mpl_3D: axis labels, voxel-index limits,
    and 3 evenly spaced ticks per axis labelled with the matching parameter values.

    Args:
        arr_isos: 3D array whose axes follow sweep.params_values[0], [1], [2]. Its shape
            gives the default limits. It is assumed to be at the sweep's full resolution,
            i.e. voxel index i on axis d corresponds to sweep.params_values[d][i].
        sweep: object exposing `params_values`, three 1D sequences assumed sorted
            ascending (required by np.searchsorted).
        xlims, ylims, zlims: optional (low, high) parameter-value ranges. Each becomes a
            half-open voxel-index range [start, stop) via np.searchsorted; the grid point
            at (or just above) `high` is included. Indices are clamped to the grid, so
            limits beyond the swept range no longer raise IndexError.

    Returns:
        dict with keys 'labels', 'xlims', 'ylims', 'zlims', 'xticks', 'xtick_labels',
        'yticks', 'ytick_labels', 'zticks', 'ztick_labels'.

    Raises:
        ValueError: if arr_isos is not 3-dimensional.
    """
    if np.ndim(arr_isos) != 3:
        raise ValueError('arr_isos must be 3D, got shape %s' % (np.shape(arr_isos),))
    n_ticks_force = 3
    label_decimals = (3, 2, 3)  # rounding of the tick labels on x, y, z

    def _index_range(axis, lims):
        # (low, high) parameter values -> [start, stop) voxel indices on `axis`
        if lims is None:
            return [0, arr_isos.shape[axis]]
        values = sweep.params_values[axis]
        n_vals = len(values)
        start = int(np.searchsorted(values, lims[0]))
        stop = int(np.searchsorted(values, lims[1])) + 1
        # clamp so that values[start] and values[stop - 1] are valid lookups
        return [min(start, n_vals - 1), min(stop, n_vals)]

    lims = [_index_range(axis, user_lims) for axis, user_lims in enumerate((xlims, ylims, zlims))]

    xyz_labels_dict = dict(
        labels=XYZ_LABEL_TEXT,  # XYZ_LABEL_TEXT or XYZ_LABEL_MATH
        xlims=lims[0],
        ylims=lims[1],
        zlims=lims[2],
    )
    for axis, name in enumerate('xyz'):
        start, stop = lims[axis]
        low, high = sweep.params_values[axis][start], sweep.params_values[axis][stop - 1]
        xyz_labels_dict[name + 'ticks'] = np.linspace(start, stop, n_ticks_force)
        xyz_labels_dict[name + 'tick_labels'] = np.round(
            np.linspace(low, high, n_ticks_force), label_decimals[axis])
    return xyz_labels_dict


xyz_ticks_orig = custom_xyz_ticks(arr_isos, sweep)

# mini_arr_isos keeps every 3rd sample per axis, so mini voxel j sits at full-grid index 3*j.
# custom_xyz_ticks looks tick labels up in sweep.params_values by voxel index, so give it
# parameter values subsampled the same way (the full grid would mislabel the mini axes).
mini_stride = 3  # must match the slicing used to build mini_arr_isos
assert mini_arr_isos.shape == arr_isos[::mini_stride, ::mini_stride, ::mini_stride].shape
sweep_mini = types.SimpleNamespace(
    params_name=sweep.params_name,
    params_values=[np.asarray(vals)[::mini_stride] for vals in sweep.params_values],
)
xyz_ticks_mini = custom_xyz_ticks(mini_arr_isos, sweep_mini)

In [None]:
def voxel_mpl_3D(arr_isos, isos_to_show, iso_collections=None,
                 sweep=None, xyz_labels_dict=None, legend=True, proj=False, rasterized=True, ax=None):
    """
    Draw selected isomorphism classes of a 3D sweep cube as a matplotlib voxel plot.

    Voxels whose index is in `isos_to_show` (or in one of the `iso_collections` groups) are
    drawn, coloured by ISO_TO_COLOR_BASE or by the collection colour. The figure is also
    saved to NB_OUTPUT as "mpl_voxel_isos<isos_to_show>.pdf" and ".svg".

    Args:
        arr_isos: 3D array of integer isomorphism indices (indices into
            NAMES_IMPORTANT_GRAPHS_OVERRIDE). Not modified; a copy is made.
        isos_to_show: iterable of integers, the subset of isomorphisms in arr_isos to draw;
            its order sets the order of the legend. Not modified; a list copy is made.
        iso_collections: [default: None] dict of the form
            {some_label_string: {'group': [iso_1, ..., iso_k], 'color': color}}
            Each group is merged into one legend entry and colour: its members are removed
            from isos_to_show and the collection is appended after the remaining entries.
        sweep: if given and xyz_labels_dict is None, x, y, z ticks and labels are automated
            from sweep.params_name / sweep.params_values. Required when proj=True.
        xyz_labels_dict: if supplied, it forces the x, y, z labels, ticks and limits.
            Structure assumed (e.g. the output of custom_xyz_ticks):
                'labels': [label x, label y, label z]
                'xlims', 'ylims', 'zlims': 2-sequences, or None to leave the limit alone
                'xticks', 'xtick_labels': equal-length sequences
                'yticks', 'ytick_labels': equal-length sequences
                'zticks', 'ztick_labels': equal-length sequences
        legend: if True, draw the legend next to the axes.
        proj: if True, scatter a grey shadow of all drawn voxels on the lowest z tick.
        rasterized: passed to the voxel (and shadow) artists.
        ax: existing 3D axes to draw on; if None, a new 4x4-inch figure with 3D axes is made.

    Returns:
        (ax, patches): the 3D axes and the legend patches (one per shown entry), so callers
        can draw the legend elsewhere.
    """
    arr_isos = np.copy(arr_isos)  # avoid in place operations overwriting the caller's array
    if np.issubdtype(arr_isos.dtype, np.unsignedinteger):
        # collections are relabelled with negative codes below; unsigned arrays cannot hold them
        arr_isos = arr_isos.astype(np.int64)
    # Work on a list copy: the collection step appends to it, which must neither mutate the
    # caller's list nor fail when a tuple is passed in.
    isos_to_show = list(isos_to_show)

    if ax is None:
        fig = plt.figure(figsize=(4, 4))
        ax = fig.add_subplot(projection='3d')
        ax.set_zorder(20)  # for later rasterizing
    fig = ax.figure

    # For each isomorphism index: its legend label and its voxel colour
    isos_to_properties = {
        i: dict(label=val, color=ISO_TO_COLOR_BASE[i])
        for i, val in enumerate(NAMES_IMPORTANT_GRAPHS_OVERRIDE)
    }

    # Merge each collection into one pseudo-isomorphism with a negative code (-1, -2, ...),
    # which cannot collide with the non-negative keys above.
    # NOTE(review): assumes arr_isos holds no negative values of its own — confirm.
    if iso_collections is not None:
        for position, (label, spec) in enumerate(iso_collections.items(), start=1):
            code = -position
            group = spec['group']
            assert code not in isos_to_properties
            isos_to_properties[code] = dict(label=label, color=spec['color'])
            # 1) replace the group's members in isos_to_show by the collection (appended last)
            isos_to_show = [a for a in isos_to_show if a not in group] + [code]
            # 2) relabel the group's voxels with the collection code
            arr_isos[np.isin(arr_isos, group)] = code

    print('forming boolean arrays (voxelarray)...')
    voxelarray = np.zeros_like(arr_isos, dtype=np.bool_)
    for idx in isos_to_show:
        isos_to_properties[idx]['cube'] = (arr_isos == idx)
        voxelarray |= isos_to_properties[idx]['cube']

    print('forming color arrays (colors)...')
    colors = np.empty(voxelarray.shape, dtype=object)
    for idx in isos_to_show:
        colors[isos_to_properties[idx]['cube']] = isos_to_properties[idx]['color']

    print('Main "voxels" plot call...')
    edgecolors = None  # default: None, or use some edgecolors scalar e.g. 'k', or use list
    shade = True       # default: True
    ax.voxels(voxelarray, facecolors=colors, edgecolors=edgecolors, shade=shade, zorder=10,
              rasterized=rasterized)

    # Axis labels and ticks, in order of preference:
    # - taken from xyz_labels_dict if given
    # - otherwise derived from sweep (4 evenly spaced ticks across the full parameter range)
    # - otherwise generic labels 'p1', 'p2', 'p3' with matplotlib's default ticks
    ax_fs = 12
    ax_labelpad = 10
    zticks = None
    if xyz_labels_dict is not None:
        xlabel, ylabel, zlabel = xyz_labels_dict['labels']
        xlims, ylims, zlims = xyz_labels_dict['xlims'], xyz_labels_dict['ylims'], xyz_labels_dict['zlims']
        zticks = xyz_labels_dict['zticks']
        ax.set_xticks(xyz_labels_dict['xticks'], xyz_labels_dict['xtick_labels'], fontsize=ax_fs)
        ax.set_yticks(xyz_labels_dict['yticks'], xyz_labels_dict['ytick_labels'], fontsize=ax_fs)
        ax.set_zticks(zticks, xyz_labels_dict['ztick_labels'], fontsize=ax_fs)
        # set ax limits at end, after ticks
        if xlims is not None:
            ax.set_xlim(xlims)
        if ylims is not None:
            ax.set_ylim(ylims)
        if zlims is not None:
            ax.set_zlim(zlims)
    elif sweep is not None:
        xlabel, ylabel, zlabel = sweep.params_name
        xlow, xhigh = sweep.params_values[0][0], sweep.params_values[0][-1]
        ylow, yhigh = sweep.params_values[1][0], sweep.params_values[1][-1]
        zlow, zhigh = sweep.params_values[2][0], sweep.params_values[2][-1]

        n_ticks = 4
        nnx, nny, nnz = arr_isos.shape

        xticks = np.linspace(0, nnx, n_ticks)
        xtick_labels = np.round(np.linspace(xlow, xhigh, n_ticks), 3)
        yticks = np.linspace(0, nny, n_ticks)
        ytick_labels = np.round(np.linspace(ylow, yhigh, n_ticks), 2)
        zticks = np.linspace(0, nnz, n_ticks)
        ztick_labels = np.round(np.linspace(zlow, zhigh, n_ticks), 3)

        ax.set_xticks(xticks, xtick_labels, fontsize=ax_fs)
        ax.set_yticks(yticks, ytick_labels, fontsize=ax_fs)
        ax.set_zticks(zticks, ztick_labels, fontsize=ax_fs)
    else:
        print('Warning: need to pass xyz_labels_dict (preferred) OR Sweep object to generate axis labels/ticks')
        xlabel, ylabel, zlabel = ['p1', 'p2', 'p3']
    ax.set_xlabel(xlabel, fontsize=ax_fs, labelpad=ax_labelpad)
    ax.set_ylabel(ylabel, fontsize=ax_fs, labelpad=ax_labelpad)
    ax.set_zlabel(zlabel, fontsize=ax_fs, labelpad=ax_labelpad)

    if proj:
        assert sweep is not None
        # Grey shadow of every drawn voxel at its voxel-index (x, y) position, placed on
        # the lowest z tick only (no shadows on the x or y planes).
        xs, ys, _ = np.nonzero(voxelarray)
        zshadow = zticks[0]
        ax.scatter(xs, ys, zshadow, c='darkgrey', alpha=0.05, s=5, zorder=0, rasterized=rasterized)

    print('Post-processing (labels, legend)...')
    patches = [mpatches.Patch(color=isos_to_properties[i]['color'],
                              label=isos_to_properties[i]['label'])
               for i in isos_to_show]
    if legend:
        ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=2.5, fontsize=14., ncol=1)

    ax.view_init(elev=11, azim=-133)

    # In vector output, artists with zorder below 0 are rasterized; the voxels and the
    # shadow are rasterized through their own `rasterized` flag instead.
    ax.set_rasterization_zorder(0)

    # Save files in pdf and svg format
    fig.savefig(NB_OUTPUT + os.sep + "mpl_voxel_isos%s.pdf" % (isos_to_show,), dpi=1200)
    fig.savefig(NB_OUTPUT + os.sep + "mpl_voxel_isos%s.svg" % (isos_to_show,), dpi=1200)

    print('exiting function and plotting...')
    return ax, patches


In [None]:
#%matplotlib widget
%matplotlib inline

fig = plt.figure(figsize=(4,4))
ax = plt.figure().add_subplot(projection='3d')

arr_isos_to_plot = mini_arr_isos

isos_to_show = [2, 3, 6, 7, 8, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26]
#isos_to_show = []
xyz_labels_dict = None
iso_collections = ISO_COLLECTION_A  # ISO_COLLECTION_A or ISO_COLLECTION_H3
ax, _ = voxel_mpl_3D(arr_isos_to_plot, isos_to_show, iso_collections=iso_collections, 
                     sweep=sweep, xyz_labels_dict=xyz_labels_dict, ax=ax)

In [None]:
# Isomorphism indices present in the subsampled cube.
np.unique(mini_arr_isos)

In [None]:
# Fresh 4x4-inch figure with 3D axes for the view-angle experiments.
fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(projection='3d')

# a list, like the other cells that pass isos_to_show to voxel_mpl_3D
isos_to_show = [2, 3, 6, 7, 8, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26]

In [None]:
# Compare the subsampled and the full-resolution cube shapes.
print(mini_arr_isos.shape)
print(arr_isos.shape)

In [None]:
# Change the viewing angle of the existing 3D axes: elevation 30 deg, azimuth 0 deg.
print('ax view re-init...?')
ax.view_init(elev=30, azim=0)
print('done')

In [None]:
# Re-display the figure that holds `ax` so the new view angle is visible.
# plt.show() takes no axes argument (a positional value reaches the backend's close/block
# parameter). By default, the inline backend has already closed this figure after its cell,
# so redisplay the Figure object explicitly.
from IPython.display import display
display(ax.figure)
print('done')

In [None]:
def video_cube_rotate(arr_isos, xyz_labels_dict=None, nframes=3, fps=1):
    """
    Plot the 3D isomorphism cube as a rotating MP4 video saved to NB_OUTPUT/cube_v1.mp4.

    Left panel: voxel_mpl_3D plot of arr_isos showing only the ISO_COLLECTION_H3 groups.
    Right panel: the matching legend. Frame i is drawn at azimuth i * 1.5 * 360 / nframes
    degrees, so the cube turns through up to 1.5 full rotations over the video.

    Args:
        arr_isos: 3D isomorphism-index array passed to voxel_mpl_3D.
        xyz_labels_dict: labels/ticks/limits dict passed to voxel_mpl_3D
            (e.g. from custom_xyz_ticks).
        nframes: number of frames (default 3, a quick test; e.g. 360 for a smooth video).
        fps: frames per second of the output (default 1; e.g. 20 for a smooth video).

    Returns:
        str: path of the written video file.
    """
    # 3D orientation parameters
    elev = 15
    azim = 0
    dz = 1.5 * 360.0 / nframes

    writer = setup_ffmpeg_writer(
        fps=fps, bitrate=None, title='rotating_cube', artist='Matplotlib',
        comment='Rotating isomorphism cube (v1)')
    dpi = 450
    fpath_video = NB_OUTPUT + os.sep + "cube_v1.mp4"

    # Two side-by-side panels of equal width: 3D cube (left) and legend (right)
    fig = plt.figure(constrained_layout=False, figsize=(10, 5))
    gspec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=[1, 1])

    # Left plot ax0
    ax0 = fig.add_subplot(gspec[0, 0], projection='3d')
    isos_to_show = []  # empty: only the isos from the iso_collections groups are plotted
    print('Drawing the first frame...')
    ax0, legend_patches = voxel_mpl_3D(arr_isos, isos_to_show,
                                       iso_collections=ISO_COLLECTION_H3,
                                       xyz_labels_dict=xyz_labels_dict,
                                       legend=False, ax=ax0)
    print('initial frame done.')

    # Right plot ax1: legend only. grid(False) is the documented form; the string 'off'
    # is not a documented value and is truthy.
    ax1 = fig.add_subplot(gspec[0, 1])
    ax1.grid(False)
    ax1.axis('off')
    ax1.legend(handles=legend_patches, fontsize=12., ncol=1, loc='center right')

    # Update the frames for the movie (just rotate the plot)
    print('Rotating:')
    with writer.saving(fig, fpath_video, dpi):
        for i in range(nframes):
            print("\tframe %d of %d..." % (i + 1, nframes))
            ax0.view_init(elev, azim + i * dz)
            writer.grab_frame()

    print('Video done. Saved to:', fpath_video)
    return fpath_video

In [None]:
# Full-resolution cube; use mini_arr_isos with xyz_ticks_mini instead for a faster draft render.
fpath_video_cube = video_cube_rotate(arr_isos=arr_isos, xyz_labels_dict=xyz_ticks_orig)

In [None]:
# Embed the saved video. vid_fpath_to_html keeps only the part of the path after "notebooks",
# so fpath_video_cube must lie under the notebooks directory.
HTML(vid_fpath_to_html(fpath_video_cube))

## Plot individual biological networks (bio ints) on the cube

In [None]:
# Static inline figures (the commented `widget` backend gives interactive figures instead).
#%matplotlib widget
%matplotlib inline

In [None]:
# Hex alpha suffixes appended to '#RRGGBB' colours: '80' = 0x80/0xFF ≈ 0.5 opacity, '' = opaque.
alphaval_dros = '80'
alphaval = ''  # '80' or '' (≈0.5 or 1.0 opacity)
proj = False  # voxel_mpl_3D(proj=True) adds a grey voxel shadow on the bottom z plane

ISO_COLLECTION_DROSOPHILA = {
    'Drosophila (16 cells)': {
        'group': [NAMES_TO_INT['D. melanogaster (16)']],
        'color': '#FFFFFF%s' % alphaval_dros,
    },
}
# original color #ECC986%s

ISO_COLLECTION_LACEWINGS = {
    'Lacewing (main; 12 cell)': {
        'group': [NAMES_TO_INT['C. perla (12)']],
        'color': '#72C2EC%s' % alphaval,
    },
    'Lacewing (alt.; 13 cell)': {
        'group': [NAMES_TO_INT['C. perla (13)']],
        'color': '#D8A0C8%s' % alphaval,
    },
    'Lacewing (alt.; 14 cell)': {
        'group': [NAMES_TO_INT['C. perla (14)']],
        'color': '#8BCBA7%s' % alphaval,
    },
}

# Drosophila followed by the three lacewing entries (independent copies of each entry).
ISO_COLLECTION_COMBO = {
    label: {'group': list(spec['group']), 'color': spec['color']}
    for source in (ISO_COLLECTION_DROSOPHILA, ISO_COLLECTION_LACEWINGS)
    for label, spec in source.items()
}

In [None]:
# Tick/label dict zoomed to a sub-range of the sweep (limits are parameter values, not indices).
xyz_ticks_sweep = custom_xyz_ticks(
    arr_isos, sweep,
    xlims=(0.0058, 0.015),
    ylims=(0, 7.0),
    zlims=(-0.09, 0.09),
)

In [None]:
# Drosophila network only (isomorphism index 3, merged into ISO_COLLECTION_DROSOPHILA).
fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(projection='3d')

arr_isos_to_plot = arr_isos  # or mini_arr_isos for a faster draft

isos_to_show = [3]

ax, _ = voxel_mpl_3D(arr_isos_to_plot, isos_to_show, iso_collections=ISO_COLLECTION_DROSOPHILA,
                     sweep=sweep, xyz_labels_dict=xyz_ticks_sweep, proj=proj, ax=ax)

In [None]:
# Lacewing networks only (isomorphism indices 7, 8, 9 = C. perla 12/13/14 cells).
fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(projection='3d')

arr_isos_to_plot = arr_isos  # or mini_arr_isos for a faster draft

isos_to_show = [7, 8, 9]

ax, _ = voxel_mpl_3D(arr_isos_to_plot, isos_to_show, iso_collections=ISO_COLLECTION_LACEWINGS,
                     sweep=sweep, xyz_labels_dict=xyz_ticks_sweep, proj=proj, ax=ax)

In [None]:
# Drosophila and lacewing networks together (isomorphism indices 3, 7, 8, 9).
fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(projection='3d')

arr_isos_to_plot = arr_isos  # or mini_arr_isos for a faster draft

isos_to_show = [3, 7, 8, 9]

ax, _ = voxel_mpl_3D(arr_isos_to_plot, isos_to_show, iso_collections=ISO_COLLECTION_COMBO,
                     sweep=sweep, xyz_labels_dict=xyz_ticks_sweep, proj=proj, ax=ax)