In [1]:
import numpy as np
import pandas as pd
import tifffile as tiff
import zarr
import napari
import dask.array as da
import zarr
import os
import numpy as np
from pathlib import Path
from magicgui import magicgui
import distinctipy 
from matplotlib.colors import to_hex
from tribus import classify
import tribus
from tribus import run_tribus
from qtpy.QtWidgets import QHBoxLayout, QPushButton, QWidget

In [24]:
# Open the napari viewer; the widget cells below dock themselves into this `viewer` and add their layers to it.
viewer = napari.Viewer()

### Open whole slide image

In [None]:
@magicgui(call_button='Open image')
def open_image(impath = Path(), channel_list=Path()):
    """Open a (possibly pyramidal) TIFF lazily and add it to ``viewer``.

    Axis 0 is treated as the channel axis, so every channel becomes its own
    hidden image layer with contrast limits 0-65535.

    Parameters
    ----------
    impath : Path
        TIFF image to open; only its first series is used.
    channel_list : Path
        Optional Excel file with a ``Channel_name`` column used to name the
        channel layers. When left at its default (``Path()``, i.e. ``Path('.')``)
        the layers keep napari's default names.
    """
    # The TiffFile is intentionally left open: the dask arrays below read
    # tiles lazily from it for as long as the layers exist.
    image = tiff.TiffFile(impath, is_ome = False)
    z = zarr.open(image.aszarr(), mode='r')
    n_levels = len(image.series[0].levels)  # number of pyramid levels

    # Multiscale layer when pyramid levels are available, plain array otherwise.
    if n_levels > 1:
        pyramid = [da.from_zarr(z[i]) for i in range(n_levels)]
        multiscale = True
    else:
        pyramid = da.from_zarr(z)
        multiscale = False

    # NOTE(review): 0-65535 presumes 16-bit data — confirm for other bit depths.
    layer_kwargs = dict(multiscale=multiscale, channel_axis=0, visible=False, contrast_limits=(0, 65535))

    # Compare as Path: the default Path() equals Path('.'), but a Path never
    # compares equal to the plain string '.', so `channel_list == '.'` was always False.
    if Path(channel_list) == Path('.'):
        viewer.add_image(pyramid, **layer_kwargs)
    else:
        list_df = pd.read_excel(channel_list)
        # napari layer names must be strings.
        channel_names = list_df.loc[:, 'Channel_name'].astype(str).tolist()
        viewer.add_image(pyramid, name=channel_names, **layer_kwargs)

viewer.window.add_dock_widget(open_image)

In [22]:
#Load helper functions 
# Colours as (r, g, b) tuples with components in [0, 1].
RED = (1, 0, 0)
GREEN = (0, 1, 0)
BLUE = (0, 0, 1)
WHITE = (1, 1, 1)
BLACK = (0, 0, 0)
MAGENTA = (1, 0, 1)
YELLOW = (1, 1, 0)
CYAN = (0, 1, 1)

# Colours passed to distinctipy.get_colors(exclude_colors=...) so generated
# cell-type colours stay away from them.
# NOTE(review): YELLOW is defined but not excluded — confirm that is intentional.
excluded_colors = [RED, GREEN, BLUE, BLACK, WHITE, CYAN, MAGENTA]

def filter_and_transform(arr, selected_ids):
    """Return a uint8 binary mask of the pixels whose label is in ``selected_ids``.

    Parameters
    ----------
    arr : array-like
        Label image (one block of it when called through ``da.map_blocks``);
        label 0 is treated as background and never selected.
    selected_ids : array-like
        Labels to keep.

    Returns
    -------
    numpy.ndarray
        Same shape as ``arr``: 1 where ``arr`` holds a non-zero label contained
        in ``selected_ids``, 0 elsewhere. The dtype is uint8 so it matches the
        ``dtype=np.uint8`` that callers declare to ``da.map_blocks``.
    """
    arr = np.asarray(arr)
    return (np.isin(arr, selected_ids) & (arr != 0)).astype(np.uint8)

### Napari widget to run Tribus on one sample and visualize results
Display options:

- `celltype_mask`: show the cell type assigned to each cell across the sample.
- `marker_intensity_mask`: show the intensity of the selected `marker` for each cell across the sample.
- `probability_mask`: show Tribus's probability score per cell type.

Inputs:

- `mask_path`: segmentation mask.
- `sample_data`: quantification table.
- `logic`: Tribus logic table.
- `output_folder`: user-selected folder where the results are saved.

