In [None]:
import nest
import os
import anndata
import scipy
import sklearn.metrics
import numpy as np
import pandas as pd
import seaborn as sns
import scanpy as sc
import squidpy as sq
import networkx as nx
from pathlib import Path
import time
from tqdm import tqdm
from scipy.sparse import csr_matrix
from scipy.stats import pearsonr


import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import rc

import leidenalg as la
import igraph as ig
from scipy.sparse import coo_array
from scipy.spatial.distance import cdist
from scipy.spatial import KDTree
from scipy.spatial import ConvexHull
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import matplotlib.colors as colors

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Dataset identifiers known to this notebook; `dataset` selects the one analysed below.
available_datasets = [
    "V1_Mouse_Brain_Sagittal_Anterior",
    "V1_Mouse_Brain_Sagittal_Posterior",
    "seqfish",
    "merfish",
    "slideseq",
    "V1_Breast_Cancer_Block_A_Section_1",
]
dataset = "merfish"

In [None]:
# Per-dataset cache location; in this notebook it is only referenced by the
# commented-out loader at the bottom of this cell.
cache_dir = os.path.expanduser(f"~/Dropbox/data/ms/datasets/{dataset}")
# Output directory for figures, relative to the current working directory.
image_save_dir = os.path.expanduser(f"images/{dataset}/")
# savefig() does not create missing directories, so make sure the target exists
# before any later cell tries to write into it.
os.makedirs(image_save_dir, exist_ok=True)
nest.plot.set_dataset_plot_parameters(dataset)

#adata = gs.read(os.path.join(cache_dir, 'adata'))
#adata.layers['exp'] = adata.X.expm1()

In [None]:
# Load the MERFISH dataset through squidpy.
adata = sq.datasets.merfish()

# Scale factor read back later as adata.uns['um_scale'] when converting
# micrometre length scales (presumably coordinates are in mm — TODO confirm).
adata.uns["um_scale"] = 0.001
# Per-cell z coordinate derived from the Bregma annotation.
adata.obs["z"] = adata.obs["Bregma"] * 0.01

# Distinct Bregma values in order of first appearance; later cells pick slices
# by position in this array.
bregma_values = pd.unique(adata.obs["Bregma"])

In [None]:
# Display the AnnData summary after loading.
adata

In [None]:
# Length scales are written in micrometres and converted with the stored scale factor.
um_scale = adata.uns['um_scale']
secreted_std = 50 * um_scale
contact_threshold = 20 * um_scale

# The hotspot cell below uses method="permutation" and reads
# adata.uns['activity_significance_cutoff'] (presumably written by the
# permutation test — TODO confirm), so the test stays enabled. The flag is now
# actually passed to the call instead of being a dead, contradictory variable.
perform_permutation = True
activity_matrix = nest.compute_activity(adata, secreted_std=secreted_std,
                                        contact_threshold=contact_threshold,
                                        sig_threshold=0.95,
                                        perform_permutation=perform_permutation,
                                        save_activity=True, verbose=True,
                                        min_active_count=20,
                                        interactions=None,
                                        z_key="z",
                                        K=None)

In [None]:
# Hotspot clustering parameters (also read by later cells that do not redefine them).
neighbor_eps=0.06
min_samples=5
hotspot_min_size=5

from nest.hotspot.hotspot import _compute_cutoff

method = "permutation"   # "permutation": use stored cutoffs; anything else: _compute_cutoff
z_key = "z"              # obs column identifying the slice; None treats all cells as one slice
hotspot_kwargs = dict(eps=neighbor_eps, min_size=hotspot_min_size, min_samples=min_samples,
                      core_only=False)

region_dict = {}
for interaction in list(activity_matrix):
    data = activity_matrix[interaction]

    if z_key is None:
        if method == 'permutation':
            cutoff = adata.uns['activity_significance_cutoff'][interaction]
        else:
            cutoff = _compute_cutoff(data, log=False)
        inds = data > cutoff
        # Previously called an undefined `compute_hotspots(**kwargs)`; use the same
        # function and arguments as the per-slice branch.
        regions = nest.compute_hotspots(adata=adata, input_data=np.where(inds)[0],
                                        return_regions=True, **hotspot_kwargs)
    else:
        # -1 marks "not in any hotspot"; region ids are made unique across slices
        # by shifting each slice's labels by the running offset.
        regions = -1 * np.ones(adata.shape[0])
        region_offset = 0
        for val in np.unique(adata.obs[z_key]):
            in_slice = adata.obs[z_key] == val
            if method == 'permutation':
                cutoff = adata.uns['activity_significance_cutoff'][interaction]
            else:
                # data-driven cutoff is computed from this slice only
                cutoff = _compute_cutoff(data[in_slice], log=False)
            cur_slice_inds = np.logical_and(data > cutoff, in_slice)
            regions_sub = nest.compute_hotspots(adata=adata, input_data=np.where(cur_slice_inds)[0],
                                                return_regions=True, **hotspot_kwargs)
            if regions_sub is None:
                continue
            out_inds = np.where(pd.notnull(regions_sub))[0]
            if out_inds.size == 0:
                # nothing labelled in this slice; np.max on an empty array would raise
                continue
            v = np.array(regions_sub[out_inds])
            regions[out_inds] = v + region_offset
            region_offset += np.max(v)

        if np.count_nonzero(regions != -1) > 0:
            # Build integer categories directly; assigning to `.categories` is no
            # longer supported by pandas. -1 is not a category and becomes NaN.
            labels = regions.astype(np.int_)
            regions = pd.Categorical(labels, categories=np.arange(1, labels.max() + 1))
        else:
            regions = None

    if regions is not None:
        region_dict[f"hotspots_{interaction}"] = regions

