In [None]:
import pickle
import json
import os
import re
import sys
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage
import tifffile
import yaml
from matplotlib.collections import PolyCollection
from einops import rearrange, repeat

In [None]:
import matplotlib as mpl
# Embed TrueType (Type 42) fonts in saved PDFs so the text stays editable in vector editors.
mpl.rcParams['pdf.fonttype'] = 42

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
from mushroom.mushroom import Mushroom, DEFAULT_CONFIG
import mushroom.utils as utils
import mushroom.visualization.utils as vis_utils
import mushroom.data.multiplex as multiplex
import mushroom.data.visium as visium
import mushroom.data.xenium as xenium
import mushroom.data.cosmx as cosmx
import mushroom.visualization.tiling_utils as tiling_utils

In [None]:
# Filepaths in the registered metadata start with `source_root`;
# `alter_filesystem` below rewrites them to `target_root` for this machine.
source_root = '/diskmnt/Projects/Users/estorrs/mushroom/data'
target_root = '/data/estorrs/mushroom/data'

In [None]:
# NOTE(review): not referenced in the visible cells; `project_dir` below repeats this path.
run_dir = '/data/estorrs/mushroom/data/projects/submission_v1'

In [None]:
def alter_filesystem(config, source_root, target_root):
    """Swap `source_root` for `target_root` in every section data filepath.

    Every occurrence of `source_root` in each ``filepath`` string is replaced.
    The config dict is modified in place and also returned so the call can be chained.
    """
    data_entries = (mapping for section in config['sections'] for mapping in section['data'])
    for mapping in data_entries:
        original_path = mapping['filepath']
        mapping['filepath'] = original_path.replace(source_root, target_root)
    return config

In [None]:
# Sample/case to analyse; switch by (un)commenting exactly one assignment.
# case = 'HT913Z1'
case = 'HT891Z1'
# case = 'HT704B1'
# case = 'HT206B1'
# case = 'HT397B1'
# case = 'HT413C1-Th1k4A1'

In [None]:
# Per-case project folder; the registered metadata is read from <project_dir>/registered below.
project_dir = Path(f'/data/estorrs/mushroom/data/projects/submission_v1/{case}')

In [None]:
# Load the registered-section metadata; `with` closes the file handle once it has been parsed.
with open(os.path.join(project_dir, 'registered', 'metadata.yaml')) as metadata_file:
    config = yaml.safe_load(metadata_file)
# Point every data filepath at this machine's copy of the data.
config = alter_filesystem(config, source_root, target_root)
config

In [None]:
# Multiplex channels drawn in the ROI pseudocolor image. Each entry gives the display
# colour, the intensity window (min_value/max_value) and the gamma that get_multiplex_pseudo
# passes to multiplex.to_pseudocolor.
multiplex_view_settings = [
    {
        'channel': 'E-Cadherin',
        'color': 'red',
        'min_value': 15,
        'max_value': 100,
        'gamma': 1.
    },
#     {
#         'channel': 'Pan-Cytokeratin',
#         'color': 'red',
#         'min_value': 15,
#         'max_value': 100,
#         'gamma': 1.
#     },
    {
        'channel': 'HLA-DR',
        'color': 'magenta',
        'min_value': 15,
        'max_value': 255,
        'gamma': 1.
    },
    {
        'channel': 'CK5',
        'color': 'white',
        'min_value': 15,
        'max_value': 255,
        'gamma': 1.
    },
#     {
#         'channel': 'Keratin 14',
#         'color': 'white',
#         'min_value': 10,
#         'max_value': 30,
#         'gamma': 1.
#     },
#     {
#         'channel': 'SMA (D)',
#         'color': 'white',
#         'min_value': 15,
#         'max_value': 255,
#         'gamma': 1.
#     },
    {
        'channel': 'CD3e',
        'color': 'green',
        'min_value': 20,
        'max_value': 75,
        'gamma': 1.
    },
#     {
#         'channel': 'CD45 (D)',
#         'color': 'green',
#         'min_value': 10,
#         'max_value': 20,
#         'gamma': 1.
#     },
    {
        'channel': 'CD68',
        'color': 'cyan',
        'min_value': 15,
        'max_value': 255,
        'gamma': 1.
    },
]

def get_multiplex_pseudo(fp, view_settings):
    """Render the configured channels of a multiplex OME-TIFF as a uint8 pseudocolor image.

    Parameters
    ----------
    fp : path of the image read by ``multiplex.extract_ome_tiff``.
    view_settings : list of dict
        One entry per displayed channel, each with the keys 'channel', 'color',
        'min_value', 'max_value' and 'gamma'.

    Returns
    -------
    np.ndarray (uint8)
        ``multiplex.to_pseudocolor`` output multiplied by 255. This presumes that
        function returns floats in [0, 1] -- TODO confirm.
    """
    # Pull each setting out as a parallel list, in the order of view_settings.
    channels, colors, min_values, max_values, gammas = (
        [setting[key] for setting in view_settings]
        for key in ('channel', 'color', 'min_value', 'max_value', 'gamma')
    )

    img_by_channel = multiplex.extract_ome_tiff(fp, as_dict=True)
    channel_stack = np.stack([img_by_channel[name] for name in channels])

    pseudo = multiplex.to_pseudocolor(
        channel_stack,
        colors=colors,
        min_values=min_values,
        max_values=max_values,
        gammas=gammas
    )

    pseudo *= 255.
    return pseudo.astype(np.uint8)

In [None]:
# Xenium genes drawn in the ROI pseudocolor image. get_xenium_pseudo rescales each binned
# channel to 0-255 before applying these colour / min_value / max_value / gamma settings,
# so the windows are on that 0-255 scale.
xenium_view_settings = [
    {
        'channel': 'EPCAM',
        'color': 'red',
        'min_value': 0,
        'max_value': 20,
        'gamma': 1.
    },
#     {
#         'channel': 'HLA-DQB2',
#         'color': 'magenta',
#         'min_value': 0,
#         'max_value': 50,
#         'gamma': 1.
#     },
    {
        'channel': 'TP63',
        'color': 'white',
        'min_value': 0,
        'max_value': 20,
        'gamma': 1.
    },
    {
        'channel': 'KRT5',
        'color': 'white',
        'min_value': 0,
        'max_value': 20,
        'gamma': 1.
    },
#     {
#         'channel': 'ACTA2',
#         'color': 'white',
#         'min_value': 0,
#         'max_value': 20,
#         'gamma': 1.
#     },
    {
        'channel': 'CP',
        'color': 'cyan',
        'min_value': 0,
        'max_value': 30,
        'gamma': 1.
    },
    {
        'channel': 'CD3E',
        'color': 'green',
        'min_value': 0,
        'max_value': 30,
        'gamma': 1.
    },
    {
        'channel': 'CD68',
        'color': 'magenta',
        'min_value': 0,
        'max_value': 30,
        'gamma': 1.
    },
]

def tile_xenium(adata, target_size=None, tile_size=20):
    """Sum per-cell expression into a dense (rows, cols, genes) grid of square tiles.

    Each cell is assigned to tile ``(x // tile_size, y // tile_size)`` using
    ``adata.obsm['spatial']`` (column 0 is x / column, column 1 is y / row), and
    the expression of all cells in a tile is summed.

    Side effect: writes the per-cell tile key to ``adata.obs['grid_name']``.

    Parameters
    ----------
    adata : AnnData-like
        Needs ``obs``, ``var``, ``obsm['spatial']`` and a dense ``X``.
    target_size : (rows, cols), optional
        Full-resolution extent used to size the grid; defaults to
        ``xenium.get_fullres_size(adata)``.
    tile_size : int
        Edge length of one tile in full-resolution pixels.

    Returns
    -------
    np.ndarray of float, shape (rows // tile_size + 1, cols // tile_size + 1, n_genes)
    """
    if target_size is None:
        target_size = xenium.get_fullres_size(adata)

    adata.obs['grid_name'] = [f'{x // tile_size}_{y // tile_size}' for x, y in adata.obsm['spatial']]
    df = pd.DataFrame(data=adata.X, columns=adata.var.index.to_list(), index=adata.obs.index.to_list())
    df['grid_name'] = adata.obs['grid_name'].to_list()
    df = df.groupby('grid_name').sum()

    img = np.zeros((target_size[0] // tile_size + 1, target_size[1] // tile_size + 1, df.shape[1]))
    for name, row in df.iterrows():
        # Float coordinates produce keys like '12.0_3.0'; int('12.0') would raise,
        # so parse through float first (works for integer keys too).
        x, y = (int(float(part)) for part in name.split('_'))
        img[y, x] = row.values
    return img
    

def get_xenium_pseudo(fp, view_settings, tile_size=20, target_size=None):
    """Render a Xenium section as a uint8 pseudocolor RGB image.

    Expression is binned with ``tile_xenium``. The genes named in
    ``view_settings`` are min-max scaled per channel to 0-255, pseudocoloured
    with ``multiplex.to_pseudocolor``, resized to ``target_size`` and scaled by
    255 to uint8. Genes missing from the panel are printed and rendered as a
    blank channel.

    Parameters
    ----------
    fp : path passed to ``xenium.adata_from_xenium``.
    view_settings : list of dict with 'channel', 'color', 'min_value',
        'max_value' and 'gamma' keys.
    tile_size : int
        Tile edge in full-resolution pixels used for binning.
    target_size : (rows, cols), optional
        Output size; defaults to ``xenium.get_fullres_size(adata)``.
    """
    channels = [x['channel'] for x in view_settings]
    colors = [x['color'] for x in view_settings]
    min_values = [x['min_value'] for x in view_settings]
    max_values = [x['max_value'] for x in view_settings]
    gammas = [x['gamma'] for x in view_settings]

    adata = xenium.adata_from_xenium(fp)
    if target_size is None:
        target_size = xenium.get_fullres_size(adata)

    img = tile_xenium(adata, tile_size=tile_size)

    genes = adata.var.index.to_list()
    both = set(channels).intersection(genes)
    missing = set(channels) - both
    print('missing', missing)
    data = np.zeros((img.shape[0], img.shape[1], len(channels)), dtype=img.dtype)
    for channel in both:
        data[..., channels.index(channel)] = img[..., genes.index(channel)]
    for channel in missing:
        # Blank channel with a single lit pixel so its max is nonzero.
        data[0, 0, channels.index(channel)] = 1.

    data -= data.min((0, 1))
    maxes = data.max((0, 1))
    # An all-zero channel would otherwise be 0/0 = NaN before the uint8 cast.
    maxes[maxes == 0] = 1
    data /= maxes
    data *= 255.
    data = data.astype(np.uint8)
    data = rearrange(data, 'h w c -> c h w')

    rgb = multiplex.to_pseudocolor(
        data,
        colors=colors,
        min_values=min_values,
        max_values=max_values,
        gammas=gammas
    )
    rgb = utils.rescale(rgb, size=target_size, dim_order='h w c', target_dtype=rgb.dtype)

    rgb *= 255.
    rgb = rgb.astype(np.uint8)

    return rgb

In [None]:
# Full-resolution size that every rendered image is resized to in the ROI loop below.
# NOTE(review): assumes the first data entry of section 0 is a xenium dataset.
target_size = xenium.get_fullres_size(xenium.adata_from_xenium(config['sections'][0]['data'][0]['filepath']))
# target_size = tifffile.imread(config['sections'][1]['data'][0]['filepath']).shape[1:]
# target_size = [int(x * .5) for x in target_size]
target_size

In [None]:
# Output folder for the Imaris ROI export and the GeoJSON annotations read back below.
imaris_dir = project_dir / 'imaris' / 'rois'
imaris_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# ROI crop window in full-resolution pixels (rows r1:r2, cols c1:c2).
r1, r2, c1, c2 = 1300, 2100, 2500, 3500
# Render every data entry of every section as RGB, resize to target_size and crop the ROI.
rgbs = []
for i, section in enumerate(config['sections']):
    print(i, section['sid'])
    sid = section['sid']
    for entry in section['data']:
        dtype = entry['dtype']
        fp = entry['filepath']

        if dtype in ['he', 'batch2_he']:
            rgb = tifffile.imread(fp)
        elif dtype == 'xenium':
            rgb = get_xenium_pseudo(
                fp,
                xenium_view_settings,
                tile_size=10,
                target_size=target_size,
            )
        elif dtype == 'multiplex':
            rgb = get_multiplex_pseudo(
                fp,
                multiplex_view_settings
            )
        else:
            rgb = None

        if rgb is not None:
            # Compare as tuples: a shape tuple never equals a list, so a list-valued
            # target_size would otherwise force a rescale on every image.
            if tuple(rgb.shape[:2]) != tuple(target_size):
                rgb = utils.rescale(rgb, size=target_size, dim_order='h w c', target_dtype=rgb.dtype)
            rgbs.append(rgb[r1:r2, c1:c2])


In [None]:
# One ROI crop per rendered data entry, stacked along a new leading (plane) axis.
stacked = np.stack(rgbs)
stacked.shape

In [None]:
# Export the ROI stack for Imaris, reordered to (1, z, c, h, w) with channels named red/green/blue.
# NOTE(review): the trailing 1. is presumably the pixel size -- confirm against write_basic_ome_tiff.
multiplex.write_basic_ome_tiff(
    imaris_dir / 'roi1.ome.tif',
    rearrange(stacked, 'z h w c -> 1 z c h w'),
    ['red', 'green', 'blue'],
    1.
)

In [None]:
# Section id -> ROI crop. NOTE(review): zip pairs sections with planes by position, which
# assumes every section produced exactly one image in the ROI loop.
sid_to_rgb = {entry['sid']:x for entry, x in zip(config['sections'], stacked)}

In [None]:
def _outer_ring(coordinates):
    """Return the first ring of Polygon-style GeoJSON coordinates as a 2-D array, or None."""
    if not isinstance(coordinates, list) or not coordinates:
        return None
    try:
        ring = np.asarray(coordinates[0])
    except ValueError:  # ragged nesting (e.g. a MultiPolygon whose rings differ in length)
        return None
    return ring if ring.ndim == 2 else None


def load_regions(regions_fp, shape=None):
    """Load polygon annotations from a GeoJSON FeatureCollection and rasterize them.

    Point and line features are skipped. For each polygon, only the outer ring is
    used, so holes are ignored and polygons with holes no longer crash the loader.

    Parameters
    ----------
    regions_fp : path-like
        GeoJSON file whose top-level ``features`` list holds the annotations.
    shape : (rows, cols), optional
        Canvas each polygon is rasterized onto. Defaults to the in-plane shape of
        the notebook-level ``stacked`` ROI array.

    Returns
    -------
    list of dict
        One dict per polygon with ``id``, ``z`` (``geometry.plane.z``, or 0 when
        absent), ``coordinates`` (outer ring swapped to (row, col) order) and
        ``mask`` (boolean rasterization of that ring).
    """
    if shape is None:
        shape = stacked.shape[1:3]

    with open(regions_fp) as f:
        features = json.load(f)['features']

    regions = []
    for feature in features:
        geometry = feature['geometry']
        if not geometry:
            continue
        ring = _outer_ring(geometry['coordinates'])
        if ring is None:
            continue
        coordinates = ring[:, [1, 0]]  # GeoJSON (x, y) -> (row, col)
        regions.append({
            'id': feature['id'],
            'z': geometry['plane']['z'] if 'plane' in geometry else 0,
            'coordinates': coordinates,
            'mask': skimage.draw.polygon2mask(shape, coordinates),
        })

    return regions

In [None]:
# Polygon annotations for the ROI stack (presumably drawn on roi1.ome.tif -- confirm).
regions_fp = imaris_dir / 'roi1.geojson'
regions = load_regions(regions_fp)

len(regions), regions[0].keys()

In [None]:
# List each region id with the ROI plane (z) it was drawn on.
for region in regions:
    print(region['id'] + '\t' + str(region['z']))

In [None]:
# Channels shared by every section of each modality (taken from each section's first data entry).
dtype_to_channels = {
    'xenium': [],
    'multiplex': []
}
for dtype in dtype_to_channels.keys():
    fps = [entry['data'][0]['filepath'] for entry in config['sections'] if entry['data'][0]['dtype']==dtype]
    if 'multiplex' in dtype:
        dtype_to_channels[dtype] = multiplex.get_common_channels(fps)
    else:
        dtype_to_channels[dtype] = xenium.get_common_channels(fps)
dtype_to_channels

In [None]:
# Tile grid: full-res size of section 0's xenium data divided by `tiling_size`.
# NOTE(review): reloads the same xenium dataset already read for `target_size` above.
fullres_size = xenium.get_fullres_size(xenium.adata_from_xenium(config['sections'][0]['data'][0]['filepath']))
tiling_size = 10
size = [x // tiling_size for x in fullres_size]
size

In [None]:
# Tile every section of each modality onto the `size` grid, restricted to the shared channels.
dtype_to_tiled = {}
for dtype, channels in dtype_to_channels.items():
    dtype_to_tiled[dtype] = tiling_utils.get_tiled_sections(
        config, dtype=dtype, channel_names=channels,
        tiling_size=tiling_size, target_size=size
    )
for dtype, tiled in dtype_to_tiled.items():
    print(dtype, tiled.shape)

In [None]:
# Index tiled arrays and modality by section id.
# NOTE(review): pairs sids with tiled arrays positionally, so assumes get_tiled_sections
# returns sections in config order -- TODO confirm.
sid_to_tiled, sid_to_dtype = {}, {}
for dtype in dtype_to_channels.keys():
    tiled = dtype_to_tiled[dtype]
    sids = [entry['sid'] for entry in config['sections'] if entry['data'][0]['dtype']==dtype]
    for sid, t in zip(sids, tiled):
        sid_to_tiled[sid] = t
        sid_to_dtype[sid] = dtype
sid_to_tiled.keys()

In [None]:
# Epithelial marker channel and threshold per modality; tiles above the threshold are
# excluded from the tumour-microenvironment (TME) band around each region.
dtype_to_epi = {
    'xenium': ('EPCAM', 1.),
    'multiplex': ('Pan-Cytokeratin', 10.),
}

for region in regions:
    sid = config['sections'][region['z']]['sid']
    if sid not in sid_to_tiled:
        continue

    tiled = sid_to_tiled[sid]
    dtype = sid_to_dtype[sid]
    channels = dtype_to_channels[dtype]

    # ROI-crop (row, col) -> full-res -> tile grid. The tiled sections were built with
    # `tiling_size`; the previous `tiling_resolution` name was never defined.
    coords = (region['coordinates'] + np.asarray([r1, c1])) // tiling_size
    mask = skimage.draw.polygon2mask(tiled.shape[-2:], coords)

    outer = skimage.morphology.binary_dilation(mask)
    inner = skimage.morphology.binary_erosion(mask)
    expanded = mask.copy()
    for _ in range(5):
        expanded = skimage.morphology.binary_dilation(expanded)

    # Boundary ring: one tile either side of the polygon edge.
    ring = outer ^ inner
    means = tiled[:, ring].mean(1)
    fracs = np.count_nonzero(tiled[:, ring] > 0, axis=1) / np.count_nonzero(ring)

    # TME band: up to 5 tiles outside the polygon, minus epithelial tiles.
    channel, thresh = dtype_to_epi[dtype]
    epithelial = tiled[channels.index(channel)] > thresh
    tme = np.logical_and(expanded ^ outer, ~epithelial)
    tme_means = tiled[:, tme].mean(1)
    tme_fracs = np.count_nonzero(tiled[:, tme] > 0, axis=1) / np.count_nonzero(tme)

    region['sid'] = sid
    region['fracs'] = fracs
    region['means'] = means
    # These two were previously stored under each other's keys.
    region['fracs_tme'] = tme_fracs
    region['means_tme'] = tme_means
    region['dtype'] = dtype
    region['channels'] = channels
    region['ring'] = ring
    region['tme'] = tme
        
        


In [None]:
# cs, zs, vals = [], [], []
# for region in regions:
#     if 'dtype' in region and region['dtype'] == 'xenium':
#         cs.append(region['coordinates'])
#         zs.append(region['z'] * 5)
#         vals.append(region['means'][region['channels'].index('KRT5')])
# zs, vals = np.asarray(zs), np.asarray(vals)
# vals -= vals.min()
# vals /= vals.max()
# cmap = plt.colormaps['viridis_r']
# facecolors = [cmap(x) for x in vals]


In [None]:
# Colour each xenium region polygon by its mean KRT5 value on the boundary ring,
# min-max scaled to [0, 1] across regions, for the 3D plot below.
cs, zs, vals = [], [], []
for region in regions:
    if region.get('dtype') == 'xenium':
        cs.append(region['coordinates'])
        zs.append(region['z'] * 5)  # spread planes apart along the plot's z axis
        print(region['sid'])
        vals.append(region['means'][region['channels'].index('KRT5')])
zs, vals = np.asarray(zs), np.asarray(vals, dtype=float)
if vals.size:
    vals -= vals.min()
    if vals.max() > 0:  # all-equal values would otherwise become 0/0 = NaN
        vals /= vals.max()
cmap = plt.colormaps['viridis_r']
facecolors = [cmap(x) for x in vals]


In [None]:
# 3D view of the xenium region polygons, coloured by the scaled KRT5 ring mean.
ax = plt.figure().add_subplot(projection='3d')
poly = PolyCollection(cs, facecolors=facecolors, alpha=.7)
ax.set(xlim=(0, max([x[0] for xs in cs for x in xs])),
       ylim=(0,  max([x[1] for xs in cs for x in xs])),
       zlim=(0, max(zs)))
ax.add_collection3d(poly, zs=zs, zdir='z')
plt.show()

In [None]:
# 3D view of all region polygons in one colour on fixed 1000 x 1000 x 250 axes.
ax = plt.figure().add_subplot(projection='3d')

coords = [x['coordinates'] for x in regions]
zs = [x['z'] * 5 for x in regions]
facecolors = [(.3, .3, .8, 1.) for i in range(len(zs))]
poly = PolyCollection(coords, facecolors=facecolors, alpha=.7, edgecolor=(.3, .3, .3, 1.))

ax.set(xlim=(0, 1000), ylim=(0, 1000), zlim=(0, 50 * 5))
ax.add_collection3d(poly, zs=zs, zdir='z')
ax.view_init(elev=10., azim=70)
plt.show()

In [None]:
# Second annotation set, used below to build the cross-plane overlap (connectivity) graph.
connectivity_fp = imaris_dir / 'roi1_aligned.geojson'
conn_regions = load_regions(connectivity_fp)

len(conn_regions), conn_regions[0].keys()

In [None]:
def regions_to_volume(regions, shape=None):
    """Rasterize region polygons into a labeled (z, rows, cols) integer volume.

    Region ``i`` (0-based position in ``regions``) is drawn on plane
    ``region['z']`` with label ``i + 1``. 0 is background, and a later region
    overwrites an earlier one where the two overlap on the same plane.

    Parameters
    ----------
    regions : list of dict
        As returned by ``load_regions``: needs ``id``, ``z`` and ``coordinates``
        ((row, col) vertices in ROI-crop pixel space).
    shape : (rows, cols), optional
        Plane shape. Defaults to the shape of the first ROI crop in the
        notebook-level ``sid_to_rgb``.

    Returns
    -------
    labeled : np.ndarray of int, shape (max z + 1, rows, cols)
    region_id_to_label : dict mapping region id -> integer label
    """
    if shape is None:
        shape = next(iter(sid_to_rgb.values())).shape[:2]

    n_planes = max((region['z'] for region in regions), default=-1) + 1
    labeled = np.zeros((n_planes, *shape), dtype=int)
    region_id_to_label = {}
    for label, region in enumerate(regions, start=1):
        # Coordinates are already in ROI-crop space, which matches `shape`
        # (load_regions rasterizes them onto the crop the same way). Shifting by
        # (r1, c1) and dividing by the tile size moved them into another frame.
        mask = skimage.draw.polygon2mask(shape, region['coordinates'])
        labeled[region['z'], mask] = label
        region_id_to_label[region['id']] = label

    return labeled, region_id_to_label

In [None]:
# Label volume of the connectivity regions, plus the region-id -> label lookup.
labeled, region_id_to_label = regions_to_volume(conn_regions)

In [None]:
def graph_from_labeled(labeled):
    """List overlap edges between labels on adjacent z-planes of a label volume.

    For every nonzero label on plane ``i``, each nonzero label it overlaps on
    plane ``i - 1`` and then on ``i + 1`` adds an edge ``(label, other)``. Every
    overlap therefore appears twice, once in each direction.

    Parameters
    ----------
    labeled : np.ndarray, shape (z, rows, cols)
        Integer labels with 0 as background. A single-plane volume is allowed.

    Returns
    -------
    list of tuple
        ``(label, neighbor_label)`` pairs of numpy integer scalars.
    """
    edges = []
    n_planes = labeled.shape[0]
    for i in range(n_planes):
        plane = labeled[i]
        # Only planes that exist; the old code indexed labeled[1] for a 1-plane volume.
        neighbors = [labeled[j] for j in (i - 1, i + 1) if 0 <= j < n_planes]
        # Skip background explicitly: np.unique(...)[1:] dropped a real label
        # whenever a plane contained no 0 pixels.
        for label in np.unique(plane):
            if not label:
                continue
            footprint = plane == label
            for neighbor in neighbors:
                edges.extend((label, other) for other in np.unique(neighbor[footprint]) if other)
    return edges
            
        
        
        

In [None]:
import networkx

In [None]:
# Overlap edges between connectivity regions on adjacent planes (networkx, imported above, is not used yet).
graph_from_labeled(labeled)

In [None]:
# Show each plane of the label volume together with the labels it contains.
for x in labeled:
    print(np.unique(x))
    plt.imshow(x)
    plt.show()

In [None]:
# Seed points annotated alongside the regions; they select which regions get analysed.
with open(regions_fp) as f:
    features = json.load(f)['features']
point_geometries = [x['geometry'] for x in features if x['geometry']['type'] == 'Point']
pts = [
    {
        'coordinates': np.asarray(g['coordinates'])[[1, 0]],  # GeoJSON (x, y) -> (row, col)
        # Same default plane as load_regions when the annotation has no `plane`.
        'z': g['plane']['z'] if 'plane' in g else 0,
    }
    for g in point_geometries
]
pts

In [None]:
def get_means_and_fracs(regions, pts, dtype='xenium', mean_key='means', frac_key='fracs',
                        return_selected=False):
    """Gather per-channel summary vectors of the regions selected by seed points.

    A region of modality ``dtype`` is selected once for every point in ``pts``
    that lies on the same z-plane and inside the region's ``mask``. Points
    outside the mask bounds are ignored. The selected regions' ``mean_key`` /
    ``frac_key`` vectors are stacked in selection order.

    Parameters
    ----------
    regions : list of dict
        Needs ``dtype``, ``z``, ``mask``, ``sid`` and the two summary keys.
        Regions without ``dtype`` are skipped.
    pts : list of dict
        Each has ``coordinates`` ((row, col) in mask space) and ``z``.
    return_selected : bool
        When True, also return the ``sid -> region`` dict of selected regions.

    Returns
    -------
    vals, fracs : np.ndarray, shape (n_channels, n_selected)
    selected_regions : dict, only when ``return_selected`` is True

    Raises
    ------
    ValueError
        If no point selects any region of ``dtype``.
    """
    vals, fracs = [], []
    selected_regions = {}
    for region in regions:
        if region.get('dtype') != dtype:
            continue
        mask = region['mask']
        for pt in pts:
            # Check the plane first, so points off this plane never touch the mask.
            if region['z'] != pt['z']:
                continue
            r, c = (int(v) for v in pt['coordinates'])
            if not (0 <= r < mask.shape[0] and 0 <= c < mask.shape[1]):
                continue  # off-canvas: would raise IndexError or wrap around negatively
            if mask[r, c]:
                print(region['z'])
                vals.append(region[mean_key])
                fracs.append(region[frac_key])
                selected_regions[region['sid']] = region
    if not vals:
        raise ValueError(f'no {dtype!r} region contains any of the {len(pts)} points')
    # (n_selected, n_channels) -> (n_channels, n_selected)
    vals, fracs = np.stack(vals).T, np.stack(fracs).T
    if return_selected:
        return vals, fracs, selected_regions
    return vals, fracs

def get_correlations(vals, default=0.):
    """Pearson correlation of each row of ``vals`` with its column index 0..n-1.

    Rows whose correlation is undefined (constant values, which give NaN, or
    fewer than two entries) get ``default`` instead.

    Parameters
    ----------
    vals : 2-D array-like, one row per channel.
    default : float
        Value used when the correlation is undefined.

    Returns
    -------
    np.ndarray, shape (n_rows,)
    """
    from scipy import stats  # scipy was never imported in this notebook

    coors = []
    for row in vals:
        if len(row) < 2:
            coors.append(default)
            continue
        # Tuple-unpack: works across scipy versions (pearsonr has no `.correlation`).
        r, _ = stats.pearsonr(row, np.arange(len(row)))
        # The old check referenced an undefined `pval`; the NaN case is the correlation itself.
        coors.append(default if pd.isnull(r) else r)
    return np.asarray(coors)

In [None]:
# Boundary-ring statistics of the regions picked by the seed points; each gene is then
# correlated with the order in which regions were selected.
vals, fracs = get_means_and_fracs(regions, pts)
coors = get_correlations(vals)
vals.shape, fracs.shape, coors.shape

In [None]:
# Genes from most negative to most positive correlation.
idxs = np.argsort(coors)
np.asarray(dtype_to_channels['xenium'])[idxs]

In [None]:
np.asarray(dtype_to_channels['xenium'])[np.flip(idxs)]

In [None]:
# Keep genes present (nonzero) in more than 10% of ring tiles in at least one selected region.
# NOTE(review): reassigns `idxs`, so re-running this cell filters an already-filtered list.
idxs = [i for i in idxs if fracs[i].max() > .1]
np.asarray(dtype_to_channels['xenium'])[idxs]

In [None]:
import seaborn as sns
# Profiles of the 10 most negatively correlated genes, in reversed selection order.
for i in range(10):
    sns.lineplot(np.flip(vals[idxs[i]]))

In [None]:
import seaborn as sns
# Profiles of the 9 most positively correlated genes.
for i in range(1, 10, 1):
    sns.lineplot(np.flip(vals[idxs[-i]]))

In [None]:
size

In [None]:
def display_genes(gene, dt='xenium', figsize=(6, 20), mask_key='ring',
                  selected_regions=None, tiling_resolution=None):
    """Plot the ROI image, the tiled ``gene`` channel and the region mask for each section.

    There is one row per section of modality ``dt``. Column 0 is the ROI RGB
    crop, column 1 is the ``gene`` channel of the tiled section cropped to the
    ROI, and column 2 is the selected region's ``mask_key`` mask, cropped the
    same way. Column 2 is drawn only for sections that have a selected region.

    Parameters
    ----------
    gene : str
        Channel name, looked up in ``dtype_to_channels[dt]``.
    dt : str
        Modality to plot.
    figsize : tuple
        Passed to ``plt.subplots``.
    mask_key : str
        Region key holding the mask to show (e.g. 'ring' or 'tme').
    selected_regions : dict, optional
        Section id -> region. Defaults to a notebook-level ``selected_regions``
        if one exists; otherwise no masks are drawn.
    tiling_resolution : int, optional
        Full-res pixels per tile. Defaults to the notebook-level ``tiling_size``.
    """
    if selected_regions is None:
        selected_regions = globals().get('selected_regions', {})
    if tiling_resolution is None:
        tiling_resolution = tiling_size

    # The ROI bounds in tile coordinates are loop-invariant.
    local_r1, local_r2, local_c1, local_c2 = [x // tiling_resolution for x in [r1, r2, c1, c2]]
    channels = dtype_to_channels[dt]
    sids = [sid for sid in sid_to_tiled if sid_to_dtype[sid] == dt]

    # squeeze=False keeps axs 2-D even when there is a single section.
    fig, axs = plt.subplots(ncols=3, nrows=len(dtype_to_tiled[dt]), figsize=figsize, squeeze=False)
    # The row counter covers only sections of `dt`. Enumerating all sections misplaced
    # rows, or overran the grid, whenever another modality's sections came first.
    for row, sid in enumerate(sids):
        tiled = sid_to_tiled[sid]
        x = tiled[channels.index(gene), local_r1:local_r2, local_c1:local_c2]
        if sid in selected_regions:
            mask = selected_regions[sid][mask_key][local_r1:local_r2, local_c1:local_c2]
            axs[row, 2].imshow(mask)

        axs[row, 0].imshow(sid_to_rgb[sid])
        axs[row, 1].imshow(x)

    

In [None]:
# Ring view for the 10 most negatively correlated genes.
for i in range(0, 10, 1):
    gene = dtype_to_channels['xenium'][idxs[i]]
    print(gene)
    display_genes(gene)
    plt.show()

In [None]:
# Ring view for the 9 most positively correlated genes.
for i in range(1, 10, 1):
    gene = dtype_to_channels['xenium'][idxs[-i]]
    print(gene)
    display_genes(gene)
    plt.show()

In [None]:
# Repeat the analysis using the TME-band statistics (means_tme / fracs_tme).
vals, fracs = get_means_and_fracs(regions, pts, mean_key='means_tme', frac_key='fracs_tme')
coors = get_correlations(vals)
vals.shape, fracs.shape, coors.shape

In [None]:
idxs = np.argsort(coors)
np.asarray(dtype_to_channels['xenium'])[idxs]

In [None]:
# NOTE(review): reassigns `idxs`; re-running filters an already-filtered list.
idxs = [i for i in idxs if fracs[i].max() > .1]
np.asarray(dtype_to_channels['xenium'])[idxs]

In [None]:
np.asarray(dtype_to_channels['xenium'])[np.flip(idxs)]

In [None]:
import seaborn as sns
for i in range(10):
    sns.lineplot(np.flip(vals[idxs[i]]))

In [None]:
import seaborn as sns
for i in range(1, 10, 1):
    sns.lineplot(np.flip(vals[idxs[-i]]))

In [None]:
# TME-mask view for the 20 most negatively correlated genes.
for i in range(0, 20, 1):
    gene = dtype_to_channels['xenium'][idxs[i]]
    print(gene)
    display_genes(gene, mask_key='tme')
    plt.show()

In [None]:
# TME-mask view for the 9 most positively correlated genes.
for i in range(1, 10, 1):
    gene = dtype_to_channels['xenium'][idxs[-i]]
    print(gene)
    display_genes(gene, mask_key='tme')
    plt.show()

In [None]:
display_genes('C5AR1', mask_key='tme')

In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np

from matplotlib.collections import PolyCollection

# Fixing random state for reproducibility (kept from the matplotlib example; this cell draws no random numbers).
np.random.seed(19680801)


def polygon_under_graph(x, y):
    """
    Build the vertex list of the polygon that fills the area below the
    (x, y) line graph down to y = 0. ``x`` is assumed to be ascending.
    """
    vertices = [(x[0], 0.)]
    vertices.extend(zip(x, y))
    vertices.append((x[-1], 0.))
    return vertices


# matplotlib "polygons under graph" example (Poisson pmfs stacked along y), kept as a
# PolyCollection-in-3D reference for the region plots above.
ax = plt.figure().add_subplot(projection='3d')

x = np.linspace(0., 10., 31)
lambdas = range(1, 9)

# verts[i] is a list of (x, y) pairs defining polygon i.
gamma = np.vectorize(math.gamma)
verts = [polygon_under_graph(x, l**x * np.exp(-l) / gamma(x + 1))
         for l in lambdas]
facecolors = plt.colormaps['viridis_r'](np.linspace(0, 1, len(verts)))
verts = [np.asarray(x) for x in verts]
print(len(verts), verts[0].shape)

poly = PolyCollection(verts, facecolors=facecolors, alpha=.7)
ax.add_collection3d(poly, zs=lambdas, zdir='y')

ax.set(xlim=(0, 10), ylim=(1, 9), zlim=(0, 0.35),
       xlabel='x', ylabel=r'$\lambda$', zlabel='probability')

plt.show()

In [None]:
# Same polygons without axis limits or labels.
ax = plt.figure().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=facecolors, alpha=.7)
ax.add_collection3d(poly, zs=lambdas, zdir='y')

In [None]:
# NOTE(review): duplicate of the earlier channel-discovery cell; recomputes the same mapping.
dtype_to_channels = {
    'xenium': [],
    'multiplex': []
}
for dtype in dtype_to_channels.keys():
    fps = [entry['data'][0]['filepath'] for entry in config['sections'] if entry['data'][0]['dtype']==dtype]
    if 'multiplex' in dtype:
        dtype_to_channels[dtype] = multiplex.get_common_channels(fps)
    else:
        dtype_to_channels[dtype] = xenium.get_common_channels(fps)
dtype_to_channels

In [None]:
# Re-tile every modality on the same grid (`size` = full-res // tiling_size) as the earlier
# tiling cell. Passing the full-res `target_size` here made dtype_to_tiled disagree with
# sid_to_tiled and with the tile-space coordinates used in the region analysis.
dtype_to_tiled = {}
for dtype, channels in dtype_to_channels.items():
    dtype_to_tiled[dtype] = tiling_utils.get_tiled_sections(
        config, dtype=dtype, channel_names=channels,
        tiling_size=tiling_size, target_size=size
    )
for dtype, tiled in dtype_to_tiled.items():
    print(dtype, tiled.shape)