Author: Erno Hänninen

Created: 25.03.2023

Title: Hypo_d16_MiSTR_d14_d21_bonefight.ipynb

Description: 
- Spatially align day 16 hypothalamic differentiation protocol data (supplemented with data from MiSTR atlas) to neural tube

Procedure
- Read the spatial data to FISHscale object
- Convert the FISHscale object to anndata and remove medullary hindbrain from the spatial data
- Read and concatenate single cell data (MiSTR and day 16)
- From the concatenated data, extract the d14, d16 and d21 timepoints (d14 and d21 batches from MiSTR, and d16 from the hypothalamus protocol)
- Keep the genes shared by both datasets and select the top differentially expressed (DE) genes of each tissue from this subset (these are used as training genes)
- Subsample the data so that all clusters contain as many cells as the smallest cluster
- Create cluster expression matrix for these genes (average expression value in each cluster)
- Train BoneFight model and use the model to align tissues to spatial data (returns cluster-by-spot probability matrix)
- Smooth the matrix
- For each spatial spot, assign the cluster with the highest probability in the smoothed probability matrix
- Visualize the result

List of non-standard modules:
- bone_fight, tangram, anndata, FISHscale, scanpy, numpy, scipy, sklearn, fastcluster, matplotlib, pandas

Additional information:
- The _smooth function is copied from the script used in the "Comprehensive cell atlas of the first-trimester developing human brain" paper (Linnarsson lab)
- The construct_obs_plot, plot_cell_annotation_custom, convert_adata_array and tangram_plot_genes_custom functions are adapted from the Tangram source code
    - Slightly modified versions of these functions, better suited to our purposes, are implemented in this notebook

Conda environment used:
- bonefight_env

Usage:
- The script was executed using the Jupyter Notebook web interface. All dependencies, including those required by Jupyter, are installed in the bonefight_env Conda environment. See the README file for further details

In [None]:
# Import python packages
import bone_fight as bf
import tangram as tg
from anndata import AnnData
import sys, os
sys.path.insert(0, "FISHscale")
from FISHscale.utils import dataset
import scanpy as sc
sc.settings.verbosity = 0
sc.set_figure_params(dpi = 450, dpi_save = 450)
import numpy as np
from scipy.spatial.distance import pdist
from scipy.spatial import KDTree
from sklearn.preprocessing import scale
import fastcluster
import scipy.cluster.hierarchy as hc
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb, hex2color
import pandas as pd
import matplotlib.lines as mlines


# Set plotting settings
import matplotlib as mpl
mpl.rcParams['figure.facecolor'] = "none"
mpl.rcParams['legend.labelcolor'] = "black"
mpl.rcParams['text.color'] = "black"

os.makedirs('figures', exist_ok=True)

# Read and process spatial data

In [None]:
# Read spatial data
d = dataset.Dataset("Data/LBEXP20211113_EEL_HE_5w_970um_RNA_transformed_assigned.parquet",
                       gene_label = 'decoded_genes', 
                       x_label = 'r_transformed', 
                       y_label = 'c_transformed', 
                       pixel_size='0.27 micrometer', 
                       other_columns=['Brain'],
                       z = 970,
                       reparse=False)
d.set_working_selection("Brain")

In [None]:
# Creating AnnData object from fishscale spatial data

# Make hexagonal bins, the output can be used to create adata object
# Returns pandas count dataframe to df_hex and spatial coordinates (x, y) to a list variable
df_hex, coordinates = d.hexbin_make(spacing=50, min_count=10)

# Initialize variables
coordinates_to_keep = []
count_df = df_hex.T
spots_to_remove = []
x, y = [], []

# The spatial data contains a region (medullary hindbrain) that does not occur in our single-cell data
# Spots in that region are identified by their spatial coordinates (x < 1800 and y < 4300) and removed
for i, spot in enumerate(count_df.index):
    if coordinates[i][0] < 1800 and coordinates[i][1] < 4300:
        spots_to_remove.append(spot)
    else:
        # Keep the spot and mirror its x-coordinate
        coordinates_to_keep.append([coordinates[i][0]*-1,coordinates[i][1]])
        x.append(coordinates[i][0]*-1)
        y.append(coordinates[i][1])

# Remove the unnecessary rows from df based on the spots_to_remove index list
count_df = count_df.drop(spots_to_remove)

# Create the volume vector for BoneFight (every spatial spot gets volume 1)
volume_sp = np.array([1] * len(count_df))

# Initialize the AnnData object with the count data and add the spatial coordinates
adata_spatial = AnnData(count_df.to_numpy(), obsm={"spatial": np.array(coordinates_to_keep)})
adata_spatial.var.index = np.array(count_df.columns)
adata_spatial.obs.index = np.array(count_df.index)
adata_spatial.obs["x"] = x
adata_spatial.obs["y"] = y

# Plot the remaining spots after removing the region that does not occur in our single-cell data
with plt.rc_context({"figure.figsize": [3.5, 2.2],   "figure.dpi": 400}):
    xs = adata_spatial.obs.x.values
    ys = adata_spatial.obs.y.values
    plt.axis("off")
    plt.scatter(xs, ys, s=0.8, color="#D3D3D3");
    plt.savefig("figures/raw_spatial_data.png", dpi=450, bbox_inches='tight')
    plt.savefig("figures/raw_spatial_data.svg", dpi=450, bbox_inches='tight')
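The loop above sorts spots into kept and removed lists one at a time. As a minimal sketch, assuming (as the loop does) that the i-th entry of coordinates belongs to the i-th row of count_df, the same region filter can be written with a boolean mask. remove_region is a hypothetical helper and the toy spots are made up for illustration.

```python
import numpy as np
import pandas as pd

def remove_region(count_df, coordinates, x_max=1800, y_max=4300):
    """Drop spots with x < x_max and y < y_max; mirror the x-coordinate of the kept spots."""
    coords = np.asarray(coordinates, dtype=float)
    in_region = (coords[:, 0] < x_max) & (coords[:, 1] < y_max)
    kept_counts = count_df.loc[~in_region]
    kept_xy = coords[~in_region] * np.array([-1.0, 1.0])  # mirror x, as the notebook does
    return kept_counts, kept_xy

# Toy data: 4 spots x 2 genes; rows are in the same order as the coordinates
counts = pd.DataFrame(np.ones((4, 2)), index=["s0", "s1", "s2", "s3"], columns=["G1", "G2"])
xy = [[100, 100], [2000, 100], [100, 5000], [2500, 5000]]
kept_counts, kept_xy = remove_region(counts, xy)
```

Only spot s0 lies inside the rectangle x < 1800, y < 4300, so s1, s2 and s3 are kept.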

In [5]:
adata_spatial

AnnData object with n_obs × n_vars = 1164 × 448
    obs: 'x', 'y'
    obsm: 'spatial'

# Read single cell data and prepare the data for bonefight

In [6]:
# Reading single cell data (mistr atlas and day 16 data)
mistr_data = sc.read_h5ad('Data/all_mistr_mnn.h5ad')
d16 = sc.read_h5ad('../Scanvi_notebooks/Data/d16.h5ad')

In [9]:
# Concatenate the d16 and MiSTR data
d16.var = d16.var.set_index('_index')
mistr_data = mistr_data.concatenate(d16, join='outer')
# The d16 cells have no "tissue" or "day" annotation after concatenation; fill the missing values
mistr_data.obs["tissue"] = mistr_data.obs["tissue"].cat.add_categories('Hypothalamus d 16')
mistr_data.obs["tissue"] = mistr_data.obs["tissue"].fillna(value="Hypothalamus d 16")
mistr_data.obs["day"] = mistr_data.obs["day"].cat.add_categories('day 16')
mistr_data.obs["day"] = mistr_data.obs["day"].fillna(value="day 16")
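In pandas, a categorical column can only be filled with a value that is already one of its categories, which is why the label is added with add_categories before fillna. A small sketch of the same pattern on a toy column (fill_missing_category is a hypothetical helper):

```python
import pandas as pd

def fill_missing_category(series, label):
    """Fill NaNs in a categorical Series with `label`, registering it as a category first."""
    if label not in series.cat.categories:
        series = series.cat.add_categories(label)
    return series.fillna(label)

tissue = pd.Series(pd.Categorical(["R/C dorsal d 14", None, "D/V forebrain d 21"]))
filled = fill_missing_category(tissue, "Hypothalamus d 16")
```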

In [None]:
mistr_data.obs["tissue"]

In [None]:
# Store spatial data before subsetting it
adata_spatial_full = adata_spatial.copy()

# Extract the d14, d16 and d21 data from the single-cell data
# The d14 and d21 data come from MiSTR, and the d16 data from the hypothalamus protocol
mistr_data_d14_d21 = mistr_data[mistr_data.obs["day"].isin(["day 14", "day 16", "day 21"])]

# Subset the single-cell data to the genes that are shared with the spatial data
common_genes = [gene for gene in mistr_data_d14_d21.var.index if gene in adata_spatial.var.index]
# .copy() turns the view into a real AnnData object, so rank_genes_groups can write to .uns without warnings
mistr_data_d14_d21_gene_subset = mistr_data_d14_d21[:, common_genes].copy()

# BoneFight is trained using DE genes only
# For each of the 7 tissue groups, the top 35 DE genes (Wilcoxon test) are selected, and their union is used
sc.tl.rank_genes_groups(mistr_data_d14_d21_gene_subset, groupby="tissue", use_raw=False, method="wilcoxon")
genes_to_keep_df = pd.DataFrame(mistr_data_d14_d21_gene_subset.uns["rank_genes_groups"]["names"]).iloc[0:35, :]
genes_to_keep = list(np.unique(genes_to_keep_df.melt().value.values))

print("Number of training genes for BoneFight: ", len(genes_to_keep))

# Subset the data using the identified DE genes
mistr_data_d14_d21_gene_subset_2 = mistr_data_d14_d21_gene_subset[:, genes_to_keep]
adata_spatial = adata_spatial[:, genes_to_keep]
# Index by the (sorted) gene list so the spatial columns follow the same gene order as the cluster expression matrix
count_df = count_df[genes_to_keep]

# Plot the d14, d16 and d21 single-cell data (UMAP coloured by tissue)
with plt.rc_context({"figure.figsize": [4, 3],  "figure.dpi": 400}):
    sc.pl.umap(mistr_data_d14_d21, color="tissue", size = 3, frameon=False, show=True,legend_fontsize="large", save="_mistr_hypo_d14_d21.png")   
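The training genes are the sorted union of the first 35 rows of the rank_genes_groups name table, which has one column per tissue group. A toy sketch of that step, assuming a names table with the same layout (top_n_union and the gene names are illustrative):

```python
import numpy as np
import pandas as pd

def top_n_union(names_df, n):
    """Sorted union of the top-n ranked genes of every group (one column per group)."""
    return list(np.unique(names_df.iloc[:n, :].melt().value.values))

# Toy ranking table: column = group, row = rank
names = pd.DataFrame({"A": ["G1", "G2", "G3"], "B": ["G2", "G4", "G5"]})
training_genes = top_n_union(names, n=2)
```

Genes ranked highly in several groups (G2 here) appear only once, so the number of training genes can be smaller than n times the number of groups.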

In [None]:
# Plot the top 25 DE genes of each tissue (the BoneFight training set uses the top 35)
with plt.rc_context({"figure.figsize": [4, 4],  "figure.dpi": 200}):
    sc.pl.rank_genes_groups(mistr_data_d14_d21_gene_subset, n_genes=25, sharey=False,ncols=3)    

In [None]:
mistr_data_d14_d21_gene_subset_2.obs["tissue"].value_counts() # Print cluster value counts

In [None]:
# BoneFight maps clusters instead of single cells
# Subsample the data so that the clusters have equal size: every tissue group with more than 6223 cells is downsampled to 6223 cells
tissue_types = np.unique(mistr_data_d14_d21_gene_subset_2.obs['tissue'])
subsampled = []
for tissue_type in tissue_types:
    # .copy() makes an independent AnnData object that sc.pp.subsample can modify in place
    subset = mistr_data_d14_d21_gene_subset_2[mistr_data_d14_d21_gene_subset_2.obs['tissue'] == tissue_type, :].copy()
    if len(subset) >= 6223:
        sc.pp.subsample(subset, n_obs=6223)
    subsampled.append(subset)

# Combine the subsampled datasets into a single AnnData object
subsampled_adata = subsampled[0].concatenate(subsampled[1:], join='outer')

subsampled_adata
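The subsampling can also be expressed on the label vector alone. The sketch below is a hypothetical numpy helper (not the scanpy call used above) that keeps at most cap randomly chosen cells per label; passing cap = obs['tissue'].value_counts().min() would equalize every group to the smallest one.

```python
import numpy as np

def downsample_per_group(labels, cap, seed=0):
    """Return sorted positional indices that keep at most `cap` randomly chosen items per label."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    keep = []
    for lab in np.unique(labels):
        idx = np.flatnonzero(labels == lab)
        if len(idx) > cap:
            idx = rng.choice(idx, size=cap, replace=False)
        keep.append(idx)
    return np.sort(np.concatenate(keep))

# Toy labels: 5 cells of group "a", 2 cells of group "b"
labels = ["a"] * 5 + ["b"] * 2
idx = downsample_per_group(labels, cap=2)
```

With an AnnData object, the indices would be applied as adata[downsample_per_group(adata.obs['tissue'], cap)].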

In [26]:
# Create the cluster gene expression matrix (genes x clusters, mean expression per cluster)
def create_expression_matrix(adata):
    adata.obs['tissue'] = adata.obs['tissue'].astype('category')
    cluster_expression = pd.DataFrame(columns=adata.var_names, index=adata.obs['tissue'].cat.categories, dtype=float)

    for clust in adata.obs["tissue"].cat.categories:
        # np.asarray(...).ravel() handles both dense and sparse X (sparse .mean returns a 1 x n matrix)
        cluster_expression.loc[clust] = np.asarray(adata[adata.obs['tissue'] == clust, :].X.mean(0)).ravel()
    return cluster_expression.transpose().sort_index()

bonefigt_expression_matrix = create_expression_matrix(subsampled_adata)

# Cluster sizes, ordered like the columns of the cluster expression matrix
# volume_sc contains the cluster counts
volume_sc = subsampled_adata.obs["tissue"].value_counts().reindex(bonefigt_expression_matrix.columns).to_numpy()
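create_expression_matrix averages the expression of each gene within each cluster and sorts the genes alphabetically. On a dense toy matrix, the same genes-by-clusters table, and cluster sizes aligned to its columns, can be sketched with a pandas groupby (cluster_mean_matrix is a hypothetical name and the values are made up):

```python
import numpy as np
import pandas as pd

def cluster_mean_matrix(X, genes, labels):
    """Genes x clusters table of mean expression, genes sorted by name (dense X assumed)."""
    df = pd.DataFrame(np.asarray(X, dtype=float), columns=genes)
    return df.groupby(np.asarray(labels)).mean().T.sort_index()

X = np.array([[1, 0], [3, 2], [5, 5]])   # 3 cells x 2 genes
labels = ["c1", "c1", "c2"]
expr = cluster_mean_matrix(X, ["GeneB", "GeneA"], labels)
# Cluster sizes in the same order as the columns of expr
volumes = pd.Series(labels).value_counts().reindex(expr.columns).to_numpy()
```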


# Bonefight

In [27]:
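# Initialize BoneFight and train the model
# View a: clusters x genes (mean expression), with the cluster sizes as volumes
a = bf.View(bonefigt_expression_matrix.T.to_numpy(), volume_sc)
# View b: spatial spots x genes (counts), each spot with volume 1
b = bf.View(count_df.to_numpy(), volume_sp)
# Train for 80 epochs with a learning rate of 0.1
model = bf.BoneFight(a, b).fit(80, 0.1)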


100%|██████████| 80/80 [00:42<00:00,  1.87it/s, loss=-1.1462272]


In [None]:
# Plot the training process to see that the model has reached convergence
with plt.rc_context({"figure.figsize": [3.5, 3.5], "figure.dpi": 250}):
    plt.figure()
    plt.plot(model.losses)
    plt.title('Convergence')
    plt.xlabel('Epoch')
    plt.ylabel('Losses')
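Convergence is judged here by inspecting the loss curve. A simple numeric check can complement it; the heuristic below is an assumption of this sketch, not part of BoneFight, and compares the mean loss of the last two windows of epochs:

```python
import numpy as np

def has_converged(losses, window=10, rel_tol=1e-3):
    """Heuristic: the mean loss of the last `window` epochs differs from that of the
    preceding `window` epochs by at most rel_tol (relative to the earlier mean)."""
    losses = np.asarray(losses, dtype=float)
    if len(losses) < 2 * window:
        return False  # not enough epochs to compare two windows
    prev = losses[-2 * window:-window].mean()
    last = losses[-window:].mean()
    return bool(abs(last - prev) <= rel_tol * abs(prev))
```

In this notebook it would be called as has_converged(model.losses); window and rel_tol are arbitrary defaults that would need tuning.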

In [29]:
# Create an identity matrix, shape (n_clusters, n_clusters): one one-hot label per cluster
labels = np.eye(len(bonefigt_expression_matrix.columns))
# Transform it; the result has shape (n_spots, n_clusters)
cluster_probs = model.transform(labels)

# Put the results in a dataframe
bonefight_df = pd.DataFrame(cluster_probs, index=count_df.index, columns=bonefigt_expression_matrix.columns)
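Transforming the identity matrix works because each row of np.eye is the one-hot label of a single cluster. Assuming the fitted model acts as a linear spots-by-clusters mapping (an illustrative simplification; the matrix M below is made up, not taken from BoneFight), transforming the identity simply returns that mapping, one probability column per cluster:

```python
import numpy as np

# Hypothetical fitted mapping: rows = spatial spots, columns = clusters
M = np.array([[0.7, 0.2, 0.1],
              [0.1, 0.1, 0.8]])

def transform(mapping, labels):
    """Project per-cluster label rows (n_clusters x n_features) onto spots, assuming a linear mapping."""
    return mapping @ labels

probs = transform(M, np.eye(M.shape[1]))                 # equals M: (n_spots, n_clusters)
marker = transform(M, np.array([[1.0], [0.0], [0.0]]))   # only cluster 0 labelled
```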

In [30]:
# Smooth the results
# This smoothing function is from the first-trimester developing human brain publication
# (https://www.biorxiv.org/content/10.1101/2022.10.24.513487v1.full)
def _smooth(xy, data, k=28, weight=True):
    """Average each spot's values over its k nearest spatial neighbours plus the spot itself."""
    k += 1  # Each spot is returned as its own nearest neighbour and is kept in the average
    tree = KDTree(xy)
    dist, kneigh = tree.query(xy, k=k)
    data_np = data.to_numpy()
    mean = data_np[kneigh[:, :]]

    if weight:
        # Weight each neighbour by (largest neighbour distance over all spots - its own distance): closer spots count more
        dist_max = dist.max()
        mean = mean * (dist_max - dist[:, :, np.newaxis])

    mean = mean.mean(axis=1)

    return mean

smooth0_5w = _smooth(coordinates_to_keep, bonefight_df)

bonefight_df_smooth = pd.DataFrame(data = smooth0_5w, index = bonefight_df.index, columns=bonefight_df.columns)
bonefight_df_smooth.shape

(1164, 7)
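In _smooth the weighted values are averaged without dividing by the sum of the weights, so every spot is scaled by its own mean weight. This does not change which cluster has the highest value in a spot, but it does change the absolute values. The sketch below reproduces that behaviour with normalize=False and, as an addition for illustration that is not part of the original function, offers a normalized weighted mean:

```python
import numpy as np
from scipy.spatial import KDTree

def knn_smooth(xy, data, k=28, normalize=True):
    """Distance-weighted k-NN smoothing of an (n_spots, n_features) array.
    The k + 1 nearest neighbours of each spot include the spot itself."""
    xy = np.asarray(xy, dtype=float)
    data = np.asarray(data, dtype=float)
    dist, idx = KDTree(xy).query(xy, k=k + 1)
    w = dist.max() - dist                          # closer neighbours get larger weights
    weighted = data[idx] * w[:, :, np.newaxis]     # (n_spots, k + 1, n_features)
    if normalize:
        return weighted.sum(axis=1) / w.sum(axis=1, keepdims=True)
    return weighted.mean(axis=1)                   # same computation as _smooth

xy = [[0, 0], [1, 0], [0, 1], [5, 5]]
values = np.full((4, 2), 3.0)
smoothed = knn_smooth(xy, values, k=2)
raw = knn_smooth(xy, values, k=2, normalize=False)
```

With constant input, the normalized version returns the constant unchanged, whereas the unnormalized version rescales each spot.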

In [None]:
# BoneFight returns a spot-by-cluster dataframe in which entry (j, c) is the probability of cluster c being in spatial spot j.
# The values are on a relative scale; after smoothing they do not necessarily sum to 1 per spot
bonefight_df_smooth

# Prepare bonefight results for tangram functions

In [32]:
# The Tangram plotting function inverts the y-axis, and there is no parameter to control this.
# This function therefore inverts the y-axis before the data are passed to the Tangram plotting functions;
# when Tangram inverts the axis again, the plot is displayed correctly.

def invert_y_axis_for_tangram(adata):
    # Multiply the y-coordinates by -1 (x-coordinates are unchanged) and store them back in adata
    adata.obsm["spatial"] = np.asarray(adata.obsm["spatial"]) * np.array([1, -1])
    return adata

In [None]:
# Invert the y-axis of both the full and the subsetted spatial data
adata_spatial = invert_y_axis_for_tangram(adata_spatial)
adata_spatial_full = invert_y_axis_for_tangram(adata_spatial_full)
adata_spatial.obsm["spatial"]

In [34]:
# This code is an edited version of Tangram's plot_cell_annotation_sc function
# The original function had some shortcomings, so this customized version is used
# Only small edits were made: the figure can be saved, and some parameters were added
import pylab as pl
sc.set_figure_params(scanpy=True, fontsize=10)
def construct_obs_plot(df_plot, adata, perc=0, suffix=None):
    # clip
    df_plot = df_plot.clip(df_plot.quantile(perc), df_plot.quantile(1 - perc), axis=1)

    # normalize
    df_plot = (df_plot - df_plot.min()) / (df_plot.max() - df_plot.min())

    if suffix:
        df_plot = df_plot.add_suffix(" ({})".format(suffix))
    adata.obs = pd.concat([adata.obs, df_plot], axis=1)

def plot_cell_annotation_custom(adata_sp, annotation_list, x="x", y="y", spot_size=None, scale_factor=None, perc=0,alpha_img=1.0,bw=False,ax=None, show=True, save=False, ncols=4):
        
    # remove previous df_plot in obs
    adata_sp.obs.drop(annotation_list, inplace=True, errors="ignore", axis=1)

    # construct df_plot
    df = adata_sp.obsm["tangram_ct_pred"][annotation_list]
    construct_obs_plot(df, adata_sp, perc=perc)
    
    #non visium data 
    if 'spatial' not in adata_sp.obsm.keys():
        #add spatial coordinates to obsm of spatial data 
        coords = [[x,y] for x,y in zip(adata_sp.obs[x].values,adata_sp.obs[y].values)]
        adata_sp.obsm['spatial'] = np.array(coords)
    
    if 'spatial' not in adata_sp.uns.keys() and spot_size == None and scale_factor == None:
        raise ValueError("Spot Size and Scale Factor cannot be None when ad_sp.uns['spatial'] does not exist")
    
    #REVIEW
    if 'spatial' in adata_sp.uns.keys() and spot_size != None and scale_factor != None:
        raise ValueError("Spot Size and Scale Factor should be None when ad_sp.uns['spatial'] exists")
    
    sc.pl.spatial(
        adata_sp, color=annotation_list, show=show, frameon=False, spot_size=spot_size,
        scale_factor=scale_factor, alpha_img=alpha_img, bw=bw, ax=ax, ncols=ncols, save=save, colorbar_loc=None)


    adata_sp.obs.drop(annotation_list, inplace=True, errors="ignore", axis=1)
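construct_obs_plot clips each column to its [perc, 1 - perc] quantiles and then rescales it to [0, 1], so every tissue map is plotted on the same 0 to 1 colour scale. The same two steps on a toy dataframe (clip_and_minmax is a hypothetical name; a constant column would divide by zero and give NaN):

```python
import pandas as pd

def clip_and_minmax(df, perc=0.0):
    """Clip each column to its [perc, 1 - perc] quantiles, then scale it to [0, 1]."""
    clipped = df.clip(df.quantile(perc), df.quantile(1 - perc), axis=1)
    return (clipped - clipped.min()) / (clipped.max() - clipped.min())

df = pd.DataFrame({"A": [0.0, 1.0, 2.0, 100.0], "B": [5.0, 5.5, 6.0, 7.0]})
out = clip_and_minmax(df, perc=0.25)
```

The outlier 100.0 in column A is clipped to the 75th percentile (26.5) before scaling, so it no longer compresses the other values towards 0.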

In [None]:
# The Tangram tutorial uses the project_cell_annotations function to project cell annotations onto space.
# The output of that function has the same form as BoneFight's output dataframe (spots x annotations),
# so the function is not run here. Instead, the smoothed BoneFight output is stored in adata_spatial
# and the tissue probabilities are plotted in space.

adata_spatial.obsm["tangram_ct_pred"] = bonefight_df_smooth
annotation_list = list(pd.unique(subsampled_adata.obs['tissue']))
adata_spatial.obsm["tangram_adjusted"] = adata_spatial.obsm["tangram_ct_pred"]
for tissue in annotation_list:
    # "/" in tissue names (e.g. "R/C dorsal d 14") would be read as a directory separator, so it is replaced
    file_stem = tissue.replace("/", "-")
    plot_cell_annotation_custom(adata_spatial, [tissue], x="x", y="y", spot_size=50, scale_factor=0.1, perc=0.001, ncols=3, save="_tissue_probability_svg/"+file_stem+"_probabilities.svg", show=True)
    plot_cell_annotation_custom(adata_spatial, [tissue], x="x", y="y", spot_size=50, scale_factor=0.1, perc=0.001, ncols=3, save="_tissue_probability_png/"+file_stem+"_probabilities.png", show=False)
    

In [None]:
# Assign each spatial spot to the tissue with the highest probability:
# that tissue gets the value 1 and the rest get 0 (a tie would give several 1s)

adata_spatial.obsm["tangram_adjusted"] = adata_spatial.obsm["tangram_adjusted"].eq(adata_spatial.obsm["tangram_adjusted"].where(adata_spatial.obsm["tangram_adjusted"] != 0).max(1), axis=0).astype(int)
adata_spatial.obsm["tangram_ct_pred"] = adata_spatial.obsm["tangram_adjusted"]
# Plot the assigned tissues on space
plot_cell_annotation_custom(adata_spatial, annotation_list,x='x', y='y',spot_size= 50, scale_factor=0.1, perc=0.001, ncols=3, save=False, show=True)
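The one-hot step marks every tissue whose value equals the row maximum, so a tie gives a spot more than one 1, and idxmax (used next to pick a single label) then returns the first tied column. A toy example with made-up probabilities:

```python
import pandas as pd

probs = pd.DataFrame({"T1": [0.2, 0.5, 0.4], "T2": [0.7, 0.5, 0.1], "T3": [0.1, 0.0, 0.5]},
                     index=["spot0", "spot1", "spot2"])
# Notebook-style one-hot: 1 where a value equals its row maximum (zeros excluded from the max)
one_hot = probs.eq(probs.where(probs != 0).max(axis=1), axis=0).astype(int)
# Single label per spot; on ties idxmax returns the first maximal column
labels = probs.idxmax(axis=1)
```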


In [None]:
# Next, merge the d14 and d21 clusters so that there is one cluster per tissue

# Merge the clusters (the "Hypothalamus d 16" label is kept as is)
adata_spatial.obs["predicted_tissue"] = list(adata_spatial.obsm["tangram_adjusted"].idxmax(axis=1).values)
adata_spatial.obs['merged_pred_tissue'] = adata_spatial.obs['predicted_tissue'].replace({'R/C dorsal d 21': 'R/C dorsal', 'R/C dorsal d 14': 'R/C dorsal',
                                'R/C ventral d 14': 'R/C ventral', 'R/C ventral d 21': 'R/C ventral',
                                'D/V forebrain d 14': 'D/V forebrain', 'D/V forebrain d 21': 'D/V forebrain'})

# Color code the clusters
adata_spatial.obs["tissue_color"] = "unknown"
adata_spatial.obs.loc[adata_spatial.obs['merged_pred_tissue'] == 'R/C dorsal', 'tissue_color'] = "#6cb16d"
adata_spatial.obs.loc[adata_spatial.obs['merged_pred_tissue'] == 'R/C ventral', 'tissue_color'] = "#f7c701"
adata_spatial.obs.loc[adata_spatial.obs['merged_pred_tissue'] == 'D/V forebrain', 'tissue_color'] = "#cc6677"
adata_spatial.obs.loc[adata_spatial.obs['merged_pred_tissue'] == 'Hypothalamus d 16', 'tissue_color'] = "#1F77B4"

# Plotting results
with plt.rc_context({"figure.figsize": [3.5, 2.5],  "figure.dpi": 450}):
    xs = adata_spatial.obs.x.values
    #xs = [x * -1 for x in xs]
    ys = adata_spatial.obs.y.values
    plt.scatter(xs, ys, s=0.7, c=adata_spatial.obs['tissue_color'].values);
    plt.axis('off')
    dorsal_legend = mlines.Line2D([],[],color='#6cb16d', label='R/C dorsal d14/d21', marker="o", markersize=6)
    ventral_legend = mlines.Line2D([],[],color='#f7c701', label='R/C ventral d14/d21', marker="o", markersize=6)
    forebrain_legend = mlines.Line2D([],[],color='#cc6677', label='D/V forebrain d14/d21', marker="o", markersize=6)
    hypothalamus_legend = mlines.Line2D([],[],color='#1F77B4', label='Hypothalamus d 16', marker="o", markersize=6)
    
    plt.legend(handles=[dorsal_legend,ventral_legend,forebrain_legend, hypothalamus_legend], loc="center left",  bbox_to_anchor=(0.98, 0.45), frameon=False, prop={'size': 10})
    plt.grid()
    plt.savefig('figures/projected_tissues_hypothalamus.png', dpi=360, bbox_inches='tight')
    #plt.savefig("figures/projected_tissues.png", dpi=450, bbox_inches='tight')  
    plt.show()
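The label merging and colour assignment can also be written with a regular expression and a single colour dictionary. In this sketch (merge_and_color and the regex are illustrative assumptions; the colours are the ones used above), any label missing from the dictionary maps to NaN, which makes an unassigned tissue easy to spot:

```python
import pandas as pd

TISSUE_COLORS = {"R/C dorsal": "#6cb16d", "R/C ventral": "#f7c701",
                 "D/V forebrain": "#cc6677", "Hypothalamus d 16": "#1F77B4"}

def merge_and_color(predicted):
    """Strip a trailing ' d 14' / ' d 21' from each label and look up its colour."""
    merged = predicted.str.replace(r" d (14|21)$", "", regex=True)
    return merged, merged.map(TISSUE_COLORS)

pred = pd.Series(["R/C dorsal d 14", "D/V forebrain d 21", "Hypothalamus d 16"])
merged, colors = merge_and_color(pred)
```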

# Visualizing measured gene expression

In [23]:
# This code is an edited version of Tangram's plot_genes_sc function
# The original function had some shortcomings,
# so a customized version of it is used to plot measured or projected genes in space
# The original function plots measured and projected expression side by side,
# whereas this function plots either measured or projected genes
# Other changes: the figure can be saved, and some additional parameters better suit our purposes

from scipy.sparse import issparse
from matplotlib.gridspec import GridSpec
sc.set_figure_params(dpi_save = 500)


def convert_adata_array(adata):
    # Convert a sparse expression matrix (e.g. CSR or CSC) to a dense array in place
    if issparse(adata.X):
        adata.X = adata.X.toarray()
        
def tangram_plot_genes_custom(adata, genes=[], x="x",y = "y",spot_size=None, save=False,scale_factor=None, ncols=4, cmap="inferno",
                              perc=0,alpha_img=1.0,bw=False,return_figure=False, plot_measured_spatial=False, show=True):
    # construct df_plot
    data = []
    
    # remove df_plot in obs
    if plot_measured_spatial == False:
        adata.obs.drop(["{} (projected gene expression)".format(gene) for gene in genes], inplace=True, errors="ignore", axis=1)
        adata.var.index = [g.upper() for g in adata.var.index]

        df = pd.DataFrame(data=np.array(adata[:, genes].X), columns=genes, index=adata.obs.index)
        construct_obs_plot(df, adata, perc=perc, suffix="projected gene expression")

    else:
        adata.obs.drop(["{} (measured gene expression)".format(gene) for gene in genes], inplace=True, errors="ignore", axis=1)

        # prepare adatas
        convert_adata_array(adata)
        adata.var.index = [g.upper() for g in adata.var.index]
        for ix, gene in enumerate(genes):
            if gene not in adata.var.index:
                data.append(np.zeros_like(np.array(adata[:, 0].X).flatten()))
            else:
                data.append(np.array(adata[:, gene].X).flatten())

        df = pd.DataFrame(
            data=np.array(data).T, columns=genes, index=adata.obs.index,
        )
        construct_obs_plot(df, adata, suffix="measured gene expression")

    fig = plt.figure(figsize=(7, len(genes) * 3.5))
    gs = GridSpec(len(genes), 2, figure=fig)
    
    # non-Visium data
    if 'spatial' not in adata.obsm.keys():
        # add spatial coordinates to obsm of spatial data
        coords = [[x_val, y_val] for x_val, y_val in zip(adata.obs[x].values, adata.obs[y].values)]
        adata.obsm['spatial'] = np.array(coords)

    if ("spatial" not in adata.uns.keys()) and (spot_size is None and scale_factor is None):
        raise ValueError("Spot Size and Scale Factor cannot be None when ad_sp.uns['spatial'] does not exist")
        
    gene_list=[]
    for gene in genes:
        if plot_measured_spatial==True:       
            gene_list.append("{} (measured gene expression)".format(gene))          
        else:
            gene_list.append("{} (projected gene expression)".format(gene))
        
    if plot_measured_spatial == True:
        # vmax=0.52 caps the colour scale of the min-max normalized measured expression at 0.52
        sc.pl.spatial(adata, spot_size=spot_size, scale_factor=scale_factor, color=gene_list,
            frameon=False, show=show, cmap=cmap, alpha_img=alpha_img, bw=bw, ncols=ncols, colorbar_loc=None, save=save, vmax=0.52)
    else:
        sc.pl.spatial(adata, spot_size=spot_size, scale_factor=scale_factor, color=gene_list,
            frameon=False, show=show, cmap=cmap, alpha_img=alpha_img, bw=bw, ncols=ncols, colorbar_loc=None, save=save)
    
    """if show == True:
        a = np.array([[0,1]])
        pl.figure(figsize=(0.14, 4.5))
        img = pl.imshow(a, cmap="YlGnBu")
        pl.gca().set_visible(False)
        cax = pl.axes([0.1, 0.2, 0.8, 0.6])
        pl.colorbar(orientation="vertical", cax=cax)
        pl.savefig("figures/YlGnBu_colorbar_1.png", dpi=450, bbox_inches='tight')

        a = np.array([[0,1]])
        pl.figure(figsize=(0.14, 1.8))
        img = pl.imshow(a, cmap="YlGnBu")
        pl.gca().set_visible(False)
        cax = pl.axes([0.1, 0.2, 0.8, 0.6])
        pl.colorbar(orientation="vertical", cax=cax)
        pl.savefig("figures/YlGnBu_colorbar_2.png", dpi=450, bbox_inches='tight')"""
    
    # remove df_plot in obs
    if plot_measured_spatial == True:
        adata.obs.drop(["{} (measured gene expression)".format(gene) for gene in genes], inplace=True, errors="ignore", axis=1)
    else:
        adata.obs.drop(["{} (projected gene expression)".format(gene) for gene in genes], inplace=True, errors="ignore", axis=1)
          
    if return_figure==True:
        return fig

In [None]:
with plt.rc_context({"figure.figsize": [3.5, 2.5],  "figure.dpi": 450}):

    for gene in ["NKX2-1"]:
        tangram_plot_genes_custom(adata_spatial_full, genes=[gene],spot_size=50,
                              scale_factor=0.1, perc = 0.001, return_figure=False, cmap="viridis", plot_measured_spatial=True, save="_measured_png/"+gene+"_measured_expression.png", show=True)