## Generate overview figure elements

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pylab as plt

from IPython.display import HTML
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.animation 
import numpy as np

import colorcet as cc
from collections import Counter

from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression, Ridge, RidgeCV
from synapse_utils import io

SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 14

# Default matplotlib font sizes for every figure in this notebook:
# (rc group, rc key, size).
_rc_font_sizes = [
    ('font', 'size', SMALL_SIZE),          # default text size
    ('axes', 'titlesize', SMALL_SIZE),     # axes title
    ('axes', 'labelsize', MEDIUM_SIZE),    # x and y labels
    ('xtick', 'labelsize', SMALL_SIZE),    # x tick labels
    ('ytick', 'labelsize', SMALL_SIZE),    # y tick labels
    ('legend', 'fontsize', SMALL_SIZE),    # legend
    ('figure', 'titlesize', BIGGER_SIZE),  # figure title
]
for _rc_group, _rc_key, _rc_value in _rc_font_sizes:
    plt.rc(_rc_group, **{_rc_key: _rc_value})

In [None]:
# Locations, relative to this notebook, of the repo, the trained checkpoint
# and the processed dataset.
repo_root = '../..'
checkpoint_path = '../../output/checkpoint__synapseclr__so3__second_stage'
dataset_path = '../../data/MICrONS__L23__8_8_40__processed'

# Pre-computed contamination row indices (may be set to None to skip).
_contamination_parts = ('indices', 'contamination_meta_df_row_indices.npy')
contamination_indices_path = os.path.join(checkpoint_path, *_contamination_parts)

# Checkpoint selection, forwarded to io.load_features.
reload_epoch = 99
node_idx_list = list(range(4))

In [None]:
import cuml
from cuml import PCA, TSNE, UMAP

In [None]:
# Load features captured at the `encoder.fc` hook together with the synapse
# metadata tables, then optionally reduce the feature dimension with PCA.
# `PCA` here resolves to cuml.PCA: the cuml import shadows sklearn's PCA.
load_prefix = 'encoder.fc'
save_prefix = '_'.join(load_prefix.split('.'))
l2_normalize = False

features_nf, meta_df, meta_ext_df = io.load_features(
    checkpoint_path, node_idx_list, reload_epoch,
    feature_hook=load_prefix,
    dataset_path=dataset_path,
    l2_normalize=l2_normalize,
    contamination_indices_path=contamination_indices_path)

n_pca_components = 128
n_features = features_nf.shape[-1]
if n_features > n_pca_components:
    pca = PCA(n_components=n_pca_components)
    features_nf = pca.fit_transform(features_nf)

## Generate images

In [None]:
from synapse_utils import vis

import torch
import numpy as np
import matplotlib.pylab as plt
import pandas as pd

from synapse_dataset import SynapseDataset
from synapse_simclr import utils
from synapse_utils import vis
from synapse_augmenter import SynapseAugmenter
from synapse_augmenter import consts as syn_consts

from scipy.ndimage import binary_erosion
import plotly.graph_objects as go
import pyvista as pv
from skimage.filters import threshold_otsu
from sklearn.decomposition import PCA as SKPCA

import warnings
from ipywidgets import interactive

from typing import Optional

In [None]:
synapse_id = 1806870

# Row position of this synapse in `meta_df`. Fail with a readable message
# (instead of numpy's ".item()" error) when the id is missing or duplicated.
synapse_rows = np.flatnonzero((meta_df['synapse_id'] == synapse_id).values)
if len(synapse_rows) != 1:
    raise ValueError(
        f'expected exactly one row with synapse_id={synapse_id} in meta_df, '
        f'found {len(synapse_rows)}')
synapse_index = synapse_rows.item()

## Mask only

In [None]:
# Overview panel "mask only": the mask channels of `synapse_id` with no EM
# intensity cloud. Channels 0 and 2 are inflated and drawn as thin boundary
# shells; channel 1 is drawn as a point cloud over all of its voxels.
aug_yaml_path = os.path.join(repo_root, 'configs', 'config__synapseclr__so3__second_stage', 'augmenter_display.yaml')
ctx = vis.SynapseVisContext(
    dataset_path, aug_yaml_path, meta_df_override=meta_df, device='cuda')

# rendering knobs
point_cloud_opacity = 0.1
point_cloud_size = 0.8
surface_point_cloud_opacity = 0.02
surface_point_cloud_size = 1.0
surface_max_points = 100_000
inflate_radius = 5
fig_width = 500
fig_height = 500