# Replace any hotspot columns left from a previous run with the fresh ones.
hotspots_df = pd.DataFrame(region_dict, index=adata.obs.index)
stale_cols = adata.obs.filter(regex='hotspots_').columns.tolist()
adata.obs = pd.concat([adata.obs.drop(columns=stale_cols), hotspots_df], axis=1)

In [None]:
# Plot the TAC1_TACR1 hotspot labels for the slice at position 8 of bregma_values.
#tmp = adata.obs['TAC1_TACR1'] > adata.uns['activity_significance_cutoff']['TAC1_TACR1']
#adata.obs['tmp']= tmp.astype(np.int_)
adata_sub = adata[adata.obs.Bregma == bregma_values[8]]
nest.plot.spatial(adata_sub, color=['hotspots_TAC1_TACR1'], color_map="Blues")
#nest.plot.hotspots(adata_sub_list_2[8], "TAC1_TACR1")

In [None]:
# One figure per interaction: left column = hotspots on the full dataset (`adata`),
# right column = hotspots from the per-slice objects in adata_sub_list_2.
# NOTE: adata_sub_list_2 is built by a later cell of this notebook; that cell must
# have been run before this one.
nv = len(bregma_values)
for interaction in adata.uns['interactions']['interaction_name']:
    print(interaction)

    fig, axs = plt.subplots(nv, 2, figsize=(8, 3*nv))
    # Iterate bregma_values in its own (first-appearance) order: adata_sub_list_2 is
    # filled in that order, so a sorted np.unique() would pair different slices.
    for idx, val in enumerate(bregma_values):

        adata_sub = adata[adata.obs.Bregma == val]
        nest.plot.hotspots(adata_sub, color=interaction, ax=axs[idx, 0], show=False)
        nest.plot.hotspots(adata_sub_list_2[idx], color=interaction, ax=axs[idx, 1], show=False)
    fig.savefig(f'images/merfish/{interaction}_2d_3d_comp.png', dpi=300, transparent=True, bbox_inches='tight')

In [None]:
# Clustering parameters, kept as globals because later cells read them.
neighbor_eps, min_samples, hotspot_min_size = 0.06, 5, 5

# Interaction hotspots on the full dataset using the "otsu" thresholding method.
otsu_hotspot_args = {
    "eps": neighbor_eps,
    "min_size": hotspot_min_size,
    "min_samples": min_samples,
    "core_only": False,
    "method": "otsu",
}
nest.interaction_hotspots(adata, **otsu_hotspot_args)

In [None]:
# Clustering parameters, kept as globals because later cells read them.
neighbor_eps, min_samples, hotspot_min_size = 0.06, 5, 5

#sub_gene_list = ["Cck", "Penk"]
sub_gene_list = list(set(adata.var_names).difference({'Fos'}))
adata_sub_list = []
adata_sub_list_2 = []
region_vals = [6, 10, 14]

dbscan_args = dict(eps=neighbor_eps, min_size=hotspot_min_size, min_samples=min_samples,
                   core_only=False)

# Process every Bregma slice independently; a copy of each processed slice is
# appended to adata_sub_list_2 in order of first appearance of its Bregma value.
for val in tqdm(pd.unique(adata.obs.Bregma)):
    adata_sub = nest.data.get_data("merfish", bregma=val)[:, sub_gene_list].copy()
    nest.compute_gene_hotspots(adata_sub, log=True, **dbscan_args)
    nest.interaction_hotspots(adata_sub, method="permutation", sig_threshold=0.95,
                              K=None, save_activity=True, **dbscan_args)
    adata_sub_list_2.append(adata_sub.copy())

In [None]:
# Ligand, receptor and interaction activity (Penk / Oprk1 / PENK_OPRK1) for the per-slice object at position 0.
nest.plot.spatial(adata_sub_list_2[0], color=["Penk", "Oprk1", "PENK_OPRK1"], color_map="Blues")

In [None]:
# Same three panels for the per-slice object at position 1.
nest.plot.spatial(adata_sub_list_2[1], color=["Penk", "Oprk1", "PENK_OPRK1"], color_map="Blues")

In [None]:
# Cck / Cckbr / CCK_CCKBR panels for the per-slice object at position 6.
nest.plot.spatial(adata_sub_list_2[6], color=["Cck", "Cckbr", "CCK_CCKBR"], color_map="Blues")

In [None]:
%matplotlib inline
rc('font',**{'family':'serif','serif':['Arial'], 'size':8})

fig, axs = plt.subplots(1, 3, figsize=[6.5, 2])
for i in range(3):
    if i == 2:
        legend_loc = "right margin"
    else:
        legend_loc = None
    adata_sub = adata[adata.obs.Bregma == bregma_values[5+i]]
    nest.plot.spatial(adata_sub, color=["Cell_class"], groups=["Ambiguous", "Excitatory", "Inhibitory"],
                     ax=axs[i], title=f"bregma={int(bregma_values[5+i])}", frameon=False, show=False,
                     legend_loc=legend_loc)
    

def save_fig(fig, name):
    fig.savefig(os.path.join(image_save_dir, name), dpi=300, bbox_inches='tight', transparent=True)
save_fig(fig, "cell_type.png")

In [None]:
# Stand-alone version of the cell-class panel.
# NOTE: reuses `adata_sub` and the loop index `i` left over from whichever cell ran last.
nest.plot.spatial(adata_sub, color=["Cell_class"], groups=["Ambiguous", "Excitatory", "Inhibitory"],
                      title=f"bregma={int(bregma_values[5+i])}", frameon=False, show=False)

In [None]:
# Cckbr expression (use_raw=False) for the per-slice object at position 5.
nest.plot.spatial(adata_sub_list_2[5], color=["Cckbr"], color_map="Blues", use_raw=False)

