In [None]:
import pickle
import json
import os
import re
import sys
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage
import tifffile
import yaml
from matplotlib.collections import PolyCollection
from einops import rearrange, repeat
from pydantic_extra_types.color import Color

In [None]:
import matplotlib as mpl
mpl.rcParams.update({
    'pdf.fonttype': 42,      # embed fonts as TrueType (Type 42) rather than Type 3
    'svg.fonttype': 'none',  # write SVG text as text elements rather than glyph paths
})

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
from mushroom.mushroom import Mushroom, DEFAULT_CONFIG
import mushroom.utils as utils
import mushroom.visualization.utils as vis_utils
import mushroom.data.multiplex as multiplex
import mushroom.data.visium as visium
import mushroom.data.xenium as xenium
import mushroom.data.cosmx as cosmx
import mushroom.visualization.tiling_utils as tiling_utils

In [None]:
# Filepaths stored in metadata.yaml live under `source_root`; alter_filesystem rewrites
# them onto `target_root` so they resolve on this machine.
source_root = '/diskmnt/Projects/Users/estorrs/mushroom/data'
target_root = '/data/estorrs/mushroom/data'

In [None]:
# Root directory holding one project sub-directory per case (project_dir is run_dir/<case>).
run_dir = '/data/estorrs/mushroom/data/projects/submission_v1'

In [None]:
def alter_filesystem(config, source_root, target_root):
    """Re-root every data filepath in a registration config.

    Each ``config['sections'][*]['data'][*]['filepath']`` that starts with
    ``source_root`` gets that prefix replaced by ``target_root``. Paths that do not
    start with ``source_root`` are left untouched. The previous ``str.replace`` also
    rewrote a match appearing in the middle of a path.

    The config is modified in place and also returned for convenience.

    Parameters
    ----------
    config : dict
        Parsed metadata.yaml with a ``'sections'`` list whose entries hold a ``'data'``
        list of dicts carrying a ``'filepath'`` string.
    source_root, target_root : str
        Old and new path prefixes.

    Returns
    -------
    dict
        The same ``config`` object.
    """
    for section in config.get('sections', []):
        for mapping in section.get('data', []):
            filepath = mapping['filepath']
            if filepath.startswith(source_root):
                mapping['filepath'] = target_root + filepath[len(source_root):]
    return config

In [None]:
# Case (sample) id to analyse; keep exactly one assignment uncommented. The per-case
# settings dicts below are keyed by this id (the ROI / path settings further down only
# define 'HT891Z1').
# case = 'HT913Z1'
case = 'HT891Z1'
# case = 'HT704B1'
# case = 'HT206B1'
# case = 'HT397B1'
# case = 'HT413C1-Th1k4A1'

In [None]:
# Per-case project directory (registered data, figures, Imaris exports) under run_dir.
project_dir = Path(run_dir) / case

In [None]:
# Output directory for saved figures; created on first run.
fig_dir = Path(project_dir, 'figures')
fig_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Load this case's registration metadata and re-root its data filepaths. The context
# manager closes the file (the previous `yaml.safe_load(open(...))` leaked the handle).
with open(project_dir / 'registered' / 'metadata.yaml') as f:
    config = yaml.safe_load(f)
config = alter_filesystem(config, source_root, target_root)
config

In [None]:
# Section id -> index of that section in config['sections'] (its z slice).
sid_to_z = dict(zip((section['sid'] for section in config['sections']), range(len(config['sections']))))

In [None]:
# Per-case multiplex pseudocolor settings. Each row is
# (channel, color, min_value, max_value, gamma). The rows are expanded into the dict form
# that get_multiplex_pseudo consumes.
case_to_multiplex_view_settings = {
    case_id: [
        dict(zip(('channel', 'color', 'min_value', 'max_value', 'gamma'), row))
        for row in rows
    ]
    for case_id, rows in {
        'HT913Z1': [
            ('E-Cadherin', 'red', 15, 100, 1.),
            ('HLA-DR', 'yellow', 15, 255, 1.),
            ('CK5', 'white', 15, 255, 1.),
            ('CD3e', 'green', 20, 75, 1.),
            ('CD68', 'cyan', 15, 255, 1.),
        ],
        'HT891Z1': [
            ('E-Cadherin', 'red', 15, 100, 1.),
            ('HLA-DR', 'yellow', 15, 255, 1.),
            ('CK5', 'white', 15, 255, 1.),
            ('CD3e', 'green', 20, 255, 1.),
            ('CD68', 'cyan', 15, 255, 1.),
        ],
        'HT704B1': [
            ('Pan-Cytokeratin', 'red', 15, 100, 1.),
            ('HLA-DR', 'yellow', 15, 255, 1.),
            ('Keratin 5', 'white', 15, 255, 1.),
            ('CD3e', 'green', 20, 255, 1.),
            ('CD68', 'cyan', 15, 255, 1.),
        ],
        'HT206B1': [
            ('Pan-Cytokeratin', 'red', 15, 100, 1.),
            ('HLA-DR', 'yellow', 15, 150, 1.),
            ('SMA (D)', 'white', 15, 150, 1.),
            ('Podoplanin (D)', 'white', 15, 65, 1.),
            ('CD4 (D)', 'green', 20, 165, 1.),
            ('CD45 (D)', 'cyan', 40, 130, 1.),
        ],
        'HT397B1': [
            ('Pan-Cytokeratin', 'red', 15, 255, 1.),
            ('HLA-DR', 'yellow', 20, 255, 1.),
            ('Keratin 14', 'white', 25, 100, 1.),
            ('CD8', 'green', 10, 31, 1.),
            ('CD45 (D)', 'cyan', 15, 35, 1.),
        ],
    }.items()
}

In [None]:

multiplex_view_settings = case_to_multiplex_view_settings[case]

def get_multiplex_pseudo(fp, view_settings):
    """Render selected channels of a multiplex OME-TIFF as an 8-bit pseudocolor image.

    Parameters
    ----------
    fp : str or path-like
        OME-TIFF read with ``multiplex.extract_ome_tiff(..., as_dict=True)``.
    view_settings : list of dict
        One dict per channel with ``'channel'``, ``'color'``, ``'min_value'``,
        ``'max_value'`` and ``'gamma'``; list order fixes the stacking order.

    Returns
    -------
    np.ndarray
        ``multiplex.to_pseudocolor`` output scaled by 255 and cast to ``uint8``.
    """
    def column(key):
        return [setting[key] for setting in view_settings]

    names = column('channel')
    color_list = column('color')
    lows = column('min_value')
    highs = column('max_value')
    gamma_list = column('gamma')

    images_by_channel = multiplex.extract_ome_tiff(fp, as_dict=True)
    channel_stack = np.stack([images_by_channel[name] for name in names])

    pseudo = multiplex.to_pseudocolor(
        channel_stack,
        colors=color_list,
        min_values=lows,
        max_values=highs,
        gammas=gamma_list,
    )
    pseudo *= 255.
    return pseudo.astype(np.uint8)

In [None]:
# Morphology-focus channels used when a Xenium morphology image has more than one channel
# (see get_xenium_pseudo). Rows are (channel, color, min_value, max_value); gamma is 1.
XENIUM_FOCUS_SETTINGS = [
    dict(channel=channel, color=color, min_value=low, max_value=high, gamma=1.)
    for channel, color, low, high in (
        ('DAPI', 'blue', 0, 100),
        ('ATP1A1/CD45/E-Cadherin', 'magenta', 0, 100),
        ('18S', 'yellow', 0, 50),
        ('alphaSMA/Vimentin', 'green', 0, 100),
    )
]