In [23]:
@magicgui(call_button='Run Tribus', output_folder={"mode": "d"})
def run_tribus_classify(celltype_mask: bool,  probability_mask: bool, marker_intensity_mask: bool, marker: str = 'type a marker', depth: int = 1, tuning: int=5, sigma=0.5, 
                        learning_rate=0.5, clustering_threshold: int=15000, undefined_threshold=0.01, other_threshold=0.4, mask_path = Path(), sample_data = Path(),logic = Path(), output_folder=Path()):
    """Run Tribus on one sample, save its results and add the requested layers to ``viewer``.

    Parameters
    ----------
    celltype_mask : bool
        Add one binary layer per distinct ``final_label`` in the Tribus result.
    probability_mask : bool
        Add one layer per sheet of the logic workbook, painting each cell with
        its value in the probability-table column of the same name.
    marker_intensity_mask : bool
        Add a layer painting each cell with its ``marker`` value from ``sample_data``.
    marker : str
        Column of ``sample_data`` shown when ``marker_intensity_mask`` is set.
    depth, tuning, sigma, learning_rate, clustering_threshold, undefined_threshold, other_threshold
        Passed unchanged to ``run_tribus``.
    mask_path : Path
        Segmentation (label) mask TIFF; only its first page is read.
    sample_data : Path
        CSV quantification table; must contain an ``ID`` column.
    logic : Path
        Excel logic workbook; every sheet is read with ``Marker`` as the index.
    output_folder : Path
        Folder receiving ``labels.csv`` and ``probability_scores.csv``.
    """
    # Read all logic sheets into a {sheet name: DataFrame} dict; the context
    # manager releases the workbook's file handle once parsing is done.
    with pd.ExcelFile(logic) as logic_file:
        sheet_names = logic_file.sheet_names
        label_data = pd.read_excel(logic_file, sheet_name=sheet_names, index_col='Marker')

    dat = pd.read_csv(sample_data)

    res, prob = run_tribus(input_df=dat, logic=label_data, depth=depth, normalization=None,
                           tuning=tuning, sigma=sigma, learning_rate=learning_rate,
                           clustering_threshold=clustering_threshold,
                           undefined_threshold=undefined_threshold,
                           other_threshold=other_threshold, random_state=None)

    # NOTE(review): assignment aligns on the index, so this assumes run_tribus
    # keeps the row index of input_df — confirm.
    res['ID'] = dat['ID']
    prob['ID'] = dat['ID']

    out_dir = Path(output_folder)
    res.to_csv(out_dir / 'labels.csv')
    prob.to_csv(out_dir / 'probability_scores.csv')

    # Read the first page of the label mask eagerly, then close the file.
    with tiff.TiffFile(mask_path) as mask_file:
        label_image = mask_file.pages[0].asarray()
    # Four levels, each 4x smaller per axis; strided slicing keeps labels intact.
    pyramid = [da.from_array(label_image[::4**level, ::4**level]) for level in range(4)]

    if celltype_mask:
        cell_types = res['final_label'].unique()
        colors = distinctipy.get_colors(len(cell_types), pastel_factor=0.6, exclude_colors=excluded_colors)

        for cell_type, color in zip(cell_types, colors):
            print(cell_type)
            # Two-stop colormap: transparent for 0 (other pixels), opaque cell-type colour for 1.
            new_cmap = napari.utils.colormaps.Colormap(colors=np.array([[0.0, 0.0, 0.0, 0.0], [*color, 1]]), name=to_hex(color))
            cell_ids = res.loc[res['final_label'] == cell_type, 'ID'].values
            filtered_pyramid = [da.map_blocks(filter_and_transform, layer, cell_ids, dtype=np.uint8) for layer in pyramid]
            viewer.add_image(filtered_pyramid, colormap=new_cmap, name=f'{cell_type}_mask', blending='translucent', contrast_limits=(0, 1))

    if marker_intensity_mask:
        print(marker)
        # NOTE(review): cell IDs are taken from the first column of sample_data — confirm it is the ID column.
        intensity_by_id = dict(zip(dat.iloc[:, 0], dat.loc[:, marker]))
        # Labels without an entry map to 0.0. otypes=[float] pins the output
        # dtype instead of letting np.vectorize infer it from the first pixel.
        markermask = np.vectorize(intensity_by_id.get, otypes=[float])(label_image, 0.0)
        viewer.add_image(markermask, name=(marker + '_intensity'), colormap='magma')

    if probability_mask:
        # NOTE(review): assumes the probability table has one column per logic sheet — confirm against run_tribus.
        for sheet_name in sheet_names:
            print(f"Processing {sheet_name}...")
            has_score = prob[sheet_name].notna()
            prob_by_id = dict(zip(prob.loc[has_score, 'ID'].values, prob.loc[has_score, sheet_name].values))
            # Cells without a score (and non-cell pixels) map to 0.0.
            prob_mask = np.vectorize(prob_by_id.get, otypes=[float])(label_image, 0.0)
            viewer.add_image(prob_mask, name=(sheet_name + '_intensity'), colormap='magma', blending='translucent', contrast_limits=(0, 1))

