In [1]:
import os
import sys
from pathlib import Path

import cv2 as cv
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage
from joblib import Parallel, delayed
from skimage import exposure, io, util
from skimage.util import img_as_ubyte
from tqdm.notebook import tqdm, trange
import cv2

In [2]:
# The notebook runs from a sub-folder of the project, so the project root is
# one level above the current working directory; data lives under <root>/data.
p_dir = Path.cwd().parents[0].absolute()
data_dir = p_dir.joinpath("data")
match_info_dir = data_dir.joinpath("match")

In [3]:
# Reload edited modules automatically so changes under src/ are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Make the project's src/ directory importable (p_dir is the project root).
module_path = str(p_dir / "src")

# Guard so re-running this cell does not append duplicate entries.
# NOTE(review): src/ is appended, not prepended, so any other top-level
# module called ``utils`` found earlier on sys.path would shadow it.
if module_path not in sys.path:
    sys.path.append(module_path)

# Project helpers; this notebook uses convert_to_rgb and load_pkl from it.
import utils as my_utils

In [4]:
# experiment = "tonsil"
# core = "DonorA"
# datasets = [
#     core + "_1",
#     core + "_2",
#     core + "_3",
#     core + "_4",
#     core + "_5",
#     core + "_6",
# ]

In [5]:
# experiment = "Lung"
# core = "B5"
# datasets = [
#     core + "_1",
#     core + "_2",
#     core + "_3",
#     core + "_4",
# ]

In [6]:
experiment = "Endometrium"
core = "L128"
# ROIs of this core to process. The core also has ROIs 1, 2 and 4;
# add their indices to the tuple to include them.
datasets = [f"{core}_{roi}" for roi in (3,)]

In [7]:
# imgs, markers = get_imgs(f"{experiment} {core}", "IMC")


In [8]:
# import napari 

# viewer = napari.view_image(np.stack(imgs), channel_axis=0, name=markers, visible=False)

# Import IMC data 


In [9]:
from functools import partial

import matplotlib.patches as mpatches
from skimage.segmentation import mark_boundaries
from skimage.transform import rotate
import matplotlib.patches as mpatches
from collections import Counter
import matplotlib.offsetbox
from matplotlib.lines import Line2D

# HDF5 image stacks read by get_imgs live in <project>/data/h5.
h5_data_dir = p_dir.joinpath("data", "h5")


def get_imgs(experiment, name):
    """Load one image stack and its marker names from an HDF5 file.

    Parameters
    ----------
    experiment : str
        File stem; the file read is ``h5_data_dir / f"{experiment}.hdf5"``.
    name : str
        Dataset key inside the file (e.g. "IMC").

    Returns
    -------
    tuple
        (array read from the dataset, list built from its "markers" attribute).
    """
    h5_path = h5_data_dir / f"{experiment}.hdf5"
    with h5py.File(h5_path, "r") as h5:
        dset = h5[name]
        stack = dset[:]
        marker_names = list(dset.attrs["markers"])
    return stack, marker_names


def contrast_stretching(img, percentiles=(1, 99)):
    """Rescale intensities between two percentiles to the full uint8 range.

    Values below the lower percentile map to 0 and values above the upper
    percentile map to 255 (rescale_intensity clips to ``in_range``); values in
    between are stretched linearly. The percentiles are computed over the whole
    array, i.e. jointly over all channels when a stack is passed.

    Parameters
    ----------
    img : np.ndarray
        Image or image stack of any numeric dtype.
    percentiles : tuple of float, optional
        (low, high) percentiles defining the input range. Default (1, 99).

    Returns
    -------
    np.ndarray
        uint8 array with the same shape as ``img``.
    """
    # Names reflect the actual percentiles; the old p2/p98 names were misleading
    # since the 1st and 99th percentiles were used.
    p_low, p_high = np.percentile(img, percentiles)
    img_rescale = exposure.rescale_intensity(
        img, in_range=(p_low, p_high), out_range=(0, 255)
    ).astype(np.uint8)
    return img_rescale


# Read mask image
def get_masks(mask_folder, dataset):
    """
    Collect every mask image under ``mask_folder`` that belongs to ``dataset``.

    The folder is walked recursively and, within each directory, file names are
    visited in sorted order. A file is loaded when its name contains both
    "tif" and ``dataset``. Masks are keyed by the file name up to its first
    "."; a later file with the same key overwrites an earlier one.
    """
    masks = {}
    for root, _subdirs, files in os.walk(mask_folder):
        wanted = (f for f in sorted(files) if "tif" in f and dataset in f)
        for fname in wanted:
            key = fname.split(".")[0]
            masks[key] = skimage.io.imread(os.path.join(root, fname))
    return masks


def crop_img(img, info):
    """Apply stored match info to ``img`` and return the matched crop as uint8.

    Steps, in order: optional zero-valued border padding, rotation with canvas
    resizing, optional pre-crop to ``info["img_region"]``, crop to
    ``info["bbox"]``, then percentile contrast stretching.

    Parameters
    ----------
    img : np.ndarray
        Image to transform (presumably a single 2-D channel — confirm).
    info : dict
        Match info with "bbox" as (y, h, x, w) and "rotation_init" (angle passed
        to skimage ``rotate``); optional "border" (padding in pixels on every
        side) and "img_region" as (row_min, row_max, col_min, col_max).

    Returns
    -------
    np.ndarray
        uint8 crop of size at most (h, w).
    """
    y, h_region, x, w_region = info["bbox"]
    rotation = info["rotation_init"]

    # Optional padding. A missing (or None) "border" means no padding; any
    # other failure now surfaces instead of being swallowed by a bare except.
    border = info.get("border")
    if border is not None:
        img = cv2.copyMakeBorder(
            img, border, border, border, border, cv2.BORDER_CONSTANT, None, value=0
        )
    img = rotate(img, rotation, resize=True)

    # The bbox is expressed relative to this region when one was recorded.
    if "img_region" in info:
        row_min, row_max, col_min, col_max = info["img_region"]
        img = img[row_min:row_max, col_min:col_max]

    img = img[y : y + h_region, x : x + w_region]
    img = contrast_stretching(img)
    return img_as_ubyte(img)


def plot_mutliplex(data, channels, RGB_MAP, markers, mask=None):
    """Show a multiplex image as a false-colour RGB composite with a legend.

    Parameters
    ----------
    data : np.ndarray
        Image data passed to ``my_utils.convert_to_rgb``.
    channels : iterable of int
        Channel ids to draw; each must be a key of ``RGB_MAP``. The legend label
        of channel ``c`` is ``markers[c - 1]`` (channel ids are 1-based).
    RGB_MAP : dict
        Maps channel id -> {"rgb": colour array, presumably 0-255, ...}.
    markers : list of str
        Marker names used for the legend.
    mask : np.ndarray, optional
        Label image whose boundaries are overlaid in light grey.
    """
    rgb = my_utils.convert_to_rgb(data, channels=channels, vmax=255, rgb_map=RGB_MAP)
    if mask is not None:
        rgb = mark_boundaries(rgb, mask, color=(0.8, 0.8, 0.8), mode="subpixel")

    _fig, ax = plt.subplots(figsize=(12, 10))
    ax.imshow(rgb)
    ax.axis("off")

    # One coloured patch per drawn channel, placed outside the lower-left corner.
    legend_handles = []
    for ch in channels:
        colour = RGB_MAP[ch]["rgb"] / 255.0
        legend_handles.append(mpatches.Patch(color=colour, label=markers[ch - 1]))
    ax.legend(
        handles=legend_handles,
        bbox_to_anchor=(0, 0),
        loc="lower right",
        borderaxespad=0.0,
        fontsize=10,
    )

    plt.show()


def get_img_subset(imgs, markers, labels):
    """Select images by marker name and stack them along axis 2.

    Parameters
    ----------
    imgs : sequence of np.ndarray
        Per-marker images, ordered like ``labels``.
    markers : iterable of str
        Marker names to pick, in the desired output order.
    labels : list of str
        Marker name of each entry in ``imgs``; an unknown marker raises
        ValueError from ``labels.index``.

    Returns
    -------
    np.ndarray
        Selected images stacked along axis 2 (channel-last for 2-D images).
    """
    picked = [imgs[labels.index(name)] for name in markers]
    return np.stack(picked, axis=2)


def joblib_loop(task, pics, n_jobs=20):
    """Apply ``task`` to every item of ``pics`` in parallel with joblib.

    Parameters
    ----------
    task : callable
        Function applied to each item.
    pics : iterable
        Inputs to process.
    n_jobs : int, optional
        Number of joblib workers. Defaults to 20, the previously hard-coded
        value; joblib treats -1 as "all CPUs".

    Returns
    -------
    list
        ``task(item)`` for each item, in input order.
    """
    return Parallel(n_jobs=n_jobs)(delayed(task)(pic) for pic in pics)

class AnchoredHScaleBar(matplotlib.offsetbox.AnchoredOffsetbox):
    """Horizontal scale bar with end ticks and a label, anchored in an Axes.

    Parameters
    ----------
    size : float
        Length of the bar in x data units.
    extent : float
        Height of the end ticks in y axes units.
    label : str
        Text shown below the bar.
    loc : int or str
        Anchor location, as for ``AnchoredOffsetbox``.
    ax : matplotlib.axes.Axes, optional
        Axes whose x-axis transform is used; defaults to ``plt.gca()``.
    pad, borderpad, prop, frameon, **kwargs
        Forwarded to ``AnchoredOffsetbox``.
    ppad, sep : float
        Padding and bar-to-label separation of the inner ``VPacker``.
    linekw : dict, optional
        Keyword arguments for the three ``Line2D`` segments.
    textprops : dict, optional
        Text properties for the label.
    """

    def __init__(self, size=1, extent=0.03, label="", loc=2, ax=None,
                 pad=0.4, borderpad=0.5, ppad=0, sep=2, prop=None,
                 frameon=True, linekw=None, textprops=None, **kwargs):
        # None defaults avoid one mutable dict being shared by every instance.
        linekw = {} if linekw is None else linekw
        textprops = {} if textprops is None else textprops
        if ax is None:
            ax = plt.gca()

        # x in data coordinates, y in axes coordinates: the bar length follows
        # the data scale while the tick height is a fraction of the axes.
        trans = ax.get_xaxis_transform()
        size_bar = matplotlib.offsetbox.AuxTransformBox(trans)
        half = extent / 2.0
        size_bar.add_artist(Line2D([0, size], [0, 0], **linekw))
        size_bar.add_artist(Line2D([0, 0], [-half, half], **linekw))
        size_bar.add_artist(Line2D([size, size], [-half, half], **linekw))

        txt = self._make_label(label, textprops)
        self.vpac = matplotlib.offsetbox.VPacker(
            children=[size_bar, txt], align="center", pad=ppad, sep=sep
        )
        super().__init__(loc, pad=pad, borderpad=borderpad, child=self.vpac,
                         prop=prop, frameon=frameon, **kwargs)

    @staticmethod
    def _make_label(label, textprops):
        """Build the label ``TextArea`` in a matplotlib-version-tolerant way.

        ``TextArea(minimumdescent=...)`` was deprecated in matplotlib 3.4 and
        removed in a later release, where passing it raises TypeError. Keep the
        original ``minimumdescent=False`` where it is still accepted.
        """
        try:
            return matplotlib.offsetbox.TextArea(
                label, minimumdescent=False, textprops=textprops
            )
        except TypeError:
            return matplotlib.offsetbox.TextArea(label, textprops=textprops)


KeyboardInterrupt: 

In [None]:
RGB_MAP = {
    1: {"rgb": np.array([0, 0, 255]), "range": [0, 150]},
    2: {"rgb": np.array([255, 0, 255]), "range": [0, 255]},
    3: {"rgb": np.array([0, 255, 255]), "range": [0, 255]},
    4: {"rgb": np.array([255, 255, 0]), "range": [0, 255]},
}


def one_channel(t, channel, vmax=255, rgb_map=RGB_MAP):
    """
    Colour the first plane of ``t`` with the colour assigned to ``channel``.

    Parameters
    ----------
    t : np.ndarray
        Either (H, W, C) data, of which only ``t[:, :, 0]`` is used, or a
        single (H, W) plane.
    channel : int
        Key of ``rgb_map`` selecting the colour and the intensity range.
    vmax : int
        Divisor applied to the raw values before the range adjustment.
    rgb_map : dict
        Maps channel -> {"rgb": RGB triple in 0-255, "range": [low, high]}.
        Defaults to the module-level RGB_MAP.

    Returns
    -------
    np.ndarray
        uint8 RGB image of shape (H, W, 3).
    """
    plane = t if t.ndim == 2 else t[:, :, 0]
    low, high = rgb_map[channel]["range"]
    colour = rgb_map[channel]["rgb"]

    # Same scaling as before: divide by vmax, stretch by 255 / (high - low),
    # then offset by low / 255. Clipping both sides of [0, 1] (not only the
    # upper one) stops negative inputs from wrapping in the uint8 cast below.
    x = (plane / vmax) / ((high - low) / 255) + low / 255
    x = np.clip(x, 0.0, 1.0)

    # Per-pixel intensity times the channel colour, truncated to int as before.
    rgb = (x[..., np.newaxis] * colour).astype(int)
    return np.clip(rgb, 0, 255).astype(np.uint8)


In [None]:
# Overlay the matched ROI rectangles of every dataset on one DNA image.
# The background comes from the first dataset that has match info; like the
# original i == 0 logic, this assumes all ROIs of a core share one rotated
# reference frame — confirm.
# (Only for endometrium L72 and L128, where each ROI has its own HDF5 file.)
subset = ["DNA1", "DNA2"]
fig = ax = None

for dataset in datasets:
    info_path = match_info_dir / f"{experiment}_{dataset}.pickle"
    if not info_path.exists():
        # Previously a bare ``except`` skipped silently and also hid real load errors.
        print(f"Skipping {dataset}: no match info at {info_path}")
        continue
    info = my_utils.load_pkl(info_path)

    if ax is None:
        # Background: max projection of the contrast-stretched DNA channels,
        # stretched again, then padded and rotated exactly like the match.
        imgs, markers = get_imgs(f"{dataset}", "IMC")
        img_dapi = np.max(
            contrast_stretching(imgs[[markers.index(marker) for marker in subset]]),
            axis=0,
        )
        img_dapi = contrast_stretching(img_dapi)

        border = info.get("border")
        if border is not None:
            img_dapi = cv2.copyMakeBorder(
                img_dapi, border, border, border, border,
                cv2.BORDER_CONSTANT, None, value=0,
            )
        img_dapi = rotate(img_dapi, info["rotation_init"], resize=True)

        fig, ax = plt.subplots(figsize=(15, 15), facecolor="k")
        ax.imshow(img_dapi, cmap=plt.cm.gray)
        ax.set_axis_off()

        # Scale bar of 150 data units (pixels) labelled 150 um; presumably
        # 1 um per pixel for IMC — confirm.
        ob = AnchoredHScaleBar(
            size=150, label="150\u03BCm", loc=4, frameon=False, extent=0.0,
            pad=0.1, sep=4, ax=ax,
            linekw=dict(color="w", linewidth=5),
            textprops=dict(color="w", fontsize=20),
        )
        ax.add_artist(ob)

    # Highlight the matched region; the bbox is relative to img_region when present.
    y, h_region, x, w_region = info["bbox"]
    if "img_region" in info:
        row_min, _, col_min, _ = info["img_region"]
        y += row_min
        x += col_min
    rect = plt.Rectangle(
        (x, y), w_region, h_region, edgecolor="r", facecolor="none", lw=3
    )
    ax.add_patch(rect)
    r = dataset.split("_")[-1]
    ax.annotate(r, (x + w_region / 2, y + h_region / 2), color="red", weight="bold",
                fontsize=50, ha="center", va="center", rotation=0)

## Save image
if fig is None:
    raise FileNotFoundError(
        f"No match info found for any of {datasets} in {match_info_dir}"
    )
# NOTE(review): named after the last dataset in ``datasets`` (previous behaviour);
# consider naming it after the core when several ROIs are overlaid.
file_path = p_dir / "figures" / "ROIs" / f"{datasets[-1]}.png"
file_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(file_path, dpi=500, bbox_inches="tight", pad_inches=0)