# Nuclear Game - Analysis
Gabriel Emilio Herrera Oropeza <br>
13/06/2022

In [None]:
import cv2
import anndata
import time
import pandas as pd
import os
from os import listdir
from os.path import isfile, join, isdir
import statistics
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import warnings
import scanpy as sc
from math import log10
from skimage.filters import threshold_otsu, threshold_triangle
from statannotations.Annotator import Annotator
from skimage.filters import threshold_multiotsu
from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler
from PIL import Image, ImageEnhance
import plotly.graph_objects as go
from matplotlib_scalebar.scalebar import ScaleBar

# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides library deprecation notices (e.g. from pandas) - consider narrowing the filter.
warnings.filterwarnings('ignore')
# Scanpy logging level; per the scanpy settings docs 3 presumably means "hints" - confirm.
sc.settings.verbosity = 3

In [None]:
def _log(txt, verbose = True):
    """Print ``txt`` to standard output unless ``verbose`` is falsy."""
    if not verbose:
        return
    print(txt)
        
        
def intensityNormalisation(df, method = "mode", nbins = 10, verbose = True, hue = "experiment"):
    """Normalise the average-intensity columns of ``df`` separately for each group of ``hue``.

    Every column whose name contains "avg_intensity" is divided by a per-group
    reference value, except columns whose name contains "dapi", "gfap" or "olig2"
    (case-insensitive), which are left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        Input measurements; it is copied and never modified.
    method : str
        "mode", "mean" or "median" (case-insensitive). With "mode" the values are
        cut into ``nbins`` equal-width bins and the reference is the median of the
        values in the most populated bin. Any other value leaves the data
        un-normalised and prints a notice.
    nbins : int
        Number of bins, used only when ``method`` is "mode".
    verbose : bool
        When true, print the reference value of each column ("mode" only).
    hue : str
        Column that identifies the groups (experiments).

    Returns
    -------
    (pandas.DataFrame, dict)
        The normalised rows, grouped by ``hue`` in order of first appearance and
        keeping their original index labels, and a ``{group: {column: reference}}``
        dictionary.
    """
    method = method.lower()
    if method not in ("mode", "mean", "median"):
        # Best effort, as before, but no longer silent.
        print(f"Method '{method}' not supported. Data was not normalised.")

    df_ = df.copy()
    dct_norm = {}
    normalised = []
    # unique() keeps the order of first appearance, so the output is reproducible
    # (iterating over a set of strings is not).
    experiments = df_[hue].unique()
    for exp in tqdm(experiments, total = len(experiments)):
        # Explicit copy: the columns below are overwritten and must not alias df_.
        subset = df_[df_[hue] == exp].copy()
        dct_norm[exp] = {}
        for col in subset.columns:
            if "avg_intensity" not in col or any(l in col.lower() for l in ["dapi", "gfap", "olig2"]):
                continue
            if method == "mode":
                bins = pd.cut(subset[col], nbins, duplicates = "drop", labels = False)
                bins_mode = statistics.mode(bins)
                reference = subset[col][bins == bins_mode].median()
                _log(f"Reference value for {col} in {exp}: {reference}", verbose)
            elif method == "mean":
                reference = subset[col].mean()
            elif method == "median":
                reference = subset[col].median()
            else:
                continue
            subset[col] = subset[col] / reference
            dct_norm[exp][col] = reference
        normalised.append(subset)

    if not normalised:
        # Empty input: nothing to normalise.
        return df_, dct_norm
    # DataFrame.append was removed in pandas 2.0; concat keeps the same row labels.
    return pd.concat(normalised), dct_norm