# Per-case transcript overlay settings: (gene, matplotlib/CSS color name, scatter marker).
# Expanded into the dict form consumed by plot_xenium_pseudo.
case_to_xenium_view_settings = {
    case_id: [dict(channel=gene, color=color, marker=marker) for gene, color, marker in rows]
    for case_id, rows in {
        'HT913Z1': [
            ('EPCAM', 'red', '.'),
            ('HLA-DQB2', 'yellow', 'v'),
            ('TP63', 'white', 's'),
            ('KRT5', 'white', 's'),
            ('CP', 'purple', 'P'),
            ('CD3E', 'limegreen', '*'),
            ('CD68', 'cyan', 'D'),
        ],
        'HT891Z1': [
            ('EPCAM', 'red', '.'),
            ('HLA-DQB2', 'yellow', 'v'),
            ('TP63', 'white', 's'),
            ('KRT5', 'white', 's'),
            ('CP', 'purple', 'P'),
            ('CD3E', 'limegreen', '*'),
            ('CD68', 'cyan', 'D'),
        ],
        'HT704B1': [
            ('EPCAM', 'red', '.'),
            ('HLA-DQB2', 'yellow', 'v'),
            ('KRT5', 'white', 's'),
            ('TP63', 'white', 's'),
            ('CD3E', 'limegreen', '*'),
            ('CD68', 'cyan', 'D'),
        ],
        'HT206B1': [
            ('EPCAM', 'red', '.'),
            ('HLA-DQB2', 'yellow', 'v'),
            ('ACTA2', 'white', 's'),
            ('PDPN', 'white', 's'),
            ('CD3E', 'limegreen', '*'),
            ('CD68', 'cyan', 'D'),
        ],
        'HT397B1': [
            ('EPCAM', 'red', '.'),
            ('HLA-DQB2', 'yellow', 'v'),
            ('KRT5', 'white', 's'),
            ('CD3E', 'limegreen', '*'),
            ('CD68', 'cyan', 'D'),
        ],
    }.items()
}

xenium_view_settings = case_to_xenium_view_settings[case]

def get_xenium_pseudo(morph_fp):
    """Render a Xenium morphology-focus OME-TIFF as an 8-bit pseudocolor image.

    A single-channel image (nuclear stain only) is shown as white DAPI clipped to
    [0, 100]. Multi-channel images use ``XENIUM_FOCUS_SETTINGS``. Rendering is delegated
    to ``get_multiplex_pseudo``, which performs exactly the extract / stack /
    to_pseudocolor / uint8 steps this function used to duplicate.

    Parameters
    ----------
    morph_fp : str or path-like
        Morphology-focus OME-TIFF.

    Returns
    -------
    np.ndarray
        uint8 pseudocolor image from ``multiplex.to_pseudocolor``.
    """
    if len(multiplex.get_ome_tiff_channels(morph_fp)) == 1:  # just nuclei
        settings = [{'channel': 'DAPI', 'color': 'white', 'min_value': 0., 'max_value': 100., 'gamma': 1.}]
    else:  # multiple channels
        settings = XENIUM_FOCUS_SETTINGS
    return get_multiplex_pseudo(morph_fp, settings)

def plot_xenium_pseudo(ax, rgb, transcripts, view_settings, s=.1):
    """Draw ``rgb`` on ``ax`` and overlay transcript locations as colored markers.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    rgb : array
        Background image passed to ``ax.imshow``.
    transcripts : pandas.DataFrame or path
        Transcript table with ``feature_name``, ``x_location`` and ``y_location``
        columns. It is read with ``pd.read_parquet`` when not already a DataFrame.
    view_settings : list of dict
        Each entry gives the gene (``'channel'``), a color name parsed by pydantic
        ``Color`` and a matplotlib ``'marker'``. Genes absent from the table are skipped.
    s : float
        Marker size forwarded to ``ax.scatter``.

    Returns
    -------
    The same ``ax``.
    """
    ax.imshow(rgb)

    table = transcripts if isinstance(transcripts, pd.DataFrame) else pd.read_parquet(transcripts)
    present = set(table['feature_name'])

    for setting in view_settings:
        rgb_color = np.asarray(Color(setting['color']).as_rgb_tuple()) / 255.
        gene = setting['channel']
        marker_style = setting['marker']
        if gene not in present:
            continue
        # Locations are truncated to integer pixel coordinates before plotting.
        row_col = table.loc[table['feature_name'] == gene, ['y_location', 'x_location']].values.astype(int)
        ax.scatter(row_col[:, 1], row_col[:, 0], color=rgb_color, s=s, marker=marker_style, edgecolors='none')
    return ax

def add_transcripts_to_rgb(rgb, transcripts, xenium_view_settings, pt_scaler=50):
    """Burn transcript markers into an RGB image by rasterizing a matplotlib figure.

    The image is drawn at 1 inch per 1000 px via ``plot_xenium_pseudo``, saved as a TIFF
    in a private temporary directory, and read back with ``tifffile``.

    Parameters
    ----------
    rgb : np.ndarray
        (h, w, c) background image.
    transcripts : pandas.DataFrame or path
        Transcript table, or a parquet path to read.
    xenium_view_settings : list of dict
        Passed to ``plot_xenium_pseudo``.
    pt_scaler : float
        Marker size is ``figure width in inches / pt_scaler``.

    Returns
    -------
    np.ndarray
        Rasterized overlay with any alpha channel dropped. Its size is set by the
        savefig dpi, so it is not necessarily ``rgb.shape``.
    """
    import tempfile

    if not isinstance(transcripts, pd.DataFrame):
        transcripts = pd.read_parquet(transcripts)

    figsize = (rgb.shape[1] / 1000, rgb.shape[0] / 1000)  # 1 inch per 1000 px
    fig, ax = plt.subplots(figsize=figsize)
    try:
        ax.set_axis_off()
        plot_xenium_pseudo(ax, rgb, transcripts, xenium_view_settings, s=figsize[0] / pt_scaler)
        # A private temp dir instead of a fixed 'temp.tif' in the CWD avoids clobbering
        # files and collisions between concurrent runs.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_fp = os.path.join(tmp_dir, 'overlay.tif')
            # NOTE(review): empirical dpi factor, presumably chosen so the tight-cropped
            # output comes back close to the input resolution — confirm.
            fig.savefig(tmp_fp, bbox_inches='tight', pad_inches=0, dpi=int(1000 * 1.2987411728584588))
            overlay = tifffile.imread(tmp_fp)
    finally:
        # This runs once per Xenium section; without closing, every figure stays open
        # (memory) and is rendered in the cell output.
        plt.close(fig)

    if overlay.shape[-1] == 4:
        overlay = overlay[..., :-1]
    return overlay



In [None]:
# Common (height, width) every section image is rescaled to before the ROI crop.
# HT397B1: half the spatial size of section 1's first image (`shape[1:]` presumably skips a
# leading channel axis — TODO confirm). Other cases: the full-resolution size reported
# for section 0's Xenium data.
if case in ['HT397B1']:
    target_size = tifffile.imread(config['sections'][1]['data'][0]['filepath']).shape[1:]
    target_size = [int(x * .5) for x in target_size]
else:
    target_size = xenium.get_fullres_size(xenium.adata_from_xenium(config['sections'][0]['data'][0]['filepath']))

target_size

In [None]:
# Where ROI image stacks and their GeoJSON annotations live; created on first run.
imaris_dir = project_dir.joinpath('imaris', 'rois')
imaris_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# ROI crop windows (R1, R2, C1, C2): row/col bounds in `target_size` pixel space, applied
# as rgb[R1:R2, C1:C2] after rescaling. 'roi2-p2' reuses the roi2 window.
case_to_rois = {
    'HT891Z1': {
        'roi1': (1300, 2100, 2500, 3500),
        'roi2': (4650, 5650, 1900, 2900),
        'roi2-p2': (4650, 5650, 1900, 2900),
    }
}

