In [None]:
import os
import re
import json
from pathlib import Path
import yaml

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import scipy
import skimage
import pandas as pd
import tifffile
from einops import rearrange, repeat
from pydantic_extra_types.color import Color

In [None]:
import matplotlib as mpl
# Font handling for exported figures: fonttype 42 embeds TrueType fonts in PDFs and
# 'none' keeps SVG text as text, so labels stay editable in vector-graphics tools.
mpl.rcParams['pdf.fonttype'] = 42
plt.rcParams['svg.fonttype'] = 'none'

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import mushroom.utils as utils
import mushroom.visualization.utils as vis_utils

In [None]:
# Root of the project and of the region-characterization analysis beneath it.
project_dir = Path('/data/estorrs/mushroom/data/projects/submission_v1')
region_dir = project_dir.joinpath('analysis', 'region_characterization')

# Output locations, created up front so later cells can write to them directly.
fig_dir = region_dir.joinpath('figures')
results_dir = region_dir.joinpath('results')
fig_dir.mkdir(parents=True, exist_ok=True)
results_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Path prefix to look for in loaded configs, and the prefix that replaces it; both are
# passed to alter_filesystem wherever a metadata.yaml is loaded in this notebook.
source_root = '/diskmnt/Projects/Users/estorrs/mushroom/data'
target_root = '/data/estorrs/mushroom/data'

def alter_filesystem(config, source_root, target_root):
    """Rewrite file paths in a config so they point at a different filesystem root.

    Every ``filepath`` under ``config['sections'][*]['data'][*]`` and, when it is
    present and not ``None``, ``config['trainer_kwargs']['data_mask']`` has
    ``source_root`` replaced by ``target_root`` (plain ``str.replace``).

    Parameters
    ----------
    config : dict
        Parsed configuration. It is modified in place.
    source_root : str
        Substring to replace in each path.
    target_root : str
        Replacement substring.

    Returns
    -------
    dict
        The same ``config`` object that was passed in.
    """
    for entry in config['sections']:
        for mapping in entry['data']:
            mapping['filepath'] = mapping['filepath'].replace(source_root, target_root)

    # `trainer_kwargs` may be absent, empty, or lack `data_mask`; only rewrite a real path.
    trainer_kwargs = config.get('trainer_kwargs') or {}
    data_mask = trainer_kwargs.get('data_mask')
    if data_mask is not None:
        trainer_kwargs['data_mask'] = data_mask.replace(source_root, target_root)

    return config

In [None]:
# Build sid_to_data: one record per section id (sid) that has a `<sid>_regions.txt.gz`
# file in region_dir. Each record gathers the section's case, position, channels,
# data type, RGB rendering, labeled-region images and raw data file paths.
fps = sorted(utils.listfiles(region_dir, regex=r'_regions\.txt\.gz$'))

sid_to_data = {}
for fp in fps:
    name = fp.split('/')[-1]
    sid = re.sub(r'^(.*)_regions\.txt\.gz$', r'\1', name)
    print(sid)
    # Strip a trailing '-U<number>' from the sid to get the case name.
    case = re.sub(r'^(.*)-U[0-9]+$', r'\1', sid)

    # Context managers so the metadata file handles are closed on every iteration.
    with open(project_dir / case / 'registered' / 'metadata.yaml') as fh:
        config = yaml.safe_load(fh)
    config = alter_filesystem(config, source_root, target_root)

    imaris_dir = project_dir / case / 'imaris' / 'rois'
    with open(imaris_dir / 'tiled' / 'metadata.json') as fh:
        meta = json.load(fh)

    sid_to_channels = meta['sid_to_channels']
    dtype_ident_to_dtype = meta['dtype_ident_to_dtype']
    sid_to_dtype_ident = meta['sid_to_dtype_ident']

    full_rgb_no_dir = imaris_dir / 'full_rgbs_no_overlays'

    if sid not in sid_to_channels:
        raise RuntimeError(f'{sid} not present in sid_to_channels of the tiled metadata')
    channels = sid_to_channels[sid]
    dti = sid_to_dtype_ident[sid]
    # The dtype identifier is '<dtype>_<suffix>'; keep only the dtype part.
    dtype = dti.split('_')[0]
    # Xenium sections with more than 2000 channels get their own label.
    dtype2 = 'xenium5k' if dtype == 'xenium' and len(channels) > 2000 else dtype

    if case in ['HT891Z1', 'HT913Z1', 'S18-5591-C8', 'S18-9906']:
        disease = 'prad'
    else:
        disease = 'brca'

    # Find this section's data entry of the matching dtype in the registered config.
    mapping = None
    position = None
    for entry in config['sections']:
        if entry['sid'] == sid:
            for item in entry['data']:
                if item['dtype'] == dtype:
                    mapping = item
                    position = entry['position']
                    break
    assert mapping is not None, f'no {dtype} data entry found for {sid}'

    # Companion files are derived from the main filepath by suffix substitution.
    if dtype in ['he', 'multiplex']:
        data = {
            'filepath': mapping['filepath']
        }
    elif dtype in ['xenium', 'cosmx']:
        data = {
            'adata': mapping['filepath'],
            'transcripts': mapping['filepath'].replace('.h5ad', '_transcripts.parquet'),
            'morphology': mapping['filepath'].replace('.h5ad', '_morphologyfocus.ome.tiff')
        }
    elif dtype in ['vishd']:
        data = {
            'adata': mapping['filepath'],
            'he': mapping['filepath'].replace('.h5ad', '_he.ome.tiff')
        }
    else:
        raise RuntimeError(f'dtype {dtype} not valid')

    for k, v in data.items():
        assert os.path.exists(v), f'{k} file missing for {sid}: {v}'

    # Prefer the overlay-free RGB; fall back to 'full_rgbs', else record None.
    rgb_fp = full_rgb_no_dir / f'{sid}.tif'
    if not os.path.exists(rgb_fp):
        rgb_fp = imaris_dir / 'full_rgbs' / f'{sid}.tif'
        if not os.path.exists(rgb_fp):
            print('no rgb', sid)
            rgb_fp = None

    sid_to_data[sid] = {
        'case': case,
        'sid': sid,
        'position': position,
        'channels': channels,
        'dti': dti,
        'dtype': dtype,
        'dtype2': dtype2,
        'rgb': rgb_fp,
        'labeled_regions': fp.replace('_regions.txt.gz', '_regions.tif'),
        'labeled_boundaries': fp.replace('_regions.txt.gz', '_boundaries.tif'),
        'labeled_tme': fp.replace('_regions.txt.gz', '_tme.tif'),
        'data': data,
    }

In [None]:
sid_to_data['HT913Z1-U1']['rgb']

## Visualization functions

In [None]:
def plot_xenium_pseudo(ax, rgb, transcripts, view_settings, s=.1):
    """Draw ``rgb`` on ``ax`` and scatter the requested transcripts over it.

    ``transcripts`` is either a DataFrame or something ``pd.read_parquet`` can load;
    it must provide ``feature_name``, ``y_location`` and ``x_location`` columns.
    Each ``view_settings`` entry is a dict with ``channel``, ``color`` and ``marker``.
    Channels absent from ``feature_name`` are silently skipped. Returns ``ax``.
    """
    ax.imshow(rgb)

    table = transcripts if isinstance(transcripts, pd.DataFrame) else pd.read_parquet(transcripts)
    present = set(table['feature_name'])

    for setting in view_settings:
        # Resolve all three fields first so a malformed entry fails even if its channel is absent.
        rgb_frac = np.asarray(Color(setting['color']).as_rgb_tuple()) / 255.
        gene, glyph = setting['channel'], setting['marker']
        if gene not in present:
            continue

        hits = table.loc[table['feature_name'] == gene, ['y_location', 'x_location']]
        yx = hits.values.astype(int)
        # Scatter takes (x, y), i.e. (column, row).
        ax.scatter(yx[:, 1], yx[:, 0], color=rgb_frac, s=s, marker=glyph, edgecolors='none')
    return ax

def add_transcripts_to_rgb(rgb, transcripts, xenium_view_settings, pt_scaler=50, bbox=None):
    """Rasterize transcripts on top of an RGB image and return the result as an array.

    Parameters
    ----------
    rgb : array
        Background image, indexed as ``[row, column, ...]``.
    transcripts : pandas.DataFrame or path
        Transcript table (or a parquet path) with ``feature_name``, ``y_location``
        and ``x_location`` columns.
    xenium_view_settings : list of dict
        Passed through to ``plot_xenium_pseudo``.
    pt_scaler : float
        Marker size is ``figure_width_in_inches / pt_scaler``.
    bbox : tuple or None
        Optional ``(r1, r2, c1, c2)`` crop applied to both the image and the
        transcripts; transcript coordinates are shifted to the crop origin.

    Returns
    -------
    array
        The rendered image read back from disk, with any alpha channel dropped.
    """
    import tempfile

    if not isinstance(transcripts, pd.DataFrame):
        transcripts = pd.read_parquet(transcripts)

    if bbox is not None:
        r1, r2, c1, c2 = bbox
        rgb = rgb[r1:r2, c1:c2]

        m = (transcripts['y_location'] > r1) & (transcripts['y_location'] < r2)
        m &= (transcripts['x_location'] > c1) & (transcripts['x_location'] < c2)
        # Copy the selection so the shift below never writes into (or warns about)
        # the caller's DataFrame.
        transcripts = transcripts[m].copy()
        transcripts['y_location'] -= r1
        transcripts['x_location'] -= c1

    # One inch per 1000 pixels.
    size = (rgb.shape[1] / 1000, rgb.shape[0] / 1000)
    fig, ax = plt.subplots(figsize=size)
    try:
        ax.set_axis_off()
        plot_xenium_pseudo(ax, rgb, transcripts, xenium_view_settings, s=size[0] / pt_scaler)
        # A private temporary directory avoids clashes on a shared 'temp.tif' in the
        # working directory and is removed even if saving or reading fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_fp = os.path.join(tmp_dir, 'temp.tif')
            # NOTE(review): the dpi factor presumably compensates for the margins
            # removed by bbox_inches='tight' - confirm before changing it.
            fig.savefig(tmp_fp, bbox_inches='tight', pad_inches=0, dpi=int(1000 * 1.2987411728584588))
            rgb = tifffile.imread(tmp_fp)
    finally:
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)

    if rgb.shape[-1] == 4:
        rgb = rgb[..., :-1]

    return rgb

def get_st_rgb(rgb_fp, transcripts_fp, view_settings, bbox=None, pt_scaler=50, black=False):
    """Load an RGB image from ``rgb_fp`` and overlay transcripts on it.

    When ``black`` is true the loaded image is zeroed first, so the transcripts are
    drawn on an all-zero background. ``bbox`` and ``pt_scaler`` are forwarded
    unchanged to ``add_transcripts_to_rgb``, whose result is returned.
    """
    background = tifffile.imread(rgb_fp)
    if black:
        background[...] = 0

    return add_transcripts_to_rgb(
        background,
        transcripts_fp,
        view_settings,
        pt_scaler=pt_scaler,
        bbox=bbox,
    )