In [None]:
# Cck expression (use_raw=False) for the per-slice object at position 5.
nest.plot.spatial(adata_sub_list_2[5], color=["Cck"], color_map="Blues", use_raw=False)

In [None]:
# 0.999999 quantile (effectively the maximum) of the Cckbr values in the slice at position 5.
np.quantile(adata_sub_list_2[5][:, "Cckbr"].X.toarray(), 0.999999)

In [None]:
# One figure per interaction: left column = hotspots on the full dataset (`adata`),
# right column = hotspots from the per-slice objects in adata_sub_list_2.
nv = len(bregma_values)
for interaction in adata.uns['interactions']['interaction_name']:
    print(interaction)

    fig, axs = plt.subplots(nv, 2, figsize=(8, 3*nv))
    # Iterate bregma_values in its own (first-appearance) order: adata_sub_list_2 is
    # filled in that order, so a sorted np.unique() would pair different slices.
    for idx, val in enumerate(bregma_values):

        adata_sub = adata[adata.obs.Bregma == val]
        #nest.plot.spatial(adata_sub, color=interaction, ax=axs[idx, 0], show=False, color_map="Blues")
        nest.plot.hotspots(adata_sub, color=interaction, ax=axs[idx, 0], show=False)
        #nest.plot.spatial(adata_sub_list_2[idx], color=interaction, ax=axs[idx, 1], show=False, color_map="Blues")
        nest.plot.hotspots(adata_sub_list_2[idx], color=interaction, ax=axs[idx, 1], show=False)
    fig.savefig(f'images/merfish/{interaction}_2d_3d_comp.png', dpi=300, transparent=True, bbox_inches='tight')

In [None]:
# Display the list of per-slice AnnData objects.
adata_sub_list_2

In [None]:
# Binary map of cells whose TAC1_TACR1 value exceeds the cutoff stored on the full dataset.
# NOTE: adata_sub is the list element itself (not a copy), so the 'tmp' column is
# added to adata_sub_list_2[8] in place.
# NOTE(review): the cutoff comes from `adata.uns`, not from the slice's own `uns` — confirm this is intended.
adata_sub = adata_sub_list_2[8]
adata_sub.obs['tmp'] = (adata_sub.obs['TAC1_TACR1'] > 
                        adata.uns['activity_significance_cutoff']['TAC1_TACR1']).astype(np.int_)
nest.plot.spatial(adata_sub, color='tmp')

In [None]:


#sub_gene_list = ["Cck", "Penk"]
sub_gene_list = list(set(adata.var_names).difference({'Fos'}))
adata_sub_list = []
adata_sub_list_2 = []
region_vals = [6, 10, 14]

# neighbor_eps, hotspot_min_size and min_samples are not set in this cell; they
# must already exist from an earlier cell.
dbscan_args = dict(eps=neighbor_eps, min_size=hotspot_min_size, min_samples=min_samples,
                   core_only=False)

for val in tqdm(pd.unique(adata.obs.Bregma)):
    adata_sub = nest.data.get_data("merfish", bregma=val)[:, sub_gene_list].copy()

    nest.compute_gene_hotspots(adata_sub, log=True, **dbscan_args)
    nest.interaction_hotspots(adata_sub, **dbscan_args)
    # snapshot taken before the closure / coexpression / boundary steps below
    adata_sub_list_2.append(adata_sub.copy())
    nest.hotspot_closure(adata_sub)
    nest.coexpression_hotspots(adata_sub, min_genes=3, verbose=False, threshold=0.5)
    nest.compute_multi_boundaries(adata_sub, 1, 0.01)
    # fully processed object
    adata_sub_list.append(adata_sub)
    #print(adata_sub.shape)

In [None]:
# Check the stored scale factor.
adata.uns['um_scale']

In [None]:
# Interaction hotspots for a single slice (position 2 of bregma_values).
# Subsetting an AnnData yields a view; take an explicit copy before writing to
# `uns` so the write does not rely on anndata's implicit view-to-copy conversion.
adata_sub = adata[adata.obs.Bregma == bregma_values[2]].copy()
neighbor_eps=0.06
min_samples=5
hotspot_min_size=5
adata_sub.uns['um_scale'] = 0.001
nest.interaction_hotspots(adata_sub, eps=neighbor_eps, min_size=hotspot_min_size, min_samples=min_samples,
                            core_only=False)

In [None]:
# Inspect the per-cell Bregma annotation.
adata.obs.Bregma

In [None]:
# Re-derive the z coordinate from Bregma (same assignment as in the data-loading cell).
adata.obs['z'] = adata.obs.Bregma*0.01

In [None]:
# Inspect the derived z column.
adata.obs.z

In [None]:
# Shape of the spatial coordinate array.
adata.obsm['spatial'].shape

In [None]:
# Display the AnnData summary again.
adata

In [None]:
# Nts expression on the current `adata_sub` (whatever the last-run cell assigned to it).
nest.plot.spatial(adata_sub, color="Nts", color_map="Blues")

In [None]:
# Positional indices of cells in adata_sub whose Nts value exceeds 0.4.
inds = np.where(adata_sub[:, "Nts"].X.toarray().ravel() > 0.4)[0]

In [None]:
# Cluster the selected cells into hotspots; with return_regions=True the labels are
# returned in `out` (used below as a pandas Categorical).
out = nest.compute_hotspots(adata_sub, input_data=inds, return_regions=True, min_samples=min_samples, eps=neighbor_eps,
                     min_size=hotspot_min_size)

In [None]:
# Shift every region label by 2. Assigning to `Categorical.categories` is no
# longer supported by pandas; rename_categories returns a new Categorical with the
# same codes and the relabelled categories, leaving `out` untouched.
out2 = out.rename_categories(out.categories + 2)

In [None]:
# Inspect the shifted category labels.
out2.categories