def _subsample(mask_xyz, limit, *volumes):
    """Return at most `limit` randomly chosen voxels selected by `mask_xyz`.

    One 1-D array is returned per entry of `volumes`. All of them use the
    same random subset, so coordinates and values stay aligned.
    """
    selected = [volume[mask_xyz] for volume in volumes]
    keep = np.random.permutation(len(selected[0]))[:limit]
    return [values[keep] for values in selected]


def _shell(mask_xyz):
    """Return `mask_xyz` XOR its radius-1 erosion (its boundary layer)."""
    return mask_xyz ^ vis.erode_mask(mask_xyz.copy(), 1)


def _marker_trace(x, y, z, size, color, opacity):
    """Build a 3-D scatter trace that draws markers only."""
    return go.Scatter3d(
        x=x, y=y, z=z, mode='markers',
        marker={'size': size, 'color': color, 'opacity': opacity})


def _style_scene(fig, grid_x, grid_y, grid_z, width, height, lam=0.1):
    """Apply the orthographic, tick-free scene layout of the overview panels.

    Each axis range is pulled towards its midpoint by the fraction `lam` of
    the half-extent, which crops a thin margin off the cube.
    """
    def axis(grid, background):
        lo, hi = np.min(grid), np.max(grid)
        mid = 0.5 * (lo + hi)
        return dict(
            range=[lo + lam * (mid - lo), hi - lam * (hi - mid)],
            visible=True,
            showticklabels=False,
            linewidth=2,
            linecolor='black',
            showgrid=False,
            gridcolor='rgb(200,200,200)',
            backgroundcolor=background)

    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            aspectratio=dict(x=1, y=1, z=1),
            xaxis=axis(grid_x, 'rgb(240,240,240)'),
            yaxis=axis(grid_y, 'rgb(230,230,230)'),
            zaxis=axis(grid_z, 'rgb(220,220,220)')),
        scene_camera=dict(eye=dict(x=-3, y=-4, z=5)),
        autosize=False,
        width=width,
        height=height,
        showlegend=False,
        font=dict(size=24),
        margin=dict(l=0, r=0, b=0, t=0, pad=4))
    fig.layout.scene.camera.projection.type = "orthographic"


# get data (only the mask is used by this panel)
mask_bcxyz = ctx.aug.augment_raw_data([ctx.synapse_dataset[synapse_index][1]])[1]

# inflate channels 0 and 2 in place before extracting their shells
for channel in (0, 2):
    mask_bcxyz[0, channel] = ctx.aug.inflate_binary_mask(
        mask_bcxyz[:, channel], radius=inflate_radius)[0]
mask_cxyz = mask_bcxyz[0, :, :, :].cpu().numpy()

# normalized voxel grid: index i maps to coordinate i / final_img_size
final_img_size = mask_cxyz.shape[-1]
X, Y, Z = np.mgrid[:final_img_size, :final_img_size, :final_img_size] / final_img_size

# (channel, colour, style): 'surface' draws only the boundary shell,
# 'points' draws (a subsample of) every voxel of the channel
layers = [
    (0, 'rgb(138,43,226)', 'surface'),
    (1, 'rgb(255,128,0)', 'points'),
    (2, 'rgb(52,235,128)', 'surface'),
]

data = []
for mask_int, color, plot_type in layers:
    if plot_type == 'surface':
        x, y, z = _subsample(_shell(mask_cxyz[mask_int]), surface_max_points, X, Y, Z)
        data.append(_marker_trace(
            x, y, z, surface_point_cloud_size, color, 2 * surface_point_cloud_opacity))
    else:
        # NOTE(review): the filled layer is capped by surface_max_points, as in
        # the original cell; confirm a separate point budget is not intended.
        x, y, z = _subsample(mask_cxyz[mask_int], surface_max_points, X, Y, Z)
        data.append(_marker_trace(
            x, y, z, point_cloud_size, color, point_cloud_opacity))

fig = go.Figure(data)
_style_scene(fig, X, Y, Z, fig_width, fig_height)

In [None]:
# Save the panel. Create the output folder first so a fresh checkout does not
# fail on a missing directory.
output_path = '../../output/analysis/overview/mask_only.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.write_image(output_path)

