In [None]:
import json
import os
import pickle
import re
import sys
import tempfile
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage
import tifffile
import yaml
from matplotlib.collections import PolyCollection
from einops import rearrange, repeat
from pydantic_extra_types.color import Color

In [None]:
import matplotlib as mpl
# Keep text editable in exported vector figures:
# fonttype 42 embeds TrueType fonts in PDFs, and 'none' writes SVG text as text rather than paths.
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['svg.fonttype'] = 'none'

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
from mushroom.mushroom import Mushroom, DEFAULT_CONFIG
import mushroom.utils as utils
import mushroom.visualization.utils as vis_utils
import mushroom.data.multiplex as multiplex
import mushroom.data.visium as visium
import mushroom.data.xenium as xenium
import mushroom.data.cosmx as cosmx
import mushroom.visualization.tiling_utils as tiling_utils

In [None]:
source_root = '/diskmnt/Projects/Users/estorrs/mushroom/data'
target_root = '/data/estorrs/mushroom/data'

In [None]:
run_dir = '/data/estorrs/mushroom/data/projects/submission_v1'

In [None]:
def alter_filesystem(config, source_root, target_root):
    """Rewrite every section's data filepath from ``source_root`` to ``target_root``.

    Args:
        config: parsed metadata with a 'sections' list; each section has a 'data'
            list of mappings carrying a 'filepath' string.
        source_root: substring to replace (every occurrence, via ``str.replace``).
        target_root: replacement text.

    Returns:
        The same ``config`` object, which is modified in place.
    """
    all_mappings = (mapping for section in config['sections'] for mapping in section['data'])
    for mapping in all_mappings:
        old_path = mapping['filepath']
        mapping['filepath'] = old_path.replace(source_root, target_root)
    return config

In [None]:
# case = 'HT913Z1'
case = 'HT891Z1'
# case = 'HT704B1'
# case = 'HT206B1'
# case = 'HT397B1'
# case = 'HT413C1-Th1k4A1'

In [None]:
# Per-case project directory under the run directory.
project_dir = Path(run_dir) / case

In [None]:
fig_dir = project_dir / 'figures' 
fig_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the registration metadata and point its filepaths at the local data root.
# A context manager closes the file handle, which a bare open() would leak.
with open(project_dir / 'registered' / 'metadata.yaml') as f:
    config = yaml.safe_load(f)
config = alter_filesystem(config, source_root, target_root)
config

In [None]:
def _multiplex_view(channel, color, min_value, max_value, gamma=1.):
    """Build one multiplex display setting: channel, pseudocolor, display range and gamma."""
    return {
        'channel': channel,
        'color': color,
        'min_value': min_value,
        'max_value': max_value,
        'gamma': gamma,
    }


# Per-case multiplex display settings, one (channel, color, min_value, max_value)
# row per displayed channel; every entry uses gamma 1.
case_to_multiplex_view_settings = {
    case_id: [_multiplex_view(*row) for row in rows]
    for case_id, rows in {
        'HT913Z1': [
            ('E-Cadherin', 'red', 15, 100),
            ('HLA-DR', 'yellow', 15, 255),
            ('CK5', 'white', 15, 255),
            ('CD3e', 'green', 20, 75),
            ('CD68', 'cyan', 15, 255),
        ],
        'HT891Z1': [
            ('E-Cadherin', 'red', 15, 100),
            ('HLA-DR', 'yellow', 15, 255),
            ('CK5', 'white', 15, 255),
            ('CD3e', 'green', 20, 255),
            ('CD68', 'cyan', 15, 255),
        ],
        'HT704B1': [
            ('Pan-Cytokeratin', 'red', 15, 100),
            ('HLA-DR', 'yellow', 15, 255),
            ('Keratin 5', 'white', 15, 255),
            ('CD3e', 'green', 20, 255),
            ('CD68', 'cyan', 15, 255),
        ],
        'HT206B1': [
            ('Pan-Cytokeratin', 'red', 15, 100),
            ('HLA-DR', 'yellow', 15, 150),
            ('SMA (D)', 'white', 15, 150),
            ('Podoplanin (D)', 'white', 15, 65),
            ('CD4 (D)', 'green', 20, 165),
            ('CD45 (D)', 'cyan', 40, 130),
        ],
        'HT397B1': [
            ('Pan-Cytokeratin', 'red', 15, 255),
            ('HLA-DR', 'yellow', 20, 255),
            ('Keratin 14', 'white', 25, 100),
            ('CD8', 'green', 10, 31),
            ('CD45 (D)', 'cyan', 15, 35),
        ],
    }.items()
}

In [None]:

multiplex_view_settings = case_to_multiplex_view_settings[case]

def get_multiplex_pseudo(fp, view_settings):
    """Render selected channels of a multiplex OME-TIFF as an 8-bit pseudocolor image.

    Args:
        fp: path to the multiplex OME-TIFF.
        view_settings: list of dicts with 'channel', 'color', 'min_value',
            'max_value' and 'gamma' keys, one per channel to display.

    Returns:
        ``multiplex.to_pseudocolor`` output scaled by 255 and cast to uint8.
    """
    setting_keys = ('channel', 'color', 'min_value', 'max_value', 'gamma')
    channels, colors, min_values, max_values, gammas = (
        [setting[key] for setting in view_settings] for key in setting_keys
    )

    channel_to_img = multiplex.extract_ome_tiff(fp, as_dict=True)
    stack = np.stack([channel_to_img[name] for name in channels])

    pseudo = multiplex.to_pseudocolor(
        stack,
        colors=colors,
        min_values=min_values,
        max_values=max_values,
        gammas=gammas,
    )
    # NOTE(review): assumes to_pseudocolor returns floats in [0, 1]; the cast truncates rather than rounds.
    pseudo *= 255.
    return pseudo.astype(np.uint8)

In [None]:
# Display settings for multi-channel Xenium morphology-focus images:
# (channel, color, max_value); every channel uses min_value 0 and gamma 1.
XENIUM_FOCUS_SETTINGS = [
    {'channel': channel, 'color': color, 'min_value': 0, 'max_value': max_value, 'gamma': 1.}
    for channel, color, max_value in [
        ('DAPI', 'blue', 100),
        ('ATP1A1/CD45/E-Cadherin', 'magenta', 100),
        ('18S', 'yellow', 50),
        ('alphaSMA/Vimentin', 'green', 100),
    ]
]

def _xenium_view(channel, color, marker):
    """Build one Xenium transcript-overlay setting: gene, scatter color and marker."""
    return {'channel': channel, 'color': color, 'marker': marker}


# Per-case transcript overlays as (gene, color, marker) rows.
case_to_xenium_view_settings = {
    case_id: [_xenium_view(*row) for row in rows]
    for case_id, rows in {
        'HT913Z1': [
            ('EPCAM', 'red', '.'),
            ('HLA-DQB2', 'yellow', 'v'),
            ('TP63', 'white', 's'),
            ('KRT5', 'white', 's'),
            ('CP', 'purple', 'P'),
            ('CD3E', 'limegreen', '*'),
            ('CD68', 'cyan', 'D'),
        ],
        'HT891Z1': [
            ('EPCAM', 'red', '.'),
            ('HLA-DQB2', 'yellow', 'v'),
            ('TP63', 'white', 's'),
            ('KRT5', 'white', 's'),
            ('CP', 'purple', 'P'),
            ('CD3E', 'limegreen', '*'),
            ('CD68', 'cyan', 'D'),
        ],
        'HT704B1': [
            ('EPCAM', 'red', '.'),
            ('HLA-DQB2', 'yellow', 'v'),
            ('KRT5', 'white', 's'),
            ('TP63', 'white', 's'),
            ('CD3E', 'limegreen', '*'),
            ('CD68', 'cyan', 'D'),
        ],
        'HT206B1': [
            ('EPCAM', 'red', '.'),
            ('HLA-DQB2', 'yellow', 'v'),
            ('ACTA2', 'white', 's'),
            ('PDPN', 'white', 's'),
            ('CD3E', 'limegreen', '*'),
            ('CD68', 'cyan', 'D'),
        ],
        'HT397B1': [
            ('EPCAM', 'red', '.'),
            ('HLA-DQB2', 'yellow', 'v'),
            ('KRT5', 'white', 's'),
            ('CD3E', 'limegreen', '*'),
            ('CD68', 'cyan', 'D'),
        ],
    }.items()
}