In [None]:
# Inspect the integer codes behind the categorical (-1 marks missing values).
out2.codes

In [None]:
# Positions of the entries of `out` that carry a (non-null) label.
v = np.where(pd.notnull(out))[0]

In [None]:
# Labels at those positions.
out[v]

In [None]:
# Scratch check: build a categorical from the raw codes at the labelled positions,
# allowing only the categories 0 and 1 (any other code becomes NaN).
pd.Categorical(out2.codes[v], categories=[0,1])

In [None]:
# Scratch check of the inclusive 1..5 range used for category construction.
np.arange(1, 5+1)

In [None]:
# Display the AnnData summary.
adata

In [None]:
# Inspect the pairwise 'transport_secreted' matrix stored in obsp.
adata.obsp['transport_secreted']

In [None]:
# Cckbr expression in the slice at position 6 of bregma_values, cropped to a fixed
# window and with the colour scale capped at 3.
nest.plot.spatial(adata[adata.obs['Bregma']==bregma_values[6]], color="Cckbr", 
                  crop_coord=[0.58, 0.94, 0.41, 0.77], color_map="Blues", use_raw=False, vmax=3)

In [None]:
# Score the Cck - Cckbr interaction per Cell_class group with the CellChat wrapper.
# The instance is not called `cc`: that name is the alias used for `colorcet`
# elsewhere in this notebook (e.g. `cc.glasbey`), and rebinding it here would break
# those cells depending on execution order.
cellchat = nest.methods.CellChat()
cellchat.cellchat_score(adata, interaction="Cck  - Cckbr", group_by="Cell_class")

In [None]:
# Plot the 'cellchat_score' values for the slice at position 5 of bregma_values.
nest.plot.spatial(adata[adata.obs.Bregma==bregma_values[5]], color="cellchat_score")

In [None]:
# RGBA value of the "Reds" colormap at 0.7 (source of a hard-coded colour used in later figures).
sns.color_palette("Reds", as_cmap=True)(0.7)

In [None]:
# Keep only the cells of the slice at position 6 whose Cck value exceeds 0.2, then plot them.
slice_6 = adata_sub_list_2[6]
adata_tmp = slice_6[slice_6[:, "Cck"].X.toarray() > 0.2]
nest.plot.spatial(adata_tmp, color_map="Blues", color="Cck")

In [None]:
# Cutoff that nest's private helper computes for the Cck values of the filtered cells.
from nest.hotspot.hotspot import _compute_cutoff
expr = adata_tmp[:, "Cck"].X.toarray()
_compute_cutoff(expr, log=False)

In [None]:
# Boolean masks over all cells of `adata`:
#   in_z    - Bregma value between bregma_values[5] and bregma_values[7] (inclusive)
#   in_bbox - x/y inside the analysis window [xmin, xmax, ymin, ymax]
bregma_all = adata.obs['Bregma'].to_numpy()
in_z = (bregma_all >= bregma_values[5]) & (bregma_all <= bregma_values[7])
coords = adata.obsm['spatial']
x = coords[:, 0]
y = coords[:, 1]
bbox = [0.58, 0.94, 0.41, 0.73]
in_bbox = (x >= bbox[0]) & (x <= bbox[1]) & (y >= bbox[2]) & (y <= bbox[3])
# A cell counts as an active receiver when it belongs to any CCK_CCKBR hotspot.
# This definition was commented out, which left `active_receiver` undefined below.
active_receiver = pd.notnull(adata.obs['hotspots_CCK_CCKBR']).to_numpy()
receiver_inds = in_z & in_bbox & active_receiver
other_inds = in_z & in_bbox & ~active_receiver

In [None]:
# Cell-class composition of the active receivers (Series.value_counts replaces
# the deprecated top-level pd.value_counts).
adata[receiver_inds].obs['Cell_class'].value_counts()

In [None]:
# Per-class ratio of active receivers to the remaining cells in the same window
# (the division aligns the two count Series on the class label).
adata[receiver_inds].obs['Cell_class'].value_counts() / adata[other_inds].obs['Cell_class'].value_counts()

In [None]:
import colorcet as cc

%matplotlib inline
rc('font',**{'family':'serif','serif':['Arial'], 'size':5})
bregma_vals = bregma_values[5:8]
cci_inds_2d = []
cci_inds_3d = []

bbox = [0.58, 0.94, 0.41, 0.73]

cm = sns.color_palette(cc.glasbey)
ax = plt.figure(figsize=(2, 2)).add_subplot(projection='3d')
ax.view_init(elev=13, azim=98)
z_scale=0.10