In [None]:
# ROI to process: key into case_to_rois, case_to_zeniths and case_to_keep.
roi = 'roi2'

In [None]:
R1, R2, C1, C2 = case_to_rois[case][roi]
# One ROI crop per (section, data entry), in config order. Every image is first brought
# to `target_size` so the same R1:R2, C1:C2 window applies to all sections.
rgbs = []
for i, section in enumerate(config['sections']):
    sid = section['sid']
    print(i, sid)
    for entry in section['data']:
        dtype = entry['dtype']
        fp = entry['filepath']

        if dtype in ['he', 'batch2_he']:
            rgb = tifffile.imread(fp)
            if rgb.shape[0] <= 4:  # channel-first on disk -> (h, w, c)
                rgb = rearrange(rgb, 'c h w -> h w c')
        elif dtype == 'xenium':
            transcripts_fp = fp.replace('.h5ad', '_transcripts.parquet')
            morph_fp = fp.replace('.h5ad', '_morphologyfocus.ome.tiff')
            assert Path(transcripts_fp).exists(), transcripts_fp
            assert Path(morph_fp).exists(), morph_fp
            rgb = get_xenium_pseudo(morph_fp)
            rgb = add_transcripts_to_rgb(rgb, transcripts_fp, xenium_view_settings)
        elif dtype == 'multiplex':
            rgb = get_multiplex_pseudo(fp, multiplex_view_settings)
        else:
            # NOTE(review): other dtypes contribute no image, which breaks the
            # one-image-per-section pairing used when sid_to_rgb is built.
            rgb = None

        if rgb is not None:
            # Compare as tuples: `target_size` may be a list (HT397B1 branch), and a tuple
            # never equals a list, which forced a rescale of every image.
            if tuple(rgb.shape[:2]) != tuple(target_size):
                rgb = utils.rescale(rgb, size=target_size, dim_order='h w c', target_dtype=rgb.dtype)
            rgbs.append(rgb[R1:R2, C1:C2])


In [None]:
# (n_crops, R2 - R1, C2 - C1, c) stack of ROI crops; c is expected to be 3 (RGB), since
# the export below names three channels.
stacked = np.stack(rgbs)
stacked.shape

In [None]:
# Export the ROI stack (sections as z) as a 3-channel OME-TIFF into the Imaris directory.
# NOTE(review): the trailing `1.` is presumably the pixel size / resolution — confirm
# against multiplex.write_basic_ome_tiff.
multiplex.write_basic_ome_tiff(
    imaris_dir / f'{roi}.ome.tif',
    rearrange(stacked, 'z h w c -> 1 z c h w'),
    ['red', 'green', 'blue'],
    1.
)

In [None]:
# Pair each section id with its ROI crop. zip() would silently mis-pair ids and images
# (or drop sections) if any section produced zero or several crops, so check the counts.
assert len(stacked) == len(config['sections']), (len(stacked), len(config['sections']))
sid_to_rgb = {entry['sid']: x for entry, x in zip(config['sections'], stacked)}

In [None]:
# Quick visual check of every section's ROI crop, in section order.
for i, (sid, rgb) in enumerate(sid_to_rgb.items()):
    print(i, sid)
    plt.imshow(rgb)
    plt.axis('off')
    plt.show()

In [None]:
# Channel list per data type, presumably the channels common to all sections of that type.
# Only each section's first data entry is inspected.
dtype_to_channels = {
    'xenium': [],
    'multiplex': []
}
for dtype in dtype_to_channels.keys():
    fps = [entry['data'][0]['filepath'] for entry in config['sections'] if entry['data'][0]['dtype']==dtype]
    if 'multiplex' in dtype:
        dtype_to_channels[dtype] = multiplex.get_common_channels(fps)
    else:
        dtype_to_channels[dtype] = xenium.get_common_channels(fps)
dtype_to_channels