# Transcript overlay settings for the selected case (KeyError if the case has no entry).
xenium_view_settings = case_to_xenium_view_settings[case]

def get_xenium_pseudo(morph_fp):
    """Render a Xenium morphology-focus OME-TIFF as an 8-bit pseudocolor image.

    A file with a single channel (nuclei only) is shown as white DAPI with a
    0-100 display range; otherwise the channels in ``XENIUM_FOCUS_SETTINGS`` are used.

    Args:
        morph_fp: path to the morphology-focus OME-TIFF.

    Returns:
        ``multiplex.to_pseudocolor`` output scaled by 255 and cast to uint8.
    """
    nuclei_only = len(multiplex.get_ome_tiff_channels(morph_fp)) == 1
    if nuclei_only:
        channels, colors = ['DAPI'], ['white']
        min_values, max_values, gammas = [0.], [100.], [1.]
    else:
        channels, colors, min_values, max_values, gammas = (
            [setting[key] for setting in XENIUM_FOCUS_SETTINGS]
            for key in ('channel', 'color', 'min_value', 'max_value', 'gamma')
        )

    channel_to_img = multiplex.extract_ome_tiff(morph_fp, as_dict=True)
    stack = np.stack([channel_to_img[name] for name in channels])

    pseudo = multiplex.to_pseudocolor(
        stack,
        colors=colors,
        min_values=min_values,
        max_values=max_values,
        gammas=gammas,
    )
    pseudo *= 255.
    return pseudo.astype(np.uint8)

def plot_xenium_pseudo(ax, rgb, transcripts, view_settings, s=.1):
    """Draw ``rgb`` on ``ax`` and overlay selected transcripts as scatter markers.

    Args:
        ax: matplotlib axes to draw on.
        rgb: image passed to ``ax.imshow``.
        transcripts: DataFrame, or a parquet path, with 'feature_name',
            'x_location' and 'y_location' columns.
        view_settings: list of dicts with 'channel' (gene), 'color' and 'marker'.
            Genes absent from ``transcripts`` are skipped.
        s: scatter marker size.

    Returns:
        The same ``ax``.
    """
    ax.imshow(rgb)

    if not isinstance(transcripts, pd.DataFrame):
        transcripts = pd.read_parquet(transcripts)

    genes_present = set(transcripts['feature_name'])
    for setting in view_settings:
        rgb_color = np.asarray(Color(setting['color']).as_rgb_tuple()) / 255.
        gene = setting['channel']
        marker = setting['marker']
        if gene not in genes_present:
            continue

        hits = transcripts[transcripts['feature_name'] == gene]
        # Locations are truncated to whole pixels before plotting.
        yx = hits[['y_location', 'x_location']].values.astype(int)
        cols, rows = yx[:, 1], yx[:, 0]
        ax.scatter(cols, rows, color=rgb_color, s=s, marker=marker, edgecolors='none')
    return ax

def add_transcripts_to_rgb(rgb, transcripts, xenium_view_settings, pt_scaler=50):
    """Burn transcript markers into an image by rendering a figure and reading it back.

    The image and transcripts are drawn with ``plot_xenium_pseudo`` on a figure of
    (width / 1000, height / 1000) inches. The figure is saved to a temporary TIFF,
    which is read back and then deleted.

    Args:
        rgb: image array indexed (rows, cols, ...) — presumably uint8 RGB; confirm.
        transcripts: DataFrame, or a parquet path, accepted by ``plot_xenium_pseudo``.
        xenium_view_settings: per-gene 'channel' / 'color' / 'marker' settings.
        pt_scaler: marker size is ``(width / 1000) / pt_scaler``.

    Returns:
        The rendered image with any alpha channel dropped. Because of tight
        bounding-box cropping, its shape is not guaranteed to equal ``rgb.shape``.
    """
    if not isinstance(transcripts, pd.DataFrame):
        transcripts = pd.read_parquet(transcripts)

    size = (rgb.shape[1] / 1000, rgb.shape[0] / 1000)  # (width, height) in inches
    fig, ax = plt.subplots(figsize=size)
    try:
        ax.set_axis_off()
        plot_xenium_pseudo(ax, rgb, transcripts, xenium_view_settings, s=size[0] / pt_scaler)
        # Render into a private temp directory rather than 'temp.tif' in the CWD,
        # so concurrent runs (or an existing file) cannot collide.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_fp = os.path.join(tmp_dir, 'temp.tif')
            # NOTE(review): empirical dpi factor, not explained in the source;
            # presumably tuned so the output lands near the input resolution.
            fig.savefig(tmp_fp, bbox_inches='tight', pad_inches=0, dpi=int(1000 * 1.2987411728584588))
            rendered = tifffile.imread(tmp_fp)
    finally:
        # Release the (very large) figure; otherwise one leaks per call and
        # the notebook's inline backend may also display it.
        plt.close(fig)

    if rendered.shape[-1] == 4:
        rendered = rendered[..., :-1]

    return rendered



In [None]:
target_size = xenium.get_fullres_size(xenium.adata_from_xenium(config['sections'][0]['data'][0]['filepath']))
# target_size = tifffile.imread(config['sections'][1]['data'][0]['filepath']).shape[1:]
# target_size = [int(x * .5) for x in target_size]
target_size

In [None]:
imaris_dir = project_dir / 'imaris' / 'rois'
imaris_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# ROI bounds (rows r1:r2, cols c1:c2) in target_size pixel space.
r1, r2, c1, c2 = 1300, 2100, 2500, 3500
# r1, r2, c1, c2 = 1450, 1900, 2670, 2670 + 450
rgbs = []
for i, section in enumerate(config['sections']):
    print(i, section['sid'])
    for entry in section['data']:
        dtype = entry['dtype']
        fp = entry['filepath']

        if dtype in ['he', 'batch2_he']:
            rgb = tifffile.imread(fp)
            if rgb.shape[0] <= 4:  # channels-first -> channels-last
                rgb = rearrange(rgb, 'c h w -> h w c')
        elif dtype == 'xenium':
            transcripts_fp = fp.replace('.h5ad', '_transcripts.parquet')
            morph_fp = fp.replace('.h5ad', '_morphologyfocus.ome.tiff')
            assert Path(transcripts_fp).exists()
            assert Path(morph_fp).exists()
            rgb = get_xenium_pseudo(morph_fp)
            rgb = add_transcripts_to_rgb(rgb, transcripts_fp, xenium_view_settings)
        elif dtype == 'multiplex':
            rgb = get_multiplex_pseudo(fp, multiplex_view_settings)
        else:
            rgb = None  # unsupported dtype: contributes no image

        if rgb is not None:
            # Compare as tuples: get_fullres_size may return a list or array, which
            # never equals a tuple (or compares elementwise), forcing a needless rescale.
            if tuple(rgb.shape[:2]) != tuple(target_size):
                rgb = utils.rescale(rgb, size=target_size, dim_order='h w c', target_dtype=rgb.dtype)
            rgbs.append(rgb[r1:r2, c1:c2])


In [None]:
stacked = np.stack(rgbs)
stacked.shape

In [None]:
# aa660b20-6a7b-484f-99b1-ece711a44c9a -> ef0a6544-d5f0-420e-9fef-2d8bf1049dfe

In [None]:
# Export the ROI stack for Imaris: (z, h, w, c) -> (1, z, c, h, w), with one channel per RGB component.
# NOTE(review): the trailing 1. argument's meaning (presumably pixel size) is not visible here.
multiplex.write_basic_ome_tiff(
    imaris_dir / 'roi1.ome.tif',
    rearrange(stacked, 'z h w c -> 1 z c h w'),
    ['red', 'green', 'blue'],
    1.
)