c1 = (0.1791464821222607, 0.49287197231833907, 0.7354248366013072, 1.0)
c2 = (0.8503344867358708, 0.14686658977316416, 0.13633217993079583, 1.0)
#c3 = sns.color_palette(cc.glasbey)[17]
c3 = (0, 0, 0, 1.0)
#c1 = sns.color_palette("muted")[0]
#c2 = sns.color_palette("muted")[1]
#c3 = sns.color_palette("muted")[5]
for z_ind, bregma in enumerate(bregma_vals):
    adata_sub = adata_sub_list_2[5+z_ind]
    coords = adata_sub.obsm['spatial']
    x = coords[:, 0]
    y = coords[:, 1]
    in_bbox = (x >= bbox[0]) & (x <= bbox[1]) & (y >= bbox[2]) & (y <= bbox[3])
    adata_sub = adata_sub[in_bbox]
    coords_tmp1 = adata_sub.obsm['spatial'].copy()
    
    inds_li = (adata_sub[:, "Cck"].X.toarray() > 0.2).ravel()
    
    inds_2d = pd.notnull(adata_sub.obs['hotspots_CCK_CCKBR'])
    
    adata_sub = adata[adata.obs["Bregma"] == bregma]
    coords = adata_sub.obsm['spatial']

    x = coords[:, 0]
    y = coords[:, 1]
    in_bbox = (x >= bbox[0]) & (x <= bbox[1]) & (y >= bbox[2]) & (y <= bbox[3])
    adata_sub = adata_sub[in_bbox]
    coords_tmp2 = adata_sub.obsm['spatial'].copy()
    coords = adata_sub.obsm['spatial']
    color_vector = np.stack([np.array([0.9, 0.9, 0.9, 0.0])]*adata_sub.shape[0])
    inds_3d = pd.notnull(adata_sub.obs['hotspots_CCK_CCKBR'])
    inds_3d &= ~inds_2d
    #cci_inds_3d.append(inds_3d)
    cur_z = z_scale*z_ind
    
    if z_ind == 0:
        label_li, label_2d, label_3d = "sender", "receiver (2D)", "receiver (3D)"
    else:
        label_li, label_2d, label_3d = None, None, None
    
    ax.scatter(coords[inds_li, 0], coords[inds_li, 1], cur_z, color=c3, s=1, label=label_li, depthshade=False)
    ax.scatter(coords[inds_2d, 0], coords[inds_2d, 1], cur_z, color=c2, s=1, label=label_2d, depthshade=False)
    ax.scatter(coords[inds_3d, 0], coords[inds_3d, 1], cur_z, color=c1, s=1, label=label_3d, depthshade=False)
    
    # draw boundary around the slice
    eps = 0.02
    xs = [bbox[0]-eps, bbox[0]-eps, bbox[1]+eps, bbox[1]+eps, bbox[0]-eps]
    ys = [bbox[2]-eps, bbox[3]+eps, bbox[3]+eps, bbox[2]-eps, bbox[2]-eps]
    ax.plot(xs, ys, cur_z, color="k", linewidth=0.5)
    

    
# Make panes transparent
ax.xaxis.pane.fill = False # Left pane
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False

# Remove grid lines
ax.grid(False)

# Remove tick labels
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])

# Transparent spines
ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))

# Transparent panes\
ax.w_xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.w_yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.w_zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

# No ticks
ax.set_xticks([]) 
ax.set_yticks([]) 
ax.set_zticks([])

ax.set_box_aspect((bbox[1]-bbox[0], bbox[3]-bbox[2], 2*z_scale))

plt.legend(loc="right", bbox_to_anchor=(0.92, 0.10))

plt.savefig(f'images/merfish/Cck_2d_3d_comp.pdf', dpi=300, transparent=True, bbox_inches='tight')

In [None]:
# List the per-cell annotation columns.
adata.obs.columns

In [None]:
# Cell_class categories on the full dataset.
adata.obs['Cell_class'].cat.categories

In [None]:
# Colours stored for Cell_class on the full dataset.
adata.uns['Cell_class_colors']

In [None]:
# Cell_class categories as seen on a single-slice subset (position 5 of bregma_values).
adata_sub = adata[adata.obs.Bregma==bregma_values[5]]
adata_sub.obs['Cell_class'].cat.categories

In [None]:
# Colours stored for Cell_class on that subset.
adata_sub.uns['Cell_class_colors']

In [None]:
import colorcet as cc

%matplotlib inline
rc('font',**{'family':'serif','serif':['Arial'], 'size':5})
bregma_vals = bregma_values[5:8]
cci_inds_2d = []
cci_inds_3d = []

cm = sns.color_palette(cc.glasbey)
ax = plt.figure(figsize=(2, 2)).add_subplot(projection='3d')
ax.view_init(elev=13, azim=98)
z_scale=0.10

c1 = (0.1791464821222607, 0.49287197231833907, 0.7354248366013072, 1.0)
c2 = (0.8503344867358708, 0.14686658977316416, 0.13633217993079583, 1.0)
#c3 = sns.color_palette(cc.glasbey)[17]
c3 = (0, 0, 0, 1.0)
for z_ind, bregma in enumerate(bregma_vals):
    cur_z = zscale*z_ind
    adata_sub = adata_sub_list_2[5+z_ind]
    coords = adata_sub.obsm['spatial']
    x = coords[:, 0]
    y = coords[:, 1]
    in_bbox = (x >= bbox[0]) & (x <= bbox[1]) & (y >= bbox[2]) & (y <= bbox[3])
    adata_sub = adata_sub[in_bbox]
    coords_tmp1 = adata_sub.obsm['spatial'].copy()
    
    inds_li = (adata_sub[:, "Cck"].X.toarray() > 0.2).ravel()
    
    inds_2d = pd.notnull(adata_sub.obs['hotspots_CCK_CCKBR'])
    
    adata_sub = adata[adata.obs["Bregma"] == bregma]
    coords = adata_sub.obsm['spatial']

    x = coords[:, 0]
    y = coords[:, 1]
    in_bbox = (x >= bbox[0]) & (x <= bbox[1]) & (y >= bbox[2]) & (y <= bbox[3])
    adata_sub = adata_sub[in_bbox]
    coords = adata_sub.obsm['spatial']
    
    for k, c in enumerate(adata_sub.obs['Cell_class'].cat.categories):
        c_ind = np.where(adata_sub.obs['Cell_class'].cat.categories == c)[0]
        color = adata_sub.uns['Cell_class_colors'][c_ind]
        inds = adata_sub.obs['Cell_class'] == c
        if z_ind == 0:
            label = c
        else:
            label = None
            
        ax.scatter(coords[inds, 0], coords[inds, 1], cur_z, color=color, s=1, label=label)
    
    # draw boundary around the slice
    eps = 0.02
    xs = [bbox[0]-eps, bbox[0]-eps, bbox[1]+eps, bbox[1]+eps, bbox[0]-eps]
    ys = [bbox[2]-eps, bbox[3]+eps, bbox[3]+eps, bbox[2]-eps, bbox[2]-eps]
    ax.plot(xs, ys, cur_z, color="k", linewidth=0.5)
    

    