def get_st_simple(transcripts, view_settings, ax, bbox=None, s=4):
    """Scatter the requested transcripts directly onto ``ax`` (no background image).

    Parameters
    ----------
    transcripts : pandas.DataFrame or path
        Transcript table (or a parquet path) with ``feature_name``, ``y_location``
        and ``x_location`` columns.
    view_settings : list of dict
        Entries with ``channel``, ``color`` and ``marker`` keys; channels missing
        from the table are skipped.
    ax : matplotlib Axes
        Axes to draw on. Ticks are removed and the aspect set to equal whenever at
        least one channel is drawn.
    bbox : tuple or None
        Optional ``(r1, r2, c1, c2)`` window; coordinates are shifted to its origin.
    s : float
        Marker size passed to ``ax.scatter``.
    """
    if not isinstance(transcripts, pd.DataFrame):
        transcripts = pd.read_parquet(transcripts)

    if bbox is not None:
        r1, r2, c1, c2 = bbox

        m = (transcripts['y_location'] > r1) & (transcripts['y_location'] < r2)
        m &= (transcripts['x_location'] > c1) & (transcripts['x_location'] < c2)
        # Copy the selection so the shift below never writes into (or warns about)
        # the caller's DataFrame.
        transcripts = transcripts[m].copy()
        transcripts['y_location'] -= r1
        transcripts['x_location'] -= c1

    pool = set(transcripts['feature_name'])
    for entry in view_settings:
        color = np.asarray(Color(entry['color']).as_rgb_tuple()) / 255.
        channel = entry['channel']
        marker = entry['marker']

        if channel in pool:
            small = transcripts[transcripts['feature_name']==channel]
            X = small[['y_location', 'x_location']].values.astype(int)
            # Scatter takes (x, y), i.e. (column, row).
            ax.scatter(X[:, 1], X[:, 0], color=color, s=s, marker=marker, edgecolors='none')
            ax.axis('equal')
            ax.set_xticks([])
            ax.set_yticks([])

In [None]:
def add_bins_to_rgb(rgb, pts, color='white', pt_scaler=.1, marker='s'):
    """Rasterize a set of points on top of an RGB image and return the result.

    Parameters
    ----------
    rgb : array
        Background image, indexed as ``[row, column, ...]``.
    pts : array
        Two-column array of ``(row, column)`` point coordinates.
    color, marker
        Passed to ``ax.scatter``.
    pt_scaler : float
        Marker size is ``pt_scaler / 10``.

    Returns
    -------
    array
        The rendered image read back from disk, with any alpha channel dropped.
        Its pixel size is not guaranteed to equal that of the input.
    """
    import tempfile

    # One inch per 1000 pixels.
    size = (rgb.shape[1] / 1000, rgb.shape[0] / 1000)
    fig, ax = plt.subplots(figsize=size)
    try:
        ax.set_axis_off()
        ax.imshow(rgb)
        ax.scatter(pts[:, 1], pts[:, 0], color=color, s=pt_scaler / 10, marker=marker, edgecolors='none')
        # A private temporary directory avoids clashes on a shared 'temp.tif' in the
        # working directory and is removed even if saving or reading fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_fp = os.path.join(tmp_dir, 'temp.tif')
            # NOTE(review): the dpi factor presumably compensates for the margins
            # removed by bbox_inches='tight' - confirm before changing it.
            fig.savefig(tmp_fp, bbox_inches='tight', pad_inches=0, dpi=int(1000 * 1.2987411728584588))
            rgb = tifffile.imread(tmp_fp)
    finally:
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)

    if rgb.shape[-1] == 4:
        rgb = rgb[..., :-1]

    return rgb

def get_vishd_rgb(rgb, adata, view_settings, pt_scaler=.5):
    """Overlay bins with non-zero expression on an RGB image, one channel at a time.

    Parameters
    ----------
    rgb : array
        Background image, ``(height, width, channels)``.
    adata : AnnData
        Must provide ``obs['y_location']`` / ``obs['x_location']`` and an ``X`` that
        supports ``.toarray()``.
    view_settings : list of dict
        Entries with ``channel``, ``color`` and ``marker`` keys. Channels missing
        from ``adata.var.index`` are reported and skipped.
    pt_scaler : float
        Forwarded to ``add_bins_to_rgb``.

    Returns
    -------
    array
        Image with all available channels drawn, rescaled to the input's height/width.
    """
    target_size = rgb.shape[:-1]
    for entry in view_settings:
        color = entry['color']
        channel = entry['channel']
        marker = entry['marker']

        if channel in adata.var.index:
            # Keep only the bins where this channel has a positive value.
            mask = adata[:, channel].X.toarray().flatten() > 0
            pts = adata.obs[['y_location', 'x_location']].values
            filtered_pts = pts[mask]

            rgb = add_bins_to_rgb(rgb, filtered_pts, color=color, marker=marker, pt_scaler=pt_scaler)
            # Rendering can change the pixel size, so restore the original one
            # before the next channel is layered on top.
            rgb = utils.rescale(rgb, size=target_size, dim_order='h w c', target_dtype=rgb.dtype)

        else:
            # Report the channel being processed; the previous message used an
            # unrelated (or undefined) `gene` name.
            print(f'{channel} not present in vishd')
    return rgb

In [None]:
# Regions of interest per case, keyed by ROI name. get_imgs unpacks each tuple as
# (r1, r2, c1, c2) and uses r1/c1 as row/column offsets.
# NOTE(review): units are presumably pixels of the registered RGB images - confirm.
# NOTE(review): the '-p2' entries share the bounds of their base ROI, and an identical
# `case_to_rois` definition appears again later in this notebook.
case_to_rois = {
    'HT891Z1': {
        'roi1': (1300, 2100, 2500, 3500), # normal -> gp3
        'roi2': (4650, 5650, 1900, 2900), # normal -> gp3
        'roi2-p2': (4650, 5650, 1900, 2900), # PIN -> gp3
    },
    'HT913Z1': {
        'roi1': (4000, 4750, 5800, 7200), # normal -> gp3
        'roi1-p2': (4000, 4750, 5800, 7200) # normal -> gp3
    },
    'HT704B1': {
        'roi1': (2400, 3000, 5800, 6600), # DCIS -> IDC
        'roi2': (3700, 4300, 5750, 6500), # DCIS -> partial IDC
        'roi3': (0, 7250, 0, 9063), # full region trace of dcis
    },
    'HT206B1': {
        'roi1': (1650, 2600, 4900, 5750), # dcis -> idc
        'roi2': (2700, 4300, 3600, 5500) # dcis -> idc
    },
    'HT397B1': {
        'roi1': (2600, 3700, 3200, 4800) # dcis -> idc
    }
}

In [None]:
# Quick check: print the recorded position of every HT206B1 section.
for sid, data in sid_to_data.items():
    if 'HT206B1' not in sid:
        continue
    print(sid, data['position'])

In [None]:
# Hand-picked square windows to render, organized as dtype -> case -> roi -> list of
# targets. Each target names a section (`sid`), a `center` read as [row, column] and a
# `radius`; the next cell turns it into the box center +/- radius on both axes.
# NOTE(review): some ROIs list the same sid twice (e.g. HT891Z1-U21 in xenium roi2,
# HT913Z1-U21 in xenium roi1). The cell that builds case_to_roidata keys results by
# sid, so the later entry replaces the earlier one - confirm that is intended.
# NOTE(review): a few centers lie outside the same-named box in case_to_rois (e.g.
# vishd HT891Z1 roi1 at row 4820, cosmx HT891Z1 roi2 at row 2090) - confirm.
case_to_targets = {
    'xenium': {
        'HT891Z1': {
            'roi1': [
                {'sid': 'HT891Z1-U104', 'center': [1850, 3080], 'radius': 110},
                {'sid': 'HT891Z1-U69', 'center': [1820, 2940], 'radius': 130},
                {'sid': 'HT891Z1-U59', 'center': [1770, 2960], 'radius': 210},
                {'sid': 'HT891Z1-U44', 'center': [1630, 2830], 'radius': 130},
                {'sid': 'HT891Z1-U31', 'center': [1590, 2770], 'radius': 120}
            ],
            'roi2': [
                {'sid': 'HT891Z1-U31', 'center': [4820, 2090], 'radius': 145},
                {'sid': 'HT891Z1-U21', 'center': [4850, 2190], 'radius': 145},
                {'sid': 'HT891Z1-U21', 'center': [5310, 2440], 'radius': 380},
                {'sid': 'HT891Z1-U59', 'center': [5370, 2390], 'radius': 210},
                {'sid': 'HT891Z1-U81', 'center': [5420, 2230], 'radius': 120}
            ],
        },
        'HT913Z1': {
            'roi1': [
                {'sid': 'HT913Z1-U21', 'center': [4410, 6590], 'radius': 270},
                {'sid': 'HT913Z1-U61', 'center': [4440, 6340], 'radius': 120},
                {'sid': 'HT913Z1-U81', 'center': [4350, 6180], 'radius': 150},
                {'sid': 'HT913Z1-U101', 'center': [4510, 6170], 'radius': 200},
                {'sid': 'HT913Z1-U21', 'center': [4490, 6750], 'radius': 150},
            ]
        },
        'HT704B1': {
            'roi1': [
                {'sid': 'HT704B1-U1', 'center': [2750, 6350], 'radius': 170},
                {'sid': 'HT704B1-U33', 'center': [2670, 6000], 'radius': 210},
                {'sid': 'HT704B1-U41', 'center': [2810, 5970], 'radius': 130},
                {'sid': 'HT704B1-U50', 'center': [2600, 6040], 'radius': 250},
            ],
        },
        'HT206B1': {
            'roi1': [
                {'sid': 'HT206B1-U8', 'center': [2000, 5300], 'radius': 500},
                {'sid': 'HT206B1-U24', 'center': [2000, 5300], 'radius': 100},
            ],
            'roi2': [
                {'sid': 'HT206B1-U8', 'center': [3650, 4700], 'radius': 500},
                {'sid': 'HT206B1-U24', 'center': [3900, 5080], 'radius': 100},
            ],
        },
    },
    'vishd': {
        'HT891Z1': {
            'roi1': [
                {'sid': 'HT891Z1-U2', 'center': [4820, 2090], 'radius': 200},
                {'sid': 'HT891Z1-U33', 'center': [4820, 2090], 'radius': 200},
            ],
            'roi2': [
                {'sid': 'HT891Z1-U2', 'center': [5000, 2300], 'radius': 200},
                {'sid': 'HT891Z1-U33', 'center': [5000, 2300], 'radius': 200},
            ]
        },
        'HT704B1': {
            'roi1': [
                {'sid': 'HT704B1-U2', 'center': [2750, 6350], 'radius': 170},
                {'sid': 'HT704B1-U51', 'center': [2600, 6040], 'radius': 250},
            ]
        },
    },
    'cosmx': {
        'HT704B1': {
            'roi1': [
                {'sid': 'HT704B1-U14', 'center': [2750, 6350], 'radius': 300},
                {'sid': 'HT704B1-U22', 'center': [2750, 6350], 'radius': 300},
                {'sid': 'HT704B1-U47', 'center': [2750, 6350], 'radius': 300},
                {'sid': 'HT704B1-U56', 'center': [2750, 6350], 'radius': 300},
            ],
        },
        'HT891Z1': {
            'roi1': [
                {'sid': 'HT891Z1-U57', 'center': [1770, 2960], 'radius': 200},
                {'sid': 'HT891Z1-U60', 'center': [1770, 2960], 'radius': 200},
            ],
            'roi2': [
                {'sid': 'HT891Z1-U57', 'center': [2090, 3990], 'radius': 300},
                {'sid': 'HT891Z1-U60', 'center': [2090, 3990], 'radius': 300},
            ]
        },
        
    }
}