In [None]:
# Pair each section id with its ROI image. This only lines up if every section
# contributed exactly one image; unsupported dtypes or multi-entry sections would
# silently shift the pairing, so check that first.
assert len(stacked) == len(config['sections']), (len(stacked), len(config['sections']))
sid_to_rgb = {section['sid']: img for section, img in zip(config['sections'], stacked)}

In [None]:
for i, (sid, rgb) in enumerate(sid_to_rgb.items()):
    print(i, sid)
    plt.imshow(rgb)
    plt.axis('off')
    plt.show()

In [None]:
def load_regions(regions_fp, shape=None):
    """Load polygon annotations from a GeoJSON export and rasterize each one.

    Args:
        regions_fp: path to a GeoJSON file with a 'features' list.
        shape: (rows, cols) of the output masks. Defaults to
            ``stacked.shape[1:3]``, the ROI image stack built earlier in the notebook.

    Returns:
        list of dicts with these keys:
            'id': the feature id.
            'z': ``geometry.plane.z``, or 0 when absent.
            'coordinates': the first ring, with GeoJSON (x, y) swapped to (row, col).
            'mask': boolean polygon mask of ``shape``.
        Features whose coordinates do not form a 3-D array are dropped.
    """
    if shape is None:
        shape = stacked.shape[1:3]

    # Context manager so the file handle is closed.
    with open(regions_fp) as f:
        features = json.load(f)['features']

    regions = [
        {
            'id': x['id'],
            'z': x['geometry']['plane']['z'] if 'plane' in x['geometry'] else 0,
            'coordinates': np.asarray(x['geometry']['coordinates']),
        }
        for x in features
    ]
    regions = [x for x in regions if len(x['coordinates'].shape) == 3]
    for x in regions:
        x['coordinates'] = x['coordinates'][0][:, [1, 0]]  # first ring, (x, y) -> (row, col)
        x['mask'] = skimage.draw.polygon2mask(shape, x['coordinates'])

    return regions

In [None]:
# Channels shared by every section of each dtype (keyed by the first data entry's dtype).
dtype_to_channels = {}
for dtype in ('xenium', 'multiplex'):
    fps = [s['data'][0]['filepath'] for s in config['sections'] if s['data'][0]['dtype'] == dtype]
    reader = multiplex if 'multiplex' in dtype else xenium
    dtype_to_channels[dtype] = reader.get_common_channels(fps)
dtype_to_channels