In [None]:
# Overview panel "intensity only": a point cloud of dark EM voxels for
# `synapse_id`, with the segmentation ignored (the whole cube counts as inside).
aug_yaml_path = os.path.join(repo_root, 'configs', 'config__synapseclr__so3__second_stage', 'augmenter_display.yaml')
ctx = vis.SynapseVisContext(
    dataset_path, aug_yaml_path, meta_df_override=meta_df, device='cuda')

# rendering knobs
inside_erosion_radius = 3
max_points = 200_000
otsu_prefactor = 0.3
point_cloud_opacity = 0.1
point_cloud_size = 0.8
fig_width = 500
fig_height = 500


def _subsample(mask_xyz, limit, *volumes):
    """Return at most `limit` randomly chosen voxels selected by `mask_xyz`.

    One 1-D array is returned per entry of `volumes`. All of them use the
    same random subset, so coordinates and values stay aligned.
    """
    selected = [volume[mask_xyz] for volume in volumes]
    keep = np.random.permutation(len(selected[0]))[:limit]
    return [values[keep] for values in selected]


def _marker_trace(x, y, z, size, color, opacity):
    """Build a 3-D scatter trace that draws markers only."""
    return go.Scatter3d(
        x=x, y=y, z=z, mode='markers',
        marker={'size': size, 'color': color, 'opacity': opacity})


def _style_scene(fig, grid_x, grid_y, grid_z, width, height, lam=0.1):
    """Apply the orthographic, tick-free scene layout of the overview panels.

    Each axis range is pulled towards its midpoint by the fraction `lam` of
    the half-extent, which crops a thin margin off the cube.
    """
    def axis(grid, background):
        lo, hi = np.min(grid), np.max(grid)
        mid = 0.5 * (lo + hi)
        return dict(
            range=[lo + lam * (mid - lo), hi - lam * (hi - mid)],
            visible=True,
            showticklabels=False,
            linewidth=2,
            linecolor='black',
            showgrid=False,
            gridcolor='rgb(200,200,200)',
            backgroundcolor=background)

    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            aspectratio=dict(x=1, y=1, z=1),
            xaxis=axis(grid_x, 'rgb(240,240,240)'),
            yaxis=axis(grid_y, 'rgb(230,230,230)'),
            zaxis=axis(grid_z, 'rgb(220,220,220)')),
        scene_camera=dict(eye=dict(x=-3, y=-4, z=5)),
        autosize=False,
        width=width,
        height=height,
        showlegend=False,
        font=dict(size=24),
        margin=dict(l=0, r=0, b=0, t=0, pad=4))
    fig.layout.scene.camera.projection.type = "orthographic"


# get data
intensity_bcxyz, mask_bcxyz = ctx.aug.augment_raw_data([ctx.synapse_dataset[synapse_index][1]])
intensity_xyz = intensity_bcxyz[0, 0, ...].cpu().numpy()

# override mask: every voxel counts as inside, so the intensity cloud is not
# restricted to the segmented objects
mask_cxyz = np.ones_like(mask_bcxyz[0, :, :, :].cpu().numpy())

# normalized voxel grid: index i maps to coordinate i / final_img_size
final_img_size = mask_cxyz.shape[-1]
X, Y, Z = np.mgrid[:final_img_size, :final_img_size, :final_img_size] / final_img_size

# keep dark voxels inside the eroded cube (threshold via vis.get_dark_mask,
# presumably Otsu-based and scaled by otsu_prefactor)
inside_mask_xyz = np.sum(mask_cxyz, 0) > 0
inside_mask_xyz = vis.erode_mask(inside_mask_xyz, inside_erosion_radius)
plot_mask_xyz = vis.get_dark_mask(intensity_xyz, inside_mask_xyz, otsu_prefactor)

x, y, z, c = _subsample(plot_mask_xyz, max_points, X, Y, Z, intensity_xyz)
# NOTE(review): colours are an (N, 4) RGBA array from the matplotlib colormap;
# confirm plotly renders this as intended (it documents strings or numbers).
color = plt.cm.Greys_r(c)
data = [_marker_trace(x, y, z, point_cloud_size, color, point_cloud_opacity)]

fig = go.Figure(data)
_style_scene(fig, X, Y, Z, fig_width, fig_height)

In [None]:
# Save the panel. Create the output folder first so a fresh checkout does not
# fail on a missing directory.
output_path = '../../output/analysis/overview/intensity_only.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.write_image(output_path)