In [None]:
# Crop the RGB image and the expression data of every target window in
# case_to_targets. Result layout: case_to_roidata[dtype][case][roi][sid] -> dict with
# 'dtype', 'rgb' and either 'transcripts' (xenium/cosmx) or 'adata' (vishd).
# NOTE(review): results are keyed by sid, so when a ROI lists the same sid more than
# once only the last window is kept.
case_to_roidata = {}

for dtype, case_to_roi_targets in case_to_targets.items():
    case_to_roidata[dtype] = {}
    for case, roi_to_targets in case_to_roi_targets.items():
        case_to_roidata[dtype][case] = {}
        for roi, targets in roi_to_targets.items():
            case_to_roidata[dtype][case][roi] = {}
            # Distinct loop names: the inner loop previously reused `d`, shadowing
            # the dict being iterated by the outer loop.
            for target in targets:
                sid = target['sid']
                print(case, roi, sid, dtype)
                r, c = target['center']
                radius = target['radius']

                # Square window of side 2 * radius around the center.
                bbox = r - radius, r + radius, c - radius, c + radius
                r1, r2, c1, c2 = bbox

                data = sid_to_data[sid]

                if dtype in ['xenium', 'cosmx']:
                    transcripts_fp = data['data']['transcripts']
                    rgb_fp = data['rgb']

                    rgb = tifffile.imread(rgb_fp)[r1:r2, c1:c2]

                    transcripts = pd.read_parquet(transcripts_fp)
                    m = (transcripts['y_location'] > r1) & (transcripts['y_location'] < r2)
                    m &= (transcripts['x_location'] > c1) & (transcripts['x_location'] < c2)
                    # Copy before shifting so the write does not target a slice.
                    transcripts = transcripts[m].copy()
                    transcripts['y_location'] -= r1
                    transcripts['x_location'] -= c1

                    case_to_roidata[dtype][case][roi][sid] = {
                        'dtype': dtype,
                        'rgb': rgb,
                        'transcripts': transcripts
                    }
                elif dtype == 'vishd':
                    adata_fp = data['data']['adata']
                    rgb_fp = data['rgb']

                    rgb = tifffile.imread(rgb_fp)[r1:r2, c1:c2]

                    adata = sc.read_h5ad(adata_fp)
                    m = (adata.obs['y_location'] > r1) & (adata.obs['y_location'] < r2)
                    m &= (adata.obs['x_location'] > c1) & (adata.obs['x_location'] < c2)
                    # Materialize the subset; assigning into the obs of a view
                    # would otherwise trigger an implicit copy with a warning.
                    adata = adata[m].copy()
                    adata.obs['y_location'] -= r1
                    adata.obs['x_location'] -= c1

                    case_to_roidata[dtype][case][roi][sid] = {
                        'dtype': dtype,
                        'rgb': rgb,
                        'adata': adata
                    }

In [None]:
# case_to_idx = {case:i for i, case in enumerate(case_to_roidata['xenium'].items())}
# case_to_pos = {case:0 for case in case_to_idx.keys()}
def show_rois(view_settings, figsize=(8, 12), black=False, dtype_to_pt_scaler=None):
    """Plot every cropped window in ``case_to_roidata`` as a grid of images.

    Each non-empty ROI gets one row; its sections fill the row left to right.

    Parameters
    ----------
    view_settings : list of dict
        Entries with ``channel``, ``color`` and ``marker`` keys. Xenium/CosMx panels
        draw all entries; Visium HD panels draw only the last entry.
    figsize : tuple
        Passed to ``plt.subplots``.
    black : bool
        Currently unused; kept for interface compatibility.
    dtype_to_pt_scaler : dict or None
        Optional per-dtype point scaler; dtypes not listed use 1.

    Raises
    ------
    ValueError
        If ``case_to_roidata`` holds no windows to show.
    """
    if dtype_to_pt_scaler is None:
        dtype_to_pt_scaler = {}

    # Collect the rows once so the grid size and the plotting loop cannot disagree.
    panels = [
        (dtype, d2)
        for dtype, d in case_to_roidata.items()
        for case, d1 in d.items()
        for roi, d2 in d1.items()
        if dtype in ['xenium', 'vishd', 'cosmx'] and len(d2)
    ]
    if not panels:
        raise ValueError('case_to_roidata contains no ROI windows to show')

    nrows = len(panels)
    ncols = max(len(d2) for _, d2 in panels)
    # squeeze=False keeps `axs` two-dimensional even with a single row or column.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)

    for i, (dtype, d2) in enumerate(panels):
        pt_scaler = dtype_to_pt_scaler.get(dtype, 1)
        for j, (sid, roidata) in enumerate(d2.items()):
            ax = axs[i, j]
            rgb = roidata['rgb']

            if dtype in ['xenium', 'cosmx']:
                x = add_transcripts_to_rgb(rgb, roidata['transcripts'], view_settings, pt_scaler=pt_scaler, bbox=None)
            else:
                # get_vishd_rgb iterates over settings dicts, so pass the last entry
                # as a one-element list rather than its bare channel name.
                x = get_vishd_rgb(rgb, roidata['adata'], view_settings[-1:], pt_scaler=pt_scaler)

            ax.imshow(x)
            ax.set_title(sid)

    for ax in axs.flatten():
        for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] + ax.get_xticklabels() + ax.get_yticklabels()):
            item.set_fontsize(6)
        ax.axis('off')


In [None]:
# view_settings = [
#     {
#         'channel': 'EPCAM',
#         'color': 'red',
#         'marker': '.',
#     },
#     {
#         'channel': 'KRT5',
#         'color': 'green',
#         'marker': '.',
#     },
#     {
#         'channel': 'AMACR',
#         'color': 'white',
#         'marker': 's',
#     }
# ]

# show_rois(view_settings, dtype_to_pt_scaler={'xenium': 1, 'vishd': 5, 'cosmx': .5})

In [None]:
genes = """SNAI2""".split('\n')

In [None]:
# For each gene of interest, render all ROI windows with EPCAM and KRT5 as fixed
# context markers and the gene itself drawn last in white.
for gene in genes:
    print(gene)
    view_settings = [
        {'channel': 'EPCAM', 'color': 'red', 'marker': '.'},
        {'channel': 'KRT5', 'color': 'limegreen', 'marker': '.'},
        {'channel': gene, 'color': 'white', 'marker': '.'},
    ]

    show_rois(view_settings, dtype_to_pt_scaler={'xenium': .2, 'vishd': 5, 'cosmx': .2})
    plt.show()

In [None]:
# Overlay panel: one dict per channel with the color and matplotlib marker to draw it
# with. Entries are drawn in list order, so later entries are painted on top.
# NOTE(review): `view_settings` is reassigned by the next cell, so this panel only
# takes effect if that cell is skipped.
view_settings = [
    {
        'channel': 'MFAP5',
        'color': 'cyan',
        'marker': 's',
    },
    {
        'channel': 'SFRP4',
        'color': 'yellow',
        'marker': 's',
    },
    {
        'channel': 'THY1',
        'color': 'purple',
        'marker': 's',
    },
    {
        'channel': 'ASPN',
        'color': 'magenta',
        'marker': 's',
    },
    {
        'channel': 'TP63',
        'color': 'white',
        'marker': '.',
    },
    {
        'channel': 'KRT5',
        'color': 'white',
        'marker': '.',
    },
    # {
    #     'channel': 'CP',
    #     'color': 'purple',
    #     'marker': 'P',
    # },
    # {
    #     'channel': 'CD3E',
    #     'color': 'limegreen',
    #     'marker': '*',
    # },
    # {
    #     'channel': 'CD68',
    #     'color': 'cyan',
    #     'marker': 'D',
    # },
]

# NOTE(review): VISHD_GENE is not referenced by any function visible in this notebook.
VISHD_GENE = 'MFAP5'


In [None]:
# Overlay panel used by the export and ROI cells below: one dict per channel with the
# color and matplotlib marker to draw it with. Entries are drawn in list order, so
# later entries are painted on top. This replaces any earlier `view_settings`.
view_settings = [
    # {
    #     'channel': 'TCIM',
    #     'color': 'yellow',
    #     'marker': 's',
    # },
    {
        'channel': 'EPCAM',
        'color': 'red',
        'marker': '.',
    },
    {
        'channel': 'LAMC2',
        'color': 'limegreen',
        'marker': 's',
    },
    {
        'channel': 'TFF1',
        'color': 'white',
        'marker': 's',
    },
    {
        'channel': 'TP63',
        'color': 'white',
        'marker': '.',
    },
    {
        'channel': 'KRT5',
        'color': 'white',
        'marker': '.',
    },
    {
        'channel': 'CAVIN2',
        'color': 'cyan',
        'marker': 's',
    },
    # {
    #     'channel': 'HLA-DQB2',
    #     'color': 'limegreen',
    #     'marker': '.',
    # },
    # {
    #     'channel': 'CD68',
    #     'color': 'cyan',
    #     'marker': 'D',
    # },
]

# VISHD_GENE = 'MFAP5'


In [None]:
# Select the case to export and stack the RGB renderings of its sections, in the
# order the sections appear in the registered config.
# case = 'HT891Z1'
case = 'HT704B1'
# case = 'HT206B1'

with open(project_dir / case / 'registered' / 'metadata.yaml') as fh:
    config = yaml.safe_load(fh)
config = alter_filesystem(config, source_root, target_root)

stacked, sids = [], []
for entry in config['sections']:
    record = sid_to_data.get(entry['sid'])
    # Skip sections with no region data, and sections whose RGB was recorded as
    # missing (None) when sid_to_data was built - those cannot be read.
    if record is None or record['rgb'] is None:
        continue
    stacked.append(tifffile.imread(record['rgb']))
    sids.append(entry['sid'])

stacked = np.stack(stacked)

stacked.shape

In [None]:
# Output directory for the pseudo-colored section images of the selected case;
# created here so the export loop can write into it.
pseudos_rgb_dir = project_dir / case / 'imaris' / 'rois' / 'pseudos_sandbox' / 'view1'
pseudos_rgb_dir.mkdir(parents=True, exist_ok=True)
pseudos_rgb_dir

In [None]:
# Render the current view_settings onto every stacked section and write one TIFF per
# section. Sections that are neither cosmx/xenium nor vishd are written unchanged.
rgbs = []
target_size = None
for i, (rgb_full, sid) in enumerate(zip(stacked, sids)):
    print(i, sid)
    data = sid_to_data[sid]
    rgb = tifffile.imread(data['rgb'])

    # The first section fixes the (height, width) all outputs are brought to.
    if target_size is None:
        target_size = rgb.shape[:-1]

    kind = data['dtype']
    if kind in ('cosmx', 'xenium'):
        rgb = add_transcripts_to_rgb(rgb, data['data']['transcripts'], view_settings, pt_scaler=50, bbox=None)
    elif kind == 'vishd':
        rgb = get_vishd_rgb(rgb, sc.read_h5ad(data['data']['adata']), view_settings, pt_scaler=10)

    # Rendering can change the pixel size; resize back when it does.
    needs_resize = rgb.shape[:2] != target_size
    if needs_resize:
        rgb = utils.rescale(rgb, size=target_size, dim_order='h w c', target_dtype=rgb.dtype)

    tifffile.imwrite(pseudos_rgb_dir / f'case_Z{i}.tif', rgb, compression='LZW')
    rgbs.append(rgb)

