In [None]:
from pathlib import Path
import numpy as np
import importlib
from matplotlib import pyplot as plt
import humanize
import sys
import os, psutil
import time
import shutil
import pandas as pd
import seaborn as sns
import warnings
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
import logging
import string
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from matplotlib.ticker import AutoLocator
import tifffile

import astrocast.reduction as red
import astrocast.clustering as clust
import astrocast.analysis as ana
import astrocast.autoencoders as AE
import astrocast.helper as helper
import astrocast.experiments as exp

# Re-import the astrocast submodules so that edits made on disk are picked up
# without restarting the kernel. Each module is reloaded exactly once: a second
# reload of `helper` after `ana`/`AE` would re-execute it and could leave those
# modules holding objects from the previous `helper` execution.
for pack in [red, clust, helper, ana, AE, exp]:
    importlib.reload(pack)


In [None]:
# Directory holding the TIFF stacks used in this notebook.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
dir_ = Path("/media/janrei1/data/astrocast_paper/22A8x5-2_small.roi")

# Read every *.tiff file into memory, keyed by file name. The paths are sorted
# because glob() yields them in filesystem order, and the dict order determines
# the panel order of the overview figure further down.
tiffs = {}
for tiff in sorted(dir_.glob("*.tiff")):
    tiffs[tiff.name] = tifffile.imread(tiff)

display(tiffs.keys())

In [None]:
# Pick the two stacks used for the per-pixel traces: the smoothed input video
# and the event map.
# Alternative sources, kept for reference:
#   data      = ana.Video(dir_.parent.joinpath("22A8x5-2_small.h5"), loc="dn/ast", lazy=False).get_data()
#   event_map = tifffile.imread(dir_.joinpath("event_map.tiff"))
data, event_map = tiffs["debug_smoothed_input.tiff"], tiffs["event_map.tiff"]

In [None]:
# Frame index shown in the overview below and in the figure panels further down.
n_frame = 165

sns.set_style("whitegrid")

# One panel per loaded stack. squeeze=False keeps the result an array even when
# only a single stack was loaded (otherwise a bare Axes is returned and cannot
# be indexed); ravel() flattens it back to the 1-D layout used below.
fig, axx = plt.subplots(1, len(tiffs), figsize=(20, 20), squeeze=False)
axx = axx.ravel()

for i, (k, img) in enumerate(tiffs.items()):
    ax = axx[i]
    ax.imshow(img[n_frame, :, :])
    # Turn the file name into a readable title, e.g. "event_map.tiff" -> "event map".
    ax.set_title(k.replace(".tiff", "").replace("_", " "))
    # Explicitly switch the grid off; grid(None) only toggles the current state.
    ax.grid(False)

In [None]:
sns.set_style("white")

# Example pixels shown in the figure. Each tuple is (px, py), where px indexes
# axis 1 and py indexes axis 2 of the stacks (i.e. px is the image row and py
# the image column when plotted).
# Earlier selection, kept for reference: [(144, 114), (195, 212), (108, 226)]
pixels = [(147, 107), (198, 206), (109, 222)]

# One distinct colour per selected pixel.
colors = sns.color_palette('husl', len(pixels))

# Mosaic: image panels A-D on the left (two columns), trace panels E-G stacked
# on the right. Previously tried: figsize=(15, 6) and
# subplot_kw=dict(width_ratios=[1, 1, 2]).
mosaic_layout = "ABEE\nABEE\nABFF\nCDFF\nCDGG\nCDGG"
fig, axx = plt.subplot_mosaic(mosaic_layout, figsize=(11, 6))

# Image panels: (panel key, source stack, title, colormap). All panels show the
# same frame `n_frame`. A colormap of None lets matplotlib use its default,
# which is what omitting the argument does.
image_panels = (
    ('C', "debug_smoothed_input.tiff", "smoothed frame", "magma"),
    ('A', "debug_active_pixels_spatial.tiff", "spatial mask", "binary"),
    ('B', "debug_active_pixels_temporal.tiff", "temporal mask", "binary"),
    ('D', "event_map.tiff", "event map", None),
)

for panel, tiff_name, title, cmap in image_panels:
    ax = axx[panel]
    ax.imshow(tiffs[tiff_name][n_frame, :, :], cmap=cmap)
    ax.set_title(title)

# Shared appearance of the arrow markers that point at the selected pixels.
marker_style = dict(marker='<', edgecolor='white', s=50, alpha=0.74)

# Image panels: no grid, no tick labels, and one coloured marker per selected
# pixel. The marker is drawn 7 px to the right of the pixel so that it does not
# sit on top of it.
for panel in "ABCD":
    ax = axx[panel]
    ax.grid(False)
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    for (px, py), color in zip(pixels, colors):
        ax.scatter(py + 7, px, color=color, **marker_style)

# In panels C and D, additionally label each marker with the letter of the
# trace panel (E, F, G) that shows this pixel.
for panel in "CD":
    ax = axx[panel]

    for i, (px, py) in enumerate(pixels):
        ax.annotate(text="EFG"[i], xy=(py + 17, px + 8), xycoords='data', color="white")

# Trace panels E, F, G: intensity over time for each selected pixel. Frames in
# which the event map is non-zero are drawn in the pixel's colour, all other
# frames in gray.
for i, (px, py) in enumerate(pixels):

    ax = axx[chr(ord("E") + i)]

    # Intensity trace of this pixel, shifted so that its minimum is zero.
    trace = data[:, px, py] - data[:, px, py].min()
    mask = event_map[:, px, py]

    # Signal frames first, then background frames (same draw order for both passes).
    frame_groups = (
        (np.where(mask > 0)[0], colors[i]),
        (np.where(mask == 0)[0], "gray"),
    )

    for frame_idx, line_color in frame_groups:
        # Positions at which consecutive frame indices are not adjacent, i.e.
        # where one contiguous run ends and the next one starts.
        run_starts = np.where(np.diff(frame_idx) != 1)[0] + 1

        # Plot each contiguous run as its own line so that no line is drawn
        # across frames belonging to the other group.
        for run in np.split(frame_idx, run_starts):
            ax.plot(run, trace[run], linestyle="-", color=line_color)

# Bold panel letter just above the top-left corner of every panel.
for label, panel_ax in axx.items():
    panel_ax.text(-0.075, 1.05, label, transform=panel_ax.transAxes, size=14, weight='bold')

# Trace panels: y-label and a dotted vertical line at the frame shown in the
# image panels. Only the bottom panel (G) keeps its x tick labels and gets the
# x-axis label.
for label in "EFG":
    trace_ax = axx[label]
    trace_ax.set_ylabel("pixel intensity")
    trace_ax.axvline(n_frame, linestyle=":", color="black", alpha=0.75)

    if label == "G":
        trace_ax.set_xlabel("Frames")
    else:
        trace_ax.set_xticklabels([])

plt.tight_layout()

# Write the figure and its legend next to each other in the parent of the
# current working directory, as "<fig_name>.png" and "<fig_name>.txt".
fig_name = "5"
save_path = Path.cwd().parent.joinpath(f"{fig_name}.png")
fig.savefig(save_path, dpi=260)

# Raw string: the legend contains LaTeX-style "\_" escapes that must be written
# verbatim; in a normal string literal "\_" is an invalid escape sequence and
# triggers a warning on recent Python versions.
legend = r"""
Depiction of astrocytic event detection employed by astroCaST using spatial and temporal thresholding. A) Binary mask of frame after application of spatial threshold (min\_ratio 1). B) Binary mask of frame after application of temporal threshold (prominence 2, width 3, rel\_height 0.9). C) Frame used for thresholding after motion correction, denoising and smoothing. D) Events detected as identified by both spatial and temporal thresholding. E-G) Pixel intensity analysis for selected pixels (as indicated in Panels A-D), with active frames color-coded in the plots. The frame shown in A-D is indicated as a vertical dotted line.
"""
legend_path = Path.cwd().parent.joinpath(f"{fig_name}.txt")
# Explicit encoding so the output does not depend on the platform default.
with open(legend_path, 'w', encoding="utf-8") as f:
    f.write(legend)