In [None]:
import re
import os

import tifffile
from tifffile import TiffFile
import scanpy as sc
import pandas as pd
import numpy as np
import anndata
import matplotlib.pyplot as plt
import mgitools.os_helpers as os_helpers
import cv2

from skimage.io import imread
from skimage import exposure
from deepcell.applications import MultiplexSegmentation
from ome_types import from_tiff, from_xml


#### convert mcd

In [None]:
from imctools.io.mcd.mcdparser import McdParser

In [None]:
fn_mcd = '../data/hyperion/b3_09012020/mcd/HT122P1 S1H3 L1 L4 090320.mcd'

# Parse the raw Hyperion MCD file with imctools.
parser = McdParser(fn_mcd)

In [None]:
session = parser.session

# Get all acquisition IDs
ids = parser.session.acquisition_ids

# Display the acquisition IDs stored in this MCD file; single acquisitions are
# loaded below (as AcquisitionData objects) with parser.get_acquisition_data(<id>).
ids

In [None]:
ac_data = parser.get_acquisition_data(1)

In [None]:
ac_data.channel_labels

In [None]:
# Switch to acquisition 4 -- this is the one exported to OME-TIFF below.
ac_data = parser.get_acquisition_data(4)

In [None]:

# Export the acquisition as an OME-TIFF so its metadata can be inspected with tifffile.
ac_data.save_ome_tiff('../data/john_will_grant_imageing/test.ome.tiff',)

In [None]:
# NOTE(review): this TiffFile handle is never closed.
t = tifffile.TiffFile('../data/john_will_grant_imageing/test.ome.tiff')
t.ome_metadata

In [None]:
ac_data

In [None]:
ac_data.channel_labels

In [None]:
ac_data.acquisition.metadata

In [None]:
# NOTE(review): start - end is negative whenever end > start; presumably end - start
# (the ROI width, in um judging by the attribute name) was intended.
ac_data.acquisition.roi_start_x_pos_um - ac_data.acquisition.roi_end_x_pos_um

In [None]:
ac_data.image_data

In [None]:
# Display the acquisition object. The original line ended with a dangling '.',
# a SyntaxError that made this cell fail under Run All.
ac_data.acquisition

#### inspect tiffs

In [None]:
import mgitools.os_helpers as os_helpers
# List the batch-1 OME-TIFFs and keep only the HT056 samples.
# NOTE(review): the regex dot is unescaped (assuming listfiles treats it as a regex),
# so '.tiff$' matches any character before 'tiff'.
fps = sorted(os_helpers.listfiles('../data/hyperion/b1_01042020/ome-tiff_htan/', regex=r'.tiff$'))
fps = [fp for fp in fps if 'HT056' in fp]
fps

In [None]:
from ome_types import from_tiff, from_xml
# Print the parsed OME header of the first image in each file.
for fp in fps:
    print(fp)
    header = from_tiff(fp)
    print(header.images[0])
    

#### extract nuclear and membrane markers

###### multichannel extraction (hyperion)

In [1]:
def get_channels(fp):
    """Return the channel names declared in an OME-TIFF's OME-XML metadata.

    Args:
        fp: path to an OME-TIFF file.

    Returns:
        list of ``channel.name`` values for the first image in the metadata,
        in metadata order.
    """
    # Context manager releases the file handle; the original never closed it.
    with TiffFile(fp) as tif:
        ome = from_xml(tif.ome_metadata)
    im = ome.images[0]
    return [channel.name for channel in im.pixels.channels]

def extract_htan_ome_tiff(fp):
    """Read every channel of a multi-page OME-TIFF into a dict.

    The i-th channel listed in the first image's OME metadata is paired with the
    i-th TIFF page. Pairing stops at the shorter of the two sequences.

    Args:
        fp: path to an OME-TIFF file.

    Returns:
        dict mapping channel name -> ``page.asarray()`` for the paired page.
        If a channel name repeats, the later page wins.
    """
    # Context manager releases the file handle; the original never closed it.
    with TiffFile(fp) as tif:
        ome = from_xml(tif.ome_metadata)
        im = ome.images[0]
        # Pixel data must be read while the file is still open.
        d = {c.name: p.asarray() for c, p in zip(im.pixels.channels, tif.pages)}
    return d

In [None]:
# def extract_hyperion(fp, ome_source='halo'):
#     tif = TiffFile(fp)
#     d = {}
#     if ome_source == 'halo':
#         for page in tif.pages:
#             img = page.asarray()
#             channel = re.sub(r'^.*Name..(.*)SamplesPerPixel.*$', r'\1', page.description.replace('\r\n', ''))
#             channel = re.sub(r'^[0-9]+[A-Z][a-z]?-?(.*)\(.*$', r'\1', channel)
#             channel = re.sub(r'^[A-Z][a-z]?[0-9]+-?(.*)\(.*$', r'\1', channel)
#             d[channel] = img
#         d.pop('')
        
        
#     else:
#         description = tif.pages[0].description
#         metadatas = re.findall(r'<Channel [^>]*>', p.description)
#         for i, page in enumerate(tif.pages):
#             img = page.asarray()
#             m = metadatas[i]
#             channel = re.sub(r'^.* Name..(.*). SamplesPerPixel.*$', r'\1', m)
#             channel = re.sub(r'^.*-(.*)$', r'\1', channel)
#             if channel and channel[0].isdigit(): channel = ''
# #             channel = re.sub(r'^[0-9]+[A-Z][a-z]?-?(.*)\(.*$', r'\1', channel)
# #             channel = re.sub(r'^[A-Z][a-z]?[0-9]+-?(.*)\(.*$', r'\1', channel)
#             d[channel] = img
#         if '' in d: d.pop('')
        
#     return d
    
# def extract_codex(fp):
#     tifs = os_helpers.listfiles(fp, regex='.tif')
#     channel_to_tif = {fp.split('/')[-1].split('.')[0].split('_')[-1]:tifffile.imread(fp) for fp in tifs}
#     return channel_to_tif
    
    

In [None]:
# CHANNEL_MAP = {
    
# }

# def rename_channels(channel_to_img):
#     channel_to_img = {c.lower():v for c, v in channel_to_img.items()}
#     channel_to_img = {CHANNEL_MAP.get(c, c):v for c, v in channel_to_img.items()}
#     return channel_to_img

In [None]:
def merge_channels(channel_to_img, channels):
    """Average several same-shaped channel images pixel-wise.

    Args:
        channel_to_img: dict mapping channel name -> image array.
        channels: iterable of channel names to merge; all images must share a shape.

    Returns:
        float array holding the per-pixel mean over the selected channels.

    Raises:
        ValueError: if ``channels`` is empty.
        KeyError: if a channel is missing from ``channel_to_img``.
    """
    imgs = [channel_to_img[c] for c in channels]
    if not imgs:
        raise ValueError('merge_channels needs at least one channel')
    # The original stored the concatenation in a throwaway variable
    # (`X = np.concatenate(...)`), so only the first channel was ever averaged.
    return np.mean(np.stack(imgs), axis=0)

###### one channel tif extraction (codex)

In [None]:
tif = TiffFile(
    '/home/estorrs/imaging/data/codex/codex_processed_hu_pancreas_TrisEDTA_2020-10-08/reg001_cyc001_ch001_DAPI1.tif')
tif

In [None]:
img = tifffile.imread(
    '/home/estorrs/imaging/data/codex/codex_processed_hu_pancreas_TrisEDTA_2020-10-08/reg001_cyc001_ch001_DAPI1.tif')
img

In [None]:
p.description

In [None]:
re.findall(r'<Channel [^>]*>', p.description)

In [None]:
len(tif.pages)

In [None]:
def extract_codex(fp):
    """Load CODEX channel images into a dict keyed by channel name.

    Two inputs are accepted:

    * a directory of single-channel ``.tif``/``.tiff`` files. The channel name is
      the last ``_``-separated token of the file name before its first ``.``;
      for example ``reg001_cyc001_ch001_DAPI1.tif`` gives ``DAPI1``.
    * a single multi-page OME-TIFF file, which is delegated to ``extract_htan_ome_tiff``.

    Args:
        fp: directory path or OME-TIFF file path.

    Returns:
        dict mapping channel name -> image array. If two files give the same
        channel name, the one that sorts last wins.
    """
    # The original stub listed the files and then implicitly returned None, so
    # every caller failed on its first channel lookup.
    if not os.path.isdir(fp):
        return extract_htan_ome_tiff(fp)
    tifs = sorted(os_helpers.listfiles(fp, regex=r'\.tiff?$'))
    channel_to_img = {}
    for tif_fp in tifs:
        channel = os.path.basename(tif_fp).split('.')[0].split('_')[-1]
        channel_to_img[channel] = tifffile.imread(tif_fp)
    return channel_to_img

##### hyperion

In [None]:
# hyperior_folder = '/home/estorrs/imaging/data/hyperion/b1_01042020/ome-tiff_htan/'
# Active folder of per-sample OME-TIFFs (alternatives kept commented out).
hyperior_folder = '/home/estorrs/imaging/data/hyperion/b4_10062020/ome-tiff_htan/'
# hyperior_folder = '../data/john_will_grant_imageing/'

In [None]:
# NOTE(review): the regex dot is unescaped, so '.tiff$' matches any character before 'tiff'.
fps = sorted(os_helpers.listfiles(hyperior_folder, regex='.tiff$'))
sample_fp_tups = []
for fp in fps:
    # sample name = file name with the '.ome.tiff' extension removed
    sample = fp.split('/')[-1].replace('.ome.tiff', '')
    # NOTE(review): '_b1' is hard-coded although the active folder is batch b4;
    # initialize_anndata later derives `batch` from this suffix -- confirm the intended label.
    sample_fp_tups.append((sample + '_b1', fp))
len(sample_fp_tups), sample_fp_tups[:5]

In [None]:
# just keep a few for now
keep = ['HT_077B1_S1H1_A1_A4_100620_ROI_A4_b1']
sample_fp_tups = [(s, fp) for s, fp in sample_fp_tups if s in keep]
sample_fp_tups

In [None]:
sample_fp_tups[0] = ('PDAC_1', sample_fp_tups[0][1])
sample_fp_tups[1] = ('PDAC_2', sample_fp_tups[1][1])
sample_fp_tups[2] = ('PDAC_3', sample_fp_tups[2][1])

In [None]:
# Channel names available in the first selected sample.
get_channels(sample_fp_tups[0][1])

In [None]:
# Channels averaged into the nuclear / membrane segmentation inputs; names must
# match the OME channel names exactly.
# NUCLEAR = ['HistoneH3']
NUCLEAR = ['Histone H3']
MEMBRANE = ['cellseg2']
# MEMBRANE = ['ICSK1', 'ICSK2', 'ICSK3']

sample_to_imgs = {}
for sample, fp in sample_fp_tups:
#     channel_to_img = extract_hyperion(fp, ome_source='a')
    channel_to_img = extract_htan_ome_tiff(fp)
#     print(channel_to_img.keys())
    # store the merged segmentation inputs next to the raw channels
    nuclear_img = merge_channels(channel_to_img, NUCLEAR)
    membrane_img = merge_channels(channel_to_img, MEMBRANE)
    channel_to_img['nuclear'] = nuclear_img
    channel_to_img['membrane'] = membrane_img
    sample_to_imgs[sample] = channel_to_img    

##### codex

In [None]:
codex_folder = '/home/estorrs/imaging/data/codex/adult_kidney/ome-tiff_htan/'

In [None]:
# Note: this replaces the hyperion `sample_fp_tups` defined above.
sample_fp_tups = [('adult_kidney', os.path.join(codex_folder, 'adult_kidney.ome.tiff'))]

In [None]:
# NOTE(review): keys are case-sensitive; 'dapi1' here vs 'DAPI1' used in a later channel_map -- confirm.
NUCLEAR = ['dapi1']
MEMBRANE = ['CD3']

# Note: this replaces the hyperion `sample_to_imgs` built above.
sample_to_imgs = {}
for sample, fp in sample_fp_tups:
    # NOTE(review): extract_codex receives an OME-TIFF *file* path here -- confirm it supports that input.
    channel_to_img = extract_codex(fp)
    nuclear_img = merge_channels(channel_to_img, NUCLEAR)
    membrane_img = merge_channels(channel_to_img, MEMBRANE)
    channel_to_img['nuclear'] = nuclear_img
    channel_to_img['membrane'] = membrane_img
    sample_to_imgs[sample] = channel_to_img    

In [None]:
## from the new tiffs
# Re-read every sample directly from its OME-TIFF. Note: this replaces
# `sample_to_imgs`, and the resulting dicts have no 'nuclear' / 'membrane' entries.
sample_to_imgs = {}
for sample, fp in sample_fp_tups:
    with TiffFile(fp) as tif:  # closes the handle; the original leaked it
        # Strip AcquisitionDate elements before parsing -- presumably their values fail
        # ome_types validation (TODO confirm). Non-greedy so that several elements on one
        # XML line are removed individually instead of everything between the first and last.
        m = re.sub(r'<Acquis.*?AcquisitionDate>', r'', tif.ome_metadata)
        ome = from_xml(m)
        im = ome.images[0]
        # pixel data must be read while the file is open
        channel_to_img = {channel.name: page.asarray()
                          for channel, page in zip(im.pixels.channels, tif.pages)}
    sample_to_imgs[sample] = channel_to_img

#### display images

In [None]:
# Channel names available for the first sample.
next(iter(sample_to_imgs.values())).keys()

In [None]:
# red, green, light blue, blue, purple, yellow, white
DEFAULTS = ['#ff0000', '#00ff04', '#00fff7', '#0008ff', '#ff00e1', '#fbff00', '#ffffff']
def hex_to_rgb(h):
    """Convert a hex colour such as '#ff8000' into an (r, g, b) tuple of floats in [0, 1]."""
    digits = h.replace('#', '').lower()
    channels = []
    for start in range(0, 6, 2):
        channels.append(int(digits[start:start + 2], 16) / 255.)
    return tuple(channels)

from skimage import color


In [None]:
def add_scale(img, mpm=1):
    """Burn a white 200-pixel scale bar into the lower-right corner of ``img``.

    The bar covers columns ``(W-210, W-10]`` and rows ``(H-20, H-15]``, clipped to
    the image. Channels 0-2 of those pixels are set to 1.0, and a '200 um' label
    is drawn on the current matplotlib axes just above the bar.

    Args:
        img: (H, W, C) image with C >= 3, presumably float in [0, 1]; modified in place.
        mpm: currently unused. NOTE(review): presumably meant to convert pixels to
            microns -- the label assumes 1 px == 1 um.

    Returns:
        the same ``img`` object.
    """
    x = img.shape[1] - 210
    y = img.shape[0] - 20

    x1, x2 = int(x), int(x + 200)
    y1, y2 = int(y), int(y + 5)

    # The original (lo, hi] comparisons expressed as slices. Clamp at 0 so a negative
    # bound on a small image clips instead of wrapping like a negative Python index.
    rows = slice(max(y1 + 1, 0), max(y2 + 1, 0))
    cols = slice(max(x1 + 1, 0), max(x2 + 1, 0))
    # One vectorised assignment replaces the original per-pixel O(H*W) Python loop.
    img[rows, cols, 0:3] = 1.
    plt.text(x, y - 10, '200 um', fontsize=12, color='white')
    return img
    

In [None]:
# Exploratory: compare contrast treatments on one channel of the first sample.
# NOTE(review): 'Pan Keratin' exists only in the hyperion panel; this fails on the codex sample.
im = next(iter(sample_to_imgs.values()))['Pan Keratin'].copy()
# stretch to 0-255 using the channel maximum
im = ((im / max(im.flatten())) * 255).astype(np.uint8)
plt.imshow(im)

In [None]:
# CLAHE (adaptive histogram equalisation)
plt.imshow((exposure.equalize_adapthist(im, clip_limit=.03) * 255).astype(np.uint8))

In [None]:
import seaborn as sns
# Intensity distribution of pixels brighter than 5, on a log x-axis.
# NOTE(review): seaborn.distplot is deprecated; histplot/displot are its replacements.
sns.distplot(im[im>5], bins=100)
plt.xscale('log')
plt.xlim((5, 255))

In [None]:
# OpenCV non-local-means denoising
plt.imshow(cv2.fastNlMeansDenoising(im, ))

In [None]:
from skimage.filters import gaussian
g = gaussian(im)
print(np.min(g), np.max(g))
plt.imshow((gaussian(im)*255).astype(np.uint8))

In [None]:
# Clip at the value found 99.9% of the way through the sorted pixels, then restretch to 0-255.
idxs = np.argsort(im.flatten())
mark = im.flatten()[idxs[int(.999 * len(idxs))]]
print(mark)
img = im.copy()
img[img>=mark] = mark
plt.imshow(((img / max(img.flatten())) * 255).astype(np.uint8))

In [None]:
def convert_channels_raw(channel_to_img):
    """Rescale every channel to uint8 by its own maximum, with no other processing.

    Args:
        channel_to_img: dict mapping channel name -> image array.

    Returns:
        new dict with the same keys. uint8 inputs are copied unchanged. Other
        inputs are scaled so their maximum maps to 255; an all-zero channel
        stays all zero.
    """
    new = {}
    for channel, img in channel_to_img.items():
        im = img.copy()
        # `isinstance(img, np.uint8)` is always False for an ndarray, so check the dtype.
        # The original also dropped uint8 channels from the output entirely.
        if im.dtype != np.uint8:
            peak = im.max()
            # guard 0 / 0, which would produce NaNs before the uint8 cast
            im = (im / peak * 255).astype(np.uint8) if peak > 0 else np.zeros(im.shape, dtype=np.uint8)
        new[channel] = im
    return new

def convert_channels_CLAHE(channel_to_img, clip_limit=.05):
    """Contrast-enhance every channel with CLAHE and return uint8 images.

    Non-uint8 channels are first rescaled to uint8 by their own maximum. Each
    channel is then passed through ``exposure.equalize_adapthist`` and scaled
    back to 0-255. If equalisation raises ``ZeroDivisionError``, the uint8
    image is kept and a message is printed.

    Args:
        channel_to_img: dict mapping channel name -> image array.
        clip_limit: forwarded to ``exposure.equalize_adapthist``.

    Returns:
        new dict with the same keys and uint8 values.
    """
    new = {}
    for channel, img in channel_to_img.items():
        im = img.copy()
        # dtype check: `isinstance(img, np.uint8)` is always False for an ndarray
        if im.dtype != np.uint8:
            peak = im.max()
            # guard 0 / 0, which would produce NaNs before the uint8 cast
            im = (im / peak * 255).astype(np.uint8) if peak > 0 else np.zeros(im.shape, dtype=np.uint8)
        try:
            new[channel] = (exposure.equalize_adapthist(im, clip_limit=clip_limit) * 255).astype(np.uint8)
        except ZeroDivisionError:
            new[channel] = im
            print(f'channel {channel} failed')
    return new

def convert_channels_gaussian(channel_to_img):
    """Gaussian-blur every channel and return uint8 images.

    Non-uint8 channels are first rescaled to uint8 by their own maximum. Each
    channel is then blurred with ``gaussian`` (default sigma), and the result
    is multiplied by 255 and cast to uint8.

    Args:
        channel_to_img: dict mapping channel name -> image array.

    Returns:
        new dict with the same keys and uint8 values.
    """
    new = {}
    for channel, img in channel_to_img.items():
        im = img.copy()
        # dtype check: `isinstance(img, np.uint8)` is always False for an ndarray
        if im.dtype != np.uint8:
            peak = im.max()
            # guard 0 / 0, which would produce NaNs before the uint8 cast
            im = (im / peak * 255).astype(np.uint8) if peak > 0 else np.zeros(im.shape, dtype=np.uint8)
        new[channel] = (gaussian(im) * 255).astype(np.uint8)
    return new

def convert_channels_threshold(channel_to_img, thresh=.99):
    """Clip each channel at a high quantile, then rescale it to uint8.

    For each channel, the pixel values are sorted. The value at index
    ``int(thresh * n_pixels)`` (clamped to the last index) becomes the ceiling.
    Brighter pixels are clipped to it, and the result is scaled so its maximum
    maps to 255.

    Args:
        channel_to_img: dict mapping channel name -> image array.
        thresh: quantile in [0, 1] used for the ceiling.

    Returns:
        new dict with the same keys and uint8 values; an all-zero channel stays zero.
    """
    new = {}
    for channel, img in channel_to_img.items():
        flat_sorted = np.sort(img, axis=None)
        # The original indexed the 2-D image with a *flat* index (`img[idx]`), which
        # selects a whole row (or raises IndexError) instead of one pixel value.
        idx = min(int(thresh * flat_sorted.size), flat_sorted.size - 1)
        mark = flat_sorted[idx]

        im = img.copy()
        im[im >= mark] = mark
        peak = im.max()
        new[channel] = (im / peak * 255).astype(np.uint8) if peak > 0 else np.zeros(im.shape, dtype=np.uint8)
    return new

In [None]:
# convert channels
# One dict per preprocessing variant: sample -> channel -> uint8 image.
# Naming reads "outer applied to inner": sti_gaussian_CLAHE = gaussian blur of the CLAHE
# output, and sti_CLAHE_gaussian = CLAHE of the blurred output.
sti_raw, sti_CLAHE, sti_gaussian, sti_CLAHE_gaussian, sti_gaussian_CLAHE = {}, {}, {}, {}, {}
for sample, channel_to_imgs in sample_to_imgs.items():
    sti_raw[sample] = convert_channels_raw(channel_to_imgs)
    sti_CLAHE[sample] = convert_channels_CLAHE(channel_to_imgs)
    sti_gaussian[sample] = convert_channels_gaussian(channel_to_imgs)
    sti_gaussian_CLAHE[sample] = convert_channels_gaussian(sti_CLAHE[sample])
    sti_CLAHE_gaussian[sample] = convert_channels_CLAHE(sti_gaussian[sample])
        

In [None]:
def colorize_bw(img, rgb):
    """Tint a single-channel image with an RGB colour.

    The image is normalised by its maximum, replicated into three channels, and
    multiplied by ``rgb``.

    Args:
        img: single-channel image (presumably 2-D, H x W).
        rgb: sequence of three floats giving the colour at full intensity.

    Returns:
        float array of shape ``img.shape + (3,)``; all zeros when ``img`` is all zero.
    """
    peak = img.max()
    # guard an all-zero channel: 0 / 0 would fill the result with NaNs
    grayscale_image = img / peak if peak > 0 else np.zeros(img.shape, dtype=float)
    # equivalent to skimage.color.gray2rgb for a 2-D input: stack the plane three times
    image = np.stack((grayscale_image,) * 3, axis=-1)
    return np.asarray(rgb) * image

def display_image(channel_to_image, channels):
    """Composite several channels into one RGB image, each tinted with its own colour.

    Args:
        channel_to_image: dict mapping channel name -> single-channel image.
        channels: dict mapping channel name -> hex colour string (e.g. '#ff0000').

    Returns:
        float RGB array: the sum of the tinted channels, linearly rescaled with
        ``exposure.rescale_intensity`` from the sum's min..max range.
    """
    final = None
    for c, hex_color in channels.items():
        # Read from the argument; the original used the notebook-global
        # `channel_to_img` and silently ignored `channel_to_image`.
        colorized = colorize_bw(channel_to_image[c].copy(), list(hex_to_rgb(hex_color)))
        final = colorized if final is None else final + colorized
    p1, p2 = np.percentile(final, (0, 100))
    final = exposure.rescale_intensity(final, in_range=(p1, p2))
    return final

def display_image_CLAHE(channel_to_image, channels, clip_limit=.03):
    """Composite tinted channels, stretch the sum to its full range, then apply CLAHE.

    Args:
        channel_to_image: dict mapping channel name -> single-channel image.
        channels: dict mapping channel name -> hex colour string.
        clip_limit: forwarded to ``exposure.equalize_adapthist``.

    Returns:
        uint8 RGB array (the CLAHE output multiplied by 255).
    """
    final = None
    for c, hex_color in channels.items():
        # Read from the argument; the original used the notebook-global `channel_to_img`.
        colorized = colorize_bw(channel_to_image[c].copy(), list(hex_to_rgb(hex_color)))
        final = colorized if final is None else final + colorized
    p1, p2 = np.percentile(final, (0, 100))
    final = exposure.rescale_intensity(final, in_range=(p1, p2))
    final = (exposure.equalize_adapthist(final, clip_limit=clip_limit) * 255).astype(np.uint8)
    return final

def display_image_gamma(channel_to_image, channels):
    """Composite tinted channels, gamma-adjusting each one before summing.

    Args:
        channel_to_image: dict mapping channel name -> single-channel image.
        channels: either a dict mapping channel name -> (hex colour, gamma), or a
            sequence of channel names coloured from ``DEFAULTS`` in order with
            gamma 0.5 (more than ``len(DEFAULTS)`` names raises IndexError).

    Returns:
        float RGB array holding the sum of the gamma-adjusted tinted channels. It
        is not rescaled, so values may exceed 1 where channels overlap.
    """
    if isinstance(channels, dict):
        channel_to_color = {ch: (hex_to_rgb(c), g) for ch, (c, g) in channels.items()}
    else:
        channel_to_color = {c: (hex_to_rgb(DEFAULTS[i]), .5) for i, c in enumerate(channels)}

    final = None
    for c, (rgb, g) in channel_to_color.items():
        # Read from the argument; the original used the notebook-global `channel_to_img`.
        colorized = colorize_bw(channel_to_image[c].copy(), list(rgb))
        img = exposure.adjust_gamma(colorized, gamma=g)
        final = img if final is None else final + img
    return final

In [None]:
## hyperion
# channel -> hex display colour
channel_map = {
    'Histone H3': '#0000ff', ## blue
    'Pan Keratin': '#ff0000', ## red
    'CD3': '#00ff00', ## green
    'CD20': '#ff00ff', ## purple
    'CD8a': '#ffff00', ## yellow
    'CD68': '#00ffff', ## light blue
}
from pathlib import Path
# Thumbnail folder next to the input folder. It is created even though saving is commented out below.
fig_dir = hyperior_folder.replace('/ome-tiff_htan/', '/thumbnails/generic_hyperion_qc_presentation')
Path(fig_dir).mkdir(parents=True, exist_ok=True)
# NOTE(review): the loop variable is the global `channel_to_img`; a display function that reads
# that global instead of its argument would still appear to work here.
for sample, channel_to_img in sti_raw.items():
    print(sample)
    img = display_image(channel_to_img, channel_map)
    plt.imshow(img)
#     plt.imsave(os.path.join(fig_dir, f'{sample}.png'), img)
    plt.show()
# display_legend({k:m for k, m in channel_map.items()})
plt.subplots_adjust(wspace=2.)
# plt.savefig(os.path.join(fig_dir, 'legend.png'))

In [None]:
# raw channels, composited then CLAHE-equalised
for sample, channel_to_img in sti_raw.items():
    print(sample)
    img = display_image_CLAHE(channel_to_img, channel_map)
    plt.imshow(img)
    plt.show()

In [None]:
# gaussian-blurred channels
for sample, channel_to_img in sti_gaussian.items():
    print(sample)
    img = display_image(channel_to_img, channel_map)
    plt.imshow(img)
    plt.show()

In [None]:
# per-channel CLAHE
for sample, channel_to_img in sti_CLAHE.items():
    print(sample)
    img = display_image(channel_to_img, channel_map)
    plt.imshow(img)
    plt.show()

In [None]:
# CLAHE applied to the blurred channels
for sample, channel_to_img in sti_CLAHE_gaussian.items():
    print(sample)
    img = display_image(channel_to_img, channel_map)
    plt.imshow(img)
    plt.show()

In [None]:
# blur applied to the CLAHE channels
for sample, channel_to_img in sti_gaussian_CLAHE.items():
    print(sample)
    img = display_image(channel_to_img, channel_map)
    plt.imshow(img)
    plt.show()

In [None]:
# channel -> (hex colour, gamma) for display_image_gamma; tuned on the unprocessed images
gamma_map = {
    'Histone H3': ('#0000ff', .2), ## blue
    'Pan Keratin': ('#ff0000', .2), ## red
    'CD8a': ('#00ff00', .1), ## green
#     'CD20': ('#ff00ff', .1), ## purple
    'Ki67': ('#ffff00', .5), ## yellow
#     'CD20': ('#00ffff', .2), ## light blue
}
for sample, channel_to_img in sample_to_imgs.items():
    print(sample)
    img = display_image_gamma(channel_to_img, gamma_map)
#     img = exposure.equalize_adapthist(img, clip_limit=.03)
    plt.imshow(img)
    plt.show()

In [None]:
# gamma values tuned for the gaussian-blurred channels
gamma_map = {
    'Histone H3': ('#0000ff', .2), ## blue
    'Pan Keratin': ('#ff0000', .2), ## red
    'CD8a': ('#00ff00', .5), ## green
#     'CD20': ('#ff00ff', .1), ## purple
    'Ki67': ('#ffff00', .3), ## yellow
#     'CD20': ('#00ffff', .2), ## light blue
}
for sample, channel_to_img in sti_gaussian.items():
    print(sample)
    img = display_image_gamma(channel_to_img, gamma_map)
#     img = exposure.equalize_adapthist(img, clip_limit=.03)
    plt.imshow(img)
    plt.show()

In [None]:
# gamma values tuned for the CLAHE channels (already contrast-enhanced, so closer to 1)
gamma_map = {
    'Histone H3': ('#0000ff', .8), ## blue
    'Pan Keratin': ('#ff0000', .8), ## red
    'CD8a': ('#00ff00', .8), ## green
#     'CD20': ('#ff00ff', .1), ## purple
    'Ki67': ('#ffff00', .8), ## yellow
#     'CD20': ('#00ffff', .2), ## light blue
}
for sample, channel_to_img in sti_CLAHE.items():
    print(sample)
    img = display_image_gamma(channel_to_img, gamma_map)
#     img = exposure.equalize_adapthist(img, clip_limit=.03)
    plt.imshow(img)
    plt.show()

###### sandbox

In [None]:
def colorize_bw(img, rgb):
    """Tint a single-channel image with an RGB colour.

    The image is normalised by its maximum, replicated into three channels, and
    multiplied by ``rgb``.

    Args:
        img: single-channel image (presumably 2-D, H x W).
        rgb: sequence of three floats giving the colour at full intensity.

    Returns:
        float array of shape ``img.shape + (3,)``; all zeros when ``img`` is all zero.
    """
    peak = img.max()
    # guard an all-zero channel: 0 / 0 would fill the result with NaNs
    grayscale_image = img / peak if peak > 0 else np.zeros(img.shape, dtype=float)
    # equivalent to skimage.color.gray2rgb for a 2-D input: stack the plane three times
    image = np.stack((grayscale_image,) * 3, axis=-1)
    return np.asarray(rgb) * image

def display_image_raw(channel_to_img, channels):
    """Blend tinted channels into one RGB image using only a min-max stretch.

    Args:
        channel_to_img: dict mapping channel name -> single-channel image.
        channels: dict mapping channel name -> hex colour string.

    Returns:
        float RGB array from ``exposure.rescale_intensity`` over the blend's min..max.
    """
    palette = {name: hex_to_rgb(hex_code) for name, hex_code in channels.items()}

    blended = None
    for name, rgb in palette.items():
        tinted = colorize_bw(channel_to_img[name].copy(), list(rgb))
        blended = tinted if blended is None else blended + tinted

    low, high = np.percentile(blended, (0, 100))
    return exposure.rescale_intensity(blended, in_range=(low, high))

def display_image_CLAHE(channel_to_img, channels, clip=.03):
    """Blend tinted channels, min-max stretch the blend, then apply CLAHE.

    If ``exposure.equalize_adapthist`` raises ``ZeroDivisionError``, a message is
    printed and the stretched (un-equalised) blend is returned instead.

    Args:
        channel_to_img: dict mapping channel name -> single-channel image.
        channels: dict mapping channel name -> hex colour string.
        clip: CLAHE ``clip_limit``.

    Returns:
        float RGB array.
    """
    palette = {name: hex_to_rgb(hex_code) for name, hex_code in channels.items()}

    blended = None
    for name, rgb in palette.items():
        tinted = colorize_bw(channel_to_img[name].copy(), list(rgb))
        blended = tinted if blended is None else blended + tinted

    low, high = np.percentile(blended, (0, 100))
    stretched = exposure.rescale_intensity(blended, in_range=(low, high))
    try:
        return exposure.equalize_adapthist(stretched, clip_limit=clip)
    except ZeroDivisionError:
        print('zero division error encountered')
        return stretched

def display_image_generic_v2(channel_to_img, channels, clip=.03):
    """Blend tinted channels, re-equalising the running composite after each addition.

    After every channel beyond the first is added, the running sum is min-max
    stretched and passed through CLAHE (``clip_limit=clip``). The final composite
    gets one more min-max stretch but no further CLAHE, so a single channel is
    never equalised.

    Args:
        channel_to_img: dict mapping channel name -> single-channel image.
        channels: dict mapping channel name -> hex colour string.
        clip: CLAHE ``clip_limit``.

    Returns:
        float RGB array.
    """
    palette = {name: hex_to_rgb(hex_code) for name, hex_code in channels.items()}

    def _stretch(arr):
        # linear min-max rescale via exposure.rescale_intensity
        lo, hi = np.percentile(arr, (0, 100))
        return exposure.rescale_intensity(arr, in_range=(lo, hi))

    composite = None
    for name, rgb in palette.items():
        tinted = colorize_bw(channel_to_img[name].copy(), list(rgb))
        if composite is None:
            composite = tinted
            continue
        composite = exposure.equalize_adapthist(_stretch(composite + tinted), clip_limit=clip)
    return _stretch(composite)

def display_image_gamma(channel_to_img, channels, clip=.03):
    """Gamma-adjust each tinted channel, sum them, stretch the sum, then apply CLAHE.

    Args:
        channel_to_img: dict mapping channel name -> single-channel image.
        channels: dict mapping channel name -> (hex colour, gamma), or a sequence of
            channel names coloured from ``DEFAULTS`` in order with gamma 0.5.
        clip: CLAHE ``clip_limit``. If CLAHE raises ``ZeroDivisionError``, a message
            is printed and the stretched sum is returned.

    Returns:
        float RGB array.
    """
    if isinstance(channels, dict):
        palette = {name: (hex_to_rgb(hex_code), gamma) for name, (hex_code, gamma) in channels.items()}
    else:
        palette = {name: (hex_to_rgb(DEFAULTS[idx]), .5) for idx, name in enumerate(channels)}

    composite = None
    for name, (rgb, gamma) in palette.items():
        tinted = colorize_bw(channel_to_img[name].copy(), list(rgb))
        corrected = exposure.adjust_gamma(tinted, gamma=gamma)
        composite = corrected if composite is None else composite + corrected

    lo, hi = np.percentile(composite, (0, 100))
    composite = exposure.rescale_intensity(composite, in_range=(lo, hi))
    try:
        composite = exposure.equalize_adapthist(composite, clip_limit=clip)
    except ZeroDivisionError:
        print('zero division error encountered')
    return composite

In [None]:
def display_legend(channel_map):
    """Draw a one-row legend: one solid colour panel per entry, labelled on its y-axis.

    Args:
        channel_map: dict mapping label -> matplotlib colour (e.g. a hex string).

    Returns:
        1-D array of the created Axes, one per entry.
    """
    # squeeze=False keeps `axs` indexable with a single entry; plt.subplots returns a
    # bare Axes when ncols == 1, which made the original `axs[i]` fail.
    fig, axs = plt.subplots(ncols=len(channel_map), figsize=(5, 1), squeeze=False)
    axs = axs[0]
    for ax, (label, colour) in zip(axs, channel_map.items()):
        # a very wide horizontal line fills the small panel with the colour
        ax.axhline(0, color=colour, linewidth=110)
        ax.set_ylabel(label)
        ax.set_xticks([])
        ax.set_yticks([])
    return axs

In [None]:
def generate_comparison_plot(channel_to_img, channels, func_map, figsize=(20, 4), **kwargs):
    """Render the same channels with several display functions, side by side.

    Args:
        channel_to_img: dict mapping channel name -> single-channel image.
        channels: channel -> colour specification, passed unchanged to every function.
        func_map: dict mapping panel title -> display function, called as
            ``func(channel_to_img, channels, **kwargs)`` and expected to return an image.
        figsize: matplotlib figure size.
        **kwargs: extra keyword arguments forwarded to every display function.

    Returns:
        (fig, axs), where ``axs`` is a 1-D array of Axes with one per ``func_map`` entry.
    """
    # The original body was unfinished (an `if` with no suite, a SyntaxError) and
    # iterated the dict itself, so `(name, func)` would have unpacked key strings.
    fig, axs = plt.subplots(nrows=1, ncols=len(func_map), figsize=figsize, squeeze=False)
    axs = axs[0]
    for ax, (name, func) in zip(axs, func_map.items()):
        ax.imshow(func(channel_to_img, channels, **kwargs))
        ax.set_title(name)
        ax.set_xticks([])
        ax.set_yticks([])
    return fig, axs
            

In [None]:
channel_to_img.keys()

In [None]:
## hyperion
channel_map = {
    'Histone H3': '#0000ff', ## blue
    'Pan Keratin': '#ff0000', ## red
    'CD3': '#00ff00', ## green
    'CD20': '#ff00ff', ## purple
    'Ki67': '#ffff00', ## yellow
    'CD68': '#00ffff', ## light blue
}

# fig_dir = '/home/estorrs/imaging/data/hyperion/b1_01042020/thumbnails/generic'
fig_dir = hyperior_folder.replace('/ome-tiff_htan/', '/thumbnails/generic_hyperion_qc_presentation')
Path(fig_dir).mkdir(parents=True, exist_ok=True)
for sample, channel_to_img in sample_to_imgs.items():
    print(sample)
    img = display_image_generic(channel_to_img, channel_map, clip=.03)
    plt.imshow(img)
    plt.imsave(os.path.join(fig_dir, f'{sample}.png'), img)
    plt.show()
display_legend({k:m for k, m in channel_map.items()})
plt.subplots_adjust(wspace=2.)
plt.savefig(os.path.join(fig_dir, 'legend.png'))



# channel_map = {
#     'Histone H3': ('#0000ff', .6), ## blue
#     'Pan Keratin': ('#ff0000', .35), ## red
#     'CD3': ('#00ff00', .95), ## green
#     'CD20': ('#ff00ff', .6), ## purple
#     'Ki67': ('#ffff00', .6), ## yellow
#     'CD68': ('#00ffff', .6), ## light blue
# }
# # fig_dir = '/home/estorrs/imaging/data/hyperion/b1_01042020/thumbnails/generic'
# fig_dir = hyperior_folder.replace('/ome-tiff_htan/', '/thumbnails/generic_high_gamma')
# Path(fig_dir).mkdir(parents=True, exist_ok=True)
# for sample, channel_to_img in sample_to_imgs.items():
#     print(sample)
#     img = display_image_gamma(channel_to_img, channel_map, clip=.03)
#     plt.imshow(img)
#     plt.imsave(os.path.join(fig_dir, f'{sample}.png'), img)
#     plt.show()
# display_legend({k:m for k, (m, g) in channel_map.items()})
# plt.subplots_adjust(wspace=2.)
# plt.savefig(os.path.join(fig_dir, 'legend.png'))

In [None]:
import os
# Candidate codex panels. Only the LAST channel_map assignment is used below; the
# earlier maps are kept for reference. Tuple values are (hex colour, gamma) for
# display_image_gamma, and plain strings are hex colours.
channel_map = {
    'DAPI1': ('#0000ff', .9),
    'ecadherin': ('#ff0000', .9),
    'CD8': ('#00ff00', .9),
    'CD20': ('#ff00ff', .9),
#     'CD31': ('#00ffff', .9),
    'Ki67': ('#ffff00', .9),
}

channel_map = {
    'CD49f': ('#0000ff', .9),
    'CD90': ('#ff0000', .9),
    'CD21': ('#00ff00', .9),
    'CD169': ('#ff00ff', .9),
    'CD45R': ('#00ffff', .9),
    'CD31': ('#ffff00', .9),
}
# ['CD140', 'CD90.2', 'CD31', 'MHCII', 'CD38', 'PDCA-1'])
channel_map = {
    'CD140': '#0000ff', ## blue
    'CD90.2': '#ff0000', ## red
    'PDCA-1': '#00ff00', ## green
    'MHCII': '#ff00ff', ## purple
    'Ki67': '#ffff00', ## yellow
    'CD31': '#00ffff', ## light blue
}
# channel_map = {
#     'CD140': ('#0000ff', .5), ## blue
#     'CD90.2': ('#ff0000', .5), ## red
#     'CD90.2': ('#00ff00', .5), ## green
#     'MHCII': ('#ff00ff', .5), ## purple
#     'PDCA-1': ('#ffff00', .5), ## yellow
#     'CD31': ('#00ffff', .5), ## light blue
# }


channel_to_img = next(iter(sample_to_imgs.values()))
print('start')
# img = display_image_gamma(channel_to_img, channel_map)
# img = display_image_raw(channel_to_img, channel_map)
# `display_image_generic` is never defined in this notebook (NameError); use the
# sandbox `display_image_generic_v2`, which takes the same arguments.
img = display_image_generic_v2(channel_to_img, channel_map, clip=.03)
print('done creating')
from pathlib import Path
fig_dir = '/home/estorrs/imaging/results/display/adult_kidney'
Path(fig_dir).mkdir(parents=True, exist_ok=True)
print('saving')
plt.imsave(os.path.join(fig_dir, 'test_kidney.png'), img)

In [None]:
# zoom into the top-left 5000 x 5000 pixel crop of the composite
plt.subplots(figsize=(10, 10))
plt.imshow(img[:5000, :5000])

In [None]:
img.shape

In [None]:
def display_legend(channel_map):
    """Draw a one-row legend: one solid colour panel per entry, labelled on its y-axis.

    Args:
        channel_map: dict mapping label -> matplotlib colour (e.g. a hex string).

    Returns:
        1-D array of the created Axes, one per entry.
    """
    # squeeze=False keeps `axs` indexable with a single entry; plt.subplots returns a
    # bare Axes when ncols == 1, which made the original `axs[i]` fail.
    fig, axs = plt.subplots(ncols=len(channel_map), figsize=(5, 1), squeeze=False)
    axs = axs[0]
    for ax, (label, colour) in zip(axs, channel_map.items()):
        # a very wide horizontal line fills the small panel with the colour
        ax.axhline(0, color=colour, linewidth=110)
        ax.set_ylabel(label)
        ax.set_xticks([])
        ax.set_yticks([])
    return axs

In [None]:
# display_legend(channel_map)
import os
# channel_map values are either plain hex strings or (hex, gamma) tuples, depending on
# which map is active. The original `(m, g)` unpacking raised ValueError on the
# plain-string map selected above.
display_legend({k: (m[0] if isinstance(m, tuple) else m) for k, m in channel_map.items()})
# display_legend({k:m for k, m in channel_map.items()})
# plt.tight_layout(pad=50.0)
plt.subplots_adjust(wspace=2.)
plt.savefig(os.path.join(fig_dir, 'legend.png'))

#### do nuclear and membrane segmentation

In [None]:
# from deepcell.applications import ScaleDetection

In [None]:
# m = ScaleDetection()

In [None]:
# DeepCell multiplex segmentation model, used as the global `app` by run_whole_cell_segmentation.
app = MultiplexSegmentation()
def get_coordinates(img):
    """Return the bounding-box centre of every distinct value in a label image.

    Args:
        img: label image; only its first two axes are used for the centre.

    Returns:
        dict mapping each distinct value of ``img`` (0 included) to the
        ``(row, col)`` midpoint of the box enclosing that value's pixels.
    """
    centers = {}
    for label in np.unique(img):
        coords = np.argwhere(img == label)
        rows, cols = coords[:, 0], coords[:, 1]
        row_lo, row_hi = rows.min(), rows.max()
        col_lo, col_hi = cols.min(), cols.max()
        centers[label] = (row_lo + (row_hi - row_lo) * .5, col_lo + (col_hi - col_lo) * .5)
    return centers

def run_whole_cell_segmentation(nuclear_img, membrane_img):
    """Segment cells from one nuclear and one membrane image with the global DeepCell ``app``.

    The two images are stacked as the two channels of a one-image batch and passed
    to ``app.predict`` with ``image_mpp=1``.

    Returns:
        the first batch element and first output channel of the prediction
        (presumably the whole-cell label mask -- TODO confirm for the DeepCell version in use).
    """
    batch = np.stack((nuclear_img, membrane_img), axis=-1)[np.newaxis, ...]
    prediction = app.predict(batch, image_mpp=1)
    return prediction[0, :, :, 0]

# def run_whole_cell_segmentation(nuclear_img, membrane_img, w=140, p=30):
#     """Return 64 bit int segmentation mask for given nuclear and membrane images"""
#     im = np.stack((nuclear_img, membrane_img), axis=-1)
#     w = 140
#     p = 30

#     nrows = int(im.shape[0] / w) + 1
#     ncols = int(im.shape[1] / w) + 1

#     img = np.zeros(((w * nrows) + (p*2), (w * ncols) + (p*2), im.shape[2]))
#     img[p:im.shape[0] + p, p:im.shape[1] + p, :] = im
#     imgs = None
#     for r in range(nrows):
#         for c in range(ncols):
#             r1, r2 = (r * w), ((r+1) * w) + (p*2)
#             c1, c2 = (c * w), ((c+1) * w) + (p*2)

#             small = img[r1:r2, c1:c2]
#             small = np.expand_dims(small, 0)

#             if imgs is None:
#                 imgs = small
#             else:
#                 imgs = np.concatenate((imgs, small), axis=0)

#     labeled_img = app.predict(imgs, image_mpp=1)
    
    
#     final = np.zeros(((w * nrows) + (p*2), (w * ncols) + (p*2))).astype(np.int64)
#     count = 0
#     for r in range(nrows):
#         for c in range(ncols):
#             r1, r2 = (r * w), ((r+1) * w) + (p*2)
#             c1, c2 = (c * w), ((c+1) * w) + (p*2)
#             small = labeled_img[count, :, :, 0].copy()

#             wr1, wr2 = r1 + p, r2 - p
#             wc1, wc2 = c1 + p, c2 - p
#             cell_to_center = get_coordinates(small)
#             # remove cells whose center point fall outside of the window
#             for c_id in sorted(set(small.flatten())):
#                 x, y = cell_to_center[c_id]
#     #             print(x, y)
#                 if y < p: small[small==c_id] = 0
#                 if y > small.shape[0] - p: small[small==c_id] = 0
#                 if x < p: small[small==c_id] = 0
#                 if x > small.shape[1] - p: small[small==c_id] = 0

#             current = len(set(final.flatten()))
#             labels = sorted(set(small.flatten()))
#             for label in labels:
#                 if label != 0: small[small==label] = label + current
#             i2 = np.zeros(final.shape)
#             i2[r1:r2, c1:c2] = small
#             final = final + i2
#             count+=1

#     final = final[p:im.shape[0] + p, p:im.shape[1] + p]

#     # reassign labels so they actually are in order
#     to_new_label = {c:i for i, c in enumerate(sorted(set(final.flatten())))}
#     for old, new in to_new_label.items():
#         final[final==old] = int(new)
#     final = final.astype(np.int64)

#     return final
            


In [None]:
next(iter(sample_to_imgs.values())).keys()

In [None]:
# Segment every sample under three preprocessing variants. The mask is stored on the
# variant's dict (d['segmentation_mask']), not on `sample_to_imgs`.
for obj in [sti_raw, sti_CLAHE, sti_gaussian]:
    for sample, d in obj.items():
        print(sample)
        mask = run_whole_cell_segmentation(d['nuclear'], d['membrane'])
#         mask = run_whole_cell_segmentation(d['nuclear'], np.zeros(d['membrane'].shape))
        d['segmentation_mask'] = mask

In [None]:
# for obj in [sti_raw, sti_CLAHE, sti_gaussian, sti_CLAHE_gaussian, sti_gaussian_CLAHE]:
#     for sample, d in obj.items():
#         print(sample)
#         mask = run_nuclear_segmentation(d['nuclear'])
#         d['nuclear_segmentation_mask'] = mask

In [None]:
import random
def display_mask(mask, colors=None):
    """Render a label mask as RGB, giving each label a random palette colour.

    Label 0 is background and stays black. Each other label gets a colour drawn
    with ``random.choice`` from ``colors``, so neighbouring cells may share a colour.

    Args:
        mask: 2-D integer label image.
        colors: optional sequence of (r, g, b) floats, presumably in [0, 1] since
            they are scaled by 255; defaults to seaborn's 'tab20' palette.

    Returns:
        uint8 array of shape (H, W, 3).
    """
    if colors is None:
        colors = sns.color_palette('tab20')
    new = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.float32)
    for i in np.unique(mask):
        if i:
            red, green, blue = random.choice(colors)
            # One boolean-mask assignment per label replaces the per-pixel Python loop.
            new[mask == i] = (red * 255, green * 255, blue * 255)
    return new.astype(np.uint8)
        
    

In [None]:
sti_raw.keys()

In [None]:
# Sample key for the inspection cells below; get_segmentation_comparison also reads this global `k`.
k = 'HT_077B1_S1H1_A1_A4_100620_ROI_A4_b1'
# crop window: [(row_start, row_end), (col_start, col_end)]
w = [(100, 600), (100, 800)]

In [None]:
# labeled_img = app.predict(np.expand_dims(np.stack((sti_raw[k]['nuclear'], sti_raw[k]['membrane']), axis=-1), 0),
#                          image_mpp=1)

In [None]:
# plt.imshow(display_mask(labeled_img[0, :, :, 0]))

In [None]:
# random-colour rendering of the raw-variant segmentation mask
n = display_mask(sti_raw[k]['segmentation_mask'])
plt.imshow(n, )

In [None]:
# n = display_mask(sti_raw[k]['segmentation_mask'][w[0][0]:w[0][1], w[1][0]:w[1][1]])
# plt.imshow(n, )

In [None]:
# segmentation inputs and a raw channel, gamma 0.4 to brighten dim signal
plt.imshow(exposure.adjust_gamma(sti_raw[k]['nuclear'], .4))

In [None]:
plt.imshow(exposure.adjust_gamma(sti_raw[k]['membrane'], .4))

In [None]:
plt.imshow(exposure.adjust_gamma(sti_raw[k]['cellseg2'], .4))

In [None]:
def get_segmentation_comparison(objs, labels, window=None, figsize=(10, 10), sample=None):
    """Plot mask, nuclear and membrane images for several preprocessing variants.

    The grid has 3 rows: the colour-rendered segmentation mask, the nuclear
    channel, and the membrane channel. It has one column per entry of
    ``objs``/``labels``. Channel images are gamma-adjusted (0.4) unless the
    column label is 'CLAHE'.

    Args:
        objs: list of dicts mapping sample -> channel dict; each channel dict holds
            'segmentation_mask', 'nuclear' and 'membrane'.
        labels: column titles, parallel to ``objs``.
        window: optional ``[(row_start, row_end), (col_start, col_end)]`` crop;
            defaults to the full mask of the first object.
        figsize: matplotlib figure size (the original accepted it but ignored it).
        sample: sample key to display; defaults to the notebook-global ``k`` for
            backward compatibility.
    """
    key = k if sample is None else sample
    # squeeze=False keeps 2-D `axs[row, col]` indexing valid with a single column
    fig, axs = plt.subplots(nrows=3, ncols=len(labels), sharex=True, sharey=True,
                            figsize=figsize, squeeze=False)
    for i, obj in enumerate(objs):
        d = obj[key]
        if window is None:
            window = [(0, d['segmentation_mask'].shape[0]), (0, d['segmentation_mask'].shape[1])]
        rows = slice(window[0][0], window[0][1])
        cols = slice(window[1][0], window[1][1])

        mask_img = display_mask(d['segmentation_mask'][rows, cols])
        nuclear = d['nuclear'][rows, cols]
        membrane = d['membrane'][rows, cols]
        if labels[i] != 'CLAHE':
            nuclear = exposure.adjust_gamma(nuclear, .4)
            membrane = exposure.adjust_gamma(membrane, .4)

        axs[0, i].imshow(mask_img)
        axs[0, i].set_title(labels[i])
        axs[1, i].imshow(nuclear)
        axs[2, i].imshow(membrane)

    plt.tight_layout()

In [None]:
import matplotlib as mpl
# global: every later figure in this kernel renders at 300 dpi
mpl.rcParams['figure.dpi'] = 300

In [None]:
# full field of view
get_segmentation_comparison([sti_raw, sti_CLAHE, sti_gaussian],
                           ['raw', 'CLAHE', 'gaussian'])

In [None]:
# zoomed crops: window = [(row_start, row_end), (col_start, col_end)]
w1 = [(100, 200), (500, 600)]
get_segmentation_comparison([sti_raw, sti_CLAHE, sti_gaussian],
                           ['raw', 'CLAHE', 'gaussian'],
                           window=w1)

In [None]:
w1 = [(400, 500), (200, 300)]
get_segmentation_comparison([sti_raw, sti_CLAHE, sti_gaussian],
                           ['raw', 'CLAHE', 'gaussian'],
                           window=w1)

In [None]:
# top-right corner; NOTE(review): slicing up to -1 excludes the last column
w1 = [(0, 200), (-200, -1)]
get_segmentation_comparison([sti_raw, sti_CLAHE, sti_gaussian],
                           ['raw', 'CLAHE', 'gaussian'],
                           window=w1)

In [None]:
w1 = [(200, 300), (300, 400)]
get_segmentation_comparison([sti_raw, sti_CLAHE, sti_gaussian],
                           ['raw', 'CLAHE', 'gaussian'],
                           window=w1)

In [None]:
## save segmentation_mask
# The segmentation loop stores masks on the preprocessed dicts (sti_raw / sti_CLAHE /
# sti_gaussian); `sample_to_imgs` never gets a 'segmentation_mask', so the original
# raised KeyError here.
# NOTE(review): the raw-variant masks are the ones saved -- confirm this is intended.
seg_source = sti_raw
seg_dir = hyperior_folder.replace('/ome-tiff_htan/', '/segmentation')
Path(seg_dir).mkdir(exist_ok=True, parents=True)
for sample, d in seg_source.items():
    # tifffile.imsave is a deprecated alias of imwrite
    tifffile.imwrite(os.path.join(seg_dir, f'{sample}.ome.tiff'), d['segmentation_mask'].astype(np.int32), ome=True)

#### get intensities and normalize

In [None]:
def _summarize_cells(segmentation_mask, channel_images):
    """Per-label area, bounding-box centre and mean channel intensity.

    Pixels are grouped by label with a single sort, instead of one full-image
    comparison per cell.

    Returns ``(labels, areas, xs, ys, data)`` for every nonzero label in
    ascending label order. ``xs``/``ys`` are the column/row centre of the
    cell's bounding box. ``data`` holds one list per cell with the mean of
    each array in ``channel_images``, in that order.
    """
    flat_labels = np.asarray(segmentation_mask).ravel()
    shape = np.shape(segmentation_mask)
    # A stable sort keeps each label's pixels in row-major order, which is the
    # element order boolean-mask indexing produces, so per-cell sums match it.
    order = np.argsort(flat_labels, kind='stable')
    labels, starts = np.unique(flat_labels[order], return_index=True)
    pixel_groups = np.split(order, starts[1:])
    flat_channels = [np.asarray(img).ravel() for img in channel_images]

    kept_labels, areas, xs, ys, data = [], [], [], [], []
    for label, pixels in zip(labels, pixel_groups):
        # Label 0 is background, not a cell.
        if label == 0:
            continue
        rs, cs = np.unravel_index(pixels, shape)
        area = len(pixels)
        kept_labels.append(label)
        areas.append(area)
        xs.append(int((np.max(cs) - np.min(cs)) / 2) + np.min(cs))
        ys.append(int((np.max(rs) - np.min(rs)) / 2) + np.min(rs))
        data.append([np.sum(values[pixels]) / area for values in flat_channels])
    return kept_labels, areas, xs, ys, data


def initialize_anndata(sample_to_imgs):
    """Build a cell-by-channel AnnData of mean intensities from segmented images.

    Every nonzero label of a sample's ``segmentation_mask`` is one cell.

    Parameters
    ----------
    sample_to_imgs : dict
        Maps sample name -> dict of images. The dict holds
        ``'segmentation_mask'`` (labels, 0 = background), ``'nuclear'``,
        ``'membrane'``, and one image per channel (all remaining keys).
        Each image has the same shape as the mask.

    Returns
    -------
    anndata.AnnData or None
        One row per cell, named ``f'{sample}_{label}'`` after its mask label.
        ``X`` holds the mean intensity per channel. ``obs`` holds:

        * ``area`` -- pixel count;
        * ``x``/``y`` -- bounding-box centre column/row;
        * ``sample``;
        * ``batch`` -- the last ``_``-separated token of the sample name.

        Returns ``None`` if ``sample_to_imgs`` is empty.
    """
    per_sample = []
    for sample, d in sample_to_imgs.items():
        print(sample)
        channels = [c for c in d.keys() if c not in ['segmentation_mask', 'nuclear', 'membrane']]
        labels, areas, xs, ys, data = _summarize_cells(d['segmentation_mask'], [d[c] for c in channels])

        # Name cells after their mask label so rows can be traced back to the
        # mask even when labels are not contiguous.
        df = pd.DataFrame(data=data, columns=channels,
                          index=[f'{sample}_{int(label)}' for label in labels])
        small = anndata.AnnData(X=df.values)
        small.obs.index = df.index.to_list()
        small.var.index = df.columns
        small.obs['area'] = areas
        small.obs['x'] = xs
        small.obs['y'] = ys
        small.obs['sample'] = sample
        small.obs['batch'] = sample.split('_')[-1]
        per_sample.append(small)

    if not per_sample:
        return None
    if len(per_sample) == 1:
        return per_sample[0]
    # Concatenate once rather than growing the result pairwise (quadratic copying).
    return anndata.concat(per_sample)

In [None]:
# One AnnData per preprocessing variant. Labels are paired explicitly with
# their objects so the two can't drift out of sync (zip() would silently drop
# unmatched names).
adata_map = {name: initialize_anndata(obj)
             for name, obj in [('raw', sti_raw),
                               ('CLAHE', sti_CLAHE),
                               ('gaussian', sti_gaussian)]}

In [None]:
# adata = None
# for sample, d in sample_to_imgs.items():
#     print(sample)
#     segmentation_mask = d['segmentation_mask']
#     # we exclude 0 bc thats not a cell
#     cells = sorted(set(segmentation_mask.flatten()))[1:]
#     data, channels = [], [c for c in d.keys() if c not in ['segmentation_mask', 'nuclear', 'membrane']]
#     areas, xs, ys = [], [], []
#     for cell in cells:
#         mask = segmentation_mask==cell
#         area = np.count_nonzero(mask)
#         rs, cs = np.where(mask)
#         xs.append(int((np.max(cs) - np.min(cs)) / 2) + np.min(cs))
#         ys.append(int((np.max(rs) - np.min(rs)) / 2) + np.min(rs))
#         areas.append(area)
#         ls = []
#         for channel in channels:
#             ls.append(np.sum(d[channel][mask]) / area)
#         data.append(ls)
    
#     df = pd.DataFrame(data=data, columns=channels, index=[f'{sample}_{i}' for i in range(len(cells))])
#     small = anndata.AnnData(X=df.values)
#     small.obs.index = df.index.to_list()
#     small.var.index = df.columns
#     small.obs['area'] = areas
#     small.obs['x'] = xs
#     small.obs['y'] = ys
#     small.obs['sample'] = sample
    
#     if adata is None:
#         adata = small
#     else:
#         adata = anndata.concat((adata, small))
# adata

In [None]:
# Summary of the 'raw' preprocessing variant's AnnData.
adata_map['raw']

In [None]:
# Per-cell metadata built by initialize_anndata (area, x, y, sample, batch).
adata_map['raw'].obs

In [None]:
import seaborn as sns
# Distribution of segmented cell areas (pixel counts), ignoring the tail above 1000.
# histplot(kde=True, stat='density') is the replacement for the deprecated distplot.
raw_areas = adata_map['raw'].obs['area']
ax = sns.histplot(raw_areas[raw_areas <= 1000], kde=True, stat='density')
ax.set(xlabel='cell area (pixels)', title='raw: cell area distribution');

In [None]:
# Drop segments smaller than MIN_CELL_AREA pixels. An upper bound
# (area <= 500) was also tried and is left disabled.
MIN_CELL_AREA = 4
for name, adata in list(adata_map.items()):
    filtered = adata[adata.obs['area'] >= MIN_CELL_AREA]
    adata_map[name] = filtered

In [None]:
# # save adatas 
# adata.write_h5ad('../results/sandbox/presentation_adata.h5ad')
# filtered.write_h5ad('../results/sandbox/presentation_filtered.h5ad')

###### trying svca

#### UMAP/leiden based cell typing

In [None]:
def process_intensity_adata(adata, n_neighbors=5):
    """Scale intensities, then compute PCA, a kNN graph and a UMAP embedding.

    The unscaled matrix is stored in ``adata.raw`` before ``sc.pp.scale``
    modifies ``adata`` in place, so pass a copy if the input must stay intact.

    Parameters
    ----------
    adata : anndata.AnnData
        Cell-by-channel intensity matrix.
    n_neighbors : int, default 5
        Neighbourhood size passed to ``sc.pp.neighbors``.

    Returns
    -------
    anndata.AnnData
        The same object, now scaled and carrying ``raw``, the PCA, the
        neighbour graph and the UMAP embedding.
    """
    adata.raw = adata
    sc.pp.scale(adata)
    sc.pp.pca(adata)
    sc.pp.neighbors(adata, n_neighbors=n_neighbors)
    sc.tl.umap(adata)

    return adata

In [None]:
# Scaled + embedded copy of every preprocessing variant; adata_map is left untouched.
regular_map = {variant: process_intensity_adata(variant_adata.copy())
               for variant, variant_adata in adata_map.items()}

In [None]:
# Current default figure size.
mpl.rcParams['figure.figsize']

In [None]:
# Smaller square figures; colour the raw-variant UMAP by image.
mpl.rcParams['figure.figsize'] = (5, 5)
sc.pl.umap(regular_map['raw'], color='sample')

In [None]:
sc.pl.umap(regular_map['raw'], color='Pan Keratin')

In [None]:
adata = adata_map['raw'].copy()
set(adata.obs['sample'])

In [None]:
ref = adata[adata.obs['sample']=='HT061P1_PA_A1_A4_ROI_03_b1']
ref

In [None]:
rest = adata[adata.obs['sample']!='HT061P1_PA_A1_A4_ROI_03_b1']
rest

In [None]:
# Scale and embed the query images.
# NOTE(review): `rest` is scaled with its own mean/std rather than the
# reference's before ingest maps it onto the reference, and the UMAP computed
# here is presumably replaced by sc.tl.ingest -- confirm this is intended.
rest_processed = process_intensity_adata(rest.copy())

In [None]:
# Scale and embed the reference image.
ref_processed = process_intensity_adata(ref.copy())
sc.pl.umap(ref_processed, color=['Pan Keratin'])

In [None]:
# Cluster the reference; these labels are transferred by sc.tl.ingest below.
sc.tl.leiden(ref_processed, resolution=1.)
sc.pl.umap(ref_processed, color=['Pan Keratin', 'leiden'])

In [None]:
sc.pl.umap(ref_processed, color=['Pan Keratin', 'CD8a', 'Ki67'])

In [None]:
# from sklearn.preprocessing import StandardScaler
# import umap
# X = umap.UMAP().fit_transform(StandardScaler().fit_transform(ref.X))
# ref_processed.obsm['X_umap_2'] = X

In [None]:
ref_processed.var

In [None]:
# sc.pl.embedding(ref_processed, basis='X_umap_2', color=['Pan Keratin', 'CD8a', 'Ki67'])

In [None]:
# sc.pl.embedding(ref_processed, basis='X_umap_2', color=['leiden'])

In [None]:
# Map the query cells onto the reference embedding and transfer its 'leiden' labels.
sc.tl.ingest(rest_processed, ref_processed, obs='leiden')

In [None]:
sc.pl.embedding(rest_processed, basis='X_umap', color=['leiden'])

In [None]:
sc.pl.embedding(rest_processed, basis='X_umap', color=['Pan Keratin', 'CD8a', 'Ki67'])

In [None]:
# Stack reference and query cells. concatenate() writes its own batch label to
# obs[batch_key]; use a separate key so the acquisition 'batch' column set in
# initialize_anndata survives for the batch plot below.
integrated = ref_processed.concatenate(rest_processed, batch_key='ingest_set')
integrated

In [None]:
# Transferred/reference leiden labels on the shared embedding.
sc.pl.umap(integrated, color=['leiden'])

In [None]:
# Unscaled intensities (raw), colour scale capped at 30.
sc.pl.umap(integrated, color=['Pan Keratin', 'CD8a', 'Ki67'], vmax=30., use_raw=True)

In [None]:
# Scaled intensities (X), colour scale capped at 10.
sc.pl.umap(integrated, color=['Pan Keratin', 'CD8a', 'CD4', 'Ki67'], vmax=10., use_raw=False)

In [None]:
sc.pl.umap(integrated, color=integrated.var.index, use_raw=True)

In [None]:
sc.pl.umap(integrated, color=['Pan Keratin', 'CD8a', 'Ki67'], vmax=10., use_raw=False)

In [None]:
# NOTE(review): if `integrated` was built with concatenate()'s default
# batch_key, 'batch' here is the reference/query split, not the acquisition batch.
sc.pl.umap(integrated, color=['batch', 'sample'], vmax=10., use_raw=False)

In [None]:
# pull up high ki67 cells across each sample

In [None]:
tumor = integrated[integrated.obs['leiden']=='7']
tumor

In [None]:
tumor[:, 'Ki67'].raw.X.flatten()

In [None]:
target = tumor[tumor[:, 'Ki67'].X.flatten() >= 2.]
target

In [None]:
target.obs.index

In [None]:
sample_adata = integrated[integrated.obs['sample']=='HT064B1_H1_A1_A4_ROI_03_b1']
sample_adata.obs

In [None]:
sample_adata.obs['highlight'] = [True if x in target.obs.index else False for x in sample_adata.obs.index]

In [None]:
sc.pl.scatter(sample_adata, x='y', y='x', color='highlight')

In [None]:
# !pip install leidenalg

In [None]:
# Cluster the area-filtered 'gaussian' variant, named explicitly rather than
# through the loop variable left over from the filtering cell.
# NOTE(review): confirm 'gaussian' is the intended preprocessing variant.
clustering = adata_map['gaussian'].copy()
exclude = {'PD1', 'PD-L1', 'Lag3', 'CD45RO', 'HLA-DR', 'HistoneH3', 'DNA', 'cellseg1', 'cellseg2', 'cellseg3',
          'CKG', 'GranzymeB', 'CD11c', 'CD133'}
# .copy() turns the channel-subset view into a real object before PCA writes to it.
clustering = clustering[:, [v for v in clustering.var.index if v not in exclude]].copy()
# NOTE(review): unlike process_intensity_adata, X is not scaled before PCA here.
sc.pp.pca(clustering)
sc.pp.neighbors(clustering, n_neighbors=15)
sc.tl.umap(clustering)
sc.tl.leiden(clustering, resolution=1.)

In [None]:
# Every clustering channel on the UMAP.
sc.pl.umap(clustering, color=[c for c in clustering.var.index], ncols=3)

In [None]:
sc.pl.umap(clustering, color=['sample', 'leiden'], ncols=1, legend_loc='on data')

In [None]:
# Manual annotation of the leiden clusters above (ids are specific to this
# clustering run); clusters not listed become 'Other'.
cell_to_cluster = {
    'Malignant': ['13', '7', '15', '19', '21'],
    'Endothelial/CAF': ['4', '12', '11', '20', '6'],
    'Monocyte': ['1'],
    'CD4': ['16'],
    'CD8': ['9'],
    'Treg': ['18'],
    'Dendritic': ['14'],
    'B': ['8'],
    'Proliferating': ['2', '3']
}
cluster_to_cell = dict(
    (cluster_id, cell_type)
    for cell_type, cluster_ids in cell_to_cluster.items()
    for cluster_id in cluster_ids
)
clustering.obs['cell_type'] = [cluster_to_cell.get(leiden_id, 'Other')
                               for leiden_id in clustering.obs['leiden']]



In [None]:
sc.pl.umap(clustering, color=['leiden', 'cell_type'])

In [None]:
small = clustering.copy()
small = small[small.obs['cell_type']=='Endothelial/CAF']
sc.pp.pca(small)
sc.pp.neighbors(small, n_neighbors=15)
sc.tl.umap(small)
sc.tl.leiden(small, resolution=.5)

In [None]:
sc.pl.umap(small, color=['leiden', 'SMA', 'Vimentin', 'FAP', 'CD74', 'Type1Coll', 'CD31'])

In [None]:
# Manual labels for the Endothelial/CAF sub-clusters; sub-clusters not listed
# become 'Other'. Descriptive names avoid clashing with the `d`/`r` globals
# used for images further down.
caf_subtype_to_clusters = {
    'myCAF': ['7'],
    'Endothelial': ['8', '0', '6'],
    'CAF1': ['1'],
    'CAF2': ['2']
}
caf_cluster_to_subtype = {v: k for k, vs in caf_subtype_to_clusters.items() for v in vs}
small.obs['cell_type'] = [caf_cluster_to_subtype.get(x, 'Other') for x in small.obs['leiden']]



In [None]:
sc.pl.umap(small, color=['cell_type'])

In [None]:
clustering.obs['cell_type'] = [small.obs.loc[i, 'cell_type'] if c=='Endothelial/CAF' else c
                                             for i, c in zip(clustering.obs.index, clustering.obs['cell_type'])]
sc.pl.umap(clustering, color=['cell_type'])

#### downstream

Cell proportions

In [None]:
clustering.obs

In [None]:
clustering.obs['image_id'] = clustering.obs['sample'].to_list()
clustering.obs['sample_id'] = ['_'.join(x.split('_')[:2]) for x in clustering.obs['image_id']]
clustering.obs

In [None]:
from collections import Counter
def plot_proportion(adata, percentage=True):
    """Stacked bar chart of cell-type composition per sample.

    Parameters
    ----------
    adata : anndata.AnnData
        Must have ``obs['sample_id']`` and ``obs['cell_type']``.
    percentage : bool, default True
        If True, plot each cell type's fraction of the sample's cells
        (column ``'fraction'``); otherwise plot raw counts (``'count'``).

    Returns
    -------
    altair.Chart
        One bar per sample, coloured by cell type. Every (sample, cell type)
        pair is present, with 0 for types absent from a sample.
    """
    # Imported here so the function does not rely on a notebook-level `alt` alias.
    import altair as alt

    value_col = 'fraction' if percentage else 'count'
    obs = adata.obs
    samples = sorted(set(obs['sample_id']))
    cell_types = sorted(set(obs['cell_type']))

    records = []
    for sample in samples:
        sample_types = obs.loc[obs['sample_id'] == sample, 'cell_type']
        counts = Counter(sample_types)
        n_cells = len(sample_types)
        for cell_type in cell_types:
            n = counts.get(cell_type, 0)
            records.append([sample, cell_type, n / n_cells if percentage else n])
    df = pd.DataFrame(records, columns=['sample', 'cell_type', value_col])

    return alt.Chart(df).mark_bar().encode(
        x='sample',
        y=value_col,
        color='cell_type'
    )

plot_proportion(clustering, percentage=False)

In [None]:
plot_proportion(clustering, percentage=True)

In [None]:
caf_mask = [True if 'CAF' in c else False for c in clustering.obs['cell_type']]
plot_proportion(clustering[caf_mask], percentage=False)

In [None]:
plot_proportion(clustering[caf_mask], percentage=True)

Display image

In [None]:
# Channel images of the first sample in sample_to_imgs (was a syntax error:
# `sample_to_imgs[]`). TODO: choose the sample to display explicitly.
img = next(iter(sample_to_imgs.values()))

In [None]:
d = next(iter(sample_to_imgs.values()))
d.keys()

In [None]:
plt.imshow(grayscale_image)

In [None]:
max(d['HLA-DR'].flatten())

In [None]:
from skimage import color
# Normalise HLA-DR to [0, 1] by its maximum and show it tinted red / yellow.
hla_dr = d['HLA-DR']
# c = .1
# grayscale_image = hla_dr / (hla_dr.max() * c)
grayscale_image = hla_dr / hla_dr.max()
# p1, p2 = np.percentile(grayscale_image, (2, 98))
# grayscale_image = exposure.rescale_intensity(grayscale_image, in_range=(p1, p2), )
# grayscale_image = exposure.equalize_adapthist(grayscale_image, clip_limit=0.03)
# grayscale_image[grayscale_image>1] = 1.
image = color.gray2rgb(grayscale_image)

red_multiplier = [1, 0, 0]
yellow_multiplier = [1, 1, 0]

fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(16, 16*3),
                               sharex=True, sharey=True)
r = (red_multiplier * image) / (red_multiplier * image).max()
ax1.imshow(r)
ax1.set_title('HLA-DR (red)')
ax2.imshow(yellow_multiplier * image)
ax2.set_title('HLA-DR (yellow)')
# red + yellow can exceed 1; clip explicitly instead of relying on imshow's clipping warning.
ax3.imshow(np.clip((yellow_multiplier * image) + (red_multiplier * image), 0, 1))
ax3.set_title('red + yellow');



In [None]:
from skimage import exposure

In [None]:
print(max(r.flatten()))

In [None]:
# Contrast stretching
img = r.copy()
p2, p98 = np.percentile(img, (2, 98))
img_rescale = exposure.rescale_intensity(img, in_range=(p2, p98), )

# Equalization
img_eq = exposure.equalize_hist(img)

# Adaptive Equalization
print(max(img.flatten()))
img_adapteq = exposure.equalize_adapthist(img, clip_limit=0.03)

# Display results
fig = plt.figure(figsize=(8, 5))
axes = np.zeros((2, 4), dtype=np.object)
axes[0, 0] = fig.add_subplot(2, 4, 1)
for i in range(1, 4):
    axes[0, i] = fig.add_subplot(2, 4, 1+i, sharex=axes[0,0], sharey=axes[0,0])
for i in range(0, 4):
    axes[1, i] = fig.add_subplot(2, 4, 5+i)


In [None]:
plt.subplots(figsize=(10, 10))
plt.imshow(img)

In [None]:
plt.subplots(figsize=(10, 10))
plt.imshow(exposure.equalize_adapthist(r, clip_limit=0.03))

In [None]:
plt.subplots(figsize=(10, 10))
plt.imshow(exposure.equalize_adapthist(img, clip_limit=0.01))