In [None]:
next(iter(sid_to_data.values())).keys()

In [None]:
plt.imshow(rgbs[2])

In [None]:
plt.imshow(rgbs[0][3000:4000, 3000:4000])

In [None]:
# List the `*_fc.txt.gz` result files of the HT206B1 ROIs (this rebinds `fps`).
fps = sorted(
    utils.listfiles(
        '/data/estorrs/mushroom/data/projects/submission_v1/HT206B1/imaris/rois/results',
        regex=r'.*_fc.txt.gz',
    )
)
fps

In [None]:
def get_fc_ordered(fp):
    """Load a tab-separated table and return it transposed and sorted.

    The first column of the file is used as the index. Columns whose value in the
    first row is null are dropped; the remainder is transposed (original columns
    become rows) and sorted ascending by the original first row.
    """
    table = pd.read_csv(fp, sep='\t', index_col=0)
    first_row = table.iloc[0]
    kept = table.loc[:, first_row.notnull()]
    return kept.T.sort_values(kept.index[0])

In [None]:
fc1 = get_fc_ordered('/data/estorrs/mushroom/data/projects/submission_v1/HT206B1/imaris/rois/results/roi1_xenium_0_fc.txt.gz')
fc1

In [None]:
fc2 = get_fc_ordered('/data/estorrs/mushroom/data/projects/submission_v1/HT206B1/imaris/rois/results/roi2_xenium_0_fc.txt.gz')
fc2

In [None]:
fc2.head(20)

In [None]:
genes = ['HMGCS2', 'MYLK', 'S100P', 'VEGFA', 'CP', 'TCIM']

In [None]:
# Side-by-side printout: for every row label of fc2, the first-column value in fc1
# (None when fc1 has no such label) next to the first-column value in fc2.
for g in fc2.index.to_list():
    x = fc1.loc[g].iloc[0] if g in fc1.index else None
    y = fc2.loc[g].iloc[0] if g in fc2.index else None
    print(g, x, y)

In [None]:
genes = ['CNN1', 'MYH11', 'MAMDC2', 'VCAN', 'C5orf46', 'APCDD1', 'S100P', 'TOP2A', 'GPRC5A', 'CDK1']

In [None]:
genes = fc1.index.to_list()[:50]
genes

In [None]:
genes = fc1.index.to_list()[-50:][::-1]
genes

In [None]:
# NOTE(review): these three statements repeat the path setup near the top of the
# notebook with the same values; they are only needed if this part is run on its own.
region_dir = project_dir / 'analysis' / 'region_characterization'
fig_dir = region_dir / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Per-case, per-ROI rendering parameters. get_imgs reads `radius` as the half-width of
# the square window cut around each region's centroid.
# NOTE(review): `thickness` and `expansion` are not read by any code visible in this
# notebook - presumably drawing parameters used elsewhere; confirm.
case_to_view_details = {
    'HT397B1': {
        'roi1': {
            'radius': 549,
            'thickness': 15,
            'expansion': 40
        },
    },
    'HT206B1': {
        'roi1': {
            'radius': 300,
            'thickness': 10,
            'expansion': 25
        },
        'roi2': {
            'radius': 700,
            'thickness': 20,
            'expansion': 50
        },
    },
    'HT704B1': {
        'roi1': {
            'radius': 250,
            'thickness': 10,
            'expansion': 25
        },
        'roi2': {
            'radius': 200,
            'thickness': 10,
            'expansion': 25
        },
        'roi3': {
            'radius': 600,
            'thickness': 20,
            'expansion': 50
        },
    },
    'HT891Z1': {
        'roi1': {
            'radius': 150,
            'thickness': 8,
            'expansion': 20
        },
        'roi2': {
            'radius': 220,
            'thickness': 10,
            'expansion': 25
        },
        'roi2-p2': {
            'radius': 300,
            'thickness': 10,
            'expansion': 25
        }
    },
    'HT913Z1': {
        'roi1': {
            'radius': 300,
            'thickness': 10,
            'expansion': 25
        },
        'roi1-p2': {
            'radius': 300,
            'thickness': 10,
            'expansion': 25
        }
    },
}

In [None]:
# Regions of interest per case, keyed by ROI name. get_imgs unpacks each tuple as
# (r1, r2, c1, c2) and uses r1/c1 as row/column offsets.
# NOTE(review): this repeats, value for value, the `case_to_rois` defined earlier in
# the notebook; keep the two in sync or drop one.
case_to_rois = {
    'HT891Z1': {
        'roi1': (1300, 2100, 2500, 3500), # normal -> gp3
        'roi2': (4650, 5650, 1900, 2900), # normal -> gp3
        'roi2-p2': (4650, 5650, 1900, 2900), # PIN -> gp3
    },
    'HT913Z1': {
        'roi1': (4000, 4750, 5800, 7200), # normal -> gp3
        'roi1-p2': (4000, 4750, 5800, 7200) # normal -> gp3
    },
    'HT704B1': {
        'roi1': (2400, 3000, 5800, 6600), # DCIS -> IDC
        'roi2': (3700, 4300, 5750, 6500), # DCIS -> partial IDC
        'roi3': (0, 7250, 0, 9063), # full region trace of dcis
    },
    'HT206B1': {
        'roi1': (1650, 2600, 4900, 5750), # dcis -> idc
        'roi2': (2700, 4300, 3600, 5500) # dcis -> idc
    },
    'HT397B1': {
        'roi1': (2600, 3700, 3200, 4800) # dcis -> idc
    }
}

In [None]:
# Select the case / ROI pair the following cells operate on; uncomment exactly one
# pair. Both names are read as globals by get_imgs and by the cells that load the
# geojson regions and the kept paths.
# case = 'HT704B1'
# roi = 'roi1'

# case = 'HT206B1'
# roi = 'roi2'

case = 'HT891Z1'
roi = 'roi1'

# case = 'HT913Z1'
# roi = 'roi1'

In [None]:
# Load the registered config of the selected case once (it was previously loaded and
# path-rewritten twice in a row) and close the file handle afterwards.
with open(project_dir / case / 'registered' / 'metadata.yaml') as fh:
    config = yaml.safe_load(fh)
config = alter_filesystem(config, source_root, target_root)

# `sids` keeps every section of the case, so a region's z index can be mapped back
# to its sid; `sids_filtered` lists only the sections that were actually stacked.
stacked, sids, sids_filtered = [], [], []
for entry in config['sections']:
    sids.append(entry['sid'])
    record = sid_to_data.get(entry['sid'])
    # Skip sections with no region data, and sections whose RGB was recorded as
    # missing (None) when sid_to_data was built - those cannot be read.
    if record is None or record['rgb'] is None:
        continue
    stacked.append(tifffile.imread(record['rgb']))
    sids_filtered.append(entry['sid'])

stacked = np.stack(stacked)

stacked.shape

In [None]:
len(sids)

In [None]:
def load_regions(regions_fp, shape=None):
    """Load polygon regions from a GeoJSON file and rasterize each into a mask.

    Parameters
    ----------
    regions_fp : path
        GeoJSON file with a top-level ``features`` list.
    shape : tuple or None
        ``(height, width)`` of the masks. Defaults to ``stacked.shape[1:3]`` from
        the notebook's global image stack.

    Returns
    -------
    list of dict
        One dict per feature whose coordinates form a 3-D array, with keys ``id``,
        ``z`` (``geometry.plane.z``, or 0 when there is no plane), ``coordinates``
        (first ring, columns swapped) and ``mask`` (boolean array from
        ``skimage.draw.polygon2mask``).

    Raises
    ------
    RuntimeError
        If a feature's coordinates cannot be converted to an array; the feature id
        is printed first.
    """
    if shape is None:
        shape = stacked.shape[1:3]

    # Context manager so the file handle is closed after parsing.
    with open(regions_fp) as fh:
        features = json.load(fh)['features']

    # Validate every feature before building anything, so a bad file fails fast.
    for x in features:
        try:
            np.asarray(x['geometry']['coordinates'])
        except ValueError as e:
            print(x['id'])
            raise RuntimeError('failed') from e

    regions = [{
        'id': x['id'],
        'z': x['geometry']['plane']['z'] if 'plane' in x['geometry'] else 0,
        'coordinates': np.asarray(x['geometry']['coordinates'])
    } for x in features]
    # Keep only features whose coordinates are a 3-D array.
    regions = [x for x in regions if len(x['coordinates'].shape) == 3]
    for x in regions:
        # Take the first ring and swap its two columns.
        # NOTE(review): presumably converts GeoJSON (x, y) to (row, column) - confirm.
        x['coordinates'] = x['coordinates'][0][:, [1, 0]]
        x['mask'] = skimage.draw.polygon2mask(shape, x['coordinates'])

    return regions

In [None]:
# Load the hand-drawn regions of the selected case / ROI from its geojson file and
# index them by region id for lookup in get_imgs.
imaris_dir = project_dir / case / 'imaris' / 'rois'
regions_fp = imaris_dir / f'{roi}.geojson'
regions = load_regions(regions_fp)
rid_to_region = {r['id']:r for r in regions}

len(regions), regions[0].keys()

In [None]:
# Region ids to render for the selected case / ROI, read from the shared paths file.
# The context manager closes the file handle once the JSON is parsed.
with open('/data/estorrs/mushroom/data/projects/submission_v1/analysis/paths.json') as fh:
    to_path = json.load(fh)
keep = to_path[case][roi]
keep

