In [None]:
import numpy as np
import gudhi as gd
from scipy.ndimage import label, gaussian_filter

In [None]:
import numpy as np
import pandas as pd

import gudhi as gd
from sklearn.preprocessing import MinMaxScaler
from scipy.ndimage import gaussian_filter, generate_binary_structure

from pathlib import Path
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

from src.inputreader import read_segmented_images, persistence_writer_ensure_filesize
from src.inputreader import read_persistence_files
from src.auxfunctions import compute_vectorizations_all
from src.auxfunctions import get_all_classifications

In [None]:
def create_ball_structure(radius):
    """
    Create a 3D ball structure of a given radius, cropped to its bounding box.

    Parameters:
    radius (float): The radius of the ball in voxels. Must be non-negative.

    Returns:
    numpy.ndarray: A 3D boolean array that is True for every voxel whose
        Euclidean distance to the central voxel is at most `radius`.
        All-False outer layers are removed, so for non-integer radii the
        array can be smaller than ``2 * ceil(radius) + 1`` along each axis.

    Raises:
    ValueError: If `radius` is negative (no voxel would be selected).
    """
    if radius < 0:
        raise ValueError(f"radius must be non-negative, got {radius}")

    # Integer offsets from the central voxel along each axis
    half_width = int(np.ceil(radius))
    offsets = np.arange(-half_width, half_width + 1)
    dz, dy, dx = np.meshgrid(offsets, offsets, offsets, indexing='ij')

    # Voxels within `radius` of the centre form the ball
    ball_structure = np.sqrt(dz**2 + dy**2 + dx**2) <= radius

    # Crop to the bounding box of the True voxels. The central voxel is always
    # True for a non-negative radius, so the bounding box is never empty.
    bounding_box = tuple(slice(idx.min(), idx.max() + 1)
                         for idx in np.nonzero(ball_structure))

    return ball_structure[bounding_box]

def label_cubical_complex_3d(array_3d, structuresize=3):
    """
    Label the connected components of a 3D array using an all-ones structuring element.

    With the default ``structuresize=3`` the structuring element is a 3x3x3
    block of ones, i.e. voxels are connected through faces, edges and corners
    (the 26-neighbourhood in 3D). According to the original author's note this
    is meant to mirror the connectivity of Gudhi's cubical complexes.

    Parameters:
    array_3d (numpy.ndarray): A 3D binary array where features to be labeled are True or 1.
    structuresize (int): Edge length of the cubic all-ones structuring element.
        NOTE(review): scipy.ndimage.label appears to accept only structuring
        elements of size 3 along every axis, so other values are expected to
        raise - confirm before relying on this parameter.

    Returns:
    tuple: ``(labeled_array, num_features)`` exactly as returned by
        scipy.ndimage.label.
    """
    full_connectivity = np.ones((structuresize,) * 3, dtype=bool)
    return label(array_3d, structure=full_connectivity)

In [None]:
# Segmented input images, one sub-folder per microscope
data_input = Path('data_segmented')
input_airyscan = data_input.joinpath('Airyscan')
input_sted = data_input.joinpath('STED')

# Root folder for everything derived from the segmented images
data_pers = Path('data_processed')

# Persistence files (default and alternative preprocessing)
pers_sted = data_pers.joinpath('persistence_sted')
pers_airyscan = data_pers.joinpath('persistence_airyscan')
pers_sted_other = pers_sted.joinpath('other_preprocessing')
pers_airyscan_other = pers_airyscan.joinpath('other_preprocessing')

# Vectorizations (default and alternative preprocessing)
vec_sted = data_pers.joinpath('vectorizations_sted')
vec_airyscan = data_pers.joinpath('vectorizations_airyscan')
vec_sted_other = vec_sted.joinpath('other_preprocessing')
vec_airyscan_other = vec_airyscan.joinpath('other_preprocessing')

classification_path_preproc = data_pers.joinpath('classification', 'preprocessing')

In [None]:
# # get the metadata for the original files as well as the segmented files
# df_sted_metadata = pd.read_csv(data_input/ 'sted_df_metadata.csv', comment='#')

# masks, bounding_boxes, labels_str, labels, npzfiles = \
#     read_segmented_images(input_sted, microscope='sted', replace_nan_with=0)

# df_labels = pd.DataFrame(labels_str, columns=['labels_str'])
# df_labels.loc[:, 'labels'] = labels
# df_labels.loc[:, 'id'] = np.arange(len(labels))
# df_labels.loc[:, 'microscope'] = 'sted'
# df_labels.loc[:, 'filename'] = [f.name for f in npzfiles]
# df_labels.to_csv(Path(data_pers, 'labels_persistence_sted.csv'),
#                  index=False)

# # get the physical pixel sizes for each image
# for filename in npzfiles:
#     assert len(df_sted_metadata.loc[df_sted_metadata['segmented_filename'] == filename.name, :]) == 1
# pixelsizes = [df_sted_metadata.loc[df_sted_metadata['segmented_filename'] == filename.name,
#                                    ['pixel_size_z', 'pixel_size_x', 'pixel_size_y']]\
#                                     .values[0] for filename in npzfiles]
# pixelsizes = np.array(pixelsizes)

In [None]:
# Metadata of the segmented Airyscan images (provides the physical pixel sizes)
df_airy_metadata = pd.read_csv(data_input / 'airyscan_df_metadata.csv')
masks, bounding_boxes, labels_str, labels, npzfiles = read_segmented_images(input_airyscan, microscope='airyscan')

# One row per image: class label (string and numeric), running id and source file
df_labels = pd.DataFrame(labels_str, columns=['labels_str'])
df_labels.loc[:, 'labels'] = labels
df_labels.loc[:, 'id'] = np.arange(len(labels))
df_labels.loc[:, 'microscope'] = 'airyscan'
df_labels.loc[:, 'filename'] = [f.name for f in npzfiles]

df_labels.to_csv(Path(data_pers, 'labels_persistence_airyscan.csv'),
                 index=False)

# Get the physical pixel sizes (z, x, y) for each image. Every segmented file
# must match exactly one metadata row; a missing or duplicated entry is
# reported explicitly instead of failing later with an opaque IndexError.
pixelsizes = []
for filename in npzfiles:
    metadata_rows = df_airy_metadata.loc[df_airy_metadata['segmented_filename'] == filename.name,
                                         ['pixel_size_z', 'pixel_size_x', 'pixel_size_y']]
    assert len(metadata_rows) == 1, \
        f'expected exactly one metadata row for {filename.name}, found {len(metadata_rows)}'
    pixelsizes.append(metadata_rows.values[0])
pixelsizes = np.array(pixelsizes)

# x and y resolution should be the same for every image
assert np.all(pixelsizes[:, 1] == pixelsizes[:, 2])

In [None]:
# Name of the preprocessing pipeline. The filtering cells below parse the
# 'gaussian<truncate><variant>' token of this name: the integer is passed to
# gaussian_filter as `truncate`, and the letter (a/b/c) selects how the
# per-axis sigmas are chosen.
preproc = 'clip_minmax_gaussian4a_minmax'

In [None]:
sigma = 1  # NOTE(review): not referenced in the visible cells; variant 'a' hard-codes sigma_pixels = 1

# The preprocessing name is identical for all images, so decode it once.
apply_gaussian = 'gaussian' in preproc
if apply_gaussian:
    # 'gaussian<truncate><variant>': strip the trailing variant letter to get `truncate`
    gaussian_truncate = \
        int(preproc[preproc.find('gaussian') + len('gaussian'):].split('_')[0][:-1])

    if 'a_minmax' in preproc or 'a_mask0' in preproc or preproc.endswith('a'):
        gaussian_variant = 'a'
    elif 'b_minmax' in preproc or 'b_mask0' in preproc or preproc.endswith('b'):
        gaussian_variant = 'b'
    elif 'c_minmax' in preproc or 'c_mask0' in preproc or preproc.endswith('c'):
        gaussian_variant = 'c'
    else:
        # Previously this case left sigma_pixels undefined (or stale from an earlier run)
        raise ValueError(f"Cannot determine the gaussian variant (a/b/c) from preproc={preproc!r}")

filtered_images = []
for i, mask_loop in tqdm(enumerate(masks), total=len(masks)):
    mask = mask_loop.astype(np.float64)
    if np.any(np.isnan(mask)):
        # NaNs are replaced by 0, which must already be the smallest value of the image
        assert np.nanmin(mask) == 0
        mask[np.isnan(mask)] = 0

    # NOTE(review): the clipping / min-max steps suggested by the preproc name are disabled:
    # if preproc != 'raw':
    #     quant05 = np.nanquantile(mask[mask > np.min(mask)], 0.05)
    #     quant95 = np.nanquantile(mask[mask > np.min(mask)], 0.95)
    #     mask = np.clip(mask, quant05, quant95)

    # if 'clip_minmax' in preproc:
    #     mask = MinMaxScaler().fit_transform(mask.reshape(-1, 1)).reshape(mask.shape)

    if apply_gaussian:
        if gaussian_variant == 'a':
            # same sigma (in pixels) along every axis
            sigma_pixels = 1
        else:
            sigma_pixels = pixelsizes[i].copy()
            # x and y resolution should be the same
            assert sigma_pixels[1] == sigma_pixels[2]
            # 'b': scale so that the x/y entries become 1; 'c': so that the z entry becomes 1
            sigma_pixels /= sigma_pixels[1 if gaussian_variant == 'b' else 0]

        mask = gaussian_filter(mask, sigma=sigma_pixels,
                truncate=gaussian_truncate, mode='constant', cval=0.0)

    filtered_images.append(mask)

# for a single image

In [None]:
## THIS TAKES SUPER LONG

from scipy.ndimage import distance_transform_edt

# labelled, numfeats = label_cubical_complex_3d(image_binary)
# distances = np.zeros([numfeats, numfeats])

# for feat in tqdm(range(1, numfeats)):
#     image_oneregion = labelled != feat
#     disttemp = distance_transform_edt(image_oneregion, sampling=(1, 1, 1), return_distances=True)

#     # sigma_pixels = pixelsizes[i].copy()
#     # sigma_pixels /= sigma_pixels[1]
#     # disttemp = distance_transform_edt(image_oneregion, sampling=sigma_pixels, return_distances=True)

#     for feat2 in range(feat+1, numfeats):
#         distances[feat, feat2] = np.min(disttemp[labelled == feat2])
#         distances[feat2, feat] = distances[feat, feat2]

In [None]:
# values = []
# for thresh in tqdm(np.linspace(np.min(mask), np.max(mask), 100, endpoint=False)):
#     labeled, numfeats = label(mask >= thresh, structure=generate_binary_structure(3, 2))
#     values.append([thresh, numfeats])

# values = np.array(values)

# plt.plot(values[:, 0], values[:, 1])

In [None]:
from scipy.ndimage import distance_transform_edt, binary_dilation, label
from scipy.spatial.distance import squareform

def compute_region_distances(labeled_array, intersecting_labels):
    """
    Compute the pairwise minimal Euclidean distances between labelled regions.

    Parameters:
    labeled_array (numpy.ndarray): Array in which each region carries its own label value.
    intersecting_labels (sequence): Labels of the regions to compare. Every
        label must occur in `labeled_array`.

    Returns:
    numpy.ndarray: Symmetric ``(k, k)`` matrix with zero diagonal, where entry
        ``[i, j]`` is the smallest distance (in voxels, unit spacing) between a
        voxel of region ``intersecting_labels[i]`` and a voxel of region
        ``intersecting_labels[j]``.
    """
    n_labels = len(intersecting_labels)
    distances_tmp = np.zeros((n_labels, n_labels))

    # The last label has no remaining partner, so its (expensive) distance
    # transform is never needed and is skipped.
    for i, labelA in enumerate(intersecting_labels[:-1]):
        # Distance of every voxel to the nearest voxel of region A
        dist_mask = distance_transform_edt(labeled_array != labelA, return_distances=True)

        for j in range(i + 1, n_labels):
            # Minimum over region B gives the distance between the two regions
            min_distance = np.min(dist_mask[labeled_array == intersecting_labels[j]])
            distances_tmp[i, j] = min_distance
            distances_tmp[j, i] = min_distance

    return distances_tmp

def find_regions_within_distance_optimized(labeled_array, N):
    """
    Group labelled regions whose dilations merge and measure the distances inside each group.

    The foreground (``labeled_array > 0``) is dilated `N` times with a 3x3x3
    cube and the dilated foreground is labelled again. All original regions
    that fall into the same dilated component form one group, and the pairwise
    minimal distances between the regions of a group are computed with
    `compute_region_distances`.

    Parameters:
    labeled_array (numpy.ndarray): A labeled 3D array where each region has a unique label (0 is background).
    N (int): Number of dilation iterations applied to the foreground.
        NOTE(review): scipy documents that ``iterations < 1`` repeats the
        dilation until nothing changes - pass a positive value.

    Returns:
    dict: Maps every region label that shares a dilated component with at
        least one other region to a tuple ``(labels, distances)``. ``labels``
        is the list of all region labels in that component (including the key
        itself) and ``distances`` is the condensed distance vector
        (scipy ``squareform``) of their pairwise minimal distances, in the
        order of ``labels``. Regions that are alone in their dilated component
        do not appear in the dictionary.
    """

    # Initialize a dictionary to store the grouped regions
    regions_within_distance = {}

    # Create a labeled dilation mask
    dilated_region = binary_dilation(labeled_array > 0,
                        structure=np.ones((3, 3, 3)),
                        iterations=N)
    labeled_dilation_mask, numfeats = label(dilated_region, structure=generate_binary_structure(3, 2))

    for label1 in range(1, numfeats + 1):
        # Original regions covered by this dilated component
        component_labels = np.unique(labeled_array[labeled_dilation_mask == label1])
        intersecting_labels = component_labels[component_labels != 0]

        # Every dilated component stems from at least one original region.
        # (This check used to sit inside the loop below, where it could never fire.)
        assert len(intersecting_labels) > 0, f'{component_labels}'

        # Distances are only needed, and only computed, for real groups;
        # the condensed form is built once per group instead of once per label.
        has_neighbours = len(intersecting_labels) > 1
        if has_neighbours:
            condensed = squareform(compute_region_distances(labeled_array, intersecting_labels))

        for label_in in intersecting_labels:
            if label_in in regions_within_distance:
                print(f"Label {label_in} is already in?")

            if has_neighbours:
                # each key gets its own list and array, as before
                regions_within_distance[label_in] = (list(intersecting_labels), condensed.copy())

    return regions_within_distance


In [None]:
import pickle
import numpy as np
from skimage.measure import regionprops
from scipy.ndimage import generate_binary_structure, label
from tqdm import tqdm

# Global intensity range over all filtered images.
# NOTE: `median` is the mean of the per-image medians, not a pooled median.
minval = np.min([np.min(mask) for mask in filtered_images])
maxval = np.max([np.max(mask) for mask in filtered_images])
median = np.mean([np.quantile(mask, 0.5) for mask in filtered_images])

numberfeatures = []
smalldistances = []
totalsizes = []
sizes = []
# 24 thresholds strictly between minval and `median`, followed by 75 thresholds
# from `median` up to (but excluding) maxval
thresholds = list(np.linspace(minval, median, 25, endpoint=False)[1:])\
    + list(np.linspace(median, maxval, 75, endpoint=False))

# Connectivity used for labelling; identical for every image and threshold
connectivity = generate_binary_structure(3, 2)

for maski, mask in tqdm(enumerate(filtered_images), total=len(filtered_images)):
    # Number of voxels of the unfiltered image above its minimum. It does not
    # depend on the threshold, so it is computed once per image. nanmin keeps
    # the count meaningful when the raw image contains NaNs (np.min would
    # return NaN and the comparison would always be False).
    foreground_voxels = np.count_nonzero(masks[maski] > np.nanmin(masks[maski]))

    numfeats = []
    smalldist = []
    size = []

    for thresh in tqdm(thresholds):
        labeled, numfeat = label(mask >= thresh, structure=connectivity)
        numfeats.append(numfeat)

        smalldist.append(find_regions_within_distance_optimized(labeled, 10))
        size.append([int(region.area) for region in regionprops(labeled)])

    # one entry per threshold, as before
    totalsize = [foreground_voxels] * len(thresholds)

    numberfeatures.append(numfeats)
    smalldistances.append(smalldist)
    totalsizes.append(totalsize)
    sizes.append(size)

# Dump everything into a single pickle file
data_to_pickle = {
    "numberfeatures": numberfeatures,
    "smalldistances": smalldistances,
    "totalsizes": totalsizes,
    "sizes": sizes,
    "thresholds": thresholds
}

with open("region_data.pkl", "wb") as file:
    pickle.dump(data_to_pickle, file)

print("Data successfully saved to region_data.pkl")

In [None]:
# Dump everything into a single pickle file (Airyscan results)
data_to_pickle = {
    "numberfeatures": numberfeatures,
    "smalldistances": smalldistances,
    "totalsizes": totalsizes,
    "sizes": sizes,
    "thresholds": thresholds
}

output_file = "region_data_airyscan.pkl"
with open(output_file, "wb") as file:
    pickle.dump(data_to_pickle, file)

# Report the file that was actually written
print(f"Data successfully saved to {output_file}")

In [None]:
# Metadata of the segmented STED images; lines starting with '#' in the CSV are skipped
df_sted_metadata = pd.read_csv(data_input / 'sted_df_metadata.csv', comment='#')

masks, bounding_boxes, labels_str, labels, npzfiles = read_segmented_images(
    input_sted, microscope='sted', replace_nan_with=0)

# One row per image: class label (string and numeric), running id and source file
df_labels = pd.DataFrame(labels_str, columns=['labels_str'])
label_columns = (
    ('labels', labels),
    ('id', np.arange(len(labels))),
    ('microscope', 'sted'),
    ('filename', [f.name for f in npzfiles]),
)
for column_name, column_values in label_columns:
    df_labels.loc[:, column_name] = column_values
df_labels.to_csv(Path(data_pers, 'labels_persistence_sted.csv'), index=False)

# Physical pixel sizes (z, x, y) per image. All files are checked first:
# each one must match exactly one metadata row.
is_file = {filename: df_sted_metadata['segmented_filename'] == filename.name
           for filename in npzfiles}
for filename in npzfiles:
    assert len(df_sted_metadata.loc[is_file[filename], :]) == 1
pixel_columns = ['pixel_size_z', 'pixel_size_x', 'pixel_size_y']
pixelsizes = np.array([df_sted_metadata.loc[is_file[filename], pixel_columns].values[0]
                       for filename in npzfiles])

In [None]:
sigma = 1  # NOTE(review): not referenced in the visible cells; variant 'a' hard-codes sigma_pixels = 1

# The preprocessing name is identical for all images, so decode it once.
apply_gaussian = 'gaussian' in preproc
if apply_gaussian:
    # 'gaussian<truncate><variant>': strip the trailing variant letter to get `truncate`
    gaussian_truncate = \
        int(preproc[preproc.find('gaussian') + len('gaussian'):].split('_')[0][:-1])

    if 'a_minmax' in preproc or 'a_mask0' in preproc or preproc.endswith('a'):
        gaussian_variant = 'a'
    elif 'b_minmax' in preproc or 'b_mask0' in preproc or preproc.endswith('b'):
        gaussian_variant = 'b'
    elif 'c_minmax' in preproc or 'c_mask0' in preproc or preproc.endswith('c'):
        gaussian_variant = 'c'
    else:
        # Previously this case left sigma_pixels undefined (or stale from an earlier run)
        raise ValueError(f"Cannot determine the gaussian variant (a/b/c) from preproc={preproc!r}")

filtered_images = []
for i, mask_loop in tqdm(enumerate(masks), total=len(masks)):
    mask = mask_loop.astype(np.float64)
    if np.any(np.isnan(mask)):
        # NaNs are replaced by 0, which must already be the smallest value of the image
        assert np.nanmin(mask) == 0
        mask[np.isnan(mask)] = 0

    # NOTE(review): the clipping / min-max steps suggested by the preproc name are disabled:
    # if preproc != 'raw':
    #     quant05 = np.nanquantile(mask[mask > np.min(mask)], 0.05)
    #     quant95 = np.nanquantile(mask[mask > np.min(mask)], 0.95)
    #     mask = np.clip(mask, quant05, quant95)

    # if 'clip_minmax' in preproc:
    #     mask = MinMaxScaler().fit_transform(mask.reshape(-1, 1)).reshape(mask.shape)

    if apply_gaussian:
        if gaussian_variant == 'a':
            # same sigma (in pixels) along every axis
            sigma_pixels = 1
        else:
            sigma_pixels = pixelsizes[i].copy()
            # x and y resolution should be the same
            assert sigma_pixels[1] == sigma_pixels[2]
            # 'b': scale so that the x/y entries become 1; 'c': so that the z entry becomes 1
            sigma_pixels /= sigma_pixels[1 if gaussian_variant == 'b' else 0]

        mask = gaussian_filter(mask, sigma=sigma_pixels,
                truncate=gaussian_truncate, mode='constant', cval=0.0)

    filtered_images.append(mask)

In [None]:
import pickle
import numpy as np
from skimage.measure import regionprops
from scipy.ndimage import generate_binary_structure, label
from tqdm import tqdm

# Global intensity range over all filtered images.
# NOTE: `median` is the mean of the per-image medians, not a pooled median.
minval = np.min([np.min(mask) for mask in filtered_images])
maxval = np.max([np.max(mask) for mask in filtered_images])
median = np.mean([np.quantile(mask, 0.5) for mask in filtered_images])

numberfeatures = []
smalldistances = []
totalsizes = []
sizes = []
# 24 thresholds strictly between minval and `median`, followed by 75 thresholds
# from `median` up to (but excluding) maxval
thresholds = list(np.linspace(minval, median, 25, endpoint=False)[1:])\
    + list(np.linspace(median, maxval, 75, endpoint=False))

# Connectivity used for labelling; identical for every image and threshold
connectivity = generate_binary_structure(3, 2)

for maski, mask in tqdm(enumerate(filtered_images), total=len(filtered_images)):
    mask_min = np.min(mask)
    # Number of voxels of the unfiltered image above its minimum; independent
    # of the threshold, so computed once per image.
    foreground_voxels = np.count_nonzero(masks[maski] > np.min(masks[maski]))

    numfeats = []
    smalldist = []  # stays empty: the neighbour search is switched off in this cell
    size = []

    for thresh in tqdm(thresholds):
        if thresh <= mask_min:
            # Threshold does not cut this image. Store placeholders instead of
            # skipping, so that index k of every per-image list still belongs
            # to thresholds[k] (the analysis cells index `sizes` by threshold).
            numfeats.append(None)
            size.append([])
            continue
        labeled, numfeat = label(mask >= thresh, structure=connectivity)
        numfeats.append(numfeat)

        # smalldist.append(find_regions_within_distance_optimized(labeled, 10))
        size.append([int(region.area) for region in regionprops(labeled)])

    # one entry per threshold
    totalsize = [foreground_voxels] * len(thresholds)

    numberfeatures.append(numfeats)
    smalldistances.append(smalldist)
    totalsizes.append(totalsize)
    sizes.append(size)

# Dump everything into a single pickle file
data_to_pickle = {
    "numberfeatures": numberfeatures,
    "smalldistances": smalldistances,
    "totalsizes": totalsizes,
    "sizes": sizes,
    "thresholds": thresholds
}

output_file = "region_data_sted.pkl"
with open(output_file, "wb") as file:
    pickle.dump(data_to_pickle, file)

# Report the file that was actually written
print(f"Data successfully saved to {output_file}")

# plot and interpret it

In [None]:
def compute_boxplot_quantiles(values):
    """
    Summarise a sample with the statistics of a boxplot plus its mean.

    Parameters:
    values (list or numpy.ndarray): A list or array of numerical values.

    Returns:
    list: ``[min, Q1, mean, median, Q3, max]`` in exactly this order, or an
        empty list when `values` is empty.
    """
    if len(values) == 0:
        return []

    sample = np.array(values)
    first_quartile, middle, third_quartile = (np.percentile(sample, q) for q in (25, 50, 75))
    return [
        np.min(sample),
        first_quartile,
        np.mean(sample),
        middle,
        third_quartile,
        np.max(sample),
    ]

## Airyscan

In [None]:
import pickle
# The pickle is produced by the cells above; only load files you created yourself.
with open("region_data_airyscan.pkl", "rb") as file:
    data_loaded = pickle.load(file)

# Show a short summary only: the full dictionary holds per-image, per-threshold
# lists and would flood the cell output.
print({key: len(value) for key, value in data_loaded.items()})

thresholds = data_loaded['thresholds'].copy()
numberfeatures = data_loaded['numberfeatures'].copy()
smalldistances = data_loaded['smalldistances'].copy()
totalsizes = data_loaded['totalsizes'].copy()
sizes = data_loaded['sizes'].copy()

del data_loaded

In [None]:
# sizes_reordered = [[[] for t in thresholds],
#                    [[] for t in thresholds]]
# labels_assignment = {}
# for i, size in tqdm(enumerate(sizes)):
#     for ti, _ in enumerate(thresholds):
#         stats = compute_boxplot_quantiles(size[ti])
#         if len(stats) > 0:
#             sizes_reordered[labels[i]][ti].append(stats)

# for i in range(len(sizes_reordered)):
#     for ti, _ in enumerate(thresholds):
#         if len(sizes_reordered[i][ti]) > 0:
#             sizes_reordered[i][ti] = np.array(sizes_reordered[i][ti])

In [None]:
# data = []
# for i in range(len(sizes_reordered)):
#     for ti, t in enumerate(thresholds):
#         if len(sizes_reordered[i][ti]) > 0:
#             data.append({
#                 'label': i,
#                 'thresh': t,
#                 'min': sizes_reordered[i][ti][0],
#                 'Q1': sizes_reordered[i][ti][1],
#                 'mean': sizes_reordered[i][ti][2],
#                 'median': sizes_reordered[i][ti][3],
#                 'Q3': sizes_reordered[i][ti][4],
#                 'max': sizes_reordered[i][ti][5]
#             }.copy())
# data = pd.DataFrame(data)
# df_data = data.groupby(['label', 'thresh']).mean().reset_index()
# df_count = data.groupby(['label', 'thresh']).count().reset_index()

In [None]:
# NOTE(review): `filtered_images` and `labels` come from the kernel state and
# must belong to the Airyscan data loaded above; the STED loading/filtering
# cells rebind both names, so re-run the Airyscan cells first if in doubt.
records = []
for i, size in tqdm(enumerate(sizes), total=len(sizes)):
    image = filtered_images[i]
    # Both values are constant per image, so compute them once instead of once
    # per threshold / per table row.
    image_min = np.min(image)
    # fix the totalsize part, since we want to exclude the minimal image value
    image_totalsize = np.count_nonzero(image > image_min)

    for ti, t in enumerate(thresholds):
        # thresholds at or below the image minimum do not segment anything
        if t <= image_min:
            continue
        stats = compute_boxplot_quantiles(size[ti])
        if len(stats) > 0:
            records.append({
                'pos': i,
                'label': labels[i],
                'thresh': t,
                'min': stats[0],
                'Q1': stats[1],
                'mean': stats[2],
                'median': stats[3],
                'Q3': stats[4],
                'max': stats[5],
                'count': len(size[ti]),
                'totalsize': image_totalsize
            })

# One row per (image, threshold) with the region-size statistics
data = pd.DataFrame(records)
# Per class and threshold: mean over images, and number of contributing images
df_data = data.groupby(['label', 'thresh']).mean().reset_index()
df_count = data.groupby(['label', 'thresh']).count().reset_index()

In [None]:
# Region-size statistics expressed as a fraction of each image's total size
size_columns = ['min', 'Q1', 'mean', 'median', 'Q3', 'max']
df_datarel = data.copy()
df_datarel[size_columns] = (
    df_datarel[size_columns]
    .astype(np.float64)
    .div(df_datarel['totalsize'].astype(np.float64), axis=0)
)
    
# df_datarel = df_datarel.groupby(['label', 'thresh']).mean().reset_index()

In [None]:
# Thresholds at which BOTH classes are represented by more than one image.
# (`df_count['mean']` is the number of images contributing to a label/threshold pair.)
# The second set previously repeated label 0, so the intersection was a no-op.
thresh_label0 = set(df_count.loc[(df_count['mean'] > 1) & (df_count['label'] == 0), 'thresh'].values)
thresh_label1 = set(df_count.loc[(df_count['mean'] > 1) & (df_count['label'] == 1), 'thresh'].values)
threshidx = thresh_label0.intersection(thresh_label1)

df_count.loc[df_count['thresh'].isin(threshidx), :]

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Mean region size per class against the threshold, split into two panels
fig, ax  = plt.subplots(1, 2, figsize=(20, 10))

# Split point: for each class take the threshold at which the mean region size
# jumps most between consecutive thresholds, then average over the two classes.
threshlimit = [0, 0]
for i in range(2):
    selected = df_data.loc[(df_data['label']==i) & (df_data['thresh'].isin(threshidx))]
    largest_diff = np.argmax(np.diff(selected['mean'].values))
    threshlimit[i] = selected['thresh'].values[largest_diff]
threshlimit = np.mean(threshlimit)

# Left panel: thresholds below the split (restricted to threshidx); right panel: the rest
panels = [
    df_data.loc[(df_data['thresh'] < threshlimit) & (df_data['thresh'].isin(threshidx))],
    df_data.loc[df_data['thresh'] >= threshlimit],
]

for axis, df in zip(ax, panels):
    sns.lineplot(data=df, x='thresh', y='mean', hue='label', marker='o', ax=axis)

    # The loop variable is deliberately not called `label`, which would
    # shadow scipy.ndimage.label for the rest of the session.
    class_labels = df['label'].unique().tolist()
    for class_index, class_label in enumerate(class_labels):
        subset = df[df['label'] == class_label]

        # IQR bars. `yerr` expects distances from the plotted point, not
        # absolute values, so the bar is anchored at the median and spans
        # Q1..Q3. Clipping guards against tiny negative rounding errors.
        below = (subset['median'] - subset['Q1']).clip(lower=0)
        above = (subset['Q3'] - subset['median']).clip(lower=0)
        axis.errorbar(subset['thresh'], subset['median'],
                      yerr=[below, above],
                      fmt='none', capsize=5, ecolor=sns.color_palette()[class_index],
                      alpha=0.5)

        # Mark min and max as outliers (legend entry only once)
        axis.scatter(subset['thresh'], subset['min'], color='red', marker='_', s=100,
                     label='Min' if class_index == 0 else "")
        axis.scatter(subset['thresh'], subset['max'], color='red', marker='_', s=100,
                     label='Max' if class_index == 0 else "")

    axis.set_yscale('log')
    axis.set_title('Lineplot with IQR Error Bars and Min/Max Outliers')
    axis.legend()

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Relative region size per class against the threshold, split into two panels
fig, ax  = plt.subplots(1, 2, figsize=(20, 10))

# Split point: for each class take the threshold at the largest jump of the
# relative mean, then average over the two classes.
# NOTE(review): df_datarel has one row per image and threshold (not grouped),
# so the jump is taken between consecutive rows - confirm this is intended.
threshlimit = [0, 0]
for i in range(2):
    selected = df_datarel.loc[(df_datarel['label']==i) & (df_datarel['thresh'].isin(threshidx))]
    largest_diff = np.argmax(np.diff(selected['mean'].values))
    threshlimit[i] = selected['thresh'].values[largest_diff]
threshlimit = np.mean(threshlimit)

# Left panel: thresholds below the split (restricted to threshidx); right panel: the rest
panels = [
    df_datarel.loc[(df_datarel['thresh'] < threshlimit) & (df_datarel['thresh'].isin(threshidx))],
    df_datarel.loc[df_datarel['thresh'] >= threshlimit],
]

for axis, df in zip(ax, panels):
    sns.lineplot(data=df, x='thresh', y='mean', hue='label', marker='o', ax=axis)

    # The loop variable is deliberately not called `label`, which would
    # shadow scipy.ndimage.label for the rest of the session.
    class_labels = df['label'].unique().tolist()
    for class_index, class_label in enumerate(class_labels):
        subset = df[df['label'] == class_label]

        # IQR bars. `yerr` expects distances from the plotted point, not
        # absolute values, so the bar is anchored at the median and spans
        # Q1..Q3. Clipping guards against tiny negative rounding errors.
        below = (subset['median'] - subset['Q1']).clip(lower=0)
        above = (subset['Q3'] - subset['median']).clip(lower=0)
        axis.errorbar(subset['thresh'], subset['median'],
                      yerr=[below, above],
                      fmt='none', capsize=5, ecolor=sns.color_palette()[class_index],
                      alpha=0.5)

        # Mark min and max as outliers (legend entry only once)
        axis.scatter(subset['thresh'], subset['min'], color='red', marker='_', s=100,
                     label='Min' if class_index == 0 else "")
        axis.scatter(subset['thresh'], subset['max'], color='red', marker='_', s=100,
                     label='Max' if class_index == 0 else "")

    # axis.set_yscale('log')
    axis.set_title('Lineplot with IQR Error Bars and Min/Max Outliers')
    axis.legend()

In [None]:
# Relative mean region size (left axis) and mean region count (right axis, dashed)
line_kwargs = dict(x='thresh', hue='label')

fig, ax = plt.subplots(figsize=(10, 5))
sns.lineplot(data=df_datarel, y='mean', ax=ax, **line_kwargs)

# Second y-axis sharing the threshold axis
ax2 = ax.twinx()
sns.lineplot(data=df_data, y='count', ax=ax2, linestyle='--', **line_kwargs)

# Axis labels for clarity
for axis, ylabel in ((ax, 'Mean'), (ax2, 'Count')):
    axis.set_ylabel(ylabel)
ax.set_xlabel('Threshold')


In [None]:
# Median and quartiles of the 'median' and 'count' columns per class and threshold
grouped_for_plot = df_data[['label', 'thresh', 'median', 'count']].groupby(['label', 'thresh'])

df_plotnew = grouped_for_plot.aggregate('median').reset_index()
# upper quartile
df_plotnew1 = grouped_for_plot.aggregate(lambda x: np.quantile(x, 0.75)).reset_index()
# lower quartile
df_plotnew2 = grouped_for_plot.aggregate(lambda x: np.quantile(x, 0.25)).reset_index()

In [None]:
# Median region size per threshold with the band between the two quartile tables, one colour per class
for class_label, color in ((0, 'red'), (1, 'blue')):
    xvals = df_plotnew.loc[df_plotnew['label'] == class_label, 'thresh'].values
    yvals = df_plotnew.loc[df_plotnew['label'] == class_label, 'median'].values

    # upper (y1) and lower (y2) edge of the band
    y1 = df_plotnew1.loc[df_plotnew1['label'] == class_label, 'median'].values
    y2 = df_plotnew2.loc[df_plotnew2['label'] == class_label, 'median'].values

    plt.plot(xvals, yvals, label=f'Label {class_label}', lw=1.5, color=color)
    plt.fill_between(xvals, y1, y2, alpha=0.3, label=f'IQR for Label {class_label}', color=color)


In [None]:
# Inspect the per-image, per-threshold table of region-size statistics
data

In [None]:
# Mean region size (left axis, log scale) and region count (right axis, dashed) per threshold
line_kwargs = dict(data=data, x='thresh', hue='label')

fig, ax = plt.subplots(figsize=(10, 5))
sns.lineplot(y='mean', ax=ax, **line_kwargs)
ax.set_yscale('log')

# Second y-axis sharing the threshold axis
ax2 = ax.twinx()
sns.lineplot(y='count', ax=ax2, linestyle='--', **line_kwargs)

# Axis labels for clarity
for axis, ylabel in ((ax, 'Mean (log-scale)'), (ax2, 'Count')):
    axis.set_ylabel(ylabel)
ax.set_xlabel('Threshold')

# Only show thresholds from 9000 upwards
ax.set_xlim(left=9000)

In [None]:
# Per-image intensity summary of the filtered images: min (x), max (o), mean (dashed), median (solid)
image_minima = [np.min(image) for image in filtered_images]
image_maxima = [np.max(image) for image in filtered_images]
image_means = [np.mean(image) for image in filtered_images]
image_medians = [np.quantile(image, 0.5) for image in filtered_images]

plt.plot(image_minima, 'x')
plt.plot(image_maxima, 'o')
plt.plot(image_means, '--')
plt.plot(image_medians, '-')

In [None]:
# Mean region size of each class relative to the larger of the class means at the same threshold
df_data_max = df_data.groupby(['thresh']).max()['mean'].reset_index()
df_data_joined = df_data.merge(df_data_max, on='thresh', suffixes=('', '_max'))

df_data_joined.loc[:, 'mean_rel'] = df_data_joined['mean'] / df_data_joined['mean_max']
fig, ax = plt.subplots(figsize=(10, 6))
# Filter with the joined table's own column (the mask used to be built from
# df_data and only matched by index coincidence) and draw on the new axes.
sns.lineplot(data=df_data_joined.loc[df_data_joined['thresh'].isin(threshidx)],
             x='thresh', y='mean_rel', hue='label', ax=ax)
# ax.set_yscale('log')

In [None]:
# Mean region size of each class relative to the larger of the class means at the same threshold
df_data_max = df_data.groupby(['thresh']).max()['mean'].reset_index()
df_data_joined = df_data.merge(df_data_max, on='thresh', suffixes=('', '_max'))

df_data_joined.loc[:, 'mean_rel'] = df_data_joined['mean'] / df_data_joined['mean_max']
fig, ax = plt.subplots(figsize=(10, 6))
# Filter with the joined table's own column (the mask used to be built from
# df_data and only matched by index coincidence) and draw on the new axes.
sns.lineplot(data=df_data_joined.loc[df_data_joined['thresh'].isin(threshidx)],
             x='thresh', y='mean_rel', hue='label', ax=ax)
# ax.set_yscale('log')

## STED

In [None]:
import pickle

# Load the STED region statistics written by the computation cell above
with open("region_data_sted.pkl", "rb") as file:
    data_loaded = pickle.load(file)

# Shallow copies of the stored lists, so the loaded dictionary can be dropped
thresholds, numberfeatures, smalldistances, totalsizes, sizes = (
    data_loaded[key].copy()
    for key in ('thresholds', 'numberfeatures', 'smalldistances', 'totalsizes', 'sizes')
)

del data_loaded

In [None]:
# NOTE(review): `filtered_images` and `labels` come from the kernel state and
# must belong to the STED data loaded above; re-run the STED loading and
# filtering cells first if the Airyscan cells were executed in between.
records = []
for i, size in tqdm(enumerate(sizes), total=len(sizes)):
    image = filtered_images[i]
    # Both values are constant per image, so compute them once instead of once
    # per threshold / per table row.
    image_min = np.min(image)
    # fix the totalsize part, since we want to exclude the minimal image value
    image_totalsize = np.count_nonzero(image > image_min)

    for ti, t in enumerate(thresholds):
        # thresholds at or below the image minimum do not segment anything
        if t <= image_min:
            continue
        stats = compute_boxplot_quantiles(size[ti])
        if len(stats) > 0:
            records.append({
                'pos': i,
                'label': labels[i],
                'thresh': t,
                'min': stats[0],
                'Q1': stats[1],
                'mean': stats[2],
                'median': stats[3],
                'Q3': stats[4],
                'max': stats[5],
                'count': len(size[ti]),
                'totalsize': image_totalsize
            })

# One row per (image, threshold) with the region-size statistics
data = pd.DataFrame(records)
# Per class and threshold: mean over images, and number of contributing images
df_data = data.groupby(['label', 'thresh']).mean().reset_index()
df_count = data.groupby(['label', 'thresh']).count().reset_index()

In [None]:
# Region-size statistics expressed as a fraction of each image's total size
size_columns = ['min', 'Q1', 'mean', 'median', 'Q3', 'max']
df_datarel = data.copy()
df_datarel[size_columns] = (
    df_datarel[size_columns]
    .astype(np.float64)
    .div(df_datarel['totalsize'].astype(np.float64), axis=0)
)

In [None]:
# Thresholds at which BOTH classes are represented by more than one image.
# (`df_count['mean']` is the number of images contributing to a label/threshold pair.)
# The second set previously repeated label 0, so the intersection was a no-op.
thresh_label0 = set(df_count.loc[(df_count['mean'] > 1) & (df_count['label'] == 0), 'thresh'].values)
thresh_label1 = set(df_count.loc[(df_count['mean'] > 1) & (df_count['label'] == 1), 'thresh'].values)
threshidx = thresh_label0.intersection(thresh_label1)

df_count.loc[df_count['thresh'].isin(threshidx), :]

In [None]:
import seaborn as sns

In [None]:
# Mean region size (left axis, log scale) and region count (right axis, dashed)
# for thresholds of at least 3
data_from_3 = data.loc[data['thresh'] >= 3, :]
line_kwargs = dict(data=data_from_3, x='thresh', hue='label')

fig, ax = plt.subplots(figsize=(10, 5))
sns.lineplot(y='mean', ax=ax, **line_kwargs)
ax.set_yscale('log')
# ax.set_xlim(left = 0)

# Second y-axis sharing the threshold axis
ax2 = ax.twinx()
# ax2.set_xlim(left = 0)
sns.lineplot(y='count', ax=ax2, linestyle='--', **line_kwargs)

# Axis labels for clarity
for axis, ylabel in ((ax, 'Mean (log-scale, solid)'), (ax2, 'Count (dashed)')):
    axis.set_ylabel(ylabel)
ax.set_xlabel('Threshold')

# Cut the threshold axis at 60
ax.set_xlim(right=60)

In [None]:
# Relative mean region size (left axis) and region count (right axis, dashed)
# for thresholds of at least 3
datarel_from_3 = df_datarel.loc[df_datarel['thresh'] >= 3, :]
line_kwargs = dict(data=datarel_from_3, x='thresh', hue='label')

fig, ax = plt.subplots(figsize=(10, 5))
sns.lineplot(y='mean', ax=ax, **line_kwargs)
# ax.set_yscale('log')

# Second y-axis sharing the threshold axis
ax2 = ax.twinx()
sns.lineplot(y='count', ax=ax2, linestyle='--', **line_kwargs)

# Axis labels for clarity
for axis, ylabel in ((ax, 'Mean (solid)'), (ax2, 'Count (dashed)')):
    axis.set_ylabel(ylabel)
ax.set_xlabel('Threshold')

# Cut the threshold axis at 60
ax.set_xlim(right=60)

In [None]:
from matplotlib.ticker import MaxNLocator

# Relative mean region size (left axis) and region count (right axis, dashed)
# for thresholds of at least 3, with tidied axis limits
plot_data = df_datarel.loc[df_datarel['thresh'] >= 3, :]

fig, ax = plt.subplots(figsize=(10, 5))
sns.lineplot(data=plot_data, x='thresh', y='mean', hue='label', ax=ax)

# Add a twin axis for the region count
ax2 = ax.twinx()
sns.lineplot(data=plot_data, x='thresh', y='count', hue='label', ax=ax2, linestyle='--')

# Customize the labels for clarity
ax.set_ylabel('Mean (solid)')
ax2.set_ylabel('Count (dashed)')
ax.set_xlabel('Threshold')

# Set the x-axis limit
ax.set_xlim(right=60)

# Secondary y-axis (count): start at 0, round the upper limit up to the next
# multiple of 10 and let MaxNLocator place integer ticks.
_, count_top = ax2.get_ylim()
ax2.set_ylim(0, np.ceil(count_top / 10) * 10)
ax2.yaxis.set_major_locator(MaxNLocator(integer=True, steps=[10]))

# Primary y-axis (relative mean size): the values are fractions of the image
# size (at most 1), so rounding the limit up to a multiple of 10 with integer
# ticks, as copied from the count axis before, flattened the curve. Only
# anchor the axis at 0 and keep matplotlib's default ticks.
ax.set_ylim(bottom=0)
ax.grid(True, which='both', linestyle='--', linewidth=0.5)

# Show the plot
plt.show()