# Make panes transparent
ax.xaxis.pane.fill = False # Left pane
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False

# Remove grid lines
ax.grid(False)

# Remove tick labels
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])

# Transparent spines
ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))

# Transparent panes\
ax.w_xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.w_yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.w_zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

# No ticks
ax.set_xticks([]) 
ax.set_yticks([]) 
ax.set_zticks([])

ax.set_box_aspect((bbox[1]-bbox[0], bbox[3]-bbox[2], 2*z_scale))

plt.legend(loc="right", bbox_to_anchor=(1.4, 0.5))

plt.savefig(f'images/merfish/merfish_cell_type.pdf', dpi=300, transparent=True, bbox_inches='tight')

In [None]:
# Colours stored for Cell_class on the last subset produced above.
adata_sub.uns['Cell_class_colors']

In [None]:
# Stacked 3D view of three consecutive slices showing expression of `gene`.
bregma_vals = bregma_values[5:8]
# Defined here so the cell does not depend on a `bbox` left over from another cell.
bbox = [0.58, 0.94, 0.41, 0.73]   # [xmin, xmax, ymin, ymax]

cm = sns.color_palette("Blues", as_cmap=True)
# Keep a handle on this figure: the colorbar below must be added to it, not to a
# `fig` left over from an earlier cell.
fig = plt.figure(figsize=(2, 2))
ax = fig.add_subplot(projection='3d')
ax.view_init(elev=13, azim=98)
z_scale=0.10                      # vertical spacing between slices in the plot

gene = "Cck"

for z_ind, bregma in enumerate(bregma_vals):
    adata_sub = adata[adata.obs["Bregma"] == bregma]
    coords = adata_sub.obsm['spatial']
    x = coords[:, 0]
    y = coords[:, 1]
    in_bbox = (x >= bbox[0]) & (x <= bbox[1]) & (y >= bbox[2]) & (y <= bbox[3])
    adata_sub = adata_sub[in_bbox]

    # only cells with expression above 0.1 are drawn; flatten to 1D so the colormap
    # returns one RGBA row per cell
    expr = adata_sub[:, gene].X.toarray().ravel()
    expressed = expr > 0.1
    adata_sub = adata_sub[expressed]
    coords = adata_sub.obsm['spatial']
    expr = expr[expressed]
    # float input to a colormap is read on a 0..1 scale, matching the colorbar norm below
    color_vector = cm(expr)

    # was `zscale*z_ind`: `zscale` is a different variable, set only by later cells
    cur_z = z_scale*z_ind

    ax.scatter(coords[:, 0], coords[:, 1], cur_z, color=color_vector, s=0.5, alpha=0.75)

    # draw boundary around the slice
    eps = 0.04
    xs = [bbox[0]-eps, bbox[0]-eps, bbox[1]+eps, bbox[1]+eps, bbox[0]-eps]
    ys = [bbox[2]-eps, bbox[3]+eps, bbox[3]+eps, bbox[2]-eps, bbox[2]-eps]
    ax.plot(xs, ys, cur_z, color="k", linewidth=0.5)

# Strip all axis decoration. `ax.xaxis` etc. are used instead of the `w_xaxis`
# aliases, which were removed in Matplotlib 3.8.
ax.grid(False)
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.pane.fill = False
    axis.line.set_color((1.0, 1.0, 1.0, 0.0))
    axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

ax.set_box_aspect((bbox[1]-bbox[0], bbox[3]-bbox[2], 2*z_scale))

norm = mpl.colors.Normalize(vmin=0, vmax=1)

# add the color bar
cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cm), ax=ax, orientation='vertical', pad=-0.05,
                   shrink=0.4, aspect=7)
plt.savefig(os.path.join(image_save_dir, "Cck.png"), dpi=300, transparent=True, bbox_inches='tight')

In [None]:
# Check the current slice spacing (note: `z_scale` and `zscale` are distinct variables in this notebook).
z_scale

In [None]:
import colorcet as cc

# Stacked 3D view of CCK_CCKBR hotspot cells from the whole-dataset analysis.
ax = plt.figure().add_subplot(projection='3d')
ax.view_init(elev=6, azim=98)

zscale = 0.5                      # vertical spacing between slices in the plot

bbox = [0.58, 0.94, 0.41, 0.73]   # [xmin, xmax, ymin, ymax]

bregma_vals = bregma_values[5:8]
for z_ind, bregma in enumerate(bregma_vals):
    adata_sub = adata[adata.obs["Bregma"] == bregma]
    coords = adata_sub.obsm['spatial']
    x = coords[:, 0]
    y = coords[:, 1]
    in_bbox = (x >= bbox[0]) & (x <= bbox[1]) & (y >= bbox[2]) & (y <= bbox[3])
    adata_sub = adata_sub[in_bbox]
    coords = adata_sub.obsm['spatial']
    # cells that belong to any CCK_CCKBR hotspot
    inds = pd.notnull(adata_sub.obs['hotspots_CCK_CCKBR'])
    cur_z = zscale*z_ind
    ax.scatter(coords[inds, 0], coords[inds, 1], cur_z, color="C3", s=1)

    # draw boundary around the slice
    eps = 0.02
    xs = [bbox[0]-eps, bbox[0]-eps, bbox[1]+eps, bbox[1]+eps, bbox[0]-eps]
    ys = [bbox[2]-eps, bbox[3]+eps, bbox[3]+eps, bbox[2]-eps, bbox[2]-eps]
    ax.plot(xs, ys, cur_z, color="k", linewidth=0.5)