In [None]:
def get_imgs(view_settings):
    """Render a square window around each kept region of the selected case / ROI.

    Reads the notebook globals ``case``, ``roi``, ``keep``, ``rid_to_region``,
    ``sids``, ``sid_to_data``, ``case_to_view_details`` and ``case_to_rois``.

    Parameters
    ----------
    view_settings : list of dict
        Entries with ``channel``, ``color`` and ``marker`` keys, forwarded to the
        overlay functions.

    Returns
    -------
    (list, list)
        Rendered images and the matching region ids. Regions whose section has no
        entry in ``sid_to_data``, or whose dtype is not xenium / cosmx / vishd, are
        left out of both lists.
    """
    radius = case_to_view_details[case][roi]['radius']
    roi_r1, roi_r2, roi_c1, roi_c2 = case_to_rois[case][roi]
    diam = radius * 2

    imgs = []
    rids = []
    for rid in keep:
        region = rid_to_region[rid]
        mask = region['mask']

        # The region's z index selects its section among all sections of the case.
        sid = sids[region['z']]
        data = sid_to_data.get(sid)
        dtype = data['dtype'] if data is not None else None

        # Window of half-width `radius` around the centroid of the first labeled
        # component, clipped to the mask and then shifted by the ROI origin.
        rlabeled = skimage.morphology.label(mask)
        prop = skimage.measure.regionprops(rlabeled)[0]
        r, c = prop['centroid']
        r, c = int(r), int(c)

        r1 = max(0, r - radius)
        c1 = max(0, c - radius)
        r2 = min(mask.shape[0] - 1, r + radius)
        c2 = min(mask.shape[1] - 1, c + radius)
        r1, r2, c1, c2 = r1 + roi_r1, r2 + roi_r1, c1 + roi_c1, c2 + roi_c1

        # Grow a clipped window back to the full diameter.
        # NOTE(review): the `== 0` tests run after the ROI offset was added, so they
        # only match when that offset is 0 - confirm this is intended.
        if r2 - r1 < diam:
            if r1 == 0:
                r2 += diam - (r2 - r1)
            else:
                r1 -= diam - (r2 - r1)
        if c2 - c1 < diam:
            if c1 == 0:
                c2 += diam - (c2 - c1)
            else:
                c1 -= diam - (c2 - c1)

        # The full RGB is only read for sections that will actually be rendered.
        if dtype in ['xenium', 'cosmx']:
            rgb = tifffile.imread(data['rgb'])[r1:r2, c1:c2].copy()

            transcripts_fp = data['data']['transcripts']
            transcripts = pd.read_parquet(transcripts_fp)
            m = (transcripts['y_location'] > r1) & (transcripts['y_location'] < r2)
            m &= (transcripts['x_location'] > c1) & (transcripts['x_location'] < c2)
            # Copy before shifting so the write does not target a slice.
            transcripts = transcripts[m].copy()
            transcripts['y_location'] -= r1
            transcripts['x_location'] -= c1

            rgb = add_transcripts_to_rgb(rgb, transcripts, view_settings, pt_scaler=4, bbox=None)
            imgs.append(rgb)
            rids.append(rid)
        elif dtype == 'vishd':
            rgb = tifffile.imread(data['rgb'])[r1:r2, c1:c2].copy()

            adata_fp = data['data']['adata']
            adata = sc.read_h5ad(adata_fp)
            m = (adata.obs['y_location'] > r1) & (adata.obs['y_location'] < r2)
            m &= (adata.obs['x_location'] > c1) & (adata.obs['x_location'] < c2)
            # Materialize the subset; assigning into the obs of a view would
            # otherwise trigger an implicit copy with a warning.
            adata = adata[m].copy()
            adata.obs['y_location'] -= r1
            adata.obs['x_location'] -= c1

            rgb = get_vishd_rgb(rgb, adata, view_settings, pt_scaler=1)
            imgs.append(rgb)
            rids.append(rid)

    return imgs, rids

In [None]:
meta.keys()

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Load the tiled xenium image of HT704B1 together with its channel names.
directory = Path('/data/estorrs/mushroom/data/projects/submission_v1/HT704B1/imaris/rois/tiled')
tiled = tifffile.imread(directory / 'xenium_0.tif')
# Context manager so the metadata file handle is closed after parsing.
with open(directory / 'metadata.json') as fh:
    meta = json.load(fh)
channels = meta['dtype_ident_to_channels']['xenium_0']
tiled.shape, len(channels)

In [None]:
def quick_show_tiled(cid, dti, idx=0):
    """Load one image of a case's tiled stack together with its channel names.

    Parameters
    ----------
    cid : str
        Case id; selects the case's ``imaris/rois/tiled`` directory.
    dti : str
        Data-type identifier (e.g. ``'xenium_0'``) used to look up the channel list
        and the tiled file path in ``metadata.json``.
    idx : int
        Index along the first axis of the tiled image.

    Returns
    -------
    (array, list)
        The selected image, whose first axis matches the channel list, and the
        channel names.
    """
    directory = Path(f'/data/estorrs/mushroom/data/projects/submission_v1/{cid}/imaris/rois/tiled')
    # Context manager so the metadata file handle is closed after parsing.
    with open(directory / 'metadata.json') as fh:
        meta = json.load(fh)
    channels = meta['dtype_ident_to_channels'][dti]
    # dti = meta['sid_to_dtype_ident'][sid]
    fp = meta['dtype_ident_to_tiled_fps'][dti]

    img = tifffile.imread(fp)[idx]
    assert img.shape[0] == len(channels)

    return img, channels

def quick_show_transcripts(sid, bbox, view_settings):
    """Crop a section's RGB image to ``bbox`` and draw expression on top of it.

    Parameters
    ----------
    sid : str
        Section id; must be a key of the module-level ``sid_to_data``.
    bbox : sequence of 4 ints
        ``(r1, r2, c1, c2)`` row/column bounds of the crop.
    view_settings : list of dict
        Passed unchanged to ``add_transcripts_to_rgb`` / ``get_vishd_rgb``.

    Returns
    -------
    The cropped RGB image, with overlays added for 'xenium', 'cosmx' and
    'vishd' sections; other dtypes return the bare crop.

    Raises
    ------
    KeyError
        If ``sid`` has no entry in ``sid_to_data``. (Previously this surfaced
        later as an opaque ``TypeError`` when slicing ``None``.)
    """
    data = sid_to_data.get(sid)
    if data is None:
        raise KeyError(f'no entry for section {sid!r} in sid_to_data')

    dtype = data['dtype']
    r1, r2, c1, c2 = bbox
    rgb = tifffile.imread(data['rgb'])[r1:r2, c1:c2].copy()

    if dtype in ['xenium', 'cosmx']:
        transcripts = pd.read_parquet(data['data']['transcripts'])
        m = (transcripts['y_location'] > r1) & (transcripts['y_location'] < r2)
        m &= (transcripts['x_location'] > c1) & (transcripts['x_location'] < c2)
        # Copy the filtered rows so the coordinate shift below edits an
        # independent frame rather than a slice of the full table.
        transcripts = transcripts[m].copy()
        # Re-express coordinates relative to the crop origin.
        transcripts['y_location'] -= r1
        transcripts['x_location'] -= c1

        rgb = add_transcripts_to_rgb(rgb, transcripts, view_settings, pt_scaler=8, bbox=None)

    if dtype == 'vishd':
        adata = sc.read_h5ad(data['data']['adata'])
        m = (adata.obs['y_location'] > r1) & (adata.obs['y_location'] < r2)
        m &= (adata.obs['x_location'] > c1) & (adata.obs['x_location'] < c2)
        # Materialise the subset; mutating `.obs` of an AnnData view would
        # otherwise trigger an implicit copy.
        adata = adata[m].copy()
        adata.obs['y_location'] -= r1
        adata.obs['x_location'] -= c1

        rgb = get_vishd_rgb(rgb, adata, view_settings, pt_scaler=2)

    return rgb

In [None]:
# xen_img, xen_channels = quick_show_tiled('HT704B1', 'xenium_0')
# cos_img, cos_channels = quick_show_tiled('HT704B1', 'cosmx_0')

# Load tile 0 of the xenium and cosmx stacks for case HT891Z1; `display_gene` reads these globals.
xen_img, xen_channels = quick_show_tiled('HT891Z1', 'xenium_0')
cos_img, cos_channels = quick_show_tiled('HT891Z1', 'cosmx_0')

In [None]:
# Crop window unpacked downstream as (r1, r2, c1, c2).
# bbox = (4000, 5500, 5500, 7500)
bbox = (4500, 6000, 2000, 3500)

In [None]:
def display_gene(gene):
    """Show `gene` inside the global `bbox` for each available platform.

    Plots the xenium and cosmx tiles (only when the gene is one of their
    channels) and then a vishd transcript overlay for section 'HT891Z1-U2'
    in which KRT5 is drawn alongside the requested gene.
    """
    def _show_tile(title, tile_img, tile_channels):
        # The tiled stacks are sliced with the bbox coordinates divided by 10.
        top, bottom, left, right = (np.asarray(bbox).astype(float) / 10).astype(int)
        plt.imshow(tile_img[tile_channels.index(gene), top:bottom, left:right])
        plt.axis('off')
        plt.title(title)
        plt.show()

    if gene in xen_channels:
        _show_tile('xenium', xen_img, xen_channels)

    if gene in cos_channels:
        _show_tile('cosmx', cos_img, cos_channels)

    overlay_settings = [
        dict(channel='KRT5', color='white', marker='o'),
        dict(channel=gene, color='yellow', marker='s'),
    ]

    # Alternative section used previously: 'HT704B1-U2'.
    overlay = quick_show_transcripts('HT891Z1-U2', bbox, overlay_settings)
    plt.imshow(overlay)
    plt.axis('off')
    plt.title('vishd')
    plt.show()

In [None]:
# Genes to display, one entry per gene.
genes = ['CEACAM1']


In [None]:
# Render the per-platform views for every gene in `genes`.
for gene in genes:
    print(gene)
    display_gene(gene)
    # display_gene already calls plt.show() after each plot; this extra call follows them.
    plt.show()

In [None]:
# genes = """MGP VIM EGFR SOX9 LTF CLU APP PTN CX3CL1 PPP1R1B CDH1 LAMC2 PLAT MMP7 ANPEP ITGA2 MS4A6A EHF C1QB TACSTD2""".split(' ')
# genes += """KRT8 SPP1 PGR AR FASN GPRC5A KLF2 LDLR SCD DUSP4 AREG KRT23 FZD1 COX6C KRT19 TOP2A PCNA MLPH TCIM KRT7""".split(' ')
# genes += """TPM2 TAGLN CALD1 MYL9 TCF7 MRC1 SNAI2 CD163 CLEC14A CSF1R AIF1 PDGFRB IL3RA PECAM1 CD34 ADGRL4 IGFBP7 SERPINE1 VWF HAS3""".split(' ')
# genes += """MFAP5 THBS2 SFRP4 ASPN VCAM1 MMP2 MSR1 BGN S100A4 FAP IL32 RYR3 FBN1 CCDC80 SFRP2 IFI8 IGHG1 BASP1 THY1 HAVCR2""".split(' ')
# genes = ['TFF1', 'PIP', 'PLAT', 'MGP', 'EHF', 'LAMC2', 'CLU', 'ANPEP', 'MS4A6A', 'TCIM', 'FASN', 'KRT8', 'CD44', 'GLUL', 'MLPH', 'SPP1']

# genes = ['ANXA2', 'RUNX1', 'CLDN1', 'CD74', 'SDC4', 'SMPD3', 'SCPEP1', 'CYP27A1', 'KRT15', 'TCIM', 'PMP22', 'SOX9', 'SMAD3', 'EGFR', 'PTPRU', 'GATA3', 'FHL2', 'ID3', 'KRT19', 'NTN4']
# genes += ['LAMC2', 'KLF6', 'ERBB2', 'SLC2A1', 'GPC1', 'FAS', 'PLAT', 'AQP3', 'CDKN1A', 'SOX9', 'CD44', 'MDM2', 'KRT19', 'TENT5C', 'EGFR', 'GATA3', 'EHF', 'KRT15', 'FHL2', 'NTN4']
# genes += ['TFF1', 'PLAT', 'MGP', 'KIT', 'PIP', 'HMGCS2']

