In [None]:
import math
import os
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc
import seaborn as sns
import tifffile
import torchvision.transforms.functional as TF
import torch
from pathlib import Path
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ColorJitter, Normalize, RandomHorizontalFlip, RandomVerticalFlip, RandomAdjustSharpness, RandomResizedCrop, RandomAffine

from einops import rearrange

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import multiplex_imaging_pipeline.utils as utils

In [None]:
# Index the per-sample .h5ad files under <output_dir>/spatial_features by sample name
# (file name without the '.h5ad' suffix).
# NOTE(review): `output_dir` is not defined in any visible cell; it has to be set
# before this cell for the notebook to run on a fresh kernel.
# The dot is escaped so the pattern matches a literal '.h5ad' extension only.
adata_fps = sorted(utils.listfiles(os.path.join(output_dir, 'spatial_features'), regex=r'\.h5ad$'))
sample_to_adata = {os.path.basename(fp).replace('.h5ad', ''): fp for fp in adata_fps}

In [None]:
# Index the per-sample region-feature tables (.txt) under <output_dir>/region_features
# by sample name (file name without the '.txt' suffix). Values are file paths, not
# loaded tables.
# The dot is escaped so the pattern matches a literal '.txt' extension only.
region_fts_fps = sorted(utils.listfiles(os.path.join(output_dir, 'region_features'), regex=r'\.txt$'))
sample_to_region_fts = {os.path.basename(fp).replace('.txt', ''): fp for fp in region_fts_fps}

In [None]:
# Index the labelled region-mask images (.tif) under <output_dir>/region_features_masks
# by file name without the '.tif' suffix.
# The dot before 'tif' is escaped so it matches a literal extension separator.
mask_fps = sorted(utils.listfiles(os.path.join(output_dir, 'region_features_masks'), regex=r'region.*\.tif$'))
sample_to_labeled = {os.path.basename(fp).replace('.tif', ''): fp for fp in mask_fps}

In [None]:
# Index the pseudo-colour full-resolution images (.tif) under <output_dir>/pseudo_fullres
# by sample name (file name without the '.tif' suffix).
# The dot is escaped so the pattern matches a literal '.tif' extension only.
pseudo_fps = sorted(utils.listfiles(os.path.join(output_dir, 'pseudo_fullres'), regex=r'\.tif$'))
sample_to_pseudo = {os.path.basename(fp).replace('.tif', ''): fp for fp in pseudo_fps}

In [None]:
# Index the level_2 OME-TIFF files by file name without the '.ome.tiff' suffix.
# The pattern is a raw string with escaped dots so '.ome.tiff' is matched literally
# rather than as "any character".
ome_fps = sorted(utils.listfiles('/diskmnt/Projects/Users/estorrs/multiplex_data/codex/htan',
                                 regex=r'/level_2/[^/]+\.ome\.tiff$'))
sample_to_ome = {os.path.basename(fp).replace('.ome.tiff', ''): fp for fp in ome_fps}

In [None]:
# OME file stem -> sample name used as key in the other sample_to_* lookups.
# Stems that are absent from this table are kept unchanged.
name_map = {
    'HT110B1-S1H4': 'HT110B1_S1H4',
    'HT206B1-H1': 'HT206B1_H1',
    'HT206B1_S1H2L4': 'HT206B1_S1H2L4_20221028',
    'HT243B1-S1H4': 'HT243B1_S1H4',
    'HT271B1-S1H3A5': 'HT271B1_S1H3A5_02172023',
    'HT305B1-S1H1': 'HT305B1_S1H1',
    'HT323B1-H1A1': 'HT323B1_H1A1',
    'HT323B1-H1A4': 'HT323B1_H1A4',
    'HT323B1-H1-08042022': 'HT323B1_H1_08042022',
    'HT323B1-H3': 'HT323B1_H3',
    'HT339B1-H1A1': 'HT339B1_H1A1',
    'HT339B1-H2A1': 'HT339B1_H2A1',
    'HT365B1_S1H1': 'HT365B1_S1H1_02132023',
    'HT397B1-H2A2': 'HT397B1_H2A2',
    'HT397B1-H3A1': 'HT397B1_H3A1',
    'HT480B1-S1H2-R001': 'HT480B1_S1H2_R001',
    'HT480B1-S1H2-R002': 'HT480B1_S1H2_R002'
}

# Re-key the OME lookup with the harmonised names.
sample_to_ome = {
    (name_map[ome_name] if ome_name in name_map else ome_name): ome_fp
    for ome_name, ome_fp in sample_to_ome.items()
}

In [None]:
# Alphabetically ordered sample names, taken from the keys of sample_to_adata.
samples = sorted(sample_to_adata)

In [None]:
# Visual check: for every region of a sample, show the pseudo-colour crop next to the
# binary mask of that region.
# Only the last sample is previewed (`samples[-1:]`); widen the slice to browse more.
for sample in samples[-1:]:
    print(sample)
    # sample_to_region_fts maps sample -> file path, so the table is loaded here.
    # NOTE(review): assumes a tab-separated table whose first column holds the region
    # id; confirm against the code that writes the region_features files.
    df = pd.read_csv(sample_to_region_fts[sample], sep='\t', index_col=0)
    bbox_cols = ['expanded_r1', 'expanded_r2', 'expanded_c1', 'expanded_c2']
    # A list selects the four columns (a bare tuple would be looked up as one key).
    # The bounds are used as slice indices, hence the integer cast.
    region_to_bbox = {region: (r1, r2, c1, c2)
                      for region, (r1, r2, c1, c2) in zip(
                          df.index, df[bbox_cols].values.astype(int))}

    pseudo = tifffile.imread(sample_to_pseudo[sample])
    labeled = tifffile.imread(sample_to_labeled[sample])

    for region, (r1, r2, c1, c2) in region_to_bbox.items():
        print(region)
        fig, axs = plt.subplots(ncols=2)

        axs[0].imshow(pseudo[r1:r2, c1:c2])
        axs[0].axis('off')

        # True where the labelled image carries this region's id.
        axs[1].imshow(labeled[r1:r2, c1:c2] == int(region))
        axs[1].axis('off')

        plt.show()
        # One figure is created per region; release it once it has been displayed.
        plt.close(fig)


In [None]:
# Region annotations: sample name -> {label -> list of integer region ids}.
# Every sample carries the same four labels: 'normal', 'dcis', 'idc' and 'artifact'.
# The ids are joined with the sample name ('<sample>_<id>') further down in this cell
# to look up a label per region; regions not listed here end up labelled 'unknown'
# when the tiles are built.
# Samples tagged `# exclude` have only empty lists (several untagged samples do too).
# NOTE(review): the lists look hand-curated from the region previews — confirm provenance.
sample_to_classifications = {
    'HT110B1_S1H4': {
        'normal': [264, 263, 117, 89, 187, 152, ],
        'dcis': [85, 220],
        'idc': [200, 235, 153, 247, 230, 255, 168, 177, 248, 240, 227, 182, 206, 243, 245, 246, 124, 159, 225],
        'artifact': [271, 270],},
    'HT171B1-S1H9A1-4_left_05122023': {
        'normal': [],
        'dcis': [],
        'idc': [95, 134, 159, 144, 156, 80, 47, 12, 166, 135, 65, 98, 79, 25, 104, 158, 59, 83, 127, 37, 2, 8, 172],
        'artifact': [72, 174, 18, 111, 20, ],},
    'HT171B1-S1H9A1-4_right_05122023': {
        'normal': [],
        'dcis': [],
        'idc': [95, 96, 70, 145, 76, 147, 157, 168, 115, 77, 155, 97, 128, 43, 39],
        'artifact': [19, 8, 29, 3, 50, 1, 170, 87, 41, ],},
    'HT206B1_H1': {
        'normal': [112, 87, 22, 129, 109, 208, 128, 202, 92, 189, 116, 201, 209, 103],
        'dcis': [101, 197, 182, 41, 42, 49, 179, 110, 178, 96, 120, 212],
        'idc': [218, 135, 108, 131, 54, 113, 136, 23, 126, 196, 224, 27, 72, 32, 70, 219, 104, 33, 190, 228, 52, 165],
        'artifact': [],},
    'HT206B1_H1_06252022': {
        'normal': [50, 87, 313, 299, 49, 162, 187, 230, 177, 174, 214, 171, 70, 209, 77, 90, 95, 10, 89],
        'dcis': [196, 91, 109, 125, 237, 57, 88, 139, 145, 75],
        'idc': [218, 255, 296, 35, 249, 186, 246, 260, 200, 97, 217, 259, 236, 216, 284, 252, 289, 124, 140],
        'artifact': [],},
    'HT206B1_S1H2L4_20221028': {
        'normal': [114, 187, 61, 79, 125, 190, 86, 106, 166],
        'dcis': [119, 113, 8, 5, 72, 49, 40, 118, 185, 30, 156, 4, 94, 90, 21, 165, 56, 54, 160, 29, 124, 177, 28, 179, 142, 169, 38, 88],
        'idc': [65, 52, 62, 115, 167, 47, 67, 122, 37, 9],
        'artifact': [],},
    'HT243B1-S1H4A4_04192023': {
        'normal': [293, 149, 152, 68, 145, 157, 271],
        'dcis': [],
        'idc': [47, 41, 32, 28, 31, 21, 8, 86, 80, 52, 156, 168, 141, 237, 202, 124, 130, 283],
        'artifact': [],},
    'HT243B1-S1H4A4_left_05122023': {
        'normal': [131, 45, 165],
        'dcis': [],
        'idc': [249, 251, 173, 220, ],
        'artifact': [35, 17, 34, 29, 256, 21],},
    'HT243B1-S1H4A4_right_05122023': {
        'normal': [],
        'dcis': [],
        'idc': [],
        'artifact': [],},
    'HT243B1_S1H4': {
        'normal': [211, 205, 30, 4, 169, 57, 10, 28, 61, 53],
        'dcis': [44, 58, 43, 49, 55, 46, 48, 51, ],
        'idc': [88, 213, 206, 146, 185, 173, 67, 142, 190],
        'artifact': [26],},
    'HT271B1-S1H6A5_04192023': {
        'normal': [255, 243, 40, 33, 238, 152, 108, 115],
        'dcis': [75, 41, 66, 96, 28, 10, 86, 67, 79, 224, 78, 208, 146, 54, 69, 63, 77, 181, 131, 44],
        'idc': [188, 167, 197, 199, 48, 57, 76, 141, 19, 174, 38, 192, 121, 83],
        'artifact': [182, 173, 262, 1],},
    'HT271B1-S1H6A5_left_05122023': {
        'normal': [30, 147, 80, 9, 38, 31, 32, 8, 46, 35, 63],
        'dcis': [25, 7, 24, 42, 14, 37, 27, 36, 47],
        'idc': [141, 73, 72, 106, 16, 99, 58],
        'artifact': [1, 44, 69, 3, 132, 149, 146, 2],},
    'HT271B1-S1H6A5_right_05122023': {
        'normal': [],
        'dcis': [],
        'idc': [],
        'artifact': [],},
    'HT297B1_H1_08042022': {
        'normal': [],
        'dcis': [],
        'idc': [],
        'artifact': []},# exclude
    'HT305B1_S1H1': { 
        'normal': [],
        'dcis': [],
        'idc': [],
        'artifact': []},# exclude
    'HT308B1-S1H5A4_04192023': {
        'normal': [75, 119],
        'dcis': [123, 94, 83, 118, 134, 85, 35],
        'idc': [21, 44, 23, 9, 74, 22, 59, 107],
        'artifact': [270, 188, 147, 41],},
    'HT308B1-S1H5A4_left_05122023': {
        'normal': [],
        'dcis': [48, 12, 55],
        'idc': [53, 9, 49, 6, 18, 67, 59, 60],
        'artifact': [1, 65, 2],},
    'HT308B1-S1H5A4_right_05122023': {
        'normal': [],
        'dcis': [],
        'idc': [],
        'artifact': [],},
    'HT323B1_H1A1': {
        'normal': [633, 711, 623, 621, 684, 630, 722, 522, 503, 438, 360, 397, 416, 328, 666],
        'dcis': [213, 24, 22],
        'idc': [323, 436, 232, 412, 19, 312, 173, 85, 381, 390, 317, 419],
        'artifact': [653],},
    'HT323B1_H1A4': {
        'normal': [669, 514, 378, 202, 284, 278, 239],
        'dcis': [265, 7, 71, 16, 55, 35, 79, 2, 134, 408, 188, 86, 213, 40, 70, 3, 128],
        'idc': [174, 343, 111, 432, 271, 22],
        'artifact': []},
    'HT323B1_H1_08042022': {
        'normal': [428, 291, 582, 438],
        'dcis': [307, 443, 141, 6],
        'idc': [161, 382, 289, 301, 181, 137, 366, 128, 30],
        'artifact': [],},
    'HT323B1_H3': {
        'normal': [727, 700, 448, 625, 584, 587],
        'dcis': [279, 546, 751, 282, 699, 654, 661, 199, 487, 549, 555, 462, 920, 814, 350, 208, 320, 979],
        'idc': [1119, 183, 45, 254, 653, 690, 148, 412],
        'artifact': [],},
    'HT339B1_H1A1': {
        'normal': [],
        'dcis': [],
        'idc': [],
        'artifact': [],},# exclude
    'HT339B1_H2A1': {
        'normal': [],
        'dcis': [],
        'idc': [],
        'artifact': [],},# exclude
    'HT339B1_H4A4': {
        'normal': [],
        'dcis': [],
        'idc': [],
        'artifact': [],},# exclude
    'HT365B1_S1H1_02132023': {
        'normal': [],
        'dcis': [159, 153],
        'idc': [158, 148, 160, 37, 6, 10],
        'artifact': [],},
    'HT397B1_H2A2': {
        'normal': [],
        'dcis': [193, 216, 176, 111, 50, 253, 9, 5, 69, 129, 48, 37, 54, 108, 55, 58, 41, 51, 6, 197, 33, 40, 142, 23, 25, 35],
        'idc': [228, 132, 227, 102, 152, 30],
        'artifact': [],},
    'HT397B1_H3A1': {
        'normal': [],
        'dcis': [122, 338, 97, 423, 392, 126, 422, 227, 455, 107, 411, 370, 157, 436, 430, 254],
        'idc': [365, 176, 255, 104, 212],
        'artifact': [],},
    'HT397B1_S1H1A3U22_04122023': {
        'normal': [],
        'dcis': [36, 17, 11, 24, 14],
        'idc': [91, 81, 77, 89, 194],
        'artifact': []},
    'HT397B1_S1H1A3U31_04062023': {
        'normal': [],
        'dcis': [53, 16, 33, 8],
        'idc': [22],
        'artifact': [],},
    'HT397B1_U12_03172023': {
        'normal': [],
        'dcis': [43, 117, 41, 79, 54, 22, 76, 35, 13, 67, 44, 17, 40, 26, 30, 11],
        'idc': [121, 33, 213, 14],
        'artifact': [],},
    'HT397B1_U2_03162023': {
        'normal': [],
        'dcis': [],
        'idc': [],
        'artifact': [],},
    'HT480B1_S1H2_R001': {
        'normal': [],
        'dcis': [176, 24, 96, 92],
        'idc': [28, 281, 309, 135, 159],
        'artifact': [],},
    'HT480B1_S1H2_R002': {
        'normal': [],
        'dcis': [92, 97, 104, 182, 27, 47, 36, 106, 88, 26],
        'idc': [237, 204],
        'artifact': [],},
    'HT565B1-H2_04262023': {
        'normal': [90, 30, ],
        'dcis': [79, 78, 91, 60, 36, 77, 25, 81, 74, 65, 71, 45, 35, 40, 50, 86, 13, 58, 41, 54],
        'idc': [85, 2],
        'artifact': [4],}
}
# Flatten the nested annotations into a single lookup: '<sample>_<region id>' -> label.
# If an id were listed under several labels of one sample, the last label would win.
r_sample_to_classifications = {
    f'{sample_id}_{region_id}': label
    for sample_id, label_to_regions in sample_to_classifications.items()
    for label, region_ids in label_to_regions.items()
    for region_id in region_ids
}

In [None]:
def _fit_to_square(img, size=256, interpolation=TF.InterpolationMode.BILINEAR):
    """Resize a (..., H, W) tensor so its longer side equals `size`, then zero-pad to a square.

    The aspect ratio is preserved; the shorter side is centred and padded with zeros
    so the result is exactly `size` x `size`.

    Args:
        img: tensor whose last two dimensions are height and width.
        size: side length of the square output.
        interpolation: torchvision interpolation mode used for the resize.

    Returns:
        Tensor with the same leading dimensions as `img` and trailing shape (size, size).
    """
    h, w = img.shape[-2:]
    if h > w:
        new_h, new_w = size, max(1, int(size * w / h))
    else:
        new_h, new_w = max(1, int(size * h / w)), size
    img = TF.resize(img, (new_h, new_w), interpolation=interpolation)

    # Split the leftover evenly; the far side absorbs any odd pixel, so the output is
    # exactly size x size and no corrective second resize is needed.
    pad_w, pad_h = size - new_w, size - new_h
    left, top = pad_w // 2, pad_h // 2
    # torchvision padding order is (left, top, right, bottom).
    return TF.pad(img, (left, top, pad_w - left, pad_h - top), padding_mode='constant')


# Build one 256x256 RGB tile and one 256x256 boolean mask per region of every annotated
# sample, keyed by '<sample>_<region>'. Regions without an annotation get label 'unknown'.
region_to_cls_imgs = {}
for sample in sample_to_classifications:
    print(sample)
    # sample_to_region_fts maps sample -> file path, so the table is loaded here.
    # NOTE(review): assumes a tab-separated table whose first column holds the region
    # id; confirm against the code that writes the region_features files.
    df = pd.read_csv(sample_to_region_fts[sample], sep='\t', index_col=0)
    bbox_cols = ['expanded_r1', 'expanded_r2', 'expanded_c1', 'expanded_c2']
    # A list selects the four columns (a bare tuple would be looked up as one key).
    # The bounds are used as slice indices, hence the integer cast.
    region_to_bbox = {region: (r1, r2, c1, c2)
                      for region, (r1, r2, c1, c2) in zip(
                          df.index, df[bbox_cols].values.astype(int))}

    # tifffile returns NumPy arrays; convert to a float32 tensor and scale by the image maximum.
    pseudo = torch.from_numpy(tifffile.imread(sample_to_pseudo[sample]).astype(np.float32))
    pseudo /= pseudo.max()
    labeled = tifffile.imread(sample_to_labeled[sample])

    for region, (r1, r2, c1, c2) in region_to_bbox.items():
        identifier = f'{sample}_{region}'

        tile_pseudo = rearrange(pseudo[r1:r2, c1:c2], 'h w c -> c h w')
        # True where the labelled image carries this region's id; shape (1, h, w).
        tile_mask = torch.from_numpy(labeled[r1:r2, c1:c2] == int(region)).unsqueeze(dim=0)

        tile_pseudo = _fit_to_square(tile_pseudo, size=256)
        # Nearest-neighbour keeps the mask binary instead of smearing its border; the
        # resize runs on floats and is thresholded back to bool.
        tile_mask = _fit_to_square(tile_mask.float(), size=256,
                                   interpolation=TF.InterpolationMode.NEAREST) > 0.5

        region_to_cls_imgs[identifier] = {
            'rgb': tile_pseudo,
            'mask': tile_mask,
            'label': r_sample_to_classifications.get(identifier, 'unknown')
        }

In [None]:
# Browse every annotated tile (RGB next to mask), grouped by label.
# The sort is keyed on the label only, so tiles keep their insertion order within a label.
tups = sorted(
    ((entry['label'], key) for key, entry in region_to_cls_imgs.items()),
    key=lambda pair: pair[0],
)
for label, k in tups:
    if label == 'unknown':
        continue  # regions without an annotation are not shown
    d = region_to_cls_imgs[k]
    print(k, d['label'])
    fig, axs = plt.subplots(ncols=2)
    rgb_hwc = rearrange(d['rgb'], 'c h w -> h w c')  # imshow expects channels last
    axs[0].imshow(rgb_hwc)
    axs[1].imshow(d['mask'][0])
    plt.show()

In [None]:
# Tile identifiers ('<sample>_<region id>') singled out for removal, in two groups.
# Entries that are listed under 'artifact' in sample_to_classifications.
_to_remove_artifact = [
    'HT110B1_S1H4_270',
    'HT110B1_S1H4_271',
    'HT171B1-S1H9A1-4_left_05122023_18',
    'HT171B1-S1H9A1-4_left_05122023_20',
    'HT171B1-S1H9A1-4_left_05122023_111',
    'HT171B1-S1H9A1-4_right_05122023_41',
    'HT171B1-S1H9A1-4_right_05122023_87',
    'HT243B1_S1H4_26',
    'HT271B1-S1H6A5_04192023_1',
    'HT271B1-S1H6A5_04192023_173',
    'HT271B1-S1H6A5_04192023_182',
    'HT323B1_H1A1_653',
]
# Entries that are listed under 'dcis' in sample_to_classifications.
_to_remove_dcis = [
    'HT271B1-S1H6A5_04192023_63',
    'HT271B1-S1H6A5_04192023_54',
    'HT271B1-S1H6A5_04192023_79',
    'HT271B1-S1H6A5_04192023_86',
    'HT271B1-S1H6A5_04192023_146',
    'HT271B1-S1H6A5_04192023_181',
    'HT397B1_U12_03172023_43',
    'HT397B1_U12_03172023_79',
    'HT397B1_U12_03172023_117',
]
to_remove = _to_remove_artifact + _to_remove_dcis

In [None]:
# Per-channel mean and standard deviation over every tile in region_to_cls_imgs.
# Stacking the (c, h, w) tiles along a new leading axis gives one (n, c, h, w) tensor.
img = torch.stack([entry['rgb'] for entry in region_to_cls_imgs.values()], dim=0)
reduce_dims = (0, 2, 3)  # every axis except the channel axis
img.mean(dim=reduce_dims), img.std(dim=reduce_dims)