# Strip all axis decoration. `ax.xaxis` etc. are used instead of the `w_xaxis`
# aliases, which were removed in Matplotlib 3.8.
ax.grid(False)
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.pane.fill = False
    axis.line.set_color((1.0, 1.0, 1.0, 0.0))
    axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

ax.set_xlim([bbox[0]-0.2, bbox[1]])
ax.set_ylim([bbox[2], bbox[3]])
ax.set_zlim([0*zscale-0.2, len(bregma_vals)*zscale])

plt.savefig(os.path.join(image_save_dir, "Cck_Cckbr_3d.png"), dpi=300, transparent=True, bbox_inches='tight')

In [None]:
import colorcet as cc

# Stacked 3D view of CCK_CCKBR hotspot cells from the per-slice objects in adata_sub_list_2.
ax = plt.figure().add_subplot(projection='3d')
ax.view_init(elev=6, azim=98)

zscale = 0.5                      # vertical spacing between slices in the plot

bbox = [0.58, 0.94, 0.41, 0.73]   # [xmin, xmax, ymin, ymax]

bregma_vals = bregma_values[5:8]
for z_ind, bregma in enumerate(bregma_vals):
    # list positions 5..7 correspond to bregma_values[5:8]
    adata_sub = adata_sub_list_2[5+z_ind]
    coords = adata_sub.obsm['spatial']
    x = coords[:, 0]
    y = coords[:, 1]
    in_bbox = (x >= bbox[0]) & (x <= bbox[1]) & (y >= bbox[2]) & (y <= bbox[3])
    adata_sub = adata_sub[in_bbox]
    coords = adata_sub.obsm['spatial']
    # cells that belong to any CCK_CCKBR hotspot
    inds = pd.notnull(adata_sub.obs['hotspots_CCK_CCKBR'])
    cur_z = zscale*z_ind
    ax.scatter(coords[inds, 0], coords[inds, 1], cur_z, color="C3", s=1)

    # draw boundary around the slice
    eps = 0.02
    xs = [bbox[0]-eps, bbox[0]-eps, bbox[1]+eps, bbox[1]+eps, bbox[0]-eps]
    ys = [bbox[2]-eps, bbox[3]+eps, bbox[3]+eps, bbox[2]-eps, bbox[2]-eps]
    ax.plot(xs, ys, cur_z, color="k", linewidth=0.5)

# Strip all axis decoration. `ax.xaxis` etc. are used instead of the `w_xaxis`
# aliases, which were removed in Matplotlib 3.8.
ax.grid(False)
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.pane.fill = False
    axis.line.set_color((1.0, 1.0, 1.0, 0.0))
    axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

ax.set_xlim([bbox[0]-0.2, bbox[1]])
ax.set_ylim([bbox[2], bbox[3]])
ax.set_zlim([0*zscale-0.2, len(bregma_vals)*zscale])

plt.savefig(os.path.join(image_save_dir, "Cck_Cckbr_2d.png"), dpi=300, transparent=True, bbox_inches='tight')

In [None]:
import colorcet as cc

# Schematic: three slice frames with arrows radiating from one point on the middle slice.
ax = plt.figure().add_subplot(projection='3d')
ax.view_init(elev=6, azim=98)

gene = "Cck"

zscale = 0.05                     # vertical spacing between slices in the plot

bbox = [0.58, 0.94, 0.41, 0.73]   # [xmin, xmax, ymin, ymax]

color_map = sns.color_palette("Blues", as_cmap=True)

bregma_vals = bregma_values[5:8]
for z_ind, bregma in enumerate(bregma_vals):
    adata_sub = adata[adata.obs["Bregma"] == bregma]
    coords = adata_sub.obsm['spatial']
    x = coords[:, 0]
    y = coords[:, 1]
    in_bbox = (x >= bbox[0]) & (x <= bbox[1]) & (y >= bbox[2]) & (y <= bbox[3])
    adata_sub = adata_sub[in_bbox]
    coords = adata_sub.obsm['spatial']
    # expression colours are prepared for the (currently disabled) scatter below
    expr = adata_sub[:, gene].X.toarray().ravel()
    inds = expr > 0
    expr=expr[inds]
    color_vector = color_map(expr)
    cur_z = zscale*z_ind
    #ax.scatter(coords[inds, 0], coords[inds, 1], cur_z, color=color_vector, s=1, vmax=5, zorder=1)

    # draw boundary around the slice
    eps = 0.02
    xs = [bbox[0]-eps, bbox[1]+eps, bbox[1]+eps, bbox[0]-eps, bbox[0]-eps]
    ys = [bbox[3]+eps, bbox[3]+eps, bbox[2]-eps, bbox[2]-eps, bbox[3]+eps]
    ax.plot(xs, ys, cur_z, color="k", linewidth=0.5, zorder=3)
    # front of the frame at a different zorder so it appears in front
    ax.plot([bbox[0]-eps, bbox[1]+eps], [bbox[3]+eps, bbox[3]+eps], cur_z, color="k", linewidth=0.5, zorder=10)

# Plot the arrows representing diffusion

# Six arrows (60 degrees apart) per elevation angle phi, all starting at the same
# point at height `zscale`, i.e. on the middle slice.
for phi in [-np.pi/5, 0, np.pi/5]:
    theta = np.arange(6)*(2*np.pi)/6
    u = np.sin(theta)*np.cos(phi)
    v = np.cos(theta)*np.cos(phi)
    w = np.sin(phi) * np.ones(6)
    x = 0.78 * np.ones(6)
    y = 0.60* np.ones(6)
    z = zscale * np.ones(6)
    if phi != 0:
        # inclined arrows: length chosen so the vertical extent is 0.05 (= zscale)
        length = 0.05/np.abs(np.sin(phi))
    else:
        length = 0.1
    ax.quiver(x,y,z,u,v,w, length=length, color="C0", zorder=4)