In [None]:
# Overview panel "joint": the dark-voxel EM intensity cloud inside the
# (eroded) union of the mask channels, overlaid with black boundary shells of
# the inflated mask channels 0 and 2.
aug_yaml_path = os.path.join(repo_root, 'configs', 'config__synapseclr__so3__second_stage', 'augmenter_display.yaml')
ctx = vis.SynapseVisContext(
    dataset_path, aug_yaml_path, meta_df_override=meta_df, device='cuda')

# rendering knobs
inside_erosion_radius = 3
max_points = 200_000
otsu_prefactor = 0.3
point_cloud_opacity = 0.1
point_cloud_size = 0.8
surface_point_cloud_opacity = 0.02
surface_point_cloud_size = 1.0
surface_max_points = 100_000
inflate_radius = 5
shell_channels = (0, 2)
shell_color = 'rgb(0,0,0)'
fig_width = 500
fig_height = 500


def _subsample(mask_xyz, limit, *volumes):
    """Return at most `limit` randomly chosen voxels selected by `mask_xyz`.

    One 1-D array is returned per entry of `volumes`. All of them use the
    same random subset, so coordinates and values stay aligned.
    """
    selected = [volume[mask_xyz] for volume in volumes]
    keep = np.random.permutation(len(selected[0]))[:limit]
    return [values[keep] for values in selected]


def _shell(mask_xyz):
    """Return `mask_xyz` XOR its radius-1 erosion (its boundary layer)."""
    return mask_xyz ^ vis.erode_mask(mask_xyz.copy(), 1)


def _marker_trace(x, y, z, size, color, opacity):
    """Build a 3-D scatter trace that draws markers only."""
    return go.Scatter3d(
        x=x, y=y, z=z, mode='markers',
        marker={'size': size, 'color': color, 'opacity': opacity})


def _style_scene(fig, grid_x, grid_y, grid_z, width, height, lam=0.1):
    """Apply the orthographic, tick-free scene layout of the overview panels.

    Each axis range is pulled towards its midpoint by the fraction `lam` of
    the half-extent, which crops a thin margin off the cube.
    """
    def axis(grid, background):
        lo, hi = np.min(grid), np.max(grid)
        mid = 0.5 * (lo + hi)
        return dict(
            range=[lo + lam * (mid - lo), hi - lam * (hi - mid)],
            visible=True,
            showticklabels=False,
            linewidth=2,
            linecolor='black',
            showgrid=False,
            gridcolor='rgb(200,200,200)',
            backgroundcolor=background)

    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            aspectratio=dict(x=1, y=1, z=1),
            xaxis=axis(grid_x, 'rgb(240,240,240)'),
            yaxis=axis(grid_y, 'rgb(230,230,230)'),
            zaxis=axis(grid_z, 'rgb(220,220,220)')),
        scene_camera=dict(eye=dict(x=-3, y=-4, z=5)),
        autosize=False,
        width=width,
        height=height,
        showlegend=False,
        font=dict(size=24),
        margin=dict(l=0, r=0, b=0, t=0, pad=4))
    fig.layout.scene.camera.projection.type = "orthographic"


# get data
intensity_bcxyz, mask_bcxyz = ctx.aug.augment_raw_data([ctx.synapse_dataset[synapse_index][1]])
mask_cxyz = mask_bcxyz[0, :, :, :].cpu().numpy()
intensity_xyz = intensity_bcxyz[0, 0, ...].cpu().numpy()

# normalized voxel grid: index i maps to coordinate i / final_img_size
final_img_size = mask_cxyz.shape[-1]
X, Y, Z = np.mgrid[:final_img_size, :final_img_size, :final_img_size] / final_img_size

# intensity cloud: dark voxels inside the eroded union of all mask channels
inside_mask_xyz = np.sum(mask_cxyz, 0) > 0
inside_mask_xyz = vis.erode_mask(inside_mask_xyz, inside_erosion_radius)
plot_mask_xyz = vis.get_dark_mask(intensity_xyz, inside_mask_xyz, otsu_prefactor)

x, y, z, c = _subsample(plot_mask_xyz, max_points, X, Y, Z, intensity_xyz)
# NOTE(review): colours are an (N, 4) RGBA array from the matplotlib colormap;
# confirm plotly renders this as intended (it documents strings or numbers).
data = [_marker_trace(x, y, z, point_cloud_size, plt.cm.Greys_r(c), point_cloud_opacity)]