In [None]:
# Full-resolution (h, w) of section 0's Xenium data and the tile grid derived from it.
fullres_size = xenium.get_fullres_size(
    xenium.adata_from_xenium(config['sections'][0]['data'][0]['filepath'])
)
tiling_size = 10
tiling_resolution = tiling_size  # pixels per tile; used to map pixel coords onto the grid
size = list(map(lambda dim: dim // tiling_size, fullres_size))
size

In [None]:
# Tile every section of each dtype onto the `size` grid (tiling_size px per tile).
# use_transcripts=True presumably makes Xenium tiles count transcripts — TODO confirm.
dtype_to_tiled = {}
for dtype, channels in dtype_to_channels.items():
    dtype_to_tiled[dtype] = tiling_utils.get_tiled_sections(
        config, dtype=dtype, channel_names=channels,
        tiling_size=tiling_size, target_size=size, use_transcripts=True
    )
for dtype, tiled in dtype_to_tiled.items():
    print(dtype, tiled.shape)

In [None]:
# Normalise each section's tiled stack channel-wise by its spatial standard deviation.
# Zero-variance channels (e.g. all zeros) are divided by 1 instead of 0, so they stay 0
# instead of becoming NaN.
sid_to_tiled, sid_to_dtype = {}, {}
for dtype in dtype_to_channels.keys():
    tiled = dtype_to_tiled[dtype]
    sids = [entry['sid'] for entry in config['sections'] if entry['data'][0]['dtype']==dtype]
    # zip() would silently drop sections if these counts ever disagreed.
    assert len(sids) == len(tiled), (dtype, len(sids), len(tiled))
    for sid, t in zip(sids, tiled):
        std = t.std(axis=(-2, -1), keepdims=True)
        sid_to_tiled[sid] = t / np.where(std > 0, std, 1.)
        sid_to_dtype[sid] = dtype
sid_to_tiled.keys()

In [None]:
def load_regions(regions_fp, shape=None):
    """Load polygon annotations from a GeoJSON file and rasterize them.

    Parameters
    ----------
    regions_fp : str or path-like
        GeoJSON with a ``'features'`` list. Each feature's
        ``geometry['plane']['z']`` (0 when there is no ``'plane'``) gives its section
        index.
    shape : tuple of int, optional
        (height, width) of the rasterized masks. Defaults to ``stacked.shape[1:3]``
        (the global ROI image stack), which was previously the only option.

    Returns
    -------
    list of dict
        One dict per feature whose coordinates form a 3-D array (Polygon geometries);
        other features are dropped. Keys:
        - ``'id'``
        - ``'z'``
        - ``'coordinates'``: the first ring as an (n, 2) array, reordered from (x, y)
          to (row, col)
        - ``'mask'``: from ``skimage.draw.polygon2mask``
    """
    # Context manager closes the file; json.load(open(...)) leaked the handle.
    with open(regions_fp) as f:
        features = json.load(f)['features']

    regions = []
    for feature in features:
        geometry = feature['geometry']
        coordinates = np.asarray(geometry['coordinates'])
        if coordinates.ndim != 3:
            continue
        regions.append({
            'id': feature['id'],
            'z': geometry['plane']['z'] if 'plane' in geometry else 0,
            # First ring, (x, y) -> (row, col) as expected by polygon2mask.
            'coordinates': coordinates[0][:, [1, 0]],
        })

    if regions and shape is None:
        shape = stacked.shape[1:3]
    for region in regions:
        region['mask'] = skimage.draw.polygon2mask(shape, region['coordinates'])
    return regions

In [None]:
# Region annotations drawn on the exported ROI stack (coordinates are ROI-local pixels).
regions_fp = imaris_dir / f'{roi}.geojson'
regions = load_regions(regions_fp)

len(regions), regions[0].keys()

In [None]:
# Sanity check: total (std-normalised) GPC1 signal per Xenium section.
for sid, x in sid_to_tiled.items():
    if sid_to_dtype[sid] == 'xenium':
        # gpc1
        print(sid, x.shape, x[dtype_to_channels['xenium'].index('GPC1')].sum())

In [None]:
# roi = f'{roi}-p2'

In [None]:
# Union of all annotated regions per section (keyed by z), on the tile grid of shape `size`.
blank = np.zeros(size, dtype=bool)
z_to_regions_mask = {}
for region in regions:
    if region['z'] not in z_to_regions_mask:
        z_to_regions_mask[region['z']] = blank.copy()

    # ROI-local (row, col) -> full-image pixels -> tile indices.
    coords = region['coordinates'] + np.asarray([R1, C1])
    coords //= tiling_resolution
    mask = skimage.draw.polygon2mask(size, coords)
    z_to_regions_mask[region['z']] |= mask

In [None]:
# Per-region expression summaries on the tile grid (only sections present in sid_to_tiled).
# For every annotated region:
#   means / fracs            per-channel mean and fraction of tiles > 0 inside the polygon
#   means_ring / fracs_ring  same over the boundary band (1 dilation out minus 1 erosion in)
#   means_tme / fracs_tme    same over the surrounding band (5 dilations minus 1 dilation),
#                            with every annotated region on that section removed
# Fixes in this cell:
#   - dropped the unused `case_to_dtype_to_epi` lookup, which raised NameError;
#   - the ring/tme keys were filled with the tme arrays in swapped order;
#   - the tme band was XOR-ed with the all-regions mask, which added every region's
#     tiles instead of removing them.


def _repeat_morphology(op, m, n):
    """Apply the binary morphology operator `op` to mask `m` `n` times."""
    for _ in range(n):
        m = op(m)
    return m


def _channel_stats(tiled, m):
    """Per-channel (mean, fraction of tiles > 0) of `tiled` over mask `m`; NaNs if `m` is empty."""
    n_tiles = np.count_nonzero(m)
    if n_tiles == 0:
        return np.full(tiled.shape[0], np.nan), np.full(tiled.shape[0], np.nan)
    values = tiled[:, m]
    return values.mean(1), np.count_nonzero(values > 0, axis=1) / n_tiles


local_r1, local_r2, local_c1, local_c2 = [x // tiling_resolution for x in [R1, R2, C1, C2]]
for region in regions:
    section = config['sections'][region['z']]
    sid = section['sid']
    region['position'] = section['position']
    if sid in sid_to_tiled:
        tiled = sid_to_tiled[sid]
        dtype = sid_to_dtype[sid]
        channels = dtype_to_channels[dtype]
        regions_mask = z_to_regions_mask[region['z']]

        # ROI-local (row, col) -> full-image pixels -> tile indices.
        coords = region['coordinates'] + np.asarray([R1, C1])
        coords //= tiling_resolution
        mask = skimage.draw.polygon2mask(tiled.shape[-2:], coords)

        outer = _repeat_morphology(skimage.morphology.binary_dilation, mask, 1)
        inner = _repeat_morphology(skimage.morphology.binary_erosion, mask, 1)
        expanded = _repeat_morphology(skimage.morphology.binary_dilation, mask, 5)

        ring = outer ^ inner
        tme = (expanded ^ outer) & ~regions_mask

        means, fracs = _channel_stats(tiled, mask)
        ring_means, ring_fracs = _channel_stats(tiled, ring)
        tme_means, tme_fracs = _channel_stats(tiled, tme)

        region['fracs'] = fracs
        region['means'] = means
        region['fracs_ring'] = ring_fracs
        region['means_ring'] = ring_means
        region['fracs_tme'] = tme_fracs
        region['means_tme'] = tme_means
        region['dtype'] = dtype
        region['channels'] = channels
        region['ring'] = ring
        region['tme'] = tme
        region['m'] = mask

    region['sid'] = sid
        
        


In [None]:
# Region id -> region dict (later duplicates of an id win).
region_id_to_region = dict((region['id'], region) for region in regions)

In [None]:

# cmap = plt.colormaps['viridis_r']
# facecolors = [cmap(x) for x in vals]


In [None]:
# Camera angle (degrees) for the 3-D reconstructions. Despite the name, the value is used
# as `view_init(azim=...)` (azimuth); elevation is fixed at 10.
case_to_zeniths = {
    'HT891Z1': {
        'roi1': 80,
        'roi2': 80,
        'roi2-p2': 80
    }
}

In [None]:
# Quick 3-D preview: every region polygon drawn in grey at its section's `position`.
# NOTE(review): coordinates are (row, col), so the x axis holds rows; xlim/ylim below use
# the column/row extents, which only match for square ROIs. Also note `coords`/`zs`
# become notebook globals.
ax = plt.figure(figsize=(10, 10)).add_subplot(projection='3d')

coords = [x['coordinates'] for x in regions]
zs = [x['position'] for x in regions]
facecolors = [(.5, .5, .5, 1.) for i in range(len(zs))]
poly = PolyCollection(coords, facecolors=facecolors, alpha=.7, edgecolor=(.2, .2, .2, 1.))

ax.set(xlim=(0, C2 - C1), ylim=(0, R2 - R1), zlim=(0, max(zs)))
ax.add_collection3d(poly, zs=zs, zdir='z')
ax.view_init(elev=10., azim=case_to_zeniths[case][roi])
ax.invert_zaxis()
plt.show()

#### Path from connectivity annotations

In [None]:
# The `_aligned` export of the same annotations (presumably re-aligned across sections);
# used to build the section-to-section connectivity graph below.
connectivity_fp = imaris_dir / f'{roi}_aligned.geojson'
conn_regions = load_regions(connectivity_fp)

len(conn_regions), conn_regions[0].keys()

In [None]:
# Every ROI region must have a counterpart (same id) in the aligned annotations.
assert len({x['id'] for x in conn_regions} & {x['id'] for x in regions}) == len(regions)

In [None]:
def regions_to_volume(regions, shape=None):
    """Rasterize annotated regions into a labeled (z, h, w) volume.

    Region ``i`` (list order) is painted with label ``i + 1`` on slice ``region['z']``;
    0 is background. Where polygons on the same slice overlap, the later region wins.

    Parameters
    ----------
    regions : list of dict
        Each with ``'id'``, ``'z'`` (slice index) and ``'coordinates'`` ((n, 2) row/col
        polygon in ROI-local pixels).
    shape : (int, int), optional
        (height, width) of each slice. Defaults to the current ROI size
        ``(R2 - R1, C2 - C1)``.

    Returns
    -------
    labeled : np.ndarray of int, shape (max z + 1, h, w)
        Empty along z when ``regions`` is empty.
    region_id_to_label : dict
        Region id -> label.
    """
    if shape is None:
        shape = (R2 - R1, C2 - C1)
    h, w = shape
    n_slices = max((region['z'] for region in regions), default=-1) + 1
    labeled = np.zeros((n_slices, h, w), dtype=int)
    region_id_to_label = {}
    for label, region in enumerate(regions, start=1):
        mask = skimage.draw.polygon2mask((h, w), region['coordinates'])
        labeled[region['z'], mask] = label
        region_id_to_label[region['id']] = label
    return labeled, region_id_to_label

In [None]:
labeled, region_id_to_label = regions_to_volume(conn_regions)
# Inverse lookup: volume label -> region id.
label_to_region_id = {label: region_id for region_id, label in region_id_to_label.items()}

In [None]:
def graph_from_labeled(labeled):
    """List edges between labels that overlap on adjacent slices of a labeled volume.

    For every nonzero label on slice ``i``, each nonzero label it overlaps on slice
    ``i - 1`` and then on slice ``i + 1`` yields an edge ``(label, other)``. Every slice
    is visited, so each overlapping pair is reported in both directions.

    Parameters
    ----------
    labeled : np.ndarray
        (z, h, w) integer label volume; 0 is background.

    Returns
    -------
    list of tuple
        ``(label, neighbor_label)`` pairs of numpy integer scalars.
    """
    n_slices = labeled.shape[0]
    edges = []
    for i in range(n_slices):
        current = labeled[i]
        # Only neighbours that exist. The previous version indexed labeled[1] even for a
        # single-slice volume and raised IndexError.
        neighbours = [labeled[j] for j in (i - 1, i + 1) if 0 <= j < n_slices]
        for label in np.unique(current):
            if not label:
                continue
            in_label = current == label
            for other in neighbours:
                edges.extend((label, l) for l in np.unique(other[in_label]) if l)
    return edges

In [None]:
# (label, label) pairs for regions overlapping on adjacent sections, in both directions.
edges = graph_from_labeled(labeled)

In [None]:
import networkx as nx

In [None]:
# Region-adjacency graph. Every label is added as a node first, so regions that overlap
# nothing on neighbouring sections appear as isolated components. Previously they were
# silently absent, and the single-component check passed vacuously.
G = nx.Graph()
G.add_nodes_from(label_to_region_id)
G.add_edges_from(edges)

In [None]:
# Dead-end regions: nodes with exactly one neighbour.
terminals = [node for node in G if len(G[node]) == 1]
terminals

In [None]:
# Connected components, largest first.
sorted(nx.connected_components(G), key=len, reverse=True)

In [None]:
# All annotated regions must form one connected structure.
assert nx.number_connected_components(G) == 1

In [None]:
# missing = []
# for i, x in enumerate(labeled):
#     m = np.zeros_like(x, dtype=bool)
#     for node in missing:
#         m |= x==node
#     plt.imshow(m)
#     plt.title(i)
#     plt.show()

In [None]:
# Draw the region-adjacency graph on an explicit axes.
fig, ax = plt.subplots(figsize=(10, 10))
nx.draw_networkx(G, ax=ax, font_size=6, node_size=100)

In [None]:
# Start/end region ids for the automatic shortest-path search, per case and ROI. Only
# HT891Z1/roi1 is defined; roi2 paths are curated by hand in case_to_keep.
case_to_endpoints = {
    'HT891Z1': {
        'roi1': ('f6032595-1fd0-44c2-82b0-72a3fac80f03', '18073a83-5d22-4c2a-8cc8-e1b27a4de88e')
    }
}

In [None]:
endpoints = case_to_endpoints[case][roi]

In [None]:
n1, n2 = endpoints
path = nx.algorithms.shortest_path(G, region_id_to_label[n1], region_id_to_label[n2])
path

In [None]:
keep = [label_to_region_id[l] for l in path]

#### User-defined path

In [None]:
# Manually traced path pasted from an object list ("Object ID<TAB><uuid>" per line). The
# regex split strips the label, and [1:] drops the empty item before the first match.
# NOTE(review): 'acaef067-ff89-4676-b10c-6b5f4fd6560e' appears twice — confirm intended.
keep = re.split(r'\nObject ID\t', """
Object ID	2a3facce-de6a-43a7-926b-8fbe42ed75b0
Object ID	f5c166e7-d9db-4bad-b55f-73a0f0799d60
Object ID	cb76f6eb-6278-452d-bb4d-9726d650aec9
Object ID	24000119-0a82-45c1-8505-dc90901e014f
Object ID	e5d1214b-8f62-403c-833a-53032515931c
Object ID	cbf8e855-8a31-4990-b144-064f544d3066
Object ID	0070e7b1-9e6b-4cda-af13-b5d6238e97ed
Object ID	686d4c90-0af6-467a-9134-39f483b04c28
Object ID	f8919fee-046b-436a-b728-94c76c3a153b
Object ID	acaef067-ff89-4676-b10c-6b5f4fd6560e
Object ID	cdef0207-f9b5-4a6b-bdf5-bde6df9f9282
Object ID	acaef067-ff89-4676-b10c-6b5f4fd6560e
Object ID	f47c43ed-3e83-4a2f-86d6-aed44f1316c4
Object ID	a29d8f43-06d4-4e0f-a059-8a4b9fbf959f
Object ID	db3c0014-4c41-49db-9d74-e88f28041aee
Object ID	60111347-4dcd-465d-8537-82cb1cc5484e
Object ID	dc5f3e39-0bea-41d3-b152-8d91c0e1ca17
Object ID	34b75982-a611-4e25-a583-bf10aafdd5bc
Object ID	227dad79-9f9d-465b-8543-0605892b42b3
Object ID	c4380e8c-f7ad-4cfc-a4f3-ed430ec2c1bc
Object ID	b5c5ff29-07e4-4ad1-a866-0e3bb78f4283
Object ID	5769757d-7e89-49a6-97e1-8a8419989506
Object ID	7ecb5ca3-161b-41da-bc89-1b64312c5840
Object ID	13b32752-7032-4403-a855-75b2a2320799
Object ID	12c03edc-81d2-432f-9809-34efbed321d3
Object ID	8cc2c7fe-be42-401d-b779-18b77ac156a3
Object ID	406ac3bf-8609-4786-a1df-6540b4970cea""")[1:]
keep

In [None]:
# Reverse the pasted order (builds a new list rather than reversing in place).
keep = keep[::-1]
keep

In [None]:
# Curated region ids forming the analysed path for each case/ROI, in path order.
# 'roi2-p2' is the manually pasted path from the cell above, in reversed order.
case_to_keep = {
    'HT891Z1': {
        'roi2': [
            '4fd912e9-9eda-4290-aab8-8cfb9bfedb43',
            '2ae1e825-b9cd-45c4-bcf1-b227b48404b1',
            'c2b325db-8518-43fa-8653-b6e0de96a436',
            '92d21f1d-968b-4205-bc6f-8791f5fb7942',
            'bc3ec071-235a-47da-a192-49736ece4a1b',
            'cbd83e7a-d4ec-4558-9145-ecbe3494ee0d',
            '89b146dd-2261-4362-ae61-d812b5a2fb8b',
            'a098aa66-da0a-46ef-8659-ae8fdb02b6c9',
            'f55e7184-f872-4f3c-a01c-6072351148d9',
            '98986d9e-1677-4bca-974f-570e96c7d7f0',
            '12c3eac2-8606-41be-892e-689133c3ad80',
            '71817de5-241d-4448-9aad-91df6ded89c1',
            '9e8d366d-95a4-4a9b-a9da-07b5acf951e2',
            '3f382232-fad5-4246-8029-b2ff0fca459d',
            '4dcbec5f-25c4-441e-b603-0c0b45b74c43',
            'ad46ac66-0e8b-48f1-aba8-e300d19278d2',
            '32ce1bff-0284-4514-ba98-04506f20f256',
            '24687041-03b2-4e55-a05b-0b069a3403e5',
            '38ba9ae1-de88-4b38-a29e-cb9cc87acc41',
            '7a3535f1-c30c-40c1-9d31-24c2ee7e58f9',
            '0070e7b1-9e6b-4cda-af13-b5d6238e97ed',
            'cbf8e855-8a31-4990-b144-064f544d3066',
            'e5d1214b-8f62-403c-833a-53032515931c',
            '24000119-0a82-45c1-8505-dc90901e014f',
            'cb76f6eb-6278-452d-bb4d-9726d650aec9',
            'f5c166e7-d9db-4bad-b55f-73a0f0799d60',
            '2a3facce-de6a-43a7-926b-8fbe42ed75b0',
            '69f28c38-daa2-4007-85ff-e5a1cf88e3eb',
            '2feaaa28-3e77-4138-b9a4-0b75f68d2f64',
            'bcbe6b32-baba-4f9e-bf87-45e5f0660327',
            'd715e31a-7607-4ac3-b0c2-9456b9244140',
            '99087c7b-e996-40ed-b0f4-4286822192fa',
            'b6d1644d-743d-4cf4-9fa0-a43d95ea6baa',
            'bf6c40d9-bdf2-49a5-a3fe-b6e68d0b77c0',
            '386d888c-4246-46f5-b7a0-d0a763511097',
            '90116201-5a8a-4f73-be12-495d21dcf084',
            '142c7012-69e4-4f2a-b00b-9d350db99fd6',
            '68f55189-0e6f-41b1-9c65-2823226b8d7e',
            '511a2ee0-a928-47a5-a586-fab0fb648b99',
            'f1034984-615e-45a5-a4af-05f2fc4eb1b4',
            '92b53248-6da0-492a-9c0b-0c2a60946a0d',
            'b536042f-2705-4627-be52-2465ff038195',
            '6d7255ad-dc1a-47c4-b99f-6fc5fc3f0772',
#             '995bd781-223b-4849-a502-640895bcf6d7', has experimental artifact at roi location
            '9ddde78f-2e7b-4259-9bf9-efdd0a005b79',
            '79e3353f-c2fe-4d54-8d30-a8a661cf61a1',
            '2db3f2d8-37ce-4ffa-81ea-7948b9af77ae',
            'c5facae8-79de-4842-b005-ad9e9d23747e'
        ],
        'roi2-p2': [
            '406ac3bf-8609-4786-a1df-6540b4970cea',
            '8cc2c7fe-be42-401d-b779-18b77ac156a3',
            '12c03edc-81d2-432f-9809-34efbed321d3',
            '13b32752-7032-4403-a855-75b2a2320799',
            '7ecb5ca3-161b-41da-bc89-1b64312c5840',
            '5769757d-7e89-49a6-97e1-8a8419989506',
            'b5c5ff29-07e4-4ad1-a866-0e3bb78f4283',
            'c4380e8c-f7ad-4cfc-a4f3-ed430ec2c1bc',
            '227dad79-9f9d-465b-8543-0605892b42b3',
            '34b75982-a611-4e25-a583-bf10aafdd5bc',
            'dc5f3e39-0bea-41d3-b152-8d91c0e1ca17',
            '60111347-4dcd-465d-8537-82cb1cc5484e',
            'db3c0014-4c41-49db-9d74-e88f28041aee',
            'a29d8f43-06d4-4e0f-a059-8a4b9fbf959f',
            'f47c43ed-3e83-4a2f-86d6-aed44f1316c4',
            'acaef067-ff89-4676-b10c-6b5f4fd6560e',
            'cdef0207-f9b5-4a6b-bdf5-bde6df9f9282',
            'acaef067-ff89-4676-b10c-6b5f4fd6560e',  # NOTE(review): repeats the id two lines up — confirm
            'f8919fee-046b-436a-b728-94c76c3a153b',
            '686d4c90-0af6-467a-9134-39f483b04c28',
            '0070e7b1-9e6b-4cda-af13-b5d6238e97ed',
            'cbf8e855-8a31-4990-b144-064f544d3066',
            'e5d1214b-8f62-403c-833a-53032515931c',
            '24000119-0a82-45c1-8505-dc90901e014f',
            'cb76f6eb-6278-452d-bb4d-9726d650aec9',
            'f5c166e7-d9db-4bad-b55f-73a0f0799d60',
            '2a3facce-de6a-43a7-926b-8fbe42ed75b0'
        ],
    }
}

## Remaining analyses

In [None]:
# Regions on the analysed path for this case/ROI; replaces any `keep` set by the cells above.
keep = case_to_keep[case][roi]

In [None]:
# List the path regions that sit on tiled Xenium sections (id, stack position, z index).
for x in keep:
    region = region_id_to_region[x]
    if region.get('dtype', '') == 'xenium':
        print(region['id'], region['position'], region['z'])

In [None]:
import seaborn as sns
# Base palette for the per-dtype face colours below.
colors = sns.color_palette('deep')
colors

In [None]:
# Face / edge colour per data type: the same hue index taken from seaborn's 'deep' and
# 'dark' palettes respectively.
dtype_to_color = {dtype: colors[k] for dtype, k in [('xenium', 0), ('multiplex', 2), ('he', 4)]}

dtype_to_edgecolor = {
    dtype: sns.color_palette('dark')[k]
    for dtype, k in [('xenium', 0), ('multiplex', 2), ('he', 4)]
}

In [None]:
def plot_regions(rs, facecolors=(.5, .5, .5, 1.), edgecolors=(.2, .2, .2, 1.), linewidths=1, alpha=.7,
                 elev=10., azim=None, figsize=(10, 10)):
    """Draw region polygons as flat slices stacked along z in a new 3-D axes.

    Each region's ``'coordinates'`` polygon is placed at height ``region['position']``,
    and the z axis is inverted after drawing.

    Parameters
    ----------
    rs : list of dict
        Regions with ``'coordinates'`` ((n, 2) row/col) and ``'position'``. Must be
        non-empty.
    facecolors, edgecolors, linewidths, alpha
        Forwarded to ``PolyCollection``.
    elev : float, optional
        Camera elevation in degrees (previously hard-coded to 10).
    azim : float, optional
        Camera azimuth in degrees; defaults to ``case_to_zeniths[case][roi]``, looked
        up at call time.
    figsize : tuple, optional
        Figure size in inches.

    Returns
    -------
    The 3-D Axes.

    Raises
    ------
    ValueError
        If ``rs`` is empty.
    """
    if not rs:
        raise ValueError('plot_regions() needs at least one region')
    if azim is None:
        azim = case_to_zeniths[case][roi]

    ax = plt.figure(figsize=figsize).add_subplot(projection='3d')

    coords = [x['coordinates'] for x in rs]
    zs = [x['position'] for x in rs]
    poly = PolyCollection(coords, facecolors=facecolors, alpha=alpha, edgecolor=edgecolors, linewidths=linewidths)

    # Vertices are (row, col), so x spans the ROI height and y the ROI width. The previous
    # limits were swapped, which clipped non-square ROIs.
    ax.set(xlim=(0, R2 - R1), ylim=(0, C2 - C1), zlim=(0, max(zs)))
    ax.add_collection3d(poly, zs=zs, zdir='z')
    ax.view_init(elev=elev, azim=azim)
    ax.invert_zaxis()

    return ax

In [None]:
# Whole-ROI reconstruction: every region in grey, with regions on the selected path
# (`keep`) outlined in red and appended last. The previous version also extended the
# `coords`/`zs` globals from the preview cell; those lists were never used by
# plot_regions and grew on every re-run, so that mutation was removed.
keep_ids = set(keep)  # O(1) membership instead of scanning the list per region
rs = [x for x in regions if x['id'] not in keep_ids]
keep_rs = [x for x in regions if x['id'] in keep_ids]

facecolors = [(.5, .5, .5, 1.)] * (len(rs) + len(keep_rs))
edgecolors = [(.2, .2, .2, 1.)] * len(rs) + [(.8, .2, .2, 1.)] * len(keep_rs)
rs += keep_rs

ax = plot_regions(rs, facecolors=facecolors, edgecolors=edgecolors)
plt.savefig(fig_dir / f'{roi}_3d_section_recon.svg')

In [None]:
# Path-only reconstruction coloured by data type. Regions without a `dtype` (sections that
# were not tiled) fall back to the 'he' colours.
rs = [x for x in regions if x['id'] in keep]
facecolors = [dtype_to_color[x.get('dtype', 'he')] for x in rs]
edgecolors = [dtype_to_edgecolor[x.get('dtype', 'he')] for x in rs]
ax = plot_regions(rs, facecolors=facecolors, edgecolors=edgecolors)
ax.set_axis_off()

plt.savefig(fig_dir / f'{roi}_3d_section_recon_iso.svg')

In [None]:
import matplotlib as mpl

In [None]:
# Color the regions of one section by the 'fracs' value of a single Xenium channel.
# section_z / gene are the only knobs; the output filename is derived from them.
section_z, gene = 6, 'KRT5'
gene_idx = dtype_to_channels['xenium'].index(gene)

rs = [x for x in regions if x['z'] == section_z]
vals = np.asarray([x['fracs'][gene_idx] for x in rs])
cmap = mpl.colormaps['viridis']
# Scale to [0, 1] for the colormap; an all-zero section would otherwise be 0/0 -> NaN colors.
peak = vals.max() if vals.size else 0
normed = vals / peak if peak > 0 else np.zeros(vals.shape)
facecolors = [cmap(v) for v in normed]
edgecolors = (.6, .6, .6, 1.)
ax = plot_regions(rs, facecolors=facecolors, edgecolors=edgecolors, linewidths=1, alpha=1.)
ax.set_axis_off()

plt.savefig(fig_dir / f'{roi}_3d_section_recon_xenium_z{section_z}_{gene.lower()}.svg')

In [None]:
# Peak value used to normalize the colormap in the previous cell.
vals.max()

In [None]:
# Draw a yellow ring around the union of the regions on one section, on that section's RGB
# image. section_z is used for the image lookup, the region filter and the filename alike.
section_z = 6
z_to_sid = {x['z']: x['sid'] for x in regions}
rgb = sid_to_rgb[z_to_sid[section_z]].copy()

# Union of all region masks on this section (displayed by the next cell).
m = np.zeros(rgb.shape[:2], dtype=bool)
for x in regions:
    if x['z'] == section_z:
        m |= x['mask']

# Ring = mask dilated 4 steps minus mask eroded 2 steps, so it straddles the region boundary.
expanded = m.copy()
for _ in range(4):
    expanded = skimage.morphology.binary_dilation(expanded)

inner = m.copy()
for _ in range(2):
    inner = skimage.morphology.binary_erosion(inner)
ring = expanded ^ inner
rgb[ring] = [255, 255, 0]

fig, ax = plt.subplots()
ax.imshow(rgb)
fig.savefig(fig_dir / f'{roi}_xenium_z{section_z}_rgb.svg', dpi=300)

In [None]:
# Union of the section's region masks built in the previous cell (before dilation/erosion).
plt.imshow(m)

In [None]:
def _stack_or_empty(arrays, like):
    """Stack `arrays` on a new first axis, or return an empty (0, *shape of `like`) array when there are none."""
    return np.stack(arrays) if arrays else np.empty((0,) + np.shape(like))


def get_means_and_fracs(region_id_to_regions, paths, dtype='xenium', mean_key='means', frac_key='fracs',
                        label_map=None):
    """Collect the per-channel summary vectors of one data type along each path.

    Args:
        region_id_to_regions: mapping region id -> region dict. Regions read here need
            'id', 'position', `mean_key` and `frac_key` (1-D per-channel vectors), and
            optionally 'dtype'.
        paths: iterable of paths; each path is a sequence of region ids, or of integer
            labels that are translated to region ids through `label_map`.
        dtype: only regions whose 'dtype' equals this value contribute.
        mean_key, frac_key: region keys of the vectors to collect.
        label_map: mapping integer label -> region id. When None, the notebook-level
            `label_to_region_id` is used (looked up only if a path contains an integer).

    Returns:
        One dict per path that has at least one matching region (paths with none are
        skipped, so the list can be shorter than `paths`):
            'vals', 'fracs': arrays of shape (n_channels, n_matched_regions).
            'zs': the matched regions' 'position' values, in path order.
            'other_vals', 'other_fracs': lists aligned with 'zs'; each entry stacks the
                vectors of every other region at the same position, shape
                (n_others, n_channels) — (0, n_channels) when there are none.
    """
    results = []
    for path in paths:
        vals, fracs, zs, other_vals, other_fracs = [], [], [], [], []
        for k in path:
            if isinstance(k, (int, np.integer)):
                # Integer entries are labels rather than region ids.
                k = (label_to_region_id if label_map is None else label_map)[k]
            region = region_id_to_regions[k]
            if 'dtype' in region and region['dtype'] == dtype:
                vals.append(region[mean_key])
                fracs.append(region[frac_key])
                zs.append(region['position'])

                # Every other region sharing this position (e.g. usable as a per-position baseline).
                others = [x for x in region_id_to_regions.values()
                          if x['position'] == region['position'] and x['id'] != region['id']]
                other_vals.append(_stack_or_empty([x[mean_key] for x in others], region[mean_key]))
                other_fracs.append(_stack_or_empty([x[frac_key] for x in others], region[frac_key]))

        if vals:
            # Stacking gives (n_matched_regions, n_channels); transpose to channels-first.
            results.append({
                'vals': np.stack(vals).T,
                'fracs': np.stack(fracs).T,
                'other_vals': other_vals,
                'other_fracs': other_fracs,
                'zs': zs,
            })
    return results

import scipy
def get_correlations(vals, default=0.):
    """Pearson-correlate each row of `vals` with its column index 0, 1, 2, ...

    Args:
        vals: 2-D array-like; each row is one series (e.g. one channel along a path).
        default: correlation reported for rows where none can be computed — fewer than
            two values, or pearsonr returning a NaN p-value (e.g. a constant row).

    Returns:
        (correlations, pvalues) as 1-D numpy arrays with one entry per row; rows without
        a computable correlation get `default` and a p-value of 1.
    """
    coors, pvals = [], []
    for row in vals:
        if len(row) < 2:
            # pearsonr needs at least two observations.
            coors.append(default)
            pvals.append(1.)
            continue
        # Positional unpacking works for both the legacy tuple and the PearsonRResult object.
        corr, pval = scipy.stats.pearsonr(row, np.arange(len(row)))
        if np.isnan(pval):
            coors.append(default)
            pvals.append(1.)
        else:
            coors.append(corr)
            pvals.append(pval)
    return np.asarray(coors), np.asarray(pvals)

In [None]:
# Summarize the single kept path (default dtype 'xenium'). Paths with no matching region are
# dropped by get_means_and_fracs, so [0] raises IndexError if no kept region is Xenium.
result = get_means_and_fracs(region_id_to_region, [keep])[0]

In [None]:
# (n_channels, n_matched_regions)
result['vals'].shape

In [None]:
# Trend test: correlate each channel's mean with its order along the kept path (see
# get_correlations). A disabled variant first subtracted, at each position, the mean of the
# other regions at that position (`other_vals`) before correlating.
vals, fracs, zs = result['vals'], result['fracs'], result['zs']
other_vals, other_fracs = result['other_vals'], result['other_fracs']
coors, pvals = get_correlations(vals)

coors.shape

In [None]:
# One p-value per channel from get_correlations (rows with NaN results are reported as 1).
pvals

In [None]:
# Channels in ascending correlation order, kept only if their fraction exceeds MIN_PEAK_FRAC
# in at least one region of the path and their correlation p-value is below MAX_PVAL.
MIN_PEAK_FRAC, MAX_PVAL = .2, .05
idxs = np.argsort(coors)
m = (fracs[idxs].max(-1) > MIN_PEAK_FRAC) & (pvals[idxs] < MAX_PVAL)
idxs = idxs[m]

ordered = np.asarray(dtype_to_channels['xenium'])[idxs]

ordered

In [None]:
# Heatmap source: rows = path positions, columns = selected channels in correlation order,
# each channel's means shifted so its minimum along the path is 0.
# (An earlier variant kept only the 15 most negatively and 15 most positively correlated channels.)
data = vals[idxs]
data = data - data.min(axis=1, keepdims=True)
source = pd.DataFrame(data.T, index=zs, columns=ordered)
source

In [None]:
# Directory the figures of this notebook are written to.
fig_dir

In [None]:
fig, ax = plt.subplots(figsize=(25, 20))
# Fixed upper limit: shifted means above 1.6 saturate so ROIs share a color scale.
sns.heatmap(source, cmap='Blues', vmax=1.6, ax=ax)
fig.savefig(fig_dir / f'{roi}_gene_heatmap.svg')

In [None]:
# Correlation and p-value per channel, restricted to (and ordered like) the heatmap columns.
df = pd.DataFrame(
    {'correlation': coors, 'pvals': pvals},
    index=dtype_to_channels['xenium'],
).loc[source.columns]
df

In [None]:
fig, ax = plt.subplots()
# PiYG is diverging: pin its midpoint to r = 0 so hue encodes the sign of the correlation.
# Without `center`, seaborn infers vmin/vmax from the data and the midpoint lands wherever it falls.
sns.heatmap(df[['correlation']], cmap='PiYG', center=0, ax=ax)
fig.savefig(fig_dir / f'{roi}_gene_heatmap_corrs.svg')

In [None]:
fig, ax = plt.subplots()
# Fixed 0-0.05 range, matching the p-value cutoff used when selecting channels.
sns.heatmap(df[['pvals']], cmap='cividis_r', vmin=0, vmax=.05, ax=ax)
fig.savefig(fig_dir / f'{roi}_gene_heatmap_pvals.svg')

In [None]:
# Per case / ROI crop settings for the image strip: 'radius' is the half-width of the crop
# window in pixels, 'expansion' the number of dilation steps applied to the region mask for the
# outline, and 'thickness' the number of erosion steps that set the outline's width.
case_to_view_details = {
    'HT891Z1': {
        roi_name: dict(radius=roi_radius, thickness=5, expansion=15)
        for roi_name, roi_radius in (('roi1', 150), ('roi2', 220), ('roi2-p2', 300))
    }
}

In [None]:
def _crop_bounds(center, radius, size):
    """Return (start, stop) of a 2 * radius wide window centred on `center` along an axis of length `size`.

    The window is shifted, not shrunk, to stay inside [0, size), so every crop has the same
    width whenever size >= 2 * radius; a shorter axis is returned whole. Bounds are never
    negative, so slicing cannot wrap around to the other end of the image.
    """
    diam = 2 * radius
    start = min(max(center - radius, 0), max(size - diam, 0))
    return start, min(start + diam, size)


def _outline_ring(mask, expansion, thickness):
    """Boolean ring around `mask`: the mask dilated `expansion` times, minus that dilation eroded `thickness` times."""
    expanded = mask.copy()
    for _ in range(expansion):
        expanded = skimage.morphology.binary_dilation(expanded)
    inner = expanded.copy()
    for _ in range(thickness):
        inner = skimage.morphology.binary_erosion(inner)
    return expanded ^ inner


# Crop each kept region's RGB image around the region and draw a yellow outline just outside it.
view = case_to_view_details[case][roi]
radius = view['radius']
imgs = []
for rid in keep:
    region = region_id_to_region[rid]
    mask = region['mask']
    if not mask.any():
        raise ValueError(f'region {rid} has an empty mask; cannot centre a crop on it')

    # Centroid of the whole mask (all components), not only of the first labeled component.
    r, c = (int(v) for v in np.argwhere(mask).mean(0))
    r1, r2 = _crop_bounds(r, radius, mask.shape[0])
    c1, c2 = _crop_bounds(c, radius, mask.shape[1])

    rgb = sid_to_rgb[region['sid']].copy()
    rgb[_outline_ring(mask, view['expansion'], view['thickness'])] = [255, 255, 0]

    imgs.append(rgb[r1:r2, c1:c2])
len(imgs)

In [None]:

# Quick look at each crop, titled with its region id.
for rid, img in zip(keep, imgs):
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.set_title(rid)
    plt.show()

In [None]:
# One row of crops titled by each region's 'position'. squeeze=False keeps `axs` 2-D, so an
# ROI with a single image still iterates (plain subplots(ncols=1) returns a bare Axes).
fig, axs = plt.subplots(ncols=len(imgs), figsize=(len(imgs) * 2, 2), squeeze=False)

for rid, img, ax in zip(keep, imgs, axs[0]):
    region = region_id_to_region[rid]
    position = region['position']

    ax.imshow(img)
    ax.set_axis_off()
    ax.set_aspect('equal')

    ax.set_title(f'Z{position}')
fig.subplots_adjust(wspace=0, hspace=0)
fig.savefig(fig_dir / f'{roi}_all_images.svg', dpi=300)

In [None]:
# For every Xenium section: the gene's tiled channel (cropped to the ROI, region outline in
# red) next to the section's full RGB image.
gene = 'ALDH1A3'
# ROI bounds (R1, R2, C1, C2) converted to tile units by integer division.
r1, r2, c1, c2 = [v // tiling_resolution for v in [R1, R2, C1, C2]]
gene_idx = dtype_to_channels['xenium'].index(gene)
for sid, tiled in sid_to_tiled.items():
    if sid_to_dtype[sid] != 'xenium':
        continue
    rgb = sid_to_rgb[sid]
    z = sid_to_z[sid]

    # One-pixel outline of the union of regions on this section.
    m = z_to_regions_mask[z]
    m = skimage.morphology.binary_dilation(m) ^ m

    # Float copy so the in-place normalization below also works for integer tiles.
    x = np.asarray(tiled[gene_idx], dtype=float)
    x = np.stack([x, x, x], -1)
    x[m] = [x.max(), 0, 0]
    x = x[r1:r2, c1:c2]
    # Scale to [0, 1] for imshow; skip when the crop is empty or all zero (would be NaN).
    peak = x.max() if x.size else 0
    if peak > 0:
        x /= peak

    fig, ax = plt.subplots(ncols=2)
    ax[0].imshow(x)
    ax[0].set_title(z)
    ax[1].imshow(rgb)
    plt.show()

In [None]:
# Exploration: take the first region in the mapping and list the fields it carries.
region = list(region_id_to_region.values())[0]
region.keys()

In [None]:
# Label the connected components of the region mask and take the first one's properties.
# NOTE(review): if the mask has several components, `prop` describes only the first.
rlabeled = skimage.morphology.label(region['mask'])
prop = skimage.measure.regionprops(rlabeled)[0]
prop

In [None]:
# Visual check of the inspected region's mask.
plt.imshow(region['mask'])

In [None]:
# Centroid relative to the component's bounding box.
prop.centroid_local

In [None]:
# Centroid in full-image (row, col) coordinates.
prop.centroid