In [None]:
fullres_size = xenium.get_fullres_size(xenium.adata_from_xenium(config['sections'][0]['data'][0]['filepath']))
tiling_size = 10
tiling_resolution = tiling_size
size = [x // tiling_size for x in fullres_size]
size

In [None]:
# Derive the transcripts path from the config (same naming convention as the image loop)
# instead of hard-coding HT891Z1, so this cell follows `case`.
# NOTE(review): assumes section 0's first entry is the xenium .h5ad, as target_size does.
transcripts = pd.read_parquet(config['sections'][0]['data'][0]['filepath'].replace('.h5ad', '_transcripts.parquet'))
transcripts


In [None]:
set(transcripts['feature_name'])

In [None]:
# Same section-0 xenium file used for target_size; this replaces a path hard-coded to HT891Z1.
adata = xenium.adata_from_xenium(config['sections'][0]['data'][0]['filepath'])
adata

In [None]:
# tiled, tiled_channels = tiling_utils.tile_xenium(adata, tile_size=10, transcripts=transcripts)
# tiled.shape

In [None]:
dtype_to_tiled = {}
for dtype, channels in dtype_to_channels.items():
    dtype_to_tiled[dtype] = tiling_utils.get_tiled_sections(
        config, dtype=dtype, channel_names=channels,
        tiling_size=tiling_size, target_size=size, use_transcripts=True
    )
for dtype, tiled in dtype_to_tiled.items():
    print(dtype, tiled.shape)

In [None]:
sid_to_tiled, sid_to_dtype = {}, {}
for dtype in dtype_to_channels.keys():
    tiled = dtype_to_tiled[dtype]
    sids = [entry['sid'] for entry in config['sections'] if entry['data'][0]['dtype'] == dtype]
    for sid, t in zip(sids, tiled):
        # Scale each channel to unit std over the section. Constant channels (std 0,
        # e.g. never detected) would otherwise become NaN/inf, so divide those by 1.
        std = t.std(axis=(-2, -1), keepdims=True)
        sid_to_tiled[sid] = t / np.where(std > 0, std, 1.)
        sid_to_dtype[sid] = dtype
sid_to_tiled.keys()

In [None]:
regions_fp = imaris_dir / 'roi1.geojson'
regions = load_regions(regions_fp)

len(regions), regions[0].keys()

In [None]:
# Per-dtype marker and threshold (in std-normalized units) used to exclude
# presumably-epithelial tiles from the surrounding band.
dtype_to_epi = {
    'xenium': ('EPCAM', 1.),
    'multiplex': ('E-Cadherin', 10.),
}

# ROI bounds in tile units (not used below; kept for inspection).
local_r1, local_r2, local_c1, local_c2 = [x // tiling_resolution for x in [r1, r2, c1, c2]]
for region in regions:
    section = config['sections'][region['z']]
    sid = section['sid']
    region['position'] = section['position']
    if sid in sid_to_tiled:
        tiled = sid_to_tiled[sid]
        dtype = sid_to_dtype[sid]
        channels = dtype_to_channels[dtype]

        # Vertices are ROI-local (row, col): shift by the ROI origin, then convert
        # to tile units to index the tiled section.
        coords = region['coordinates'] + np.asarray([r1, c1])
        coords //= tiling_resolution
        mask = skimage.draw.polygon2mask(tiled.shape[-2:], coords)

        outer = skimage.morphology.binary_dilation(mask)
        inner = skimage.morphology.binary_erosion(mask)
        expanded = mask.copy()
        for _ in range(5):
            expanded = skimage.morphology.binary_dilation(expanded)

        # Boundary band: one dilation outside to one erosion inside the polygon.
        ring = outer ^ inner
        means = tiled[:, ring].mean(1)
        fracs = np.count_nonzero(tiled[:, ring] > 0, axis=1) / np.count_nonzero(ring)

        # Surrounding band (dilations 2-5), minus tiles above the marker threshold.
        tme = expanded ^ outer
        channel, thresh = dtype_to_epi[dtype]
        m = tiled[channels.index(channel)] > thresh
        tme = np.logical_and(tme, ~m)
        tme_means = tiled[:, tme].mean(1)
        tme_fracs = np.count_nonzero(tiled[:, tme] > 0, axis=1) / np.count_nonzero(tme)

        region['fracs'] = fracs
        region['means'] = means
        # These two keys were previously assigned the other one's value.
        region['fracs_tme'] = tme_fracs
        region['means_tme'] = tme_means
        region['dtype'] = dtype
        region['channels'] = channels
        region['ring'] = ring
        region['tme'] = tme
    region['sid'] = sid
        
        


In [None]:
region_id_to_region = {x['id']:x for x in regions}

In [None]:
# cs, zs, vals = [], [], []
# for region in regions:
#     if 'dtype' in region and region['dtype'] == 'xenium':
#         cs.append(region['coordinates'])
#         zs.append(region['z'] * 5)
#         vals.append(region['means'][region['channels'].index('KRT5')])
# zs, vals = np.asarray(zs), np.asarray(vals)
# vals -= vals.min()
# vals /= vals.max()
# cmap = plt.colormaps['viridis_r']
# facecolors = [cmap(x) for x in vals]


In [None]:
# Xenium region outlines, their z offsets (z * 5) and KRT5 ring means, with the means
# min-max scaled to [0, 1] for coloring.
cs, zs, vals = [], [], []
for region in regions:
    if region.get('dtype') == 'xenium':
        cs.append(region['coordinates'])
        zs.append(region['z'] * 5)
        vals.append(region['means'][region['channels'].index('KRT5')])
zs, vals = np.asarray(zs), np.asarray(vals, dtype=float)
vals -= vals.min()
span = vals.max()
if span > 0:  # avoid 0/0 when every region has the same value
    vals /= span
cmap = plt.colormaps['viridis_r']
facecolors = [cmap(x) for x in vals]


In [None]:
ax = plt.figure().add_subplot(projection='3d')
poly = PolyCollection(cs, facecolors=facecolors, alpha=.7)
ax.set(xlim=(0, max([x[0] for xs in cs for x in xs])),
       ylim=(0,  max([x[1] for xs in cs for x in xs])),
       zlim=(0, max(zs)))
ax.add_collection3d(poly, zs=zs, zdir='z')
plt.show()

In [None]:
# 3-D overview of every annotated region, drawn at its section position on the z axis.
ax = plt.figure(figsize=(10, 10)).add_subplot(projection='3d')

coords = [x['coordinates'] for x in regions]
# zs = [x['z'] * 5 for x in regions]
zs = [x['position'] for x in regions]
facecolors = [(.5, .5, .5, 1.) for i in range(len(zs))]
poly = PolyCollection(coords, facecolors=facecolors, alpha=.7, edgecolor=(.2, .2, .2, 1.))

# NOTE(review): fixed 0-1000 x/y limits; the current ROI is (r2 - r1) x (c2 - c1) = 800 x 1000 pixels.
ax.set(xlim=(0, 1000), ylim=(0, 1000), zlim=(0, max(zs)))
ax.add_collection3d(poly, zs=zs, zdir='z')
ax.view_init(elev=10., azim=80)
ax.invert_zaxis()  # first section at the top
plt.show()

In [None]:
# x = json.load(open(connectivity_fp))
# [geom for geom in x['features'] if geom['geometry']['type']!='Polygon']

In [None]:
connectivity_fp = imaris_dir / 'roi1_aligned.geojson'
conn_regions = load_regions(connectivity_fp)

len(conn_regions), conn_regions[0].keys()

In [None]:
# Every region in roi1.geojson must also be in the aligned connectivity export;
# otherwise graph labels cannot be mapped back to measured regions.
region_ids = [x['id'] for x in regions]
conn_ids = {x['id'] for x in conn_regions}
assert len(conn_ids.intersection(region_ids)) == len(region_ids), \
    f'regions missing from connectivity export: {sorted(set(region_ids) - conn_ids)}'

In [None]:
# def regions_to_volume(regions):
#     zs = [region['z'] for region in regions]
#     rgb = next(iter(sid_to_rgb.values()))
#     labeled = np.zeros((max(zs) + 1, *rgb.shape[:2]))
#     region_id_to_label = {}
#     for i, region in enumerate(regions):
#         sid = config['sections'][region['z']]['sid']
#         coords = region['coordinates'] + np.asarray([r1, c1])
#         coords //= tiling_resolution
#         mask = skimage.draw.polygon2mask(rgb.shape[:2], coords)
#         labeled[region['z'], mask] = i + 1
#         region_id_to_label[region['id']] = i + 1

#     return labeled.astype(int), region_id_to_label
def regions_to_volume(regions, shape=None):
    """Rasterize ROI-local polygons into a labeled (z, h, w) integer volume.

    Args:
        regions: dicts with 'id', 'z' (slice index) and 'coordinates'
            ((row, col) vertices in ROI-local pixels), as from ``load_regions``.
        shape: (h, w) of each slice. Defaults to the ROI size ``(r2 - r1, c2 - c1)``.

    Returns:
        (labeled, region_id_to_label):
            labeled: int array with ``max(z) + 1`` slices. Region i is painted with
                label i + 1; 0 is background, and later regions overwrite earlier
                ones where they overlap.
            region_id_to_label: maps each region id to its label.
    """
    if shape is None:
        shape = (r2 - r1, c2 - c1)
    h, w = shape
    n_slices = max(region['z'] for region in regions) + 1
    labeled = np.zeros((n_slices, h, w), dtype=int)
    region_id_to_label = {}
    for label, region in enumerate(regions, start=1):
        mask = skimage.draw.polygon2mask((h, w), region['coordinates'])
        labeled[region['z'], mask] = label
        region_id_to_label[region['id']] = label

    return labeled, region_id_to_label

In [None]:
labeled, region_id_to_label = regions_to_volume(conn_regions)
label_to_region_id = {v:k for k, v in region_id_to_label.items()}

In [None]:
# for i, x in enumerate(labeled):
#     plt.imshow(x>0)
#     plt.title(i)
#     plt.show()

In [None]:
def graph_from_labeled(labeled):
    """List edges between labels that overlap in adjacent z-slices.

    Args:
        labeled: integer array indexed (z, rows, cols); 0 is background.

    Returns:
        list of (label, neighbor_label) tuples. For each slice and each of its labels,
        overlaps with the previous slice come first, then the next slice. Each
        overlap is reported from both slices.
    """
    edges = []
    n_slices = labeled.shape[0]
    for i in range(n_slices):
        current = labeled[i]
        # Only slices that exist: a single-slice volume has no neighbors
        # (previously this indexed labeled[1] and raised IndexError).
        neighbors = [labeled[j] for j in (i - 1, i + 1) if 0 <= j < n_slices]
        for label in np.unique(current):
            if not label:
                continue
            footprint = current == label  # computed once per label
            for neighbor in neighbors:
                edges.extend((label, other) for other in np.unique(neighbor[footprint]) if other)

    return edges

In [None]:
edges = graph_from_labeled(labeled)

In [None]:
import networkx as nx

In [None]:
G = nx.Graph()
G.add_edges_from(edges)

In [None]:
# Leaf nodes: regions overlapping exactly one other region.
terminals = [node for node, nbrs in G.adjacency() if len(nbrs) == 1]
terminals

In [None]:
[c for c in sorted(nx.connected_components(G), key=len, reverse=True)]

In [None]:
# missing = []
# for i, x in enumerate(labeled):
#     m = np.zeros_like(x, dtype=bool)
#     for node in missing:
#         m |= x==node
#     plt.imshow(m)
#     plt.title(i)
#     plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
nx.draw_networkx(G, font_size=6, node_size=100)

In [None]:
# Object ID	bd096426-679b-4216-a698-7d35a88fe3f0
# endpoints = {
#     'T1': ('f6032595-1fd0-44c2-82b0-72a3fac80f03', 'bd096426-679b-4216-a698-7d35a88fe3f0'),
#     'P1': ('cc6ab9c1-0805-46d4-87bb-8b3fb8b3b98d', '18073a83-5d22-4c2a-8cc8-e1b27a4de88e'),
#     'P2': ('6c6e1ff9-1c07-408b-a4a6-e4e029a14523', '18073a83-5d22-4c2a-8cc8-e1b27a4de88e'),
#     'P3': ('7ac3ea71-73a8-4311-8379-1b646d2f278b', '18073a83-5d22-4c2a-8cc8-e1b27a4de88e'),
#     'P4': ('7ac3ea71-73a8-4311-8379-1b646d2f278b', '18073a83-5d22-4c2a-8cc8-e1b27a4de88e'),
# }
# Object ID	18073a83-5d22-4c2a-8cc8-e1b27a4de88e
endpoints = [
    'f6032595-1fd0-44c2-82b0-72a3fac80f03', '18073a83-5d22-4c2a-8cc8-e1b27a4de88e'
]


In [None]:
# Shortest chain of overlapping regions (graph labels) between the two endpoint annotations.
n1, n2 = endpoints
path = nx.algorithms.shortest_path(G, region_id_to_label[n1], region_id_to_label[n2])
path
# for n1 in terminals:
#     for n2 in terminals:
#         if n1 != n2:
#             path = nx.algorithms.shortest_path(G, n1, n2)
#             paths.append(path)

In [None]:
keep = [label_to_region_id[l] for l in path]

In [None]:
for x in keep:
    region = region_id_to_region[x]
    if region.get('dtype', '') == 'xenium':
        print(region['id'], region['position'])

In [None]:
import seaborn as sns
colors = sns.color_palette('deep')
colors

In [None]:
colors[0]

In [None]:
# Fill colors come from seaborn's 'deep' palette and outline colors from 'dark';
# indices 0, 2 and 4 map to xenium, multiplex and H&E.
dark = sns.color_palette('dark')
dtype_to_color = dict(zip(('xenium', 'multiplex', 'he'), (colors[0], colors[2], colors[4])))
dtype_to_edgecolor = dict(zip(('xenium', 'multiplex', 'he'), (dark[0], dark[2], dark[4])))

In [None]:
ax = plt.figure(figsize=(10, 10)).add_subplot(projection='3d')

rs = [x for x in regions if x['id'] not in keep]
coords = [x['coordinates'] for x in rs]
zs = [x['position'] for x in rs]
facecolors = [(.5, .5, .5, 1.)] * len(rs)
edgecolors = [(.2, .2, .2, 1.)] * len(rs)


rs = [x for x in regions if x['id'] in keep]
coords += [x['coordinates'] for x in rs]
zs += [x['position'] for x in rs]
facecolors += [(.5, .5, .5, 1.)] * len(rs)
edgecolors += [(.8, .2, .2, 1.)] * len(rs)
poly = PolyCollection(coords, facecolors=facecolors, alpha=.7, edgecolor=edgecolors)

ax.set(xlim=(0, 1000), ylim=(0, 1000), zlim=(0, max(zs)))
ax.add_collection3d(poly, zs=zs, zdir='z')
ax.view_init(elev=10., azim=80)
ax.invert_zaxis()

plt.show()

In [None]:
ax = plt.figure(figsize=(10, 10)).add_subplot(projection='3d')

rs = [x for x in regions if x['id'] in keep]
coords = [x['coordinates'] for x in rs]
zs = [x['position'] for x in rs]
facecolors = [dtype_to_color[x.get('dtype', 'he')] for x in rs]
edgecolors = [dtype_to_edgecolor[x.get('dtype', 'he')] for x in rs]
poly = PolyCollection(coords, facecolors=facecolors, alpha=.7, edgecolor=edgecolors)

ax.set(xlim=(0, 1000), ylim=(0, 1000), zlim=(0, max(zs)))
ax.add_collection3d(poly, zs=zs, zdir='z')
ax.view_init(elev=10., azim=80)
ax.invert_zaxis()
ax.set_axis_off()

plt.show()

In [None]:
# missing = [1, 174]
# for x in labeled:
#     m = np.zeros_like(x, dtype=bool)
#     for node in missing:
#         m |= x==node
#     plt.imshow(m)
#     plt.show()
    

In [None]:
# for x in missing:
#     print(label_to_region_id[x])

In [None]:
def get_means_and_fracs(region_id_to_regions, paths, dtype='xenium', mean_key='means', frac_key='fracs',
                        label_to_id=None):
    """Gather per-channel features for the regions along each path.

    Args:
        region_id_to_regions: mapping region id -> region dict. This argument is
            now actually used; previously the global ``region_id_to_region`` was read.
        paths: iterable of graph-label sequences.
        dtype: only regions whose 'dtype' equals this contribute.
        mean_key, frac_key: region keys holding per-channel vectors.
        label_to_id: graph label -> region id mapping. Defaults to the notebook
            global ``label_to_region_id``.

    Returns:
        list with one dict per path that has at least one matching region:
            'vals', 'fracs': (n_channels, n_regions) arrays.
            'zs': the region positions along the path.
            'other_vals', 'other_fracs': for each path region, the stacked vectors of
                the other regions at the same position. These have shape
                (0, n_channels) when there are no other regions.
    """
    if label_to_id is None:
        label_to_id = label_to_region_id

    results = []
    for path in paths:
        vals, fracs, zs, other_vals, other_fracs = [], [], [], [], []
        for label in path:
            region = region_id_to_regions[label_to_id[label]]
            if region.get('dtype') != dtype:
                continue
            vals.append(region[mean_key])
            fracs.append(region[frac_key])
            zs.append(region['position'])

            # Other regions on the same section (same position).
            others = [x for x in region_id_to_regions.values()
                      if x['position'] == region['position'] and x['id'] != region['id']]
            # np.stack([]) raises, so use an empty (0, n_channels) array when alone.
            n_channels = len(region[mean_key])
            other_vals.append(np.stack([x[mean_key] for x in others]) if others else np.empty((0, n_channels)))
            other_fracs.append(np.stack([x[frac_key] for x in others]) if others else np.empty((0, n_channels)))

        if vals:
            results.append({
                'vals': np.stack(vals).T,    # (n_regions, n_channels) -> (n_channels, n_regions)
                'fracs': np.stack(fracs).T,
                'other_vals': other_vals,
                'other_fracs': other_fracs,
                'zs': zs,
            })
    return results

import scipy.stats  # importing the submodule explicitly; `import scipy` alone may not load it

def get_correlations(vals, default=0.):
    """Pearson-correlate each row of ``vals`` with its index sequence 0, 1, ..., n-1.

    Args:
        vals: 2-D array-like with one row per feature.
        default: correlation reported when it is undefined. That covers a NaN
            p-value (e.g. a constant row) and rows with fewer than two points.

    Returns:
        (correlations, pvalues) as 1-D arrays. Undefined rows get ``default`` and a
        p-value of 1.
    """
    coors, pvals = [], []
    for row in vals:
        if len(row) < 2:
            # pearsonr requires at least two points.
            coors.append(default)
            pvals.append(1.)
            continue
        # Unpack positionally: the result is a (statistic, pvalue) tuple in every
        # SciPy version, whereas attribute names differ across versions.
        statistic, pvalue = scipy.stats.pearsonr(row, np.arange(len(row)))
        if pd.isnull(pvalue):
            coors.append(default)
            pvals.append(1.)
        else:
            coors.append(statistic)
            pvals.append(pvalue)
    return np.asarray(coors), np.asarray(pvals)

In [None]:
result = get_means_and_fracs(region_id_to_region, [path])[0]

In [None]:
result['vals'].shape

In [None]:
# avg = np.stack([x.mean(0) for x in result['other_vals']]).T
# coors, pvals = get_correlations(result['vals'] - avg)
coors, pvals = get_correlations(result['vals'])
vals, fracs, zs, other_vals, other_fracs = result['vals'], result['fracs'], result['zs'], result['other_vals'], result['other_fracs']

coors.shape

In [None]:
# Mean of the other regions at each position, shaped (n_channels, n_positions).
# `avg` was previously only defined by the commented-out line in an earlier cell,
# so it is computed here to make the cell work on a fresh kernel.
avg = np.stack([x.mean(0) for x in result['other_vals']]).T
avg.shape, result['vals'].shape

In [None]:
pvals

In [None]:
# Channels in ascending order of correlation with path order. Keep those detected
# (signal > 0) in more than 10% of ring tiles of at least one region, with p < 0.05.
idxs = np.argsort(coors)
m = (fracs[idxs].max(-1) > .1) & (pvals[idxs] < .05)
idxs = idxs[m]

ordered = np.asarray(dtype_to_channels['xenium'])[idxs]

ordered

In [None]:
# n = 15

# data = vals[idxs[:n]]
# data = np.concatenate((data, vals[idxs[-n:]]))
# source = pd.DataFrame(data=data.T, columns=np.concatenate((ordered[:n], ordered[-n:])), index=zs)
data = vals[idxs]
source = pd.DataFrame(data=data.T, columns=ordered, index=zs)
source

In [None]:
fig_dir

In [None]:
fig, ax = plt.subplots(figsize=(15, 10))
sns.heatmap(source, cmap='Blues')
plt.savefig(fig_dir / 'roi1_gene_heatmap.svg')

In [None]:
df = pd.DataFrame(coors, index=dtype_to_channels['xenium'], columns=['correlation'])
df['pvals'] = pvals
df = df.loc[source.columns]
df

In [None]:
sns.heatmap(df[['correlation']], cmap='PiYG')
plt.savefig(fig_dir / 'roi1_gene_heatmap_corrs.svg')

In [None]:
sns.heatmap(df[['pvals']], cmap='cividis_r', vmin=0, vmax=.05)
plt.savefig(fig_dir / 'roi1_gene_heatmap_pvals.svg')

In [None]:
region['coordinates']

In [None]:
radius = 150  # rows/cols kept before the centroid
after = 100   # NOTE(review): rows/cols kept after the centroid; asymmetric vs radius, confirm intended
imgs = []
for rid in keep:
    region = region_id_to_region[rid]
    mask = region['mask']

    expanded = mask.copy()
    for _ in range(15):
        expanded = skimage.morphology.binary_dilation(expanded)

    prop = skimage.measure.regionprops(skimage.morphology.label(mask))[0]
    r, c = (int(v) for v in prop['centroid'])

    # Use crop_* names: reusing r1/r2/c1/c2 here overwrote the ROI bounds read by
    # the region-feature cell and regions_to_volume.
    crop_r1 = max(0, r - radius)
    crop_c1 = max(0, c - radius)
    crop_r2 = min(mask.shape[0], r + after)  # slice end is exclusive, so no -1
    crop_c2 = min(mask.shape[1], c + after)

    # Copy before drawing. The arrays in sid_to_rgb are views into `stacked`, so drawing
    # in place would corrupt them, accumulate outlines on re-run, and alter earlier crops.
    rgb = sid_to_rgb[region['sid']].copy()

    # Yellow outline 5 erosions thick along the 15-dilation boundary.
    inner = expanded.copy()
    for _ in range(5):
        inner = skimage.morphology.binary_erosion(inner)
    ring = expanded ^ inner
    rgb[ring] = [255, 255, 0]

    imgs.append(rgb[crop_r1:crop_r2, crop_c1:crop_c2])
len(imgs)

In [None]:

for img in imgs:
    plt.imshow(img)
    plt.show()

In [None]:
# One panel per region on the path, titled with its z position.
# squeeze=False keeps `axs` 2-D even when there is a single image, so zip always works.
fig, axs = plt.subplots(ncols=len(imgs), figsize=(len(imgs), 1), squeeze=False)

for rid, img, ax in zip(keep, imgs, axs[0]):
    position = region_id_to_region[rid]['position']

    ax.imshow(img)
    ax.set_axis_off()
    ax.set_aspect('equal')

    ax.set_title(f'Z{position}')
fig.subplots_adjust(wspace=0, hspace=0)
fig.savefig(fig_dir / 'roi1_all_images.svg')

In [None]:
region = list(region_id_to_region.values())[0]
region.keys()

In [None]:
rlabeled = skimage.morphology.label(region['mask'])
prop = skimage.measure.regionprops(rlabeled)[0]
prop

In [None]:
plt.imshow(region['mask'])

In [None]:
prop.centroid_local

In [None]:
prop.centroid

In [None]:
def plot_genes(genes):
    """Line-plot expression of `genes` against z depth on the current axes.

    Reads the notebook globals `zs` (x values), `vals` (one row per channel,
    indexed by the gene's position in ``dtype_to_channels['xenium']``) and
    `dtype_to_channels`, so it plots whatever those globals hold when called.
    The x axis is inverted and the legend is placed outside the axes on the right.

    Args:
        genes: iterable of Xenium channel names to plot.
    """
    palette = sns.color_palette()
    channels = dtype_to_channels['xenium']
    for i, gene in enumerate(genes):
        # Cycle the palette so more genes than palette colors don't raise IndexError.
        sns.lineplot(x=zs, y=vals[channels.index(gene)], label=gene, c=palette[i % len(palette)])

    ax = plt.gca()
    ax.set_xlabel('Z depth')
    ax.set_ylabel('Expression')
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    ax.invert_xaxis()

In [None]:
# First 10 genes of `ordered` (as ranked in an earlier cell).
genes = ordered[:10]
plot_genes(genes)
plt.savefig(fig_dir / 'roi1_expression_lineplot_pos.svg', dpi=300)

In [None]:
# Last 10 genes of `ordered`, reversed so the final entry is plotted first.
genes = np.flip(ordered[-10:])
plot_genes(genes)
plt.savefig(fig_dir / 'roi1_expression_lineplot_neg.svg', dpi=300)

In [None]:
import enrichrpy.enrichr as een
import enrichrpy.plotting as epl


In [None]:
# GO Biological Process 2021 enrichment (via enrichrpy) for the first 10 genes
# of `ordered`. NOTE: rebinds `df`, replacing the correlation table built earlier.
df = een.get_pathway_enrichment(ordered[:10], gene_set_library='GO_Biological_Process_2021')
epl.enrichment_barplot(df, n=20)

In [None]:
# Same enrichment for the last 10 genes of `ordered`.
df = een.get_pathway_enrichment(ordered[-10:], gene_set_library='GO_Biological_Process_2021')
epl.enrichment_barplot(df, n=20)

In [None]:
# TME-compartment (`*_tme` keys) depth analysis, using only the first result.
results = get_means_and_fracs(
    region_id_to_region, paths.values(),
    mean_key='means_tme', frac_key='fracs_tme',
)
vals, fracs, zs, other_vals, other_fracs = (
    results[0][key] for key in ('vals', 'fracs', 'zs', 'other_vals', 'other_fracs')
)

# Baseline: mean over axis 0 of each entry of `other_vals`, stacked and transposed,
# subtracted from `vals` before correlating.
avg = np.stack([v.mean(0) for v in other_vals]).T
coors = get_correlations(vals - avg)

# Channels sorted by ascending correlation, keeping those whose max fraction exceeds 0.01.
idxs = np.argsort(coors)
idxs = idxs[fracs[idxs].max(-1) > .01]
ordered = np.asarray(dtype_to_channels['xenium'])[idxs]
ordered

In [None]:
# First 10 genes of the TME ranking (`ordered` from the cell above).
genes = ordered[:10]
plot_genes(genes)
plt.savefig(fig_dir / 'roi1_expression_lineplot_pos_tme.svg', dpi=300)

In [None]:
# Last 10 genes of the TME ranking, reversed so the final entry is plotted first.
genes = np.flip(ordered[-10:])
plot_genes(genes)
plt.savefig(fig_dir / 'roi1_expression_lineplot_neg_tme.svg', dpi=300)

In [None]:
def display_genes(gene, dt='xenium', figsize=(6, 20), mask_key='ring'):
    """Show each `dt` section's RGB image next to one gene channel.

    One row per matching section: column 0 is ``sid_to_rgb[sid]``, and column 1
    is the `gene` channel of ``sid_to_tiled[sid]`` cropped to the global
    ``r1:r2, c1:c2`` window (integer-divided by ``tiling_resolution``).
    Column 2 is left empty in this version.

    Reads notebook globals: dtype_to_tiled, sid_to_tiled, sid_to_dtype,
    sid_to_rgb, dtype_to_channels, tiling_resolution, r1, r2, c1, c2.

    Args:
        gene: channel name; must be in ``dtype_to_channels[dt]``.
        dt: data type whose sections are shown.
        figsize: matplotlib figure size.
        mask_key: accepted for call compatibility; not used by this version.
    """
    # nrows follows len(dtype_to_tiled[dt]) (presumably one entry per section);
    # squeeze=False keeps `axs` 2D even for a single row.
    fig, axs = plt.subplots(ncols=3, nrows=len(dtype_to_tiled[dt]), figsize=figsize, squeeze=False)

    channel_idx = dtype_to_channels[dt].index(gene)
    # Same crop window for every section: compute it once.
    local_r1, local_r2, local_c1, local_c2 = [v // tiling_resolution for v in (r1, r2, c1, c2)]

    # Count rows over matching sections only. Enumerating every section would
    # skip rows for other dtypes and overrun `axs` when dtypes are interleaved.
    row = 0
    for sid, tiled in sid_to_tiled.items():
        if sid_to_dtype[sid] != dt:
            continue
        axs[row, 0].imshow(sid_to_rgb[sid])
        axs[row, 1].imshow(tiled[channel_idx, local_r1:local_r2, local_c1:local_c2])
        row += 1

In [None]:
display_genes('AQP9', mask_key='tme')

In [None]:
transcripts = pd.read_parquet(project_dir / 'sandbox' / 'transcripts.parquet')
transcripts

In [None]:
h, w = transcripts['y_location'].max(), transcripts['x_location'].max()
h, w = int(h) + 1, int(w) + 1
h, w

In [None]:
intact_coords = pd.read_csv(project_dir / 'sandbox' / 'intact_coordinates.csv', skiprows=2)
intact_coords

In [None]:
deg_coords = pd.read_csv(project_dir / 'sandbox' / 'degraded_coordinates.csv', skiprows=2)
deg_coords

In [None]:
def get_transcripts(coord_df, transcripts_df, shape=None):
    """Return the rows of `transcripts_df` that fall inside the annotated polygons.

    Each distinct value of ``coord_df['Selection']`` is one polygon whose
    vertices are its ``Y`` (row) and ``X`` (column) values. All polygons are
    rasterized into one boolean mask, and a transcript is kept when the pixel
    at ``(int(y_location), int(x_location))`` lies inside any polygon.

    Args:
        coord_df: polygon table with 'Selection', 'Y' and 'X' columns.
        transcripts_df: transcript table with 'y_location' and 'x_location'.
        shape: (rows, cols) of the rasterized mask. Defaults to the extent of
            `transcripts_df`, so every transcript indexes a valid pixel.

    Returns:
        The filtered subset of `transcripts_df` (same columns and index).
    """
    if transcripts_df.empty:
        return transcripts_df.iloc[:0]

    # astype(int) truncates like int() did in the per-row version.
    rows = transcripts_df['y_location'].to_numpy().astype(int)
    cols = transcripts_df['x_location'].to_numpy().astype(int)
    if shape is None:
        shape = (int(rows.max()) + 1, int(cols.max()) + 1)

    mask = np.zeros(shape, dtype=bool)
    # Rasterize each selection from `coord_df` itself (the previous version
    # filtered the global `intact_coords`, so every call used intact polygons).
    for _, polygon in coord_df.groupby('Selection', sort=True):
        mask |= skimage.draw.polygon2mask(shape, polygon[['Y', 'X']].values)

    # Vectorized lookup instead of a per-transcript Python loop.
    return transcripts_df[mask[rows, cols]]

In [None]:
intact_df = get_transcripts(intact_coords, transcripts)
intact_df

In [None]:
deg_df = get_transcripts(deg_coords, transcripts)
deg_df

In [None]:
intact_counts = {i:x for i, x in intact_df.groupby('feature_name').count()['transcript_id'].items()}
deg_counts = {i:x for i, x in deg_df.groupby('feature_name').count()['transcript_id'].items()}

pool = set(intact_counts.keys()).union(set(deg_counts.keys()))
pool = [x for x in pool if intact_counts.get(x, 0) > 10 or deg_counts.get(x, 0) > 10]
len(pool)

In [None]:
ratios = np.asarray([deg_counts.get(gene, 0) / (deg_counts.get(gene, 0) + intact_counts.get(gene, 0)) for gene in pool])
ratios

In [None]:
np.asarray(pool)[np.argsort(ratios)]

In [None]:
genes = ordered[:5]
fig, axs = plt.subplots(ncols=1, nrows=len(results))
for ax, result in zip(axs, results):
    for gene in genes:
        sns.lineplot(x=result['zs'], y=result['vals'][dtype_to_channels['xenium'].index(gene)], ax=ax)
    ax.set_ylim(0, 1.5)
    ax.set_xlim(0, 600)

In [None]:
genes = np.flip(ordered[-5:])
fig, axs = plt.subplots(ncols=1, nrows=len(results))
for ax, result in zip(axs, results):
    for gene in genes:
        sns.lineplot(x=result['zs'], y=result['vals'][dtype_to_channels['xenium'].index(gene)], ax=ax)
    ax.set_ylim(0, 2.5)
    ax.set_xlim(0, 600)

In [None]:
import seaborn as sns

In [None]:
# Point features from the annotation file at `regions_fp`, as {'coordinates', 'z'} dicts.
# Read inside `with` so the file handle is closed.
with open(regions_fp) as fh:
    features = json.load(fh)['features']
pts = [
    {
        # GeoJSON positions are (x, y); reorder to (row, col) for array indexing.
        'coordinates': np.asarray(feat['geometry']['coordinates'])[[1, 0]],
        'z': feat['geometry']['plane']['z'],
    }
    for feat in features
    if feat['geometry']['type'] == 'Point'
]
pts

In [None]:
def get_means_and_fracs(regions, pts, dtype='xenium', mean_key='means', frac_key='fracs'):
    """Collect per-region means/fractions for regions hit by an annotation point.

    A region is selected for a point when ``region['dtype'] == dtype``, it sits
    on the same z as the point (``region['z'] == pt['z']``), and its mask is True
    at the point's integer (row, col) coordinates. Points outside the mask
    bounds are ignored for that region instead of wrapping around (negative
    indices) or raising IndexError. A region hit by several points contributes
    once per point.

    Args:
        regions: iterable of region dicts with 'mask', 'z', `mean_key`,
            `frac_key` and optionally 'dtype'.
        pts: iterable of dicts with 'coordinates' (row, col) and 'z'.
        dtype: only regions whose 'dtype' equals this are considered.
        mean_key, frac_key: region keys holding the vectors to collect
            (each expected to be 1D, one entry per channel).

    Returns:
        (vals, fracs): the collected vectors stacked as columns, i.e. shape
        (channels, n_hits).

    Raises:
        ValueError: if no region contains any of the points.
    """
    # Materialize once: `pts` is scanned for every region, and a one-shot
    # iterator would be exhausted after the first region.
    pts = list(pts)

    vals, fracs = [], []
    for region in regions:
        if not ('dtype' in region and region['dtype'] == dtype):
            continue
        mask = region['mask']
        n_rows, n_cols = mask.shape[0], mask.shape[1]
        for pt in pts:
            if region['z'] != pt['z']:
                continue
            r, c = (int(v) for v in pt['coordinates'])
            if 0 <= r < n_rows and 0 <= c < n_cols and mask[r, c]:
                vals.append(region[mean_key])
                fracs.append(region[frac_key])

    if not vals:
        raise ValueError(f'no {dtype!r} region contains any of the given points')
    # Stacking along a new last axis gives (channels, n_hits) directly.
    return np.stack(vals, axis=1), np.stack(fracs, axis=1)

def get_correlations(vals, default=0.):
    """Pearson correlation of each row of `vals` with its position index.

    Every row is correlated against ``0, 1, ..., len(row) - 1``. Rows where the
    correlation is undefined get `default`: fewer than two values, or a NaN
    result (e.g. constant input).

    Args:
        vals: 2D array-like, one row per channel.
        default: value used when a row's correlation is undefined.

    Returns:
        1D numpy array with one correlation per row.
    """
    import scipy.stats  # local import: scipy is not in the notebook's import cell

    coors = []
    for row in vals:
        row = np.asarray(row)
        if len(row) < 2:
            coors.append(default)
            continue
        # Index 0 is the correlation coefficient in both the legacy tuple
        # return and the newer PearsonRResult object.
        r = scipy.stats.pearsonr(row, np.arange(len(row)))[0]
        coors.append(default if pd.isnull(r) else r)
    return np.asarray(coors)

In [None]:
# Per-channel means/fractions for the regions hit by the annotation points,
# then each channel's correlation with column position.
vals, fracs = get_means_and_fracs(regions, pts)
coors = get_correlations(vals)
vals.shape, fracs.shape, coors.shape

In [None]:
# Channel names ordered by ascending correlation.
idxs = np.argsort(coors)
np.asarray(dtype_to_channels['xenium'])[idxs]

In [None]:
# Same names, descending correlation.
np.asarray(dtype_to_channels['xenium'])[np.flip(idxs)]

In [None]:
# Keep channels whose largest fraction across the selected hits exceeds 0.1.
# NOTE: `idxs` becomes a plain list here; later cells index it positionally.
idxs = [i for i in idxs if fracs[i].max() > .1]
np.asarray(dtype_to_channels['xenium'])[idxs]

In [None]:
# Profiles of the first 10 retained channels, each with its columns reversed.
# NOTE(review): raises IndexError if fewer than 10 channels pass the filter.
import seaborn as sns
for i in range(10):
    sns.lineplot(np.flip(vals[idxs[i]]))

In [None]:
# Profiles of the last 9 retained channels (idxs[-1] ... idxs[-9]).
import seaborn as sns
for i in range(1, 10, 1):
    sns.lineplot(np.flip(vals[idxs[-i]]))

In [None]:
# NOTE(review): `size` is not defined in this part of the notebook; on a fresh
# kernel this raises NameError unless an earlier cell defines it.
size

In [None]:
def display_genes(gene, dt='xenium', figsize=(6, 20), mask_key='ring', selected_regions=None):
    """Show each `dt` section's RGB image, one gene channel and a selected-region mask.

    One row per matching section:
    - column 0 is ``sid_to_rgb[sid]``;
    - column 1 is the `gene` channel of ``sid_to_tiled[sid]``, cropped to the
      global ``r1:r2, c1:c2`` window (integer-divided by ``tiling_resolution``);
    - column 2 is ``selected_regions[sid][mask_key]`` cropped to the same
      window, drawn only when that section has a selected region.

    Reads notebook globals: dtype_to_tiled, sid_to_tiled, sid_to_dtype,
    sid_to_rgb, dtype_to_channels, tiling_resolution, r1, r2, c1, c2.

    Args:
        gene: channel name; must be in ``dtype_to_channels[dt]``.
        dt: data type whose sections are shown.
        figsize: matplotlib figure size.
        mask_key: key of the region mask drawn in column 2 (e.g. 'ring', 'tme').
        selected_regions: mapping sid -> region dict. When omitted, the notebook
            global ``selected_regions`` is used if it exists; otherwise column 2
            stays empty (instead of raising NameError).
    """
    if selected_regions is None:
        selected_regions = globals().get('selected_regions', {})

    # nrows follows len(dtype_to_tiled[dt]) (presumably one entry per section);
    # squeeze=False keeps `axs` 2D even for a single row.
    fig, axs = plt.subplots(ncols=3, nrows=len(dtype_to_tiled[dt]), figsize=figsize, squeeze=False)

    channel_idx = dtype_to_channels[dt].index(gene)
    # Same crop window for every section: compute it once.
    local_r1, local_r2, local_c1, local_c2 = [v // tiling_resolution for v in (r1, r2, c1, c2)]

    # Count rows over matching sections only. Enumerating every section would
    # skip rows for other dtypes and overrun `axs` when dtypes are interleaved.
    row = 0
    for sid, tiled in sid_to_tiled.items():
        if sid_to_dtype[sid] != dt:
            continue
        axs[row, 0].imshow(sid_to_rgb[sid])
        axs[row, 1].imshow(tiled[channel_idx, local_r1:local_r2, local_c1:local_c2])
        if sid in selected_regions:
            mask = selected_regions[sid][mask_key][local_r1:local_r2, local_c1:local_c2]
            axs[row, 2].imshow(mask)
        row += 1

In [None]:
for i in range(0, 10, 1):
    gene = dtype_to_channels['xenium'][idxs[i]]
    print(gene)
    display_genes(gene)
    plt.show()

In [None]:
for i in range(1, 10, 1):
    gene = dtype_to_channels['xenium'][idxs[-i]]
    print(gene)
    display_genes(gene)
    plt.show()

In [None]:
# Same ranking as above, using the TME values (`means_tme` / `fracs_tme`).
vals, fracs = get_means_and_fracs(regions, pts, mean_key='means_tme', frac_key='fracs_tme')
coors = get_correlations(vals)
vals.shape, fracs.shape, coors.shape

In [None]:
# Channel names ordered by ascending correlation.
idxs = np.argsort(coors)
np.asarray(dtype_to_channels['xenium'])[idxs]

In [None]:
# Keep channels whose largest fraction across the selected hits exceeds 0.1
# (`idxs` becomes a plain list).
idxs = [i for i in idxs if fracs[i].max() > .1]
np.asarray(dtype_to_channels['xenium'])[idxs]

In [None]:
# Retained names, descending correlation.
np.asarray(dtype_to_channels['xenium'])[np.flip(idxs)]

In [None]:
# Profiles of the first 10 retained channels, each with its columns reversed.
# NOTE(review): raises IndexError if fewer than 10 channels pass the filter.
import seaborn as sns
for i in range(10):
    sns.lineplot(np.flip(vals[idxs[i]]))

In [None]:
# Profiles of the last 9 retained channels (idxs[-1] ... idxs[-9]).
import seaborn as sns
for i in range(1, 10, 1):
    sns.lineplot(np.flip(vals[idxs[-i]]))

In [None]:
for i in range(0, 20, 1):
    gene = dtype_to_channels['xenium'][idxs[i]]
    print(gene)
    display_genes(gene, mask_key='tme')
    plt.show()

In [None]:
for i in range(1, 10, 1):
    gene = dtype_to_channels['xenium'][idxs[-i]]
    print(gene)
    display_genes(gene, mask_key='tme')
    plt.show()

In [None]:
display_genes('C5AR1', mask_key='tme')

In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np

from matplotlib.collections import PolyCollection

# Seed NumPy's legacy global RNG (carried over from the matplotlib gallery example).
np.random.seed(19680801)


def polygon_under_graph(x, y):
    """Vertices of the polygon enclosing the area under the (x, y) line graph.

    The outline starts at (x[0], 0), follows every (x_i, y_i) point in order,
    and closes at (x[-1], 0). Assumes `x` is sorted in ascending order.

    Returns:
        list of (x, y) tuples.
    """
    vertices = [(x[0], 0.)]
    vertices.extend(zip(x, y))
    vertices.append((x[-1], 0.))
    return vertices


ax = plt.figure().add_subplot(projection='3d')

x = np.linspace(0., 10., 31)
lambdas = range(1, 9)

# verts[i] is a list of (x, y) pairs defining polygon i.
gamma = np.vectorize(math.gamma)
verts = [polygon_under_graph(x, l**x * np.exp(-l) / gamma(x + 1))
         for l in lambdas]
facecolors = plt.colormaps['viridis_r'](np.linspace(0, 1, len(verts)))
verts = [np.asarray(x) for x in verts]
print(len(verts), verts[0].shape)

poly = PolyCollection(verts, facecolors=facecolors, alpha=.7)
ax.add_collection3d(poly, zs=lambdas, zdir='y')

ax.set(xlim=(0, 10), ylim=(1, 9), zlim=(0, 0.35),
       xlabel='x', ylabel=r'$\lambda$', zlabel='probability')

plt.show()

In [None]:
ax = plt.figure().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=facecolors, alpha=.7)
ax.add_collection3d(poly, zs=lambdas, zdir='y')

In [None]:
# Channels common to all sections of each data type, read with the matching
# loader (any dtype name containing 'multiplex' uses the multiplex reader).
dtype_to_channels = {name: [] for name in ('xenium', 'multiplex')}
for dtype in dtype_to_channels:
    fps = [
        entry['data'][0]['filepath']
        for entry in config['sections']
        if entry['data'][0]['dtype'] == dtype
    ]
    dtype_to_channels[dtype] = (multiplex if 'multiplex' in dtype else xenium).get_common_channels(fps)
dtype_to_channels

In [None]:
# Tile the sections of each data type over that dtype's channel list.
# `tiling_size` and `target_size` come from earlier cells (not visible here).
# NOTE(review): cells above already read dtype_to_channels / dtype_to_tiled, so
# they must also be defined earlier — check execution order on a fresh kernel.
dtype_to_tiled = {}
for dtype, channels in dtype_to_channels.items():
    dtype_to_tiled[dtype] = tiling_utils.get_tiled_sections(
        config, dtype=dtype, channel_names=channels,
        tiling_size=tiling_size, target_size=target_size
    )
for dtype, tiled in dtype_to_tiled.items():
    print(dtype, tiled.shape)