# genes = ['TMEFF2', 'PPP1R1B', 'FBP1', 'GRK2', 'MYC', 'PRMT1', 'SCIN', 'ANPEP', 'GDF15', 'FOXA1', 'MLPH', 'ALDH1A3', 'TFF3', 'STEAP1', 'EPCAM', 'SLC39A6', 'KLK2', 'AMACR', 'TMPRSS2', 'NKX3-1']
# genes += ['ONECUT2', 'KLK3', 'IGF1', 'GATA2', 'EPCAM', 'MLPH', 'EPHA6', 'ANPEP', 'PEBP4', 'SLC39A6', 'PLA2G7', 'NPDC1', 'MYB', 'PC1', 'GLYATL1', 'GPR160', 'AR', 'ARF', 'GEF3', 'ALDH1A3', 'SCD', 'SPDEF']
# genes += ['TCIM', 'MKI67', 'VEGFA', 'CP', 'KRT8', 'GLUL', 'FASN']


# genes = ['FAS', 'FHL2', 'KRT19', 'LAMC2', 'NTN4', 'SOX9', 'TCIM', 'EDN1', 'MET', 'AMY2A']
# genes = ['NOS1', 'UPK1A', 'GABRP', 'MACC1', 'SCNN1A', 'SGK1', 'BCL11A', 'NAV2', 'FLRT3', 'DUOXA1', 'SORL1', 'FAT2', 'TRV4', 'DMKN']

# genes = ['PLA2G7', 'SERPINA3', 'ALDH1A3', 'AMACR', 'ANPEP', 'AR', 'CP', 'EPHA6', 'FOXA1', 'GDF15', 'GLYATL1', 'KLK3', 'MYC', 'NKX3-1', 'NPDC1', 'PPP1R1B', 'SCD', 'SCIN', 'SPDEF', 'STEAP1', 'VEGFA', 'ARFGEF3', 'STEAP4']

# genes = ['DUOXA1', 'UPK1A']

# genes = ['SERPINA3', 'OPN1LW', 'IL1F10', 'COL10A1', 'SERPINC1', 'EEF1A2', 'COMP', 'MKRN3', 'FUT3', 'ELMOD1', 'GHSR', 'NRXN2', 'ADAM', 'TS2', 'GRIN2D', 'HGF', 'GGT1', 'COL4A1', 'SFRP4', 'MYO15A', 'PLA2G7', 'SLC7A11', 'PGR', 'CCNA2', 'IHH', 'PTGER2', 'POSTN', 'EBF2', 'ITGAX', 'OPRK1', 'PROS1']
# genes += ['ALDH1A3', 'AMACR', 'ANPEP', 'AR', 'EPHA6', 'FOXA1', 'GDF15', 'GYLATL1', 'KLK3', 'MYC', 'NXX3-1', 'NPDC1', 'PPP1R1B', 'SCD', 'SCIN', 'STEAP1', 'SPDEF', 'ARFGEF3', 'STEAP4']
# genes += ['NKX3-1', 'DNAH5', 'FSTL1', 'TMTC4', 'FOXA1', 'ST14', 'GLUD1', 'NEDD4L', 'FASN', 'ABAT', 'ACLY', 'FEV', 'TMPRSS2', 'SORD', 'VSTM2L', 'EPCAM', 'SPON2', 'TRPM4', 'ABCC4', 'TMEFF2']


# genes = ['MAFB', 'MVP', 'FHL2', 'EGFR', 'SCNN1A', 'HOXD10', 'MEIS1', 'CASP1', 'P2RY2', 'ANXA2', 'SCPEP1', 'NOTCH1', 'TMEM173', 'SLC40A1', 'MYOF', 'CYP27A1', 'MEIS2', 'MMP14', 'GATA3', 'NTN4']

# genes = ['FLRT3', 'DUOXA1', 'UPK1A']
# genes = ['FLRT3', 'EPHA6']
# genes = ['SLC40A1', 'NPCD1']
# NOTE(review): this first list is replaced by the very next assignment, so only
# ['KRT8', 'CD44'] takes effect — confirm which list is intended.
genes = ['TFF1', 'TCIM', 'MGP', 'PLAT', 'HMGCS2', 'HK2', 'CAMK2N1', 'GFRA1', 'LGALS1', 'COX6B2', 'VEGFA', 'FASN', 'CA9', 'CDKN1A']
genes = ['KRT8', 'CD44']



# De-duplicate and order alphabetically.
genes = sorted(set(genes))

genes

In [None]:
# TFF1, TCIM, MGP, PLAT, HMGCS2, HK2
# CAMK2N1
# GFRA1
# LGALS1
# COX6B2


In [None]:
# Context channel drawn in every panel; the gene of interest is appended per iteration.
# (Extra context channels tried previously: EPCAM red '.', TP63 white 'o'.)
view_settings = [
    {
        'channel': 'KRT5',
        'color': 'white',
        'marker': 'o',
    },
]


for gene in genes:
    view = {
        'channel': gene,
        'color': 'cyan',
        'marker': 's',
    }
    imgs, _ = get_imgs(view_settings + [view])
    if not imgs:
        # plt.subplots rejects ncols=0, so skip genes with nothing to draw.
        print(f'no images for {gene}; skipping')
        continue
    # squeeze=False keeps `axs` an array even when a single image is returned,
    # so the zip below never receives a bare Axes object.
    fig, axs = plt.subplots(ncols=len(imgs), figsize=(20, 10), squeeze=False)
    for ax, img in zip(axs.ravel(), imgs):
        ax.imshow(img)
        ax.axis('off')
    # pyplot's current axes is the last panel, so that panel gets the title.
    plt.title(gene)
    plt.savefig(fig_dir / f'geneplots_{case}_{roi}_{gene}.svg')
    plt.show()
    # Release the figure so a long gene list does not accumulate open figures.
    plt.close(fig)


In [None]:
# KRT5 is the shared context channel; each gene is overlaid in cyan squares.
# (Extra context channels tried previously: EPCAM red '.', TP63 white 'o'.)
view_settings = [
    {
        'channel': 'KRT5',
        'color': 'white',
        'marker': 'o',
    },
]


for gene in genes:
    view = {
        'channel': gene,
        'color': 'cyan',
        'marker': 's',
    }
    imgs, _ = get_imgs(view_settings + [view])
    if not imgs:
        # plt.subplots raises for ncols=0; nothing to plot for this gene.
        print(f'no images for {gene}; skipping')
        continue
    # With squeeze=False the axes come back as a 2-D array regardless of the
    # number of images, which makes the single-image case iterable too.
    fig, axs = plt.subplots(ncols=len(imgs), figsize=(20, 10), squeeze=False)
    for ax, img in zip(axs.ravel(), imgs):
        ax.imshow(img)
        ax.axis('off')
    # Title lands on the last panel (pyplot's current axes).
    plt.title(gene)
    plt.savefig(fig_dir / f'geneplots_{case}_{roi}_{gene}.svg')
    plt.show()
    # Close explicitly to keep memory flat across iterations.
    plt.close(fig)


In [None]:
# Context channels drawn in every panel: EPCAM in red, KRT5 and TP63 in white.
view_settings = [
    {
        'channel': 'EPCAM',
        'color': 'red',
        'marker': '.',
    },
    {
        'channel': 'KRT5',
        'color': 'white',
        'marker': '.',
    },
    {
        'channel': 'TP63',
        'color': 'white',
        'marker': '.',
    },
]


for gene in genes:
    view = {
        'channel': gene,
        'color': 'cyan',
        'marker': 's',
    }
    imgs, _ = get_imgs(view_settings + [view])
    if not imgs:
        # plt.subplots rejects ncols=0, so skip genes with nothing to draw.
        print(f'no images for {gene}; skipping')
        continue
    # squeeze=False keeps `axs` an array even when a single image is returned.
    fig, axs = plt.subplots(ncols=len(imgs), figsize=(20, 10), squeeze=False)
    for ax, img in zip(axs.ravel(), imgs):
        ax.imshow(img)
        ax.axis('off')
    # pyplot's current axes is the last panel, so that panel gets the title.
    plt.title(gene)
    plt.savefig(fig_dir / f'geneplots_{case}_{roi}_{gene}.svg')
    plt.show()
    # Release the figure so a long gene list does not accumulate open figures.
    plt.close(fig)


In [None]:
# Epithelial context (EPCAM red; KRT5 / TP63 white) shared by every gene panel.
view_settings = [
    {
        'channel': 'EPCAM',
        'color': 'red',
        'marker': '.',
    },
    {
        'channel': 'KRT5',
        'color': 'white',
        'marker': '.',
    },
    {
        'channel': 'TP63',
        'color': 'white',
        'marker': '.',
    },
]


for gene in genes:
    view = {
        'channel': gene,
        'color': 'cyan',
        'marker': 's',
    }
    imgs, _ = get_imgs(view_settings + [view])
    if not imgs:
        # plt.subplots raises for ncols=0; nothing to plot for this gene.
        print(f'no images for {gene}; skipping')
        continue
    # With squeeze=False the axes are always an array, so one image works too.
    fig, axs = plt.subplots(ncols=len(imgs), figsize=(20, 10), squeeze=False)
    for ax, img in zip(axs.ravel(), imgs):
        ax.imshow(img)
        ax.axis('off')
    # Title lands on the last panel (pyplot's current axes).
    plt.title(gene)
    plt.savefig(fig_dir / f'geneplots_{case}_{roi}_{gene}.svg')
    plt.show()
    # Close explicitly to keep memory flat across iterations.
    plt.close(fig)


In [None]:
# Placeholder cell: evaluates the literal 1 and has no other effect.
1

In [None]:
# Macrophage panel: KRT5 as context plus SPP1 and MS4A6A.
# (Channels tried previously: EPCAM red '.', CD68 limegreen 'o',
#  CD14 purple 'o', CD163 yellow 'o'.)
view_settings = [
    {
        'channel': 'KRT5',
        'color': 'white',
        'marker': 'o',
    },
    {
        'channel': 'SPP1',
        'color': 'orange',
        'marker': 'o',
    },
    {
        'channel': 'MS4A6A',
        'color': 'cyan',
        'marker': 'o',
    },
]

imgs, _ = get_imgs(view_settings)
# squeeze=False keeps `axs` an array even when only one image is returned,
# so the zip below never receives a bare Axes object.
fig, axs = plt.subplots(ncols=len(imgs), figsize=(40, 10), squeeze=False)
for ax, img in zip(axs.ravel(), imgs):
    ax.imshow(img)
    ax.axis('off')
plt.savefig(fig_dir / f'geneplots_{case}_{roi}_macrophage.svg', dpi=300)

In [None]:
# cases = ['HT206B1', 'HT704B1', 'HT397B1']
# dfs = []
# for case in cases:
#     fps = sorted(utils.listfiles(f'/data/estorrs/mushroom/data/projects/submission_v1/{case}/imaris/rois/results', regex=r'tme.*txt.gz'))
#     for fp in fps:
#         df = pd.read_csv(fp, sep='\t', index_col=0)
#         df['ident'] = fp.split('/')[-1]
#         dfs.append(df)
# df = pd.concat(dfs)
# df


In [None]:
# Build case -> roi -> {'tumor' | 'tme'} -> dtype ident -> expression table.
cases = ['HT206B1', 'HT704B1', 'HT397B1']
case_to_df = {}