viewer.window.add_dock_widget(run_tribus_classify)
    
    

<napari._qt.widgets.qt_viewer_dock_widget.QtViewerDockWidget at 0x22f018a3920>

False False True
Global, subsetting done
Start hyperparameter tuning. 
x is 24
Current quantization error is 112530.50118092149                                                                       
100%|██████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.75trial/s, best loss: 112530.50118092149]
best: {'learning_rate': np.float64(3.6624602028328934), 'sig': np.float64(1.0539935952425987)}
Current grid size x is 24, grid size y is 24, sigma is 1.0539935952425987, learning rate is 3.6624602028328934.
less than min sample_size
Tumor, subsetting done
Start hyperparameter tuning. 
x is 24
Current quantization error is 112530.48529765345                                                                       
100%|██████████████████████████████████████████████████| 1/1 [00:00<00:00,  5.81trial/s, best loss: 112530.48529765345]
best: {'learning_rate': np.float64(2.4869357363196047), 'sig': np.float64(4.6630679470753975)}
Current grid size x is 24, grid size y is 24, sigma 

### Visualize existing data

In [None]:
#Visualize cell types
@magicgui(call_button = 'Show celltype masks')
def show_celltypes(mask_path = Path(), labels = Path()):
    """Add one coloured binary layer per cell type from a saved Tribus labels table.

    Parameters
    ----------
    mask_path : Path
        Segmentation (label) mask TIFF; only its first page is read.
    labels : Path
        CSV with ``ID`` and ``final_label`` columns (e.g. a saved ``labels.csv``).
    """
    c_table = pd.read_csv(labels)
    cell_types = c_table['final_label'].unique()

    # Read the first page of the label mask eagerly, then close the file.
    with tiff.TiffFile(mask_path) as mask_file:
        label_image = mask_file.pages[0].asarray()
    # Four levels, each 4x smaller per axis; strided slicing keeps labels intact.
    pyramid = [da.from_array(label_image[::4**level, ::4**level]) for level in range(4)]

    colors = distinctipy.get_colors(len(cell_types), pastel_factor=0.6, exclude_colors=excluded_colors)

    for cell_type, color in zip(cell_types, colors):
        print(cell_type)
        # Two-stop colormap: transparent for 0 (other pixels), opaque cell-type colour for 1.
        new_cmap = napari.utils.colormaps.Colormap(colors=np.array([[0.0, 0.0, 0.0, 0.0], [*color, 1]]), name=to_hex(color))
        cell_ids = c_table.loc[c_table['final_label'] == cell_type, 'ID'].values
        filtered_pyramid = [da.map_blocks(filter_and_transform, layer, cell_ids, dtype=np.uint8) for layer in pyramid]
        viewer.add_image(filtered_pyramid, colormap=new_cmap, name=f'{cell_type}_mask', blending='translucent', contrast_limits=(0, 1))

viewer.window.add_dock_widget(show_celltypes)

In [26]:
#Visualize marker intensities
@magicgui(call_button = 'Show marker intensities')
def show_marker_intensities(marker: str = 'type a marker', mask_path = Path(), quantification = Path()):
    """Paint every cell of a label mask with its ``marker`` value and add it to ``viewer``.

    Parameters
    ----------
    marker : str
        Column of the quantification table to display.
    mask_path : Path
        Segmentation (label) mask TIFF; only its first page is read.
    quantification : Path
        CSV quantification table whose first column is matched against the
        mask labels.
    """
    # Read the first page of the label mask eagerly, then close the file.
    with tiff.TiffFile(mask_path) as mask_file:
        label_image = mask_file.pages[0].asarray()

    dat = pd.read_csv(quantification)

    print(marker)
    # NOTE(review): assumes the first column holds the cell IDs used in the mask — confirm.
    intensity_by_id = dict(zip(dat.iloc[:, 0], dat.loc[:, marker]))
    # Labels without an entry map to 0.0. otypes=[float] pins the output
    # dtype instead of letting np.vectorize infer it from the first pixel.
    markermask = np.vectorize(intensity_by_id.get, otypes=[float])(label_image, 0.0)
    viewer.add_image(markermask, name=(marker + '_intensity'), colormap='magma')

viewer.window.add_dock_widget(show_marker_intensities)

<napari._qt.widgets.qt_viewer_dock_widget.QtViewerDockWidget at 0x22f10cc1910>

CD8a
E-cadherin