def find_SingleCells(df, byExperiment = True, nbins = 10, spread = 0.2, channel = "dapi", hue = "experiment"):
    """Flag cells whose total intensity in ``channel`` lies close to the most common value.

    The column ``total_intensity_<channel>`` is divided by a reference value (the
    median of the values falling in the most populated of ``nbins`` equal-width
    bins). A cell is flagged as single when the normalised value lies in
    ``[1 - spread, 1 + spread]`` (both ends included).

    Parameters
    ----------
    df : pandas.DataFrame
        Input measurements; it is copied and never modified.
    byExperiment : bool
        When true the reference value is computed separately for each group of
        ``hue``; otherwise a single reference is computed over all rows.
    nbins : int
        Number of bins used to locate the mode.
    spread : float
        Half-width of the accepted interval around 1.
    channel : str
        Channel suffix of the ``total_intensity_`` column to use.
    hue : str
        Grouping column, used only when ``byExperiment`` is true.

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` with the intensity column normalised and a boolean
        ``isSingleCell`` column added. With ``byExperiment`` the rows are grouped
        by ``hue`` in order of first appearance.

    Raises
    ------
    ValueError
        If the ``total_intensity_<channel>`` column does not exist.
    """
    df_ = df.copy()
    col = f"total_intensity_{channel}"

    if col not in df_.columns:
        raise ValueError("Ops! Channel not found.")

    def _reference_value(values):
        # Median of the values that fall in the most populated bin.
        bins = pd.cut(values, nbins, duplicates = "drop", labels = False)
        return values[bins == statistics.mode(bins)].median()

    if byExperiment:
        experiments = df_[hue].unique()
        parts = []
        for exp in tqdm(experiments, total = len(experiments)):
            # Explicit copy so the assignment below does not write into a slice of df_.
            subset = df_[df_[hue] == exp].copy()
            subset[col] = subset[col] / _reference_value(subset[col])
            parts.append(subset)
        # DataFrame.append was removed in pandas 2.0; an empty input yields an empty result.
        out_df = pd.concat(parts) if parts else df_
    else:
        df_[col] = df_[col] / _reference_value(df_[col])
        out_df = df_

    # Vectorised interval test; to_numpy() avoids index alignment on duplicated row labels.
    out_df["isSingleCell"] = out_df[col].between(1 - spread, 1 + spread).to_numpy()

    return out_df


def generatePairs(data):
    """Return every unordered pair of elements of ``data`` as a list of tuples.

    Pairs follow the order of the input: each element is paired with every
    element that comes after it. Inputs that are not lists are converted with
    ``list()`` first.

    Raises
    ------
    ValueError
        If ``data`` cannot be converted to a list (it is not iterable).
    """
    if not isinstance(data, list):
        try:
            data = list(data)
        except TypeError as err:
            # Only "not iterable" is translated; KeyboardInterrupt and friends propagate.
            raise ValueError("Input should be list or vector.") from err
    return [(a, b) for idx, a in enumerate(data) for b in data[idx + 1:]]


def _normalise_data(X, method = "standardscaler", copy = False):
    """Scale ``X`` with the scikit-learn scaler named by ``method``.

    ``method`` is matched case-insensitively against "standardscaler",
    "minmaxscaler" and "maxabsscaler". Any other name returns the data
    un-normalised. With ``copy`` the input is copied before scaling.
    """
    data = X.copy() if copy else X
    name = method.lower()

    if name == "standardscaler":
        scaler_cls = StandardScaler
    elif name == "minmaxscaler":
        scaler_cls = MinMaxScaler
    elif name == "maxabsscaler":
        scaler_cls = MaxAbsScaler
    else:
        # Unsupported method: hand the data back untouched.
        return data

    return scaler_cls().fit_transform(data)