# Strip all axis decoration. `ax.xaxis` etc. are used instead of the `w_xaxis`
# aliases, which were removed in Matplotlib 3.8.
ax.grid(False)
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.pane.fill = False
    axis.line.set_color((1.0, 1.0, 1.0, 0.0))
    axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

ax.set_xlim([bbox[0]-0.2, bbox[1]])
ax.set_ylim([bbox[2], bbox[3]])
ax.set_zlim([0, (len(bregma_vals)-1)*zscale])

ax.set_box_aspect((bbox[1]-bbox[0], bbox[3]-bbox[2], 2*zscale))

plt.savefig(os.path.join(image_save_dir, f"{gene}_3d.png"), dpi=300, transparent=True, bbox_inches='tight')

In [None]:
# Scratch: show the bound quiver method of the current axes.
ax.quiver

In [None]:
# CCK_CCKBR hotspot labels for one slice.
# NOTE: `bregma` is the loop variable left over from the last plotting cell that ran.
adata[adata.obs["Bregma"] == bregma].obs['hotspots_CCK_CCKBR']

In [None]:
# Inspect the interaction table stored in uns (its 'interaction_name' column drives the comparison loops).
adata.uns['interactions']

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

# Unit sphere sampled on a 100 x 100 (polar, azimuth) grid.
r = 1
phi, theta = np.mgrid[0:np.pi:100j, 0:2*np.pi:100j]
sin_phi = np.sin(phi)
x = r*sin_phi*np.cos(theta)
y = r*sin_phi*np.sin(theta)
z = r*np.cos(phi)

# Constant-height sheets at z = -0.5, 0 and 0.5, using the sphere's x/y grid.
z1 = np.full(x.shape, -0.5)
z2 = np.zeros(x.shape)
z3 = np.full(x.shape, 0.5)

# Draw the sphere first, then the three coloured sheets.
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.view_init(elev=6, azim=98)
ax.plot_surface(x, y, z, alpha=0.5)
for sheet, sheet_color in ((z1, 'r'), (z2, 'g'), (z3, 'b')):
    ax.plot_surface(x, y, sheet, alpha=0.5, color=sheet_color)
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

# Schematic of a 3D signal sphere cut by three parallel slices.
# The unused glasbey palette lookup was dropped, so this cell no longer depends on
# `cc` having been imported by another cell.
cm2 = sns.color_palette("Reds", as_cmap=True)
cm3 = sns.color_palette("Blues", as_cmap=True)

# Unit sphere sampled on a 100 x 100 (polar, azimuth) grid.
r = 1 # radius of the sphere
phi, theta = np.mgrid[0:np.pi:100j, 0:2*np.pi:100j]
x = r*np.sin(phi)*np.cos(theta)
y = r*np.sin(phi)*np.sin(theta)
z = r*np.cos(phi)
# Face colours for the flat discs, taken from |z|^2 of the sphere grid.
c = cm2(np.abs(z)**2)    # centre slice (red)
c2 = cm3(np.abs(z)**2)   # offset slices (blue)

# x/y grid with radius sqrt(3)/2 = sqrt(1 - 0.5^2): the radius of the circle in
# which the planes z = +/-0.5 cut the unit sphere.
r = np.sqrt(3)/2
phi, theta = np.mgrid[0:np.pi:100j, 0:2*np.pi:100j]
xs = r*np.sin(phi)*np.cos(theta)
ys = r*np.sin(phi)*np.sin(theta)

# square outlines of the slices
xp, yp = np.mgrid[-1:1:100j, -1:1:100j]

# constant heights of the three slices
zs1 = -0.5*np.ones(x.shape)
zs2 = 0*np.ones(x.shape)
zs3 = 0.5*np.ones(x.shape)

# Plot the sphere, the planes, and the squares
fig = plt.figure(figsize=(2, 2))
ax = plt.axes(projection='3d')
ax.view_init(elev=10, azim=120)

#ax.scatter(0, 0, 0, c='r', marker='o', s=500, zorder=100)

# coloured intersection discs
ax.plot_surface(xs, ys, zs1, alpha=0.75, facecolors=c2)
ax.plot_surface(x, y, zs2, alpha=0.75, facecolors=c)
ax.plot_surface(xs, ys, zs3, alpha=0.75, facecolors=c2)

# faint sphere and slice squares
ax.plot_surface(x, y, z, alpha=0.1, color=(0.5, 0.5, 0.5))
ax.plot_surface(xp, yp, zs1, alpha=0.1, color='k')
ax.plot_surface(xp, yp, zs2, alpha=0.1, color='k')
ax.plot_surface(xp, yp, zs3, alpha=0.1, color='k')

# Strip all axis decoration. `ax.xaxis` etc. are used instead of the `w_xaxis`
# aliases, which were removed in Matplotlib 3.8.
ax.grid(False)
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.pane.fill = False
    axis.line.set_color((1.0, 1.0, 1.0, 0.0))
    axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

fig.savefig('images/merfish/3d_cci_illustration.png', dpi=300, transparent=True, bbox_inches='tight')

In [None]:
# Rebind `cm` to the "Reds" colormap and display it.
cm = sns.color_palette("Reds", as_cmap=True)
cm

In [None]:
# RGBA value returned by the colormap for the integer 0.
cm(0)

In [None]:
# Shape of `c` as left by the most recently run cell (the face-colour array in the illustration cell).
c.shape

In [None]:
# (n_obs, n_vars) of the full dataset.
adata.shape

In [None]:
# Cell_class map for the slice at position 6 of bregma_values, restricted to three groups.
nest.plot.spatial(adata[adata.obs.Bregma == bregma_values[6]], color="Cell_class", groups=["Ambiguous", "Excitatory", "Inhibitory"])