for c in cases:
    case_to_df[c] = {}
    fps = sorted(utils.listfiles(f'/data/estorrs/mushroom/data/projects/submission_v1/{c}/imaris/rois/results', regex=r'.*_expression.txt.gz'))
    fps_pool = fps

    for fp in fps:
        name = fp.split('/')[-1]
        parts = name.split('_')
        # File names containing 'tme' have one extra leading token before the
        # ROI name, so every field index shifts by one.
        key = 'tme' if 'tme' in name else 'tumor'
        offset = 1 if key == 'tme' else 0
        r = parts[offset]
        dti = '_'.join(parts[offset + 1:offset + 3])

        df = pd.read_csv(fp, sep='\t', index_col=0)

        print(c, r, key, dti)
        # Create the nested levels on demand, then store the table.
        case_to_df[c].setdefault(r, {}).setdefault(key, {})[dti] = df



In [None]:
# Load the tiled-ROI metadata for the current `case` and read every tiled stack.
# A context manager closes the metadata file instead of leaking the handle.
with open('/data/estorrs/mushroom/data/projects/submission_v1/' + case + '/imaris/rois/tiled/metadata.json') as fh:
    meta = json.load(fh)
sid_to_channels = meta['sid_to_channels']
dtype_ident_to_dtype = meta['dtype_ident_to_dtype']
dtype_ident_to_channels = meta['dtype_ident_to_channels']
sid_to_dtype_ident = meta['sid_to_dtype_ident']
# Other metadata keys referenced previously: 'tiling_size', 'size', 'fullres_size'.
dtype_ident_to_tiled_fps = meta['dtype_ident_to_tiled_fps']
dtype_ident_to_tiled = {dti:tifffile.imread(fp) for dti, fp in dtype_ident_to_tiled_fps.items()}
for dti, tiled in dtype_ident_to_tiled.items():
    print(dti, tiled.shape)

In [None]:
# Number of channels recorded for the 'cosmx_0' dtype ident.
len(dtype_ident_to_channels['cosmx_0'])

In [None]:
# Per-channel standard deviation for each dtype ident; plot_roi_heatmaps
# divides expression values by these.
dti_to_std = {}
for dti, channels in dtype_ident_to_channels.items():
    if dti in dtype_ident_to_tiled:
        # Tiled stacks: reduce over every axis except axis 1.
        dti_to_std[dti] = dtype_ident_to_tiled[dti].std((0, 2, 3))

# vishd / visium are not taken from the tiled stacks here; compute the
# per-gene spread from each section's AnnData and average across sections.
to_stack = {}
for entry in config['sections']:
    for m in entry['data']:
        dtype = m['dtype']
        dti = f'{dtype}_0'
        if dtype in ['vishd', 'visium']:
            if dti not in to_stack:
                to_stack[dti] = []
            adata = sc.read_h5ad(m['filepath'])
            adata.var_names_make_unique()
            adata = adata[:, dtype_ident_to_channels[dti]]
            x_mat = adata.X
            # Sparse matrices expose .power(); dense arrays use ** instead.
            if hasattr(x_mat, 'power'):
                sq_mean = np.asarray(x_mat.power(2).mean(0))
            else:
                sq_mean = np.asarray((x_mat ** 2).mean(0))
            variance = sq_mean - np.asarray(x_mat.mean(0))**2
            # E[x^2] - E[x]^2 is the variance; take the square root so these
            # entries are on the same scale as the .std() values above. Clip
            # tiny negative values caused by floating-point cancellation.
            std = np.sqrt(np.clip(variance, 0, None))
            to_stack[dti].append(std)
for dti, stack in to_stack.items():
    dti_to_std[dti] = np.stack(stack).mean(0).flatten()


In [None]:
# List the dtype idents for which a per-channel std was computed.
dti_to_std.keys()

In [None]:
# Space-separated gene panels; later cells pass them to plot_roi_heatmaps
# with the 'tumor' and 'tme' keys respectively.
tumor_genes = """MGP VIM EGFR SOX9 LTF CLU APP PTN CX3CL1 PPP1R1B CDH1 LAMC2 PLAT MMP7 ANPEP ITGA2 MS4A6A EHF C1QB TACSTD2 KRT8 SPP1 PGR AR FASN GPRC5A KLF2 LDLR SCD DUSP4 AREG KRT23 FZD1 COX6C KRT19 TOP2A PCNA MLPH TCIM KRT7""".split(' ')
tme_genes = """TPM2 TAGLN CALD1 MYL9 TCF7 MRC1 SNAI2 CD163 CLEC14A CSF1R AIF1 PDGFRB IL3RA PECAM1 CD34 ADGRL4 IGFBP7 SERPINE1 VWF HAS3 MFAP5 THBS2 SFRP4 ASPN VCAM1 MMP2 MSR1 BGN S100A4 FAP IL32 RYR3 FBN1 CCDC80 SFRP2 IFI8 IGHG1 BASP1 THY1 HAVCR2""".split(' ')


# Default seaborn colormap per dtype ident; used by plot_roi_heatmaps when no cmap is given.
cmap_mapping = {
    'xenium_0': 'Blues',
    'xenium_1': 'Blues',
    'cosmx_0': 'Oranges',
    'vishd_0': 'Greens',
    'visium_0': 'Reds'
}


In [None]:
def plot_roi_heatmaps(genes, key, cmap=None, d=None):
    """Draw one std-scaled expression heatmap per dtype ident.

    Parameters
    ----------
    genes : list of str
        Columns to show, in order. Genes missing from a table appear as an
        all-NaN column so every heatmap shares the same column layout.
    key : str
        Compartment key (e.g. 'tumor' or 'tme') used to look up the tables
        in the global ``case_to_df[case][roi]`` when ``d`` is not given.
    cmap : optional
        Colormap for every heatmap; defaults to ``cmap_mapping[dti]``.
    d : dict, optional
        Mapping of dtype ident -> DataFrame to plot instead of the global
        lookup.

    Returns
    -------
    The axes created by ``plt.subplots`` (wrapped in a list when there is
    only one).

    Notes
    -----
    Reads the globals ``keep`` (row ids to display, in order),
    ``dti_to_std`` and ``dtype_ident_to_channels``.
    """
    if d is None:
        d = case_to_df[case][roi][key]

    fig, axs = plt.subplots(nrows=len(d), figsize=(20, 15))
    if len(d) == 1:
        axs = [axs]
    for (dti, table), ax in zip(d.items(), axs):
        print(dti)
        stds = dti_to_std[dti]
        channels = dtype_ident_to_channels[dti]

        rows = [rid for rid in keep if rid in table.index]
        # dict.fromkeys de-duplicates while preserving order so the scaled
        # frame has unique columns and can be safely reindexed below.
        present = [g for g in dict.fromkeys(genes) if g in table.columns]

        # Build a new frame rather than mutating a slice of `table` in place.
        scaled = table.loc[rows, present] / stds[[channels.index(g) for g in present]]
        # Missing genes become NaN columns; column order follows `genes`.
        scaled = scaled.reindex(columns=genes)

        if scaled.shape[0]:
            cm = cmap_mapping[dti] if cmap is None else cmap
            sns.heatmap(scaled, cmap=cm, ax=ax)
    return axs


In [None]:
# Collect every row id present in any tumor table for the current case/roi,
# then print the 'z' field of each kept region that appears among them.
pool =  [x for v in case_to_df[case][roi]['tumor'].values() for x in v.index]
for k in keep:
    if k in pool:
        print(k, rid_to_region[k]['z'])

In [None]:
# Heatmaps of the tumor gene panel for the current case/roi, saved as SVG.
plot_roi_heatmaps(tumor_genes, 'tumor')
plt.savefig(fig_dir / f'{case}_{roi}_tumor_heatmaps.svg')

In [None]:
# Heatmaps of the TME gene panel for the current case/roi, saved as SVG.
plot_roi_heatmaps(tme_genes, 'tme')
plt.savefig(fig_dir / f'{case}_{roi}_tme_heatmaps.svg')

In [None]:
# Overlay a single gene on top of the current `view_settings` context channels.
# gene = 'TFF2'
gene = 'CAVIN2'
view = {
    'channel': gene,
    'color': 'cyan',
    'marker': 's',
}
VISHD_GENE = gene
imgs, rids = get_imgs(view_settings + [view])
# squeeze=False keeps `axs` an array even when only one image is returned,
# so the zip below never receives a bare Axes object.
fig, axs = plt.subplots(ncols=len(imgs), figsize=(20, 10), squeeze=False)
for ax, img, rid in zip(axs.ravel(), imgs, rids):
    ax.imshow(img)
    ax.set_title(rid, rotation=90)
    ax.axis('off')

In [None]:
# Pathway enrichment for a hand-picked gene list; prints the term name and
# overlapping genes for up to the first 100 result rows.
# NOTE(review): `een` is not imported in the visible part of this notebook —
# confirm where it is defined before relying on Restart & Run All.
pdf = een.get_pathway_enrichment(
        ['PIP', 'MS4A6A', 'ANPEP', 'PLAT', 'LAMC2', 'CLU', 'SOX9', 'MGP', 'TCIM', 'KRT23', 'AREG', 'DUSP4', 'FASN', 'PGR', 'SPP1', 'KRT8', 'CD44'],
        gene_set_library='GO_Biological_Process_2021')
for i, row in pdf.iloc[:100].iterrows():
    print(row['Term name'], row['Overlapping genes'])

In [None]:
# Heatmaps restricted to a curated subset of genes, saved as SVG.
selected_genes = ['TFF1', 'PIP', 'PLAT', 'MGP', 'EHF', 'LAMC2', 'CLU', 'ANPEP', 'MS4A6A', 'TCIM', 'FASN', 'KRT8', 'CD44', 'GLUL', 'MLPH', 'SPP1']
plot_roi_heatmaps(selected_genes, 'tumor')
plt.savefig(fig_dir / f'{case}_{roi}_selected_tumor_heatmaps.svg')

In [None]:
# Cases for which expression tables were loaded.
case_to_df.keys()

In [None]:

# Marker set 1: one colour/marker style per channel.
selected_view = [
    {'channel': 'EPCAM', 'color': 'red', 'marker': 'o'},
    {'channel': 'PIP', 'color': 'teal', 'marker': 'P'},
    {'channel': 'PLAT', 'color': 'blue', 'marker': '*'},
    {'channel': 'EHF', 'color': 'magenta', 'marker': '.'},
    {'channel': 'LAMC2', 'color': 'yellow', 'marker': 'D'},
    {'channel': 'CLU', 'color': 'limegreen', 'marker': 'v'},
    {'channel': 'KRT5', 'color': 'white', 'marker': 'o'},
    {'channel': 'TP63', 'color': 'white', 'marker': 'o'},
    {'channel': 'TFF1', 'color': 'cyan', 'marker': 's'},
]
imgs, rids = get_imgs(selected_view)
# squeeze=False keeps `axs` an array even when only one image is returned,
# so the zip below never receives a bare Axes object.
fig, axs = plt.subplots(ncols=len(imgs), figsize=(20, 10), squeeze=False)
for ax, img, rid in zip(axs.ravel(), imgs, rids):
    ax.imshow(img)
    ax.axis('off')
plt.savefig(fig_dir / f'{case}_{roi}_set1_rgbs.svg', dpi=300)

In [None]:

# Marker set 3: epithelial context plus SPP1, MS4A6A, CD68 and CD163.
selected_view = [
    {'channel': 'EPCAM', 'color': 'red', 'marker': 'o'},
    {'channel': 'KRT5', 'color': 'white', 'marker': 'o'},
    {'channel': 'TP63', 'color': 'white', 'marker': 'o'},
    {'channel': 'SPP1', 'color': 'cyan', 'marker': 'D'},
    {'channel': 'MS4A6A', 'color': 'yellow', 'marker': 's'},
    {'channel': 'CD68', 'color': 'limegreen', 'marker': '*'},
    {'channel': 'CD163', 'color': 'magenta', 'marker': '*'},
]
imgs, rids = get_imgs(selected_view)
# squeeze=False keeps `axs` an array even when only one image is returned,
# so the zip below never receives a bare Axes object.
fig, axs = plt.subplots(ncols=len(imgs), figsize=(20, 10), squeeze=False)
for ax, img, rid in zip(axs.ravel(), imgs, rids):
    ax.imshow(img)
    ax.axis('off')
plt.savefig(fig_dir / f'{case}_{roi}_set3_rgbs.svg', dpi=300)

In [None]:

# Marker set 2: epithelial context plus FASN, CD44, MLPH, GLUL, KRT8 and TCIM.
selected_view = [
    {'channel': 'EPCAM', 'color': 'red', 'marker': 'o'},
    {'channel': 'KRT5', 'color': 'white', 'marker': 'o'},
    {'channel': 'TP63', 'color': 'white', 'marker': 'o'},
    {'channel': 'FASN', 'color': 'yellow', 'marker': '.'},
    {'channel': 'CD44', 'color': 'magenta', 'marker': '.'},
    {'channel': 'MLPH', 'color': 'teal', 'marker': '.'},
    {'channel': 'GLUL', 'color': 'orange', 'marker': '.'},
    {'channel': 'KRT8', 'color': 'cyan', 'marker': '.'},
    {'channel': 'TCIM', 'color': 'limegreen', 'marker': '.'},
]
imgs, rids = get_imgs(selected_view)
# squeeze=False keeps `axs` an array even when only one image is returned,
# so the zip below never receives a bare Axes object.
fig, axs = plt.subplots(ncols=len(imgs), figsize=(20, 10), squeeze=False)
for ax, img, rid in zip(axs.ravel(), imgs, rids):
    ax.imshow(img)
    ax.axis('off')
plt.savefig(fig_dir / f'{case}_{roi}_set2_rgbs.svg', dpi=300)

In [None]:
# Named marker presets (view name -> list of channel settings). Only 'view7'
# is active here; the presets used earlier are recorded below in condensed
# form as "channel color marker".
#   view1 (HT704B1 roi1): KRT5 white o | TCIM orange D | TFF1 cyan s
#   view2 (HT704B1 roi1): KRT5 white o | MKI67 yellow D | HMGCS2 blue s
#   view3 (HT704B1 roi2): KRT5 white o | TCIM orange D | HMGCS2 blue s
#   view4 (HT704B1 roi2): KRT5 white o | VEGFA yellow D | PLAT limegreen s
#   view5 (HT206B1 roi1): ACTA2 white o | MYC orange D | CA9 cyan s
#   view6 (HT206B1 roi1): ACTA2 white o | MKI67 yellow D | CDKN1A limegreen s
# NOTE: later cells assign `view_dict` again, so run order decides which
# presets are in effect.
view_dict = {
    'view7': [  # HT206B1 roi2
        dict(channel='ACTA2', color='white', marker='o'),
        dict(channel='MYC', color='orange', marker='D'),
        dict(channel='CA9', color='blue', marker='s'),
    ],
}

In [None]:
# Named marker presets (view name -> list of channel settings). Only 'view5'
# is active here; the presets used earlier are recorded below in condensed
# form as "channel color marker".
#   view1 (HT891Z1 roi1 and 2): KRT5 white o | MYC orange D | FAS cyan s
#   view2 (HT891Z1 roi1): KRT5 white o | GDF15 yellow D | SOX9 limegreen s
#   view3 (HT913Z1 roi1 and roi1-p2): KRT5 white o | SERPINA3 yellow D | SORL1 cyan s
#   view4: KRT5 white o | DNAH5 orange D | NTN4 limegreen s
# NOTE: a later cell assigns `view_dict` again, so run order decides which
# presets are in effect.
view_dict = {
    'view5': [  # HT891Z1 roi1
        dict(channel='KRT5', color='white', marker='o'),
        # Disabled entry: DNAH5 orange D
        dict(channel='VEGFA', color='cyan', marker='s'),
    ],
}

In [None]:
# Named marker presets (view name -> list of channel settings).
# Disabled preset, condensed as "channel color marker":
#   brca_comb4: KRT5 white o | LGALS1 yellow s | GFRA1 limegreen s
view_dict = {
    'brca_comb3': [
        dict(channel='KRT5', color='white', marker='o'),
        dict(channel='CD44', color='#FF0000', marker='s'),
        dict(channel='HMGCS2', color='blue', marker='s'),
    ],
}

In [None]:
# # genes = ['ANXA2', 'RUNX1', 'CLDN1', 'CD74', 'SDC4', 'SMPD3', 'SCPEP1', 'CYP27A1', 'KRT15', 'TCIM', 'PMP22', 'SOX9', 'SMAD3', 'EGFR', 'PTPRU', 'GATA3', 'FHL2', 'ID3', 'KRT19', 'NTN4']
# # genes += ['TFF1', 'PLAT', 'MGP', 'KIT', 'PIP', 'HMGCS2']

# view = [
#     {
#         'channel': 'KRT5',
#         'color': 'white',
#         'marker': 'o',
#     },
# ]
    
# view_dict = {g:view + [{'channel': g, 'color': 'cyan', 'marker': 's'}] for g in genes}

In [None]:
# for case, d1 in case_to_rois.items():
# for case in ['HT206B1', 'HT704B1', 'HT397B1']:
# for case in ['HT704B1', 'HT206B1']:
# for case in ['HT206B1']:
# for case in ['HT704B1']:
# For every marker preset and case: reload the case's config, images and
# per-channel statistics, then render and save one figure per ROI.
for view_name, selected_view in view_dict.items():
    # Case selections used previously: ['HT704B1', 'HT206B1'],
    # ['HT891Z1', 'HT913Z1'], ['HT891Z1'], ['HT913Z1'], ['HT704B1', 'HT891Z1'].
    for case in ['HT704B1']:
        d1 = case_to_rois[case]

        # Load the registration config once, closing the file afterwards.
        with open(project_dir / case / 'registered' / 'metadata.yaml') as fh:
            config = yaml.safe_load(fh)
        config = alter_filesystem(config, source_root, target_root)

        stacked, sids, sids_filtered = [], [], []
        for entry in config['sections']:
            sids.append(entry['sid'])
            if entry['sid'] in sid_to_data:
                rgb = tifffile.imread(sid_to_data[entry['sid']]['rgb'])
                stacked.append(rgb)
                sids_filtered.append(entry['sid'])

        stacked = np.stack(stacked)

        with open('/data/estorrs/mushroom/data/projects/submission_v1/' + case + '/imaris/rois/tiled/metadata.json') as fh:
            meta = json.load(fh)
        sid_to_channels = meta['sid_to_channels']
        dtype_ident_to_dtype = meta['dtype_ident_to_dtype']
        dtype_ident_to_channels = meta['dtype_ident_to_channels']
        sid_to_dtype_ident = meta['sid_to_dtype_ident']
        dtype_ident_to_tiled_fps = meta['dtype_ident_to_tiled_fps']
        dtype_ident_to_tiled = {dti:tifffile.imread(fp) for dti, fp in dtype_ident_to_tiled_fps.items()}
        for dti, tiled in dtype_ident_to_tiled.items():
            print(dti, tiled.shape)

        # Per-channel std from the tiled stacks (all axes except axis 1).
        dti_to_std = {}
        for dti, channels in dtype_ident_to_channels.items():
            if dti in dtype_ident_to_tiled:
                dti_to_std[dti] = dtype_ident_to_tiled[dti].std((0, 2, 3))

        # vishd / visium: per-gene std from each section's AnnData, averaged.
        to_stack = {}
        for entry in config['sections']:
            for m in entry['data']:
                dtype = m['dtype']
                dti = f'{dtype}_0'
                if dtype in ['vishd', 'visium']:
                    if dti not in to_stack:
                        to_stack[dti] = []
                    adata = sc.read_h5ad(m['filepath'])
                    adata.var_names_make_unique()
                    adata = adata[:, dtype_ident_to_channels[dti]]
                    x_mat = adata.X
                    # Sparse matrices expose .power(); dense arrays use **.
                    if hasattr(x_mat, 'power'):
                        sq_mean = np.asarray(x_mat.power(2).mean(0))
                    else:
                        sq_mean = np.asarray((x_mat ** 2).mean(0))
                    variance = sq_mean - np.asarray(x_mat.mean(0))**2
                    # E[x^2] - E[x]^2 is the variance; the square root puts it
                    # on the same scale as the .std() values above. Clip tiny
                    # negatives produced by floating-point cancellation.
                    std = np.sqrt(np.clip(variance, 0, None))
                    to_stack[dti].append(std)
        for dti, stack in to_stack.items():
            dti_to_std[dti] = np.stack(stack).mean(0).flatten()

        for roi, d2 in d1.items():
            # NOTE(review): as written this skips every 'roi3' AND every ROI of
            # HT397B1; confirm that is intended rather than skipping only
            # roi3 of HT397B1.
            if roi != 'roi3' and case != 'HT397B1':
                print(case, roi)
                keep = to_path[case][roi]
                imaris_dir = project_dir / case / 'imaris' / 'rois'
                regions_fp = imaris_dir / f'{roi}.geojson'
                regions = load_regions(regions_fp)
                rid_to_region = {r['id']:r for r in regions}

                # Heatmap output is currently disabled:
                # plot_roi_heatmaps(selected_genes, 'tumor')
                # plt.savefig(fig_dir / f'{case}_{roi}_selected_tumor_heatmaps.svg')
                # plt.show()

                imgs, rids = get_imgs(selected_view)
                # squeeze=False keeps `axs` an array even for a single image.
                fig, axs = plt.subplots(ncols=len(imgs), figsize=(20, 10), squeeze=False)
                for ax, img, rid in zip(axs.ravel(), imgs, rids):
                    ax.imshow(img)
                    ax.axis('off')
                plt.savefig(fig_dir / f'{case}_{roi}_{view_name}_selected_tumor_imgs.svg', dpi=300)
                plt.show()
                # Release the figure so nested loops do not accumulate open figures.
                plt.close(fig)
                    

In [None]:
# Print the number of channels recorded for each section id.
for k, v in sid_to_channels.items():
    print(k, len(v))

In [None]:
# Compartment keys available for the current case/roi.
case_to_df[case][roi].keys()

In [None]:
# Shape of the expression matrix of the most recently loaded AnnData.
adata.X.shape

In [None]:
# Convert the expression matrix to a dense array for inspection.
# NOTE(review): presumably X is a sparse matrix here (it needs a .toarray() method) — confirm.
adata.X.toarray()

In [None]:
# Check which container type backs the expression matrix.
type(adata.X)