def show_cell(data, order_by = "areaNucleus", fig_height = 15, fig_width = 40, show_nucleus = True,
              contrast_red = 3, contrast_green = 3, contrast_blue = 4, uniqID = False,
              channels = ("var", "rfp", "beta3"), pixel_size = 0.227, zoom_box_side = 300):
    """Interactively display a random sample of nuclei from ``data`` as RGB crops.

    The function prompts (via ``input``) for the number of nuclei, for which of
    ``channels`` to show, and for the colour (red/green/blue) of each shown
    channel. For every sampled row it loads ``<imageID>_masks.npy`` and
    ``<imageID>_wkarray.npy`` from the folder derived from ``path2ong``, crops a
    box around the nucleus position and draws the merged channels.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows to sample from. The columns "path2ong", "imageID", "cellID", "x_pos"
        and "y_pos" are read.
    order_by : str
        Unused; kept for backward compatibility.
    fig_height, fig_width : number
        Figure size in centimetres (converted to inches by dividing by 2.54).
    show_nucleus : bool
        If true, overlay the nucleus image; otherwise draw a white outline of the
        nucleus mask.
    contrast_red, contrast_green, contrast_blue : number
        Contrast enhancement factor applied to the respective colour plane.
    uniqID : bool
        If true, title the panels "ID: <n>" and also return a dictionary mapping
        that ID to the cell and image IDs.
    channels : sequence of str
        Channel names. Channel ``i`` is read from ``wk_array[i + 1]``; index 0 is
        used as the nucleus image.
    pixel_size : float
        Value passed to the scale bar as the size of one pixel, in micrometres.
    zoom_box_side : int
        Side, in pixels, of the crop box around each nucleus.

    Returns
    -------
    matplotlib.figure.Figure, or (Figure, dict) when ``uniqID`` is true.

    Raises
    ------
    ValueError
        If ``data`` has no rows.
    """
    df = data.copy()
    if len(df) == 0:
        raise ValueError("No cells found in the selected data.")

    # Ask for the number of cells to show; re-prompt on anything but a positive integer or "all"
    while True:
        answer = input('\nEnter number of nuclei to show (any integer OR "all"): ').strip()
        if answer.lower() == 'all':
            no_cells = len(df)
            break
        try:
            no_cells = int(answer)
        except ValueError:
            no_cells = 0
        if no_cells >= 1:
            break
        print('Ops! Invalid number format! Enter an integer or "all"')

    if len(df) == no_cells:
        print(f"\nShowing all cells ({len(df)}) in the selected area")
    elif len(df) > no_cells:
        print('\nShowing {0} cells of a total of {1} in the selected data'.format(no_cells, str(len(df))))
    else:
        no_cells = len(df)
        print('\nONLY ' + str(len(df)) + ' cells were found in the selected data')

    new_df = df.sample(n = no_cells)

    # Ask which channels to show
    dct_channels = {}
    for ch in channels:
        dct_channels[ch] = input(f'Show {ch} (y/n): ')

    # Ask for the colour of every channel that will be shown
    valid_colours = ('red', 'green', 'blue')
    dct_colors = {}
    for ch, show_answer in dct_channels.items():
        if show_answer.strip().lower() != 'y':
            continue
        while True:
            colour = input('Desired colour for {0} (red/green/blue): '.format(ch)).strip().lower()
            if colour in valid_colours:
                dct_colors[ch] = colour
                break
            print('Input color {0} is not valid!'.format(colour))

    # Grid: a single column for up to 5 cells, otherwise 5 columns and as many rows as needed.
    # The figure height grows by half of fig_height per row in the 5-column layout.
    if no_cells <= 5:
        nrows, ncols = no_cells, 1
        height_factor = no_cells
    else:
        ncols = 5
        nrows = -(-no_cells // ncols)  # ceiling division
        height_factor = nrows / 2
    # squeeze = False always returns a 2-D array of axes, even for a single cell.
    fig, axes = plt.subplots(nrows = nrows, ncols = ncols, sharex = True, sharey = True, squeeze = False,
                             figsize = (int(fig_width)/2.54, (int(fig_height)/2.54) * height_factor))
    ax = axes.ravel()

    # Unique ID dictionary (returned only when uniqID is true)
    uniqID_dct = {}

    contrasts = {'red': contrast_red, 'green': contrast_green, 'blue': contrast_blue}
    RGB = np.array((*"RGB",))
    plane_selectors = {'red': RGB == 'R', 'green': RGB == 'G', 'blue': RGB == 'B'}

    for n, (index, row) in enumerate(tqdm(new_df.iterrows(), total = new_df.shape[0])):
        image_dir = join(row["path2ong"].replace("output.csv", ""), row["imageID"])
        masks = np.load(join(image_dir, f"{row['imageID']}_masks.npy"))
        wk_array = np.load(join(image_dir, f"{row['imageID']}_wkarray.npy"))

        # Nucleus image with everything outside this cell's mask blanked out
        nucleus = wk_array[0].copy()
        nucleus[masks != row['cellID']] = 0
        cX_low, cX_high, cY_low, cY_high = zoomIN(nucleus, row["x_pos"], row["y_pos"],
                                                  zoom_box_side = zoom_box_side)
        nucleus = nucleus[cY_low:cY_high, cX_low:cX_high]
        y, x = nucleus.shape

        # Start from three black planes and fill the ones the user asked for
        planes = {colour: Image.fromarray(np.zeros((y, x, 3), dtype = 'uint8')).convert('L')
                  for colour in valid_colours}
        for ch, colour in dct_colors.items():
            channel = wk_array[channels.index(ch) + 1].copy()
            channel = channel[cY_low:cY_high, cX_low:cX_high]
            channel = cv2.normalize(channel, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
            # NOTE(review): the channel is placed in one RGB plane and then converted to 'L';
            # presumably this scales it by PIL's luma weight for that colour before the
            # contrast factor is applied - kept as in the original, confirm it is intended.
            color = np.multiply.outer(channel, plane_selectors[colour])
            plane = Image.fromarray(color).convert('L')
            planes[colour] = ImageEnhance.Contrast(plane).enhance(contrasts[colour])

        mrg = Image.merge("RGB", (planes['red'], planes['green'], planes['blue']))
        mrg = np.array(mrg, dtype = 'uint8')
        scalebar = ScaleBar(pixel_size, 'um', box_alpha = 0, location = "upper left", color = "w")

        if not show_nucleus:
            # White outline: the binary mask minus its erosion
            kernel = np.ones((3, 3), np.uint8)
            cell_mask = np.uint16(masks[cY_low:cY_high, cX_low:cX_high] == int(row['cellID']))
            eroded = cv2.erode(cell_mask, kernel, iterations = 2)
            outline = np.array(cell_mask - eroded, dtype = 'uint8')
            mrg[outline == 1] = [255, 255, 255]

        ax[n].set_xlim(0, zoom_box_side - 1)
        ax[n].set_ylim(zoom_box_side - 1, 0)
        ax[n].imshow(mrg)
        if show_nucleus:
            # Masked array: pixels outside the nucleus stay transparent over the merged image
            ax[n].imshow(np.ma.masked_where(nucleus == 0, nucleus))
        ax[n].add_artist(scalebar)

        if uniqID:
            ax[n].set_title(f"ID: {n+1}", fontdict = {'fontsize' : 8})
            uniqID_dct[str(n+1)] = {"cellID": row["cellID"], "imageID": row["imageID"]}
        else:
            ax[n].set_title(f"{row['imageID']} | {row['cellID']}", fontdict = {'fontsize' : 8})
        ax[n].axis("off")

    # Hide the grid positions that were not needed
    for spare in ax[no_cells:]:
        spare.axis("off")
    plt.tight_layout()

    if uniqID:
        return fig, uniqID_dct
    else:
        return fig
    
    
def zoomIN(nucleus, x_pos, y_pos, zoom_box_side = 300):
    """Return the bounds of a square box centred on a position, clipped to the image.

    The box has side ``zoom_box_side`` and is centred on the integer part of
    (``x_pos``, ``y_pos``). Bounds are clipped to ``[0, number of columns]``
    horizontally and ``[0, number of rows]`` vertically, then truncated to int.

    Returns
    -------
    tuple of int
        ``(x_low, x_high, y_low, y_high)``.
    """
    half_side = zoom_box_side / 2
    centre_y = int(y_pos)
    centre_x = int(x_pos)
    n_rows = len(nucleus)
    n_cols = len(nucleus[0])

    y_low = max(centre_y - half_side, 0)
    y_high = min(centre_y + half_side, n_rows)
    x_low = max(centre_x - half_side, 0)
    x_high = min(centre_x + half_side, n_cols)

    return int(x_low), int(x_high), int(y_low), int(y_high)


def embeddingPlotter(adata, basis = "umap", size = 20):
    """Build a linked plotly scatter plot and table for an embedding stored in ``adata``.

    The coordinates are read from ``adata.obsm["X_<basis>"]``. Selecting points
    in the scatter plot fills the table with the matching rows of ``adata.obs``.

    Parameters
    ----------
    adata : anndata.AnnData
        Object holding the embedding in ``obsm`` and the annotations in ``obs``.
    basis : str
        Name of the embedding, e.g. "umap" or "diffmap".
    size : number
        Marker size.

    Returns
    -------
    (plotly FigureWidget, plotly FigureWidget)
        The scatter plot and the table.

    Raises
    ------
    ValueError
        If ``adata.obsm`` has no ``X_<basis>`` entry.
    """
    obsm_key = f"X_{basis}"
    if obsm_key not in adata.obsm.keys():
        raise ValueError(f"Embedding '{obsm_key}' not found in adata.obsm; compute it first.")

    df = adata.obs.copy()
    df = df.reset_index(drop = True)

    # For diffmap use components 1 and 2 (the same ones the notebook plots with
    # dimensions = [1,2]); every other embedding uses its first two components.
    x_dim, y_dim = (1, 2) if basis == "diffmap" else (0, 1)
    coordinates = adata.obsm[obsm_key]
    df["x"] = coordinates[..., x_dim]
    df["y"] = coordinates[..., y_dim]

    f = go.FigureWidget([go.Scatter(y = df["y"],
                                x = df["x"],
                                mode = 'markers',
                                marker=dict(size = size
                                            )
                                )
                     ]
                    )

    scatter = f.data[0]

    # One alignment entry per column instead of a fixed five
    alignment = ['left'] * len(df.columns)
    t = go.FigureWidget([go.Table(
    header=dict(values = df.columns,
                fill = dict(color='#C2D4FF'),
                align = alignment),
    cells=dict(values=[df[col].to_list() for col in df.columns],
               fill = dict(color='#F5F8FF'),
               align = alignment))])

    def selection_fn(trace,points,selector):
        # Replace the table contents with the rows of the selected points
        t.data[0].cells.values = [df.reindex(index = points.point_inds)[col] for col in df.columns]

    scatter.on_selection(selection_fn)

    return f, t
    
    
def selection2df(table):
    """Convert the first table trace of ``table`` into a DataFrame.

    ``table.to_dict()`` is read: the header values become the column names and
    the cell values the columns. The result has a fresh 0..n-1 index.
    """
    first_trace = table.to_dict()['data'][0]
    column_names = first_trace['header']['values']
    column_values = first_trace['cells']['values']
    # One row per column, then transpose so that each header becomes a column
    by_column = pd.DataFrame(column_values, index = column_names)
    return by_column.T.reset_index(drop = True)


def centerDAPI(data, splitBy = "experiment", nbins = 100, showPlot = True):
    """Rescale ``avg_intensity_dapi`` so that the groups of ``splitBy`` share a DAPI reference.

    For each group the mode of ``total_intensity_dapi`` is estimated as the
    median of the values in the most populated of ``nbins`` equal-width bins.
    Each row's ``avg_intensity_dapi`` is multiplied by
    ``median(total_intensity_dapi over all rows) / mode of its group``.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain ``splitBy``, "total_intensity_dapi" and "avg_intensity_dapi".
        The "avg_intensity_dapi" column is overwritten IN PLACE.
    splitBy : str
        Column that identifies the groups.
    nbins : int
        Number of bins used to locate each group's mode.
    showPlot : bool
        If true, draw a violin plot of "total_intensity_dapi" per group with the
        group mode marked by a red line.

    Returns
    -------
    pandas.DataFrame
        The same ``data`` object that was passed in.
    """
    groups = data[splitBy].unique()

    modes_ = {}
    for exp in groups:
        intensities = data.loc[data[splitBy] == exp, "total_intensity_dapi"]
        # Bins are kept in a local Series, so no helper column is written into a slice of data
        bins = pd.cut(intensities, nbins, duplicates = "drop", labels = False)
        bins_mode = statistics.mode(bins)
        modes_[exp] = intensities[bins == bins_mode].median()

    dapi_reference = data["total_intensity_dapi"].median()
    dapi_norm = {exp: dapi_reference / mode_ for exp, mode_ in modes_.items()}

    # Look the factor up through splitBy (previously hard-coded to the "experiment" column).
    # to_numpy() avoids index alignment on duplicated row labels.
    factors = data[splitBy].map(dapi_norm).astype(float).to_numpy()
    data["avg_intensity_dapi"] = data["avg_intensity_dapi"].to_numpy() * factors

    if showPlot:
        fig, ax = plt.subplots(figsize = (3*len(groups),6))
        # Explicit order keeps every violin under the red line of its own group
        sns.violinplot(x = splitBy, y = "total_intensity_dapi", data = data, ax = ax, order = list(groups))
        for n, exp in enumerate(groups):
            X = n
            ax.plot([X-0.4,X+0.4], [modes_[exp],modes_[exp]], color = 'r')
        plt.tight_layout()
        plt.show()

    return data

## Data Processing
### Select Experiment

In [None]:
# Folder whose sub-folders are listed as the available experiments below.
# Machine-specific absolute path: edit before running.
path_to_experiments = "E:/emilio/phd/NucleusAnalysis/data"

In [None]:
# Every sub-folder of path_to_experiments counts as an experiment
experiments = [entry for entry in listdir(path_to_experiments) if isdir(join(path_to_experiments, entry))]
print("Experiments in current directory:\n")
for exp in experiments:
    print(f"\t{exp}")

# Keep asking until the answer matches one of the listed folders
experiment = input("\nEnter one of the experiments above: ")
while experiment not in experiments:
    print("Experiment entered NOT in list of experiments available. Try again...")
    experiment = input("\nEnter one of the experiments above: ")

### Read Data

In [None]:
# Every sub-folder of the chosen experiment is expected to hold out_ng/output.csv
within_experiment = [file for file in listdir(join(path_to_experiments, experiment))
                     if isdir(join(path_to_experiments, experiment, file))]

frames = []
for file in within_experiment:
    frame = pd.read_csv(join(path_to_experiments, experiment, file, "out_ng", "output.csv"))
    # Tag every row with the sub-folder it came from
    frame["experiment"] = file
    frames.append(frame)

if not frames:
    raise FileNotFoundError(f"No sub-folders with data found in {join(path_to_experiments, experiment)}")

# DataFrame.append was removed in pandas 2.0. ignore_index gives every cell a unique
# row label, which is later used as the observation name of the AnnData object.
data = pd.concat(frames, ignore_index = True)

### Center DAPI

In [None]:
# Rescale avg_intensity_dapi per experiment; showPlot draws the per-experiment violins with the mode in red
data = centerDAPI(data, splitBy = "experiment", nbins = 100, showPlot = True)

### Identify Single Cells
Identify single cells based on DNA marker content.

In [None]:
# Flag cells whose normalised total DAPI intensity lies within 1 +/- spread of the per-experiment mode
scData = find_SingleCells(data, byExperiment = True, nbins = 100, spread = 0.4, channel = "dapi", hue = "experiment")

Check selection of single cells:

In [None]:
# DAPI intensity against nuclear area on log-log axes, coloured by the single-cell flag
fig, ax = plt.subplots(figsize = (6.4, 4.8))
ax = sns.scatterplot(x = "nuclear_area", y = "avg_intensity_dapi", hue = "isSingleCell", data = scData,
                     alpha = 0.5, ax = ax)
ax.set(xscale = "log", yscale = "log")
plt.tight_layout()
plt.show()

In [None]:
# Drop every row that was not flagged as a single cell
scData = scData.loc[scData["isSingleCell"] == True].copy()

### Intensity Normalisation
Statistic-based normalisation of intensity data. **Options are: mode, mean, and median.** *nbins* is used only when method is *mode*. The DAPI, GFAP and OLIG2 channels are not normalised.

In [None]:
# normData: normalised copy of scData; normMetadata: {experiment: {column: reference value}}
normData, normMetadata = intensityNormalisation(scData, method = "mode", nbins = 100, verbose = False, 
                                                hue = "experiment")

Inspect the data for one channel before normalisation. The red line marks the reference value (mode, mean or median, depending on the chosen method) used for normalisation.

In [None]:
channel = "dcx" # Modify channel as needed
log_scale = False # Modify to True or False as needed

fig, ax = plt.subplots(figsize = (len(normMetadata)*1.5, 6))

x = "experiment"
y = f"avg_intensity_{channel}"
data_to_plot = scData[[x, y]].copy()

if log_scale:
    data_to_plot[y] = data_to_plot[y].map(log10)

# Violins are drawn in the key order of normMetadata so the reference lines below line up with them
ax = sns.violinplot(data = data_to_plot, x = x, y = y, order = list(normMetadata.keys()),
                    palette = "Set3", bw = .2, ax = ax)

plt.xticks(rotation = 45, ha = "right")

# Red line: reference value used to normalise this channel in each experiment
for n, exp in enumerate(normMetadata):
    Y = normMetadata[exp][y]
    if log_scale:
        Y = log10(Y)
    X = n
    plt.plot([X-0.4, X+0.4], [Y, Y], color = 'r')

if log_scale:
    plt.ylabel(f"{y} (log10)")
plt.tight_layout()
#fig.savefig(f"J:/emilio/figures/laminAC/{channel}_normalisation.pdf")
plt.show()

## Data Exploration

### Linear relationships

In [None]:
x = "avg_intensity_dapi" # Change as needed
y = "nuclear_area" # Change as needed

# One LOWESS trend line per experiment, without the underlying scatter points
fig = sns.lmplot(data = normData, x = x, y = y, hue = "experiment", lowess = True, scatter = False)
#plt.xlim(10,)
plt.ylim(bottom = 0)
plt.show()

### Dimension Reduction

In [None]:
# Work on a copy so that normData stays untouched
wkData = normData.copy()

In [None]:
# Columns copied into adata.obs as per-cell annotations.
# Names missing from the data are skipped when the AnnData object is built.
obs_ = ['imageID', 
        'experiment', 
        'cellID',
        "x_pos", 
        "y_pos",
        "angle",
        "dcx_class",
        "rfp_class",
        "laminB1_class",
        'avg_intensity_laminB1',
        'path2ong',
       ]

In [None]:
# Columns used as the feature matrix (adata.X).
# Commented-out entries are excluded; uncomment to include them.
data_cols = [
'avg_intensity_dapi',
'nuclear_area',
'nuclear_perimeter',
'major_axis',
'minor_axis',
'axes_ratio',
'circularity',
'eccentricity',
'solidity',
'avg_intensity_core_dapi',
'avg_intensity_internal_ring_dapi',
'avg_intensity_external_ring_dapi',
#'total_intensity_core_dapi',
#'total_intensity_internal_ring_dapi',
#'total_intensity_external_ring_dapi',
#'total_intensity_dapi',
#'total_intensity_gfap',
'avg_intensity_rfp',
#'total_intensity_rfp',
'avg_intensity_dcx',
#'total_intensity_beta3',
#'dna_peaks',
'dna_dots',
'dna_dots_size_median',
'spatial_entropy',
#'total_intensity_olig2',
#'gfap_positive',
#'gfap_frac_covered',
#'rfp_positive',
#'rfp_frac_covered',
#'beta3_positive',
#'beta3_frac_covered',
#'isSingleCell',
#'rfp_dcx_product'
]

In [None]:
# Build the AnnData object: features in X, one observation per row of wkData
feature_table = wkData[data_cols]
cell_ids = wkData.index.to_list()

adata = anndata.AnnData(
    X = feature_table.values,
    obs = pd.DataFrame({"cell_uniqID": cell_ids}, index = [str(cell) for cell in cell_ids]),
    var = pd.DataFrame({"feature": feature_table.columns.to_list()},
                       index = [str(pos) for pos in range(feature_table.shape[1])]),
    )

# Use the feature names themselves as variable names
adata.var_names = adata.var["feature"].to_list()

# Copy the annotation columns: intensities keep their values, everything else is stored as text
for o in tqdm(obs_):
    if o not in wkData.columns:
        continue
    if "intensity" in o:
        adata.obs[o] = wkData[o].to_list()
    else:
        adata.obs[o] = [str(l) for l in wkData[o]]

In [None]:
# Data pre-processing: scale the features with the default scaler ("standardscaler"),
# then let scanpy scale again with values capped at 10.
# NOTE(review): the two steps look partly redundant - confirm both are intended.
adata.X = _normalise_data(adata.X)
sc.pp.scale(adata, max_value = 10)

#### UMAP

In [None]:
# Find neighbours on the feature matrix itself (use_rep = 'X'), using the 'umap' connectivity method
sc.pp.neighbors(adata, n_neighbors = 30, use_rep = 'X', method = 'umap')

In [None]:
# Run UMAP on the neighbour graph computed above
sc.tl.umap(adata)

In [None]:
# Plot UMAP coloured by a feature
fig, ax = plt.subplots(figsize = (4, 4))
# show = False stops scanpy from rendering the figure itself, so tight_layout() below still takes effect
sc.pl.umap(adata, color = "avg_intensity_dapi", frameon = False, ax = ax,
          size = 30, show = False,
          )

fig.tight_layout()
plt.show()

In [None]:
# Find clusters with the Leiden algorithm; the labels are plotted below from the "leiden" annotation
sc.tl.leiden(adata, resolution = 0.6)

In [None]:
# Plot UMAP coloured by Leiden cluster, with the cluster labels drawn on the data
fig, ax = plt.subplots(figsize = (4, 4))
# show = False stops scanpy from rendering the figure itself, so tight_layout() below still takes effect
sc.pl.umap(adata, color = "leiden", frameon = False, ax = ax, legend_loc = "on data",
          size = 30, show = False,
          )

fig.tight_layout()
plt.show()

#### DIFFMAP

In [None]:
# Find neighbours again, this time with the 'gauss' connectivity method; this replaces the graph used for UMAP
sc.pp.neighbors(adata, n_neighbors = 30, use_rep = 'X', method = 'gauss')

In [None]:
# Run DIFFMAP (diffusion map) on the current neighbour graph
sc.tl.diffmap(adata)

In [None]:
# Plot DIFFMAP (components 1 and 2) coloured by a feature
fig, ax = plt.subplots(figsize = (4, 4))
# show = False stops scanpy from rendering the figure itself, so tight_layout() below still takes effect
sc.pl.diffmap(adata, color = "avg_intensity_dapi", frameon = False, ax = ax,
              size = 30, dimensions = [1,2], show = False
             )

fig.tight_layout()
plt.show()

In [None]:
# Find clusters again on the 'gauss' neighbour graph; this overwrites the earlier "leiden" labels
sc.tl.leiden(adata, resolution = 0.6)

In [None]:
# Plot DIFFMAP (components 1 and 2) coloured by Leiden cluster, with the cluster labels drawn on the data
fig, ax = plt.subplots(figsize = (4, 4))
# show = False stops scanpy from rendering the figure itself, so tight_layout() below still takes effect
sc.pl.diffmap(adata, color = "leiden", frameon = False, ax = ax, legend_loc = "on data",
              size = 30, dimensions = [1,2], show = False
             )

fig.tight_layout()
plt.show()

#### Pseudotime
Choose a root cell for diffusion pseudotime:

In [None]:
root_cluster = '3' # Modify as needed: Leiden cluster that contains the root cell

# The first cell of the chosen cluster becomes the root of the pseudotime
root_candidates = np.flatnonzero(adata.obs['leiden'] == root_cluster)
if root_candidates.size == 0:
    raise ValueError(f"Leiden cluster '{root_cluster}' not found; choose one of "
                     f"{sorted(adata.obs['leiden'].unique())}")
adata.uns['iroot'] = root_candidates[0]

Run diffusion pseudotime:

In [None]:
# Diffusion pseudotime from the root cell stored in adata.uns['iroot'];
# the result is plotted below from the "dpt_pseudotime" annotation
sc.tl.dpt(adata)

In [None]:
# Plot DIFFMAP (components 1 and 2) coloured by diffusion pseudotime
fig, ax = plt.subplots(figsize = (4, 4))
# show = False stops scanpy from rendering the figure itself, so tight_layout() below still takes effect
sc.pl.diffmap(adata, color = "dpt_pseudotime", frameon = False, ax = ax,
              size = 30, dimensions = [1,2], show = False
             )

fig.tight_layout()
plt.show()

#### Stacked violin plot

In [None]:
# Distribution of every feature per experiment, with features on the rows (swap_axes)
fig, ax = plt.subplots(figsize = (5, 7))
# show = False stops scanpy from rendering the figure itself, so tight_layout() below still takes effect
sc.pl.stacked_violin(adata, data_cols, groupby = 'experiment', swap_axes = True, ax = ax, dendrogram = True,
                     show = False)
fig.tight_layout()
plt.show()

#### Pseudotime - heatmap

In [None]:
# Enter order of clusters in pseudotime (passed as the path of nodes to sc.pl.paga_path below).
# NOTE(review): these are integers while the Leiden labels are strings - presumably they are
# read as positions within the cluster categories; confirm against the scanpy documentation.
pseudotime_path = [3,4,7]

In [None]:
# Heatmap - pseudotime: the features in data_cols along the chosen path of Leiden clusters,
# with dpt_pseudotime drawn as an extra annotation row
sc.pl.paga_path(
    adata, 
    pseudotime_path, 
    data_cols,
    show_node_names = True,
    n_avg = 50,
    annotations = ['dpt_pseudotime'],
    show_colorbar = True,
    color_map = 'coolwarm',
    groups_key = 'leiden',
    color_maps_annotations = {'dpt_pseudotime': 'viridis'},
    title = 'Path',
    return_data = False,
    normalize_to_zero_one = True,
    show = True
)

### Save Object

In [None]:
# Save the AnnData object to disk. Placeholder path: replace it with a real destination before running.
adata.write("/save/path/filename.hdf5")