# shells: inflate the shell channels in place, then draw their boundary layer
for channel in shell_channels:
    mask_bcxyz[0, channel] = ctx.aug.inflate_binary_mask(
        mask_bcxyz[:, channel], radius=inflate_radius)[0]
mask_cxyz = mask_bcxyz[0, :, :, :].cpu().numpy()

for channel in shell_channels:
    x, y, z = _subsample(_shell(mask_cxyz[channel]), surface_max_points, X, Y, Z)
    data.append(_marker_trace(
        x, y, z, surface_point_cloud_size, shell_color, 2 * surface_point_cloud_opacity))

fig = go.Figure(data)
_style_scene(fig, X, Y, Z, fig_width, fig_height)

In [None]:
# Save the panel. Create the output folder first so a fresh checkout does not
# fail on a missing directory.
output_path = '../../output/analysis/overview/joint.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.write_image(output_path)

## SynapseCLR

In [None]:
synapse_id = 1615486

# Row position of this synapse in `meta_df`. Fail with a readable message
# (instead of numpy's ".item()" error) when the id is missing or duplicated.
synapse_rows = np.flatnonzero((meta_df['synapse_id'] == synapse_id).values)
if len(synapse_rows) != 1:
    raise ValueError(
        f'expected exactly one row with synapse_id={synapse_id} in meta_df, '
        f'found {len(synapse_rows)}')
synapse_index = synapse_rows.item()

In [None]:
# Overview panel "joint" for the SynapseCLR example: the dark-voxel EM
# intensity cloud inside the (eroded) union of the mask channels, overlaid
# with black boundary shells of the inflated mask channels 0 and 2.
aug_yaml_path = os.path.join(repo_root, 'configs', 'config__synapseclr__so3__second_stage', 'augmenter_display.yaml')
ctx = vis.SynapseVisContext(
    dataset_path, aug_yaml_path, meta_df_override=meta_df, device='cuda')

# rendering knobs
inside_erosion_radius = 3
max_points = 200_000
otsu_prefactor = 0.3
point_cloud_opacity = 0.1
point_cloud_size = 0.8
surface_point_cloud_opacity = 0.02
surface_point_cloud_size = 1.0
surface_max_points = 100_000
inflate_radius = 5
shell_channels = (0, 2)
shell_color = 'rgb(0,0,0)'
fig_width = 500
fig_height = 500


def _subsample(mask_xyz, limit, *volumes):
    """Return at most `limit` randomly chosen voxels selected by `mask_xyz`.

    One 1-D array is returned per entry of `volumes`. All of them use the
    same random subset, so coordinates and values stay aligned.
    """
    selected = [volume[mask_xyz] for volume in volumes]
    keep = np.random.permutation(len(selected[0]))[:limit]
    return [values[keep] for values in selected]


def _shell(mask_xyz):
    """Return `mask_xyz` XOR its radius-1 erosion (its boundary layer)."""
    return mask_xyz ^ vis.erode_mask(mask_xyz.copy(), 1)


def _marker_trace(x, y, z, size, color, opacity):
    """Build a 3-D scatter trace that draws markers only."""
    return go.Scatter3d(
        x=x, y=y, z=z, mode='markers',
        marker={'size': size, 'color': color, 'opacity': opacity})


def _style_scene(fig, grid_x, grid_y, grid_z, width, height, lam=0.1):
    """Apply the orthographic, tick-free scene layout of the overview panels.

    Each axis range is pulled towards its midpoint by the fraction `lam` of
    the half-extent, which crops a thin margin off the cube.
    """
    def axis(grid, background):
        lo, hi = np.min(grid), np.max(grid)
        mid = 0.5 * (lo + hi)
        return dict(
            range=[lo + lam * (mid - lo), hi - lam * (hi - mid)],
            visible=True,
            showticklabels=False,
            linewidth=2,
            linecolor='black',
            showgrid=False,
            gridcolor='rgb(200,200,200)',
            backgroundcolor=background)

    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            aspectratio=dict(x=1, y=1, z=1),
            xaxis=axis(grid_x, 'rgb(240,240,240)'),
            yaxis=axis(grid_y, 'rgb(230,230,230)'),
            zaxis=axis(grid_z, 'rgb(220,220,220)')),
        scene_camera=dict(eye=dict(x=-3, y=-4, z=5)),
        autosize=False,
        width=width,
        height=height,
        showlegend=False,
        font=dict(size=24),
        margin=dict(l=0, r=0, b=0, t=0, pad=4))
    fig.layout.scene.camera.projection.type = "orthographic"


# get data
intensity_bcxyz, mask_bcxyz = ctx.aug.augment_raw_data([ctx.synapse_dataset[synapse_index][1]])
mask_cxyz = mask_bcxyz[0, :, :, :].cpu().numpy()
intensity_xyz = intensity_bcxyz[0, 0, ...].cpu().numpy()

# normalized voxel grid: index i maps to coordinate i / final_img_size
final_img_size = mask_cxyz.shape[-1]
X, Y, Z = np.mgrid[:final_img_size, :final_img_size, :final_img_size] / final_img_size

# intensity cloud: dark voxels inside the eroded union of all mask channels
inside_mask_xyz = np.sum(mask_cxyz, 0) > 0
inside_mask_xyz = vis.erode_mask(inside_mask_xyz, inside_erosion_radius)
plot_mask_xyz = vis.get_dark_mask(intensity_xyz, inside_mask_xyz, otsu_prefactor)

x, y, z, c = _subsample(plot_mask_xyz, max_points, X, Y, Z, intensity_xyz)
# NOTE(review): colours are an (N, 4) RGBA array from the matplotlib colormap;
# confirm plotly renders this as intended (it documents strings or numbers).
data = [_marker_trace(x, y, z, point_cloud_size, plt.cm.Greys_r(c), point_cloud_opacity)]

# shells: inflate the shell channels in place, then draw their boundary layer
for channel in shell_channels:
    mask_bcxyz[0, channel] = ctx.aug.inflate_binary_mask(
        mask_bcxyz[:, channel], radius=inflate_radius)[0]
mask_cxyz = mask_bcxyz[0, :, :, :].cpu().numpy()

for channel in shell_channels:
    x, y, z = _subsample(_shell(mask_cxyz[channel]), surface_max_points, X, Y, Z)
    data.append(_marker_trace(
        x, y, z, surface_point_cloud_size, shell_color, 2 * surface_point_cloud_opacity))

fig = go.Figure(data)
_style_scene(fig, X, Y, Z, fig_width, fig_height)

In [None]:
# Save the panel. Create the output folder first so a fresh checkout does not
# fail on a missing directory.
output_path = f'../../output/analysis/overview/joint__{synapse_id}.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.write_image(output_path)

In [None]:
# Overview panel "joint" for another draw of `ctx.aug.augment_raw_data`:
# the dark-voxel EM intensity cloud (fainter than the other panels) inside the
# eroded union of the mask channels, overlaid with black boundary shells of
# the inflated mask channels 0 and 2.
aug_yaml_path = os.path.join(repo_root, 'configs', 'config__synapseclr__so3__second_stage', 'augmenter_display.yaml')
ctx = vis.SynapseVisContext(
    dataset_path, aug_yaml_path, meta_df_override=meta_df, device='cuda')

# rendering knobs
inside_erosion_radius = 3
max_points = 200_000
otsu_prefactor = 0.3
point_cloud_opacity = 0.05
point_cloud_size = 0.8
surface_point_cloud_opacity = 0.02
surface_point_cloud_size = 1.0
surface_max_points = 100_000
inflate_radius = 5
shell_channels = (0, 2)
shell_color = 'rgb(0,0,0)'
fig_width = 500
fig_height = 500


def _subsample(mask_xyz, limit, *volumes):
    """Return at most `limit` randomly chosen voxels selected by `mask_xyz`.

    One 1-D array is returned per entry of `volumes`. All of them use the
    same random subset, so coordinates and values stay aligned.
    """
    selected = [volume[mask_xyz] for volume in volumes]
    keep = np.random.permutation(len(selected[0]))[:limit]
    return [values[keep] for values in selected]


def _shell(mask_xyz):
    """Return `mask_xyz` XOR its radius-1 erosion (its boundary layer)."""
    return mask_xyz ^ vis.erode_mask(mask_xyz.copy(), 1)


def _marker_trace(x, y, z, size, color, opacity):
    """Build a 3-D scatter trace that draws markers only."""
    return go.Scatter3d(
        x=x, y=y, z=z, mode='markers',
        marker={'size': size, 'color': color, 'opacity': opacity})


def _style_scene(fig, grid_x, grid_y, grid_z, width, height, lam=0.1):
    """Apply the orthographic, tick-free scene layout of the overview panels.

    Each axis range is pulled towards its midpoint by the fraction `lam` of
    the half-extent, which crops a thin margin off the cube.
    """
    def axis(grid, background):
        lo, hi = np.min(grid), np.max(grid)
        mid = 0.5 * (lo + hi)
        return dict(
            range=[lo + lam * (mid - lo), hi - lam * (hi - mid)],
            visible=True,
            showticklabels=False,
            linewidth=2,
            linecolor='black',
            showgrid=False,
            gridcolor='rgb(200,200,200)',
            backgroundcolor=background)

    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            aspectratio=dict(x=1, y=1, z=1),
            xaxis=axis(grid_x, 'rgb(240,240,240)'),
            yaxis=axis(grid_y, 'rgb(230,230,230)'),
            zaxis=axis(grid_z, 'rgb(220,220,220)')),
        scene_camera=dict(eye=dict(x=-3, y=-4, z=5)),
        autosize=False,
        width=width,
        height=height,
        showlegend=False,
        font=dict(size=24),
        margin=dict(l=0, r=0, b=0, t=0, pad=4))
    fig.layout.scene.camera.projection.type = "orthographic"


# get data
intensity_bcxyz, mask_bcxyz = ctx.aug.augment_raw_data([ctx.synapse_dataset[synapse_index][1]])
mask_cxyz = mask_bcxyz[0, :, :, :].cpu().numpy()
intensity_xyz = intensity_bcxyz[0, 0, ...].cpu().numpy()

# NOTE(review): this hand-set camera is not consumed by any code visible in
# this cell (the plotly camera eye is fixed in _style_scene); kept in case a
# later cell reads `cam_props`.
cam_props = {
    'cleft_com_x': 95.,
    'cleft_com_y': 95.,
    'cleft_com_z': 95.,
    'cleft_normal': np.asarray([0., 0., 1.]),
    'cleft_inplane_1': np.asarray([1., 0., 0.]),
    'cleft_inplane_2': np.asarray([0., 1., 0.]),
}

# normalized voxel grid: index i maps to coordinate i / final_img_size
final_img_size = mask_cxyz.shape[-1]
X, Y, Z = np.mgrid[:final_img_size, :final_img_size, :final_img_size] / final_img_size

# intensity cloud: dark voxels inside the eroded union of all mask channels
inside_mask_xyz = np.sum(mask_cxyz, 0) > 0
inside_mask_xyz = vis.erode_mask(inside_mask_xyz, inside_erosion_radius)
plot_mask_xyz = vis.get_dark_mask(intensity_xyz, inside_mask_xyz, otsu_prefactor)

x, y, z, c = _subsample(plot_mask_xyz, max_points, X, Y, Z, intensity_xyz)
# NOTE(review): colours are an (N, 4) RGBA array from the matplotlib colormap;
# confirm plotly renders this as intended (it documents strings or numbers).
data = [_marker_trace(x, y, z, point_cloud_size, plt.cm.Greys_r(c), point_cloud_opacity)]

# shells: inflate the shell channels in place, then draw their boundary layer
for channel in shell_channels:
    mask_bcxyz[0, channel] = ctx.aug.inflate_binary_mask(
        mask_bcxyz[:, channel], radius=inflate_radius)[0]
mask_cxyz = mask_bcxyz[0, :, :, :].cpu().numpy()

for channel in shell_channels:
    x, y, z = _subsample(_shell(mask_cxyz[channel]), surface_max_points, X, Y, Z)
    data.append(_marker_trace(
        x, y, z, surface_point_cloud_size, shell_color, 2 * surface_point_cloud_opacity))

fig = go.Figure(data)
_style_scene(fig, X, Y, Z, fig_width, fig_height)

In [None]:
# Save this augmented view under a random numeric suffix (0-999). The suffix
# is redrawn when the file already exists, so re-running the cell does not
# silently overwrite an earlier augmentation. Only after many collisions does
# it fall back to overwriting, with a warning.
output_dir = '../../output/analysis/overview'
os.makedirs(output_dir, exist_ok=True)
for _attempt in range(100):
    output_path = f'{output_dir}/joint__{synapse_id}__aug__{int(1000 * np.random.rand())}.png'
    if not os.path.exists(output_path):
        break
else:
    warnings.warn(f'no unused augmentation file name found; overwriting {output_path}')
fig.write_image(output_path)