In [None]:
import pandas as pd
import os
import numpy as np
import warnings
import matplotlib.pyplot as plt
from scipy.stats import zscore
import seaborn as sns


from scmultiplex.plotting.functions_io import (zarr_wellpaths, 
                          append_table_to_zarr_url, 
                          load_features_for_well, 
                          make_anndata, 
                          invert_conditions_dict,
                          make_object_dict,
                          randomize_object_dict,
                          load_imgs_from_object_dict,
                          make_filtered_dict)

from scmultiplex.plotting.functions_plotting import (plot_heatmap, build_heatmap_df, plot_heatmap_means, plot_image_grid, 
                                plot_rgb_grid, plot_single_image, plot_single_rgb,  plot_pos_and_neg_sets, 
                                count_positive_fraction, plot_positive_fraction, plot_feature_violin)


from scmultiplex.plotting.functions_classify import classify_me

# Show up to 100 columns when wide feature tables are displayed.
pd.set_option("display.max_columns", 100)

# Eight-colour categorical palette: generate 9 husl hues and drop the last one.
husl = sns.color_palette("husl", 9).as_hex()[0:8]

# List the matplotlib styles available in this environment, then apply the
# dark theme once (the original applied it twice, which was redundant).
print(plt.style.available)
plt.style.use('dark_background')

# USER INPUT

In [None]:
from configpath import exp_path, classifier_path

# Echo both configured paths, one per line, so they are visible in the cell output.
print(exp_path, classifier_path, sep='\n')

# Folder name of the feature-extraction tables, as specified in the Fractal run.
table_name = "org_feat_table"

# Folder name of the ROI tables, as specified in the Fractal run.
roi_name = "org_ROI_table"


In [None]:
# Also remember to adjust the 'conditions' settings in the "Visualize images prior to filtering" section below.


# END USER INPUT

## Aggregate data from all plates and wells in experiment

In [None]:
# zarr_wellpaths returns the zarr URL dict plus plate/well/row/column id lookups;
# all of them are indexed below with the same per-well key.
zarr_url_dict, plate_ids, well_ids, row_ids, col_ids = zarr_wellpaths(exp_path, select_mip = True, make_zarr_url = True)
zarr_url_tables_dict = append_table_to_zarr_url(zarr_url_dict, table_name)

# Collect the per-well frames in a list and concatenate once at the end;
# concatenating inside the loop re-copies the growing frame on every iteration.
well_frames = []

for key, path in zarr_url_tables_dict.items():
    if not os.path.exists(path):
        # Look the ids up by key: the original formatted undefined names
        # (plate_id / well_id) here and raised NameError instead of warning.
        warnings.warn('no feature extraction detected in plate %s well %s' % (plate_ids[key], well_ids[key]))
        continue

    df_well = load_features_for_well(path)
    if df_well is None:
        continue

    # Tag every organoid row with the plate/well it came from.
    df_well["plate_id"] = plate_ids[key]
    df_well["well_id"] = well_ids[key]
    df_well["row_id"] = row_ids[key]
    df_well["col_id"] = col_ids[key]
    well_frames.append(df_well)

# The original prepended each new well to the running frame, so concatenate in
# reverse to keep the same row order. Fall back to an empty frame when no well
# yielded features (pd.concat raises on an empty list).
df = pd.concat(well_frames[::-1]) if well_frames else pd.DataFrame()



In [None]:
# Report how many organoid rows and feature columns were aggregated, then preview the table.
n_organoids, n_columns = df.shape
print('detected ', n_organoids, ' organoids and ', n_columns, ' feature columns')
df.head(5)



## Run classifier

In [None]:
# Per-organoid identifier "<plate_id>_<well_id>_<label>"; passed to the classifier cell as the id column.
label_as_str = df["label"].astype(str)
df['roi_id'] = df["plate_id"] + "_" + df["well_id"] + "_" + label_as_str

In [None]:
# Run the classifier stored at classifier_path on the aggregated frame, using 'roi_id' as the id column.
# NOTE(review): the meaning of the three return values is defined in
# scmultiplex.plotting.functions_classify; presumably df_predicted holds per-organoid predictions - confirm.
df_predicted, new_prediction, class_names = classify_me(df, classifier_path, 'roi_id')
df_predicted

## Convert aggregated organoid df into AnnData object and save as H5AD

In [None]:
# Organoid unique id in two forms: a "<plate>_<well>_<label>" string (becomes the index)
# and a (plate, well, label) tuple column used later when filtering organoids.
label_as_str = df.label.astype(str)

df['oUID'] = df["plate_id"] + "_" + df["well_id"] + "_" + label_as_str
df['oUID_tuple'] = list(zip(df.plate_id, df.well_id, label_as_str))

df = df.set_index('oUID')


In [None]:
# Numeric feature names handed to make_anndata, grouped by kind; the overall order is unchanged.
org_numerics_list = (
    # position and image dimensions
    ['x_pos_pix', 'y_pos_pix', 'imgdim_x', 'imgdim_y']
    # intensity summary statistics
    + ['mean_intensity', 'max_intensity', 'min_intensity']
    + ['percentile25', 'percentile50', 'percentile75', 'percentile90', 'percentile95', 'percentile99']
    + ['stdev', 'skew', 'kurtosis']
    # intensity-weighted position
    + ['x_pos_weighted_pix', 'y_pos_weighted_pix', 'x_massDisp_pix', 'y_massDisp_pix']
    # shape descriptors
    + ['area_bbox', 'area_convhull', 'equivDiam', 'extent', 'solidity']
    + ['majorAxisLength', 'minorAxisLength', 'minmajAxisRatio']
    + ['aspectRatio_equivalentDiameter', 'area_pix', 'perimeter', 'concavity']
    + ['asymmetry', 'eccentricity', 'circularity', 'concavity_count']
)

# Per-organoid annotation columns handed to make_anndata as the observation list.
org_obs_list = [
    "label", "ROI_table_name", "ROI_name", "index",
    "is_touching_border_xy", "disconnected_components",
    "plate_id", "well_id", "col_id", "row_id",
]

adata = make_anndata(df, org_numerics_list, org_obs_list)



In [None]:
# Save the AnnData object as 'org.h5ad' inside the experiment folder.
h5ad_path = os.path.join(exp_path, 'org.h5ad')
adata.write(filename = h5ad_path)




## Visualize images prior to filtering

# USER INPUT

## Make conditions table

In [None]:
# Keys of zarr_url_dict are already tuples of the form (plate_id, well_id); show the first one.
print('example of key: ')
list(zarr_url_dict)[0]

In [None]:
# Map every well to a condition name. Condition names may repeat across wells.
# Required format:
#   unique well id : condition id
#
# The condition name can be built from any of: plate_ids, well_ids, row_ids, col_ids.
# The well id is the zarr_url_dict key, i.e. a (plate_id, well_id) tuple.

# Plate layout where condition replicates run along columns:
# condition = last five characters of the plate id + "." + column id.
conditions = {
    well_key: plate_ids[well_key][-5:] + "." + col_ids.get(well_key, '')
    for well_key in zarr_url_dict
}

conditions


In [None]:
# Mirror the condition naming in the DataFrame:
# last five characters of plate_id + "." + column id.
plate_suffix = df['plate_id'].str.slice(-5)
df['condition'] = plate_suffix + "." + df['col_id'].astype(str)

df.head(2)

# END USER INPUT

In [None]:
# Invert the well -> condition mapping (presumably into condition -> wells; see invert_conditions_dict).
inv_cond = invert_conditions_dict(conditions)

# Build the object dictionary from the inverted conditions, the zarr URLs and the ROI table name.
objects_to_randomize = make_object_dict(inv_cond, zarr_url_dict, roi_name)

# Select n_obj = 6 objects with a fixed seed (presumably per condition, reproducibly - confirm in randomize_object_dict).
objects_randomized = randomize_object_dict(objects_to_randomize, n_obj = 6, seed = 4)


In [None]:
# Display the randomized selection for inspection.
objects_randomized



In [None]:
# Load the random image set once per channel (channel indices 0-3), all at level 1.
def _load_random_set_channel(channel_index):
    """Load images of the current `objects_randomized` selection for one channel index."""
    return load_imgs_from_object_dict(objects_randomized,
                                      zarr_url_dict,
                                      channel_index = channel_index,
                                      level = 1,
                                      roi_name = roi_name,
                                      reset_origin = False)


# Assigned one at a time so earlier channels stay available if a later channel fails to load.
c01_all_dict = _load_random_set_channel(0)
c02_all_dict = _load_random_set_channel(1)
c03_all_dict = _load_random_set_channel(2)
c04_all_dict = _load_random_set_channel(3)



In [None]:
# Plot the channel-index-0 images of the random selection as a grid.
plot_image_grid(c01_all_dict)

In [None]:
# Show one organoid from the channel-index-0 set, addressed by condition, plate, well and organoid id.
plot_single_image(
    c01_all_dict,
    cond = 'd3-P1.02',
    plate_id = '20230712-d3-P1',
    well_id = 'D02',
    org_id = '32',
)



## Visualize RGB

In [None]:
# RGB composite grid of the random selection; image sets are passed in the order
# channel index 2, 1, 0 (presumably mapped to R, G, B - confirm in plot_rgb_grid).
plot_rgb_grid(
    c03_all_dict,
    c02_all_dict,
    c01_all_dict,
    ncols = None,
    min_quantile = 0,
    max_quantile = 0.9,
    global_norm = False,
    auto_range = True,
    ranges = (),
)

## Visualize heatmap

# USER INPUT

In [None]:
# Add a timepoint column: the second '-'-separated field of plate_id
# (e.g. '20230712-d3-P1' -> 'd3'). Adjust the parsing if plate names follow another scheme.
# `n` is passed by keyword because pandas >= 2.0 accepts only `pat` positionally
# in Series.str.split; the original positional form raises TypeError there.
df["timepoint"] = df['plate_id'].str.split('-', n=2, expand=True)[1]
df.head(2)

# END USER INPUT

## Heatmap plate visualization: number of organoids per plate

In [None]:
# One heatmap per plate showing the number of organoid rows found in each well.
for plate in np.unique(df['plate_id']):
    df_hm = build_heatmap_df(plate_size = 96)

    # Count the rows per well on this plate.
    df_plt = df.loc[df['plate_id'] == plate].copy()
    df_plt['count'] = 1
    grouped = df_plt.groupby(["well_id"])['count'].count().to_frame()

    # A well id is split into its first character (row) and the remainder (column)
    # to address the matching cell of the plate-layout frame.
    for well in grouped.index:
        row_letter, col_number = well[0], well[1:]
        df_hm.loc[row_letter, col_number] = grouped.loc[well, 'count']

    # Colour scale spans the smallest to the largest populated cell.
    col_minima = df_hm.min().dropna()
    col_maxima = df_hm.max().dropna()
    vmin = min(col_minima)
    vmax = max(col_maxima)
    hm = plot_heatmap(df_hm, 'viridis', annot = True, vmin = vmin, vmax = vmax)

    hm.set_title(plate + "\n", loc = 'left')
    plt.subplots_adjust(top = 0.6)
                


## Jitterplot visualization: organoid features per timepoint

In [None]:
# Features shown in the 2 x 3 panel grid, in panel order (left to right, top to bottom).
jitter_features = [
    "area_pix",
    "circularity",
    "disconnected_components",
    "C01.mean_intensity",
    "is_touching_border_xy",
    "C02.mean_intensity",
]

plt.figure(figsize= (16,10))

# One strip plot per feature, grouped by timepoint.
for panel_number, jitter_feature in enumerate(jitter_features, start=1):
    plt.subplot(2, 3, panel_number)
    cx = sns.stripplot(x="timepoint", y=jitter_feature, data=df, size=3, palette = husl)
    plt.title(jitter_feature, fontsize=12)

plt.show()

## Filter organoids by features, per timepoint

# USER INPUT

In [None]:
features_filt = ['area_pix','circularity', 'C01.mean_intensity']


def _quantile_bounds(timepoint, feature):
    """Return a fresh [lower, upper] quantile pair for one timepoint/feature combination."""
    if feature == 'area_pix':
        return [0.01, 0.999]
    if feature == 'circularity':
        # remove objects with high circularity at later tps
        return [0., 0.85] if timepoint == 'd5' else [0., 1.]
    if feature == 'C01.mean_intensity':
        return [0.02, 1.0]
    # default for any other feature added to features_filt
    return [0.05, 1.0]


# Desired quantiles, stored as q[timepoint][feature] = [lower, upper].
q = {
    tp: {feat: _quantile_bounds(tp, feat) for feat in features_filt}
    for tp in df["timepoint"].unique()
}




# END USER INPUT

In [None]:
# Z-score every filter feature within its timepoint group (ddof=0).
for feat in features_filt:
    df[feat+"_z"] = df.groupby(["timepoint"])[feat].transform(lambda x : zscore(x,ddof=0))

# Quantile filter on the z-scores, separately for each timepoint and feature.
org_to_omit_q = []  # one array of oUID tuples per (timepoint, feature) filter
inv_tp = {}         # "<timepoint>_<feature>" -> unique oUID tuples removed by that filter

for tp in df["timepoint"].unique():
    df_tp = df.loc[(df["timepoint"] == tp)]
    for feat in features_filt:
        z_col = feat + "_z"

        # Only the first two entries are the user-set quantiles; anything after
        # them is a cutoff stored by an earlier run of this cell.
        lower_q, upper_q = q[tp][feat][:2]
        qval1 = np.quantile(df_tp[z_col], lower_q)
        qval2 = np.quantile(df_tp[z_col], upper_q)

        # Keep the z-score cutoffs at positions 2 and 3 for the plotting cell.
        # Rebuild the list instead of appending, so re-running this cell replaces
        # the cutoffs rather than growing the list and leaving stale values in
        # positions 2 and 3.
        q[tp][feat] = [lower_q, upper_q, qval1, qval2]

        temp_removed = df_tp.loc[(df_tp[z_col] < qval1) | (df_tp[z_col] > qval2)]
        removed_ids = temp_removed["oUID_tuple"].unique()

        org_to_omit_q.append(removed_ids)
        inv_tp[tp + "_" + feat] = np.unique(removed_ids)

        print("Omitted", len(removed_ids), "organoids based on", feat, "in timepoint", tp)


# Unique list of organoids to remove across all timepoints and features.
# np.concatenate raises on an empty list, so fall back to an empty array when
# there was nothing to filter.
if org_to_omit_q:
    org_to_omit_q = np.unique(np.concatenate(org_to_omit_q))
else:
    org_to_omit_q = np.array([], dtype=object)

df_r = df[df["oUID_tuple"].isin(org_to_omit_q)] # dataframe of removed organoids

#print("These ", len(org_to_omit_q), "organoid_IDs have been removed during quantile filter:", org_to_omit_q)
     

In [None]:
# Swarm plot of every z-scored feature per timepoint, with the filter cutoffs drawn in red.
timepoints = df["timepoint"].unique()
m = len(timepoints)

for feat in features_filt:
    z_col = feat+"_z"
    plt.figure(figsize=(9,6))
    sns.swarmplot(x="timepoint", y=z_col, data=df, size =3, palette = husl)
    plt.title(z_col, fontsize=12)

    for n, tp in enumerate(timepoints):
        # xmin/xmax are axes fractions: restrict each line to the n-th of m
        # equal-width slots, leaving a 10% margin on either side of the slot.
        span = dict(xmin=(n/m+(0.1/m)), xmax=(n/m+(0.9/m)), color = 'r')
        # positions 2 and 3 of q[tp][feat] hold the lower/upper z-score cutoffs
        lower_cut, upper_cut = q[tp][feat][2], q[tp][feat][3]
        plt.axhline(lower_cut, **span)
        plt.axhline(upper_cut, **span)

## Plot removed organoids

In [None]:
# Re-use the name objects_randomized for a selection drawn from the removed organoids,
# keyed by "<timepoint>_<feature>" (inv_tp); n_obj = 6 with a fixed seed.
objects_randomized = randomize_object_dict(inv_tp, n_obj = 6, seed = 3)
# Uncomment to inspect the selection:
#objects_randomized


In [None]:
# Load channel-index-0 images of the removed organoids; note level 0 here, not level 1 as elsewhere.
filt_npimg_dict = load_imgs_from_object_dict(
    objects_randomized,
    zarr_url_dict,
    channel_index = 0,
    level = 0,
    roi_name = roi_name,
    reset_origin = False,
)

In [None]:
# Plot the removed organoids as a grid.
plot_image_grid(filt_npimg_dict)

## Remove organoids from source and plot cleaned-up dataset

# USER INPUT

In [None]:

# Drop every organoid that is on the quantile-filter removal list.
omit_index = df.index[df["oUID_tuple"].isin(org_to_omit_q)]
df_filtered = df.drop(omit_index)

# Additionally drop organoids flagged (== 1.0) for any of these columns:
for flag_col in ('disconnected_components', 'is_touching_border_xy'):
    flagged_index = df_filtered.index[df_filtered[flag_col] == 1.0]
    df_filtered = df_filtered.drop(flagged_index)

In [None]:
# drop any unwanted conditions (only from DF, not from image plotting objects!)
#df_filtered.drop(df_filtered[df_filtered['condition'] == "d5-P2.05"].index, inplace = True)

# END USER INPUT

In [None]:
# Rebuild the full object dictionary, then remove the quantile-filtered organoids from it.
# NOTE(review): only org_to_omit_q is passed here, so organoids dropped from df_filtered for
# disconnected_components / is_touching_border_xy presumably remain in the image dictionary - confirm intended.
all_objects = make_object_dict(inv_cond, zarr_url_dict, roi_name)

objects_filtered = make_filtered_dict(all_objects, org_to_omit_q, omit_my_list = True)


In [None]:
# Randomize the filtered objects and load their images, once per channel (indices 0-3), at level 1.
objects_randomized = randomize_object_dict(objects_filtered, n_obj = 6, seed = 9)


def _load_filtered_set_channel(channel_index):
    """Load images of the current `objects_randomized` selection for one channel index."""
    return load_imgs_from_object_dict(objects_randomized,
                                      zarr_url_dict,
                                      channel_index = channel_index,
                                      level = 1,
                                      roi_name = roi_name,
                                      reset_origin = False)


# Assigned one at a time so earlier channels stay available if a later channel fails to load.
c01_filt_dict = _load_filtered_set_channel(0)
c02_filt_dict = _load_filtered_set_channel(1)
c03_filt_dict = _load_filtered_set_channel(2)
c04_filt_dict = _load_filtered_set_channel(3)


In [None]:
# RGB composite grid of the filtered selection; unlike the pre-filter grid, global_norm is True here.
plot_rgb_grid(
    c03_filt_dict,
    c02_filt_dict,
    c01_filt_dict,
    ncols = None,
    min_quantile = 0,
    max_quantile = 0.9,
    global_norm = True,
    auto_range = True,
    ranges = (),
)

In [None]:
# Display the loaded channel-index-2 dictionary (e.g. to look up ids for the single-image plot).
c03_filt_dict

In [None]:
# RGB composite of a single filtered organoid, addressed by condition, plate, well and organoid id.
plot_single_rgb(
    c03_filt_dict,
    c02_filt_dict,
    c01_filt_dict,
    cond = 'd3-P1.02',
    plate_id = '20230712-d3-P1',
    well_id = 'C02',
    org_id = '56',
)

 

## Plot organoid-level feature data across conditions

In [None]:
# Preview the cleaned-up table.
df_filtered.head(2)

In [None]:
# Plate heatmap for 'C03.mean_intensity' on a 96-well layout (presumably per-well means, per the helper name).
plot_heatmap_means(df_filtered, feature = 'C03.mean_intensity', plate_size = 96, vmax_multiplier=0.5)

In [None]:
# Violin plot of 'C03.mean_intensity' for the filtered organoids.
plot_feature_violin(df_filtered, colname = 'C03.mean_intensity')

## Filter positive/negative organoids with threshold cutoff and plot 

In [None]:
# Threshold 'C03.mean_intensity' at 2000. df_filtered is rebound to the frame returned
# by count_positive_fraction, and `grouped` feeds the positive-fraction plot.
df_filtered, grouped = count_positive_fraction(df_filtered, colname = 'C03.mean_intensity', thresh = 2000)
plot_positive_fraction(grouped)


In [None]:
# Plot example image sets for the thresholded groups (see plot_pos_and_neg_sets);
# RGB is composed from channel indices 2, 1, 0.
plot_pos_and_neg_sets(
    df_filtered,
    grouped,
    inv_cond,
    zarr_url_dict,
    roi_name,
    n_obj = 6,
    seed = 3,
    level = 1,
    min_quantile = 0,
    max_quantile = 0.88,
    r_ch_idx = 2,
    g_ch_idx = 1,
    b_ch_idx = 0,
)

## Repeat with another marker

In [None]:
# Same plate heatmap as above, now for 'C02.max_intensity' with a smaller vmax multiplier.
plot_heatmap_means(df_filtered, feature = 'C02.max_intensity', plate_size = 96, vmax_multiplier=0.1)

In [None]:
# Optional derived feature (disabled): summed intensity = area * mean intensity.
#df_filtered['C02.sum_intensity'] = df_filtered['area_pix'] * df_filtered['C02.mean_intensity']
# Violin plot of 'C02.max_intensity' for the filtered organoids.
plot_feature_violin(df_filtered, colname = 'C02.max_intensity') 

In [None]:
# Threshold 'C02.max_intensity' at 5000. This rebinds df_filtered and `grouped`,
# replacing the values produced for the C03 marker.
df_filtered, grouped = count_positive_fraction(df_filtered, colname = 'C02.max_intensity', thresh = 5000)

plot_positive_fraction(grouped)

In [None]:
# Plot example image sets for the C02-thresholded groups; same settings as for the
# previous marker except for the random seed.
plot_pos_and_neg_sets(
    df_filtered,
    grouped,
    inv_cond,
    zarr_url_dict,
    roi_name,
    n_obj = 6,
    seed = 2,
    level = 1,
    min_quantile = 0,
    max_quantile = 0.88,
    r_ch_idx = 2,
    g_ch_idx = 1,
    b_ch_idx = 0,
)

In [None]:
# Derived feature: convex-hull area minus object area.
df_filtered['area_diff'] = df_filtered['area_convhull'].sub(df_filtered['area_pix'])

In [None]:
# Violin plot of the derived 'area_diff' feature.
plot_feature_violin(df_filtered, colname = 'area_diff')