In [1]:
import openslide
from openslide import lowlevel as openslide_ll
from ctypes import c_uint32, POINTER, cast, addressof, c_void_p, byref, c_uint8, sizeof, c_uint16
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional
from loguru import logger
import re
import cv2
from multiprocessing import RawArray
import concurrent
from concurrent.futures import ProcessPoolExecutor
from concurrent import futures
from tqdm import tqdm
from pathlib import Path
import dask.array as da
import dask

import shiprec
from shiprec.read import load_slide

# Whole-slide image that gets loaded and normalized below (absolute path; adjust for your environment).
SLIDE_FILE = "/data/data/TCGA-BRCA-DX-IMGS_1133/TCGA-AO-A12B-01Z-00-DX1.B215230B-5FF7-4B0A-9C1E-5F1658534B11.svs"
# Reference patch the stain normalizers are fitted on (absolute path; adjust for your environment).
TARGET_PATCH_FILE = "/app/normalization_template.jpg"

In [2]:
from dask.distributed import Client

# Start a local Dask cluster: 8 workers with one thread each, capped at 16 GB per worker.
client = Client(n_workers=8, threads_per_worker=1, memory_limit="16GB")
# Last expression of the cell: shows the cluster summary (dashboard address, workers, memory).
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 8
Total threads: 8,Total memory: 119.21 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:40855,Workers: 8
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 119.21 GiB

0,1
Comm: tcp://127.0.0.1:38463,Total threads: 1
Dashboard: http://127.0.0.1:37083/status,Memory: 14.90 GiB
Nanny: tcp://127.0.0.1:34649,
Local directory: /tmp/dask-scratch-space/worker-uy5_jrbd,Local directory: /tmp/dask-scratch-space/worker-uy5_jrbd

0,1
Comm: tcp://127.0.0.1:40299,Total threads: 1
Dashboard: http://127.0.0.1:44615/status,Memory: 14.90 GiB
Nanny: tcp://127.0.0.1:40689,
Local directory: /tmp/dask-scratch-space/worker-a_597aki,Local directory: /tmp/dask-scratch-space/worker-a_597aki

0,1
Comm: tcp://127.0.0.1:32931,Total threads: 1
Dashboard: http://127.0.0.1:36459/status,Memory: 14.90 GiB
Nanny: tcp://127.0.0.1:34273,
Local directory: /tmp/dask-scratch-space/worker-1rjfbl50,Local directory: /tmp/dask-scratch-space/worker-1rjfbl50

0,1
Comm: tcp://127.0.0.1:33949,Total threads: 1
Dashboard: http://127.0.0.1:32837/status,Memory: 14.90 GiB
Nanny: tcp://127.0.0.1:38605,
Local directory: /tmp/dask-scratch-space/worker-gqwr6m3e,Local directory: /tmp/dask-scratch-space/worker-gqwr6m3e

0,1
Comm: tcp://127.0.0.1:39161,Total threads: 1
Dashboard: http://127.0.0.1:38817/status,Memory: 14.90 GiB
Nanny: tcp://127.0.0.1:34339,
Local directory: /tmp/dask-scratch-space/worker-deumq2g2,Local directory: /tmp/dask-scratch-space/worker-deumq2g2

0,1
Comm: tcp://127.0.0.1:35105,Total threads: 1
Dashboard: http://127.0.0.1:34379/status,Memory: 14.90 GiB
Nanny: tcp://127.0.0.1:45019,
Local directory: /tmp/dask-scratch-space/worker-vn0ny09f,Local directory: /tmp/dask-scratch-space/worker-vn0ny09f

0,1
Comm: tcp://127.0.0.1:34921,Total threads: 1
Dashboard: http://127.0.0.1:44687/status,Memory: 14.90 GiB
Nanny: tcp://127.0.0.1:35261,
Local directory: /tmp/dask-scratch-space/worker-85jeyvjd,Local directory: /tmp/dask-scratch-space/worker-85jeyvjd

0,1
Comm: tcp://127.0.0.1:43541,Total threads: 1
Dashboard: http://127.0.0.1:45187/status,Memory: 14.90 GiB
Nanny: tcp://127.0.0.1:43399,
Local directory: /tmp/dask-scratch-space/worker-ayxi8e7u,Local directory: /tmp/dask-scratch-space/worker-ayxi8e7u


In [3]:
# Open the slide through shiprec; the cell output shows the result is a chunked Dask array.
# target_mpp = 256/224 ~= 1.143 (assumes "mpp" is microns per pixel -- TODO confirm in shiprec.read).
slide = load_slide(SLIDE_FILE, target_mpp=256./224.)
slide

[32m2023-08-07 07:29:42.041[0m | [1mINFO    [0m | [36mshiprec.read.mpp[0m:[36mget_slide_mpp[0m:[36m42[0m - [1mMPP successfully extracted using extract_mpp_from_properties: 0.499[0m
[32m2023-08-07 07:29:42.042[0m | [34m[1mDEBUG   [0m | [36mshiprec.read[0m:[36mload_slide[0m:[36m21[0m - [34m[1mSlide has 3 levels with following downsamples: {0: 1.0, 1: 4.000140674347211, 2: 16.004303859372538}[0m
[32m2023-08-07 07:29:42.042[0m | [1mINFO    [0m | [36mshiprec.read[0m:[36mload_slide[0m:[36m30[0m - [1mUsing level 0 with level_mpp=0.499 for slide_mpp=0.499 and target_mpp=1.143[0m


Unnamed: 0,Array,Chunk
Bytes,1.96 GiB,57.42 MiB
Shape,"(22400, 31360, 3)","(4480, 4480, 3)"
Dask graph,35 chunks in 3 graph layers,35 chunks in 3 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 1.96 GiB 57.42 MiB Shape (22400, 31360, 3) (4480, 4480, 3) Dask graph 35 chunks in 3 graph layers Data type uint8 numpy.ndarray",3  31360  22400,

Unnamed: 0,Array,Chunk
Bytes,1.96 GiB,57.42 MiB
Shape,"(22400, 31360, 3)","(4480, 4480, 3)"
Dask graph,35 chunks in 3 graph layers,35 chunks in 3 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray


In [4]:
# Read the reference patch and convert OpenCV's BGR channel order to RGB.
target_np = cv2.cvtColor(cv2.imread(str(TARGET_PATCH_FILE)), cv2.COLOR_BGR2RGB)
# Dask wrapper around the same pixels; the plain numpy array stays available as target_np.
target = da.from_array(target_np)
target

Unnamed: 0,Array,Chunk
Bytes,768.00 kiB,768.00 kiB
Shape,"(512, 512, 3)","(512, 512, 3)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 768.00 kiB 768.00 kiB Shape (512, 512, 3) (512, 512, 3) Dask graph 1 chunks in 1 graph layer Data type uint8 numpy.ndarray",3  512  512,

Unnamed: 0,Array,Chunk
Bytes,768.00 kiB,768.00 kiB
Shape,"(512, 512, 3)","(512, 512, 3)"
Dask graph,1 chunks in 1 graph layer,1 chunks in 1 graph layer
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray


In [5]:
from shiprec.macenko import DaskMacenkoNormalizer

# Fit the packaged normalizer on the reference patch, then normalize the whole slide.
# NOTE(review): the meaning of exact=False is defined in shiprec.macenko, not in this notebook -- confirm.
norm = DaskMacenkoNormalizer(exact=False)
norm.fit(target)
result = norm.normalize(slide)
# Last expression of the cell: the output shows a Dask array summary for the normalized slide.
result

Unnamed: 0,Array,Chunk
Bytes,1.96 GiB,57.42 MiB
Shape,"(22400, 31360, 3)","(4480, 4480, 3)"
Dask graph,35 chunks in 194 graph layers,35 chunks in 194 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 1.96 GiB 57.42 MiB Shape (22400, 31360, 3) (4480, 4480, 3) Dask graph 35 chunks in 194 graph layers Data type uint8 numpy.ndarray",3  31360  22400,

Unnamed: 0,Array,Chunk
Bytes,1.96 GiB,57.42 MiB
Shape,"(22400, 31360, 3)","(4480, 4480, 3)"
Dask graph,35 chunks in 194 graph layers,35 chunks in 194 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray


In [6]:
# Materialize the normalized slide in the notebook process and display its array repr.
# The array summary in the previous output reports 1.96 GiB for this result.
result.compute()

read 2 0 (20514, 0) 0 (10257, 10257)
read 2 2 (20514, 20514) 0 (10257, 10257)
read 3 2 (30771, 20514) 0 (10257, 10257)
read 2 1 (20514, 10257) 0 (10257, 10257)
read 1 1 (10257, 10257) 0 (10257, 10257)
read 5 1 (51285, 10257) 0 (10257, 10257)
read 4 1 (41028, 10257) 0 (10257, 10257)
read 6 1 (61542, 10257) 0 (10257, 10257)
read 5 0 (51285, 0) 0 (10257, 10257)
read 3 1 (30771, 10257) 0 (10257, 10257)
read 0 1 (0, 10257) 0 (10257, 10257)
read 6 0 (61542, 0) 0 (10257, 10257)
read 4 0 (41028, 0) 0 (10257, 10257)
read 3 0 (30771, 0) 0 (10257, 10257)
read 1 0 (10257, 0) 0 (10257, 10257)
read 0 0 (0, 0) 0 (10257, 10257)
read 4 2 (41028, 20514) 0 (10257, 10257)
read 6 2 (61542, 20514) 0 (10257, 10257)
read 0 2 (0, 20514) 0 (10257, 10257)
read 5 2 (51285, 20514) 0 (10257, 10257)
read 0 3 (0, 30771) 0 (10257, 10257)
read 1 3 (10257, 30771) 0 (10257, 10257)
read 6 3 (61542, 30771) 0 (10257, 10257)
read 3 3 (30771, 30771) 0 (10257, 10257)
read 0 4 (0, 41028) 0 (10257, 10257)
read 1 2 (10257, 20514)

array([[[240, 240, 240],
        [240, 240, 240],
        [240, 240, 240],
        ...,
        [196,  67, 108],
        [196,  67, 108],
        [196,  67, 108]],

       [[240, 240, 240],
        [240, 240, 240],
        [240, 240, 240],
        ...,
        [196,  67, 108],
        [196,  67, 108],
        [196,  67, 108]],

       [[240, 240, 240],
        [240, 240, 240],
        [240, 240, 240],
        ...,
        [196,  67, 108],
        [196,  67, 108],
        [196,  67, 108]],

       ...,

       [[196,  67, 108],
        [196,  67, 108],
        [196,  67, 108],
        ...,
        [196,  67, 108],
        [196,  67, 108],
        [196,  67, 108]],

       [[196,  67, 108],
        [196,  67, 108],
        [196,  67, 108],
        ...,
        [196,  67, 108],
        [196,  67, 108],
        [196,  67, 108]],

       [[196,  67, 108],
        [196,  67, 108],
        [196,  67, 108],
        ...,
        [196,  67, 108],
        [196,  67, 108],
        [196,  67, 108]]

In [5]:
# Base class used by the DaskMacenkoNormalizer defined in this cell; here it is just `object`.
# NOTE(review): presumably a stand-in for a shared H&E normalizer base class -- confirm.
HENormalizer = object

def namedelayprint(name):
    """Build a pass-through function that prints its argument under a header.

    Args:
        name: label placed in the ``=== name ===`` header line.

    Returns:
        A one-argument function that prints the header, prints the value it
        receives, and returns that value unchanged.
    """
    header = f"=== {name} ==="

    def _print_and_pass(value):
        print(header)
        print(value)
        return value

    return _print_and_pass

def printd(name, x):
    """Return a Dask array equal to ``x`` that prints itself when computed.

    ``x`` is passed through one delayed call of the function built by
    ``namedelayprint(name)`` and wrapped back into a Dask array that declares
    the same shape and dtype as ``x``.

    NOTE(review): the whole of ``x`` goes through a single delayed call, so
    the original chunking of ``x`` is not carried over -- keep this to
    debugging on small inputs.
    """
    printed = dask.delayed(namedelayprint(name))(x)
    return da.from_delayed(printed, shape=x.shape, dtype=x.dtype)

@dask.delayed
def delayed_eigh(X):
    """Delayed eigenvector computation based on `np.linalg.eigh`.

    Only the eigenvector matrix (second element of the `np.linalg.eigh`
    result) is returned; the eigenvalues are discarded. The eigenvectors are
    also printed under an ``=== eigh ===`` header when the task runs.
    """
    eigenvectors = np.linalg.eigh(X)[1]
    namedelayprint("eigh")(eigenvectors)
    return eigenvectors

def _cov(X: da.Array, N: da.Array) -> da.Array:
    """Compute the covariance matrix of the rows of X.

    Args:
        X: array of shape (D, N) where D is the number of features and N is the
            number of samples, i.e. one row per feature and one column per
            sample (the caller in this file passes `ODhat.T`).
        N: number of samples as an array of shape ().

    Returns:
        Array of shape (D, D): the centered scatter matrix divided by N - 1.

    Unlike `da.cov(X)`, this function doesn't break when the number of samples
    (X.shape[1]) is unknown at graph construction time; that is why N is passed
    in separately instead of being read from X.shape.
    """

    # Mean of each feature (row), taken across the sample axis
    mean = X.mean(axis=-1, keepdims=True)

    # Center every feature around zero
    X_centered = X - mean

    # (D, N) . (N, D) -> (D, D), then the N - 1 normalization
    cov = X_centered.dot(X_centered.T) / (N - 1)
    return cov

class DaskMacenkoNormalizer(HENormalizer):
    """Macenko H&E stain normalizer expressed as lazy Dask array operations.

    `fit` estimates a stain matrix and per-stain maximum concentrations from a
    reference image. `normalize` estimates the same quantities for another
    image, rescales its concentrations towards the reference and re-mixes them
    with the reference stain matrix.

    Args:
        debug: when True (the default, which reproduces the original notebook
            behavior) selected intermediates are routed through `printd`, so
            they are printed when the graph is computed. Every `printd` call
            sends the whole array through a single delayed call, so pass
            `debug=False` to leave the arrays untouched. The eigenvector print
            lives inside `delayed_eigh` and is not affected by this flag.
    """

    def __init__(self, debug=True):
        super().__init__()

        self.debug = debug

        # Built-in reference stain matrix (3 x 2) and maximum concentrations (2,).
        # Both are overwritten by fit().
        self.HERef = da.array([[0.5626, 0.2159], [0.7201, 0.8012], [0.4062, 0.5581]])
        self.maxCRef = da.array([1.9705, 1.0308])

    def _trace(self, name, x):
        """Return `printd(name, x)` when debugging is enabled, otherwise `x` itself."""
        # getattr keeps instances created without __init__ on the original (printing) path
        return printd(name, x) if getattr(self, "debug", True) else x

    def _convert_rgb2od(self, I, Io=240, beta=0.15):
        """Convert flattened pixels to optical density (OD).

        Args:
            I: 2-D array with one pixel per row (`_compute_matrices` passes shape (N, 3)).
            Io: intensity that maps to zero OD; OD is `-log((I + 1) / Io)`.
            beta: OD threshold; a pixel is dropped from `ODhat` if any channel is below it.

        Returns:
            Tuple `(OD, ODhat, ODhatN)`: OD of every pixel, OD of the kept
            pixels only, and the number of kept pixels as a 0-d array.
        """
        # Calculate optical density
        OD = -da.log((I.astype(float) + 1) / Io)

        # Remove transparent pixels
        mask = ~da.any(OD < beta, axis=1)
        ODhat = OD[mask]

        # Counted explicitly and handed to _cov instead of being read from ODhat.shape
        ODhatN = mask.sum()

        return OD, ODhat, ODhatN

    def _find_HE(self, ODhat, eigvecs, alpha):
        """Estimate the (3, 2) stain matrix from the kept OD pixels.

        Args:
            ODhat: OD values of the kept pixels, one pixel per row.
            eigvecs: (3, 3) eigenvector matrix of the OD covariance.
            alpha: percentile (and `100 - alpha`) used to pick the extreme angles.

        Returns:
            Array of shape (3, 2) in a single chunk, with the two stain vectors as columns.
        """
        # Project on the plane spanned by the eigenvectors corresponding to the two largest eigenvalues
        That = ODhat.dot(eigvecs[:, 1:3])

        # Angle of every projected pixel inside that plane
        phi = da.arctan2(That[:, 1], That[:, 0])
        phi = self._trace("phi", phi)

        minPhi = da.percentile(phi, alpha).squeeze(-1)
        maxPhi = da.percentile(phi, 100 - alpha).squeeze(-1)

        minPhi, maxPhi = self._trace("minPhi", minPhi), self._trace("maxPhi", maxPhi)

        # Map the two extreme angles back to 3-D OD space: (3, 2) . (2, 1) -> (3, 1)
        vMin = eigvecs[:, 1:3].dot(da.expand_dims(da.stack([da.cos(minPhi), da.sin(minPhi)], axis=0), axis=0).T)
        vMax = eigvecs[:, 1:3].dot(da.expand_dims(da.stack([da.cos(maxPhi), da.sin(maxPhi)], axis=0), axis=0).T)

        vMin = vMin[:, 0]
        vMax = vMax[:, 0]

        vMin, vMax = self._trace("vMin", vMin), self._trace("vMax", vMax)

        # The next few lines are a heuristic to make the vector corresponding to hematoxylin first and the one corresponding to eosin second.
        # It is equivalent to the following code:
        # HE = da.array((vMin, vMax)).T if vMin[0] > vMax[0] else da.array((vMax, vMin)).T

        stacked = da.stack([vMin, vMax], axis=0)
        is_bigger = vMin[0] > vMax[0]
        HE = da.where(is_bigger, stacked, stacked[::-1]).T
        HE = HE.rechunk(-1)

        return HE

    def _find_concentration(self, OD, HE):
        """Solve `HE . C = OD.T` in the least-squares sense and return C (one row per stain)."""
        # Rows correspond to channels (RGB), columns to OD values
        Y = da.reshape(OD, (-1, 3)).T

        # Determine concentrations of the individual stains
        C = da.linalg.lstsq(HE, Y)[0]

        return C

    def _compute_matrices(self, I, Io, alpha, beta):
        """Build the lazy stain matrix, concentrations and max concentrations for an image.

        Args:
            I: image array whose last axis holds the 3 color channels; it is
                flattened to (-1, 3) here.
            Io, beta: forwarded to `_convert_rgb2od`.
            alpha: forwarded to `_find_HE`.

        Returns:
            Tuple `(HE, C, maxC)`: the (3, 2) stain matrix, the stain
            concentrations (one row per stain) and the 99th percentile of each
            concentration row.
        """
        I = I.reshape((-1, 3))

        OD, ODhat, ODhatN = self._convert_rgb2od(I, Io=Io, beta=beta)

        # Compute eigenvectors
        cov = _cov(ODhat.T, ODhatN)

        # Now cov has shape (3, 3), so we can compute eigenvectors locally
        cov = cov.rechunk(-1)
        cov = self._trace("cov", cov)
        eigvecs = delayed_eigh(cov)
        eigvecs = da.from_delayed(eigvecs, (3, 3), dtype="float")

        HE = self._find_HE(ODhat, eigvecs, alpha)
        C = self._find_concentration(OD, HE)

        # Normalize stain concentrations
        maxC = da.concatenate([da.percentile(C[0], 99, internal_method="tdigest"), da.percentile(C[1], 99, internal_method="tdigest")])

        HE, C, maxC = self._trace("HE", HE), self._trace("C", C), self._trace("maxC", maxC)

        return HE, C, maxC

    def fit(self, I, Io=240, alpha=1, beta=0.15):
        """Estimate the reference stain matrix and max concentrations from image I.

        The results are stored on `self.HERef` / `self.maxCRef` as lazy arrays;
        nothing is computed here, so the reference graph becomes part of every
        array later returned by `normalize`.

        Args:
            I: reference image array with 3 color channels on the last axis.
            Io: intensity that maps to zero optical density.
            alpha: percentile used for the extreme stain angles.
            beta: OD threshold below which a pixel is ignored.
        """
        HE, _, maxC = self._compute_matrices(I, Io, alpha, beta)

        self.HERef = HE
        self.maxCRef = maxC

    def normalize(self, I, Io=240, alpha=1, beta=0.15):
        """Return a lazily stain-normalized copy of image I.

        Args:
            I: Dask image array of shape (h, w, c) with c == 3.
            Io: intensity that maps to zero optical density; also scales the reconstruction.
            alpha: percentile used for the extreme stain angles.
            beta: OD threshold below which a pixel is ignored when estimating stains.

        Returns:
            Dask array with the shape, dtype and chunks of the input image.
        """
        I_chunks = I.chunks
        I_dtype = I.dtype
        h, w, c = I.shape

        # _compute_matrices flattens the image to (-1, 3) itself
        HE, C, maxC = self._compute_matrices(I, Io, alpha, beta)

        # Rescale each stain so its max concentration matches the reference
        maxC = da.divide(maxC, self.maxCRef)
        C2 = da.divide(C, da.expand_dims(maxC, axis=-1))

        # Recreate the image using reference mixing matrix
        Inorm = da.multiply(Io, da.exp(-self.HERef.dot(C2)))
        Inorm[Inorm > 255] = 255
        Inorm = Inorm.astype(I_dtype)
        # Inorm is (3, N) at this point; transpose back to one pixel per row before unflattening
        Inorm = da.reshape(Inorm.T, (h, w, c))

        Inorm = Inorm.rechunk(I_chunks)
        return Inorm
    

# Load the reference patch as RGB. cv2.imread returns None instead of raising
# when the file cannot be read, so fail early with a clear message.
target_bgr = cv2.imread(str(TARGET_PATCH_FILE))
if target_bgr is None:
    raise FileNotFoundError(f"Could not read reference patch: {TARGET_PATCH_FILE}")
target_np = cv2.cvtColor(target_bgr, cv2.COLOR_BGR2RGB)

# Build a larger test image by upscaling a crop of the reference patch.
# cv2.resize takes dsize as (width, height), so img_np has shape (2560, 2048, 3).
img_np = cv2.resize(target_np[200:300, 200:500], (2048, 2560))

target = da.from_array(target_np, chunks=-1)
# Split the test image into 2 x 2 equal chunks. The chunk sizes are derived from
# the actual array shape: the previous literal (2048//2, 2560//2, 3) had rows and
# columns swapped relative to img_np and produced uneven 3 x 2 chunks.
img_rows, img_cols = img_np.shape[:2]
img = da.from_array(img_np, chunks=(img_rows // 2, img_cols // 2, 3))
# Single-chunk variant, kept for comparison:
# img = da.from_array(img_np, chunks=-1)

norm = DaskMacenkoNormalizer()
norm.fit(target)

# Lazy result; nothing is computed until result.compute() is called.
result = norm.normalize(img)

In [6]:
# Run the graph built by norm.normalize; the debug prints in the output are emitted while it executes.
# NOTE: this rebinds `img` from the lazy input array to the computed result.
img = result.compute()
# plt.imshow(img)

=== cov ===
[[0.11751138 0.15543969 0.07614929]
 [0.15543969 0.23981599 0.1209225 ]
 [0.07614929 0.1209225  0.06492574]]
=== eigh ===
[[ 0.192072   -0.8353876  -0.51501058]
 [-0.54262705  0.34686327 -0.76501095]
 [ 0.81771892  0.42639586 -0.38668118]]
=== phi ===
[-1.50896731 -1.52083376 -1.52849165 ... -1.39107331 -1.37304593
 -1.36061369]
=== maxPhi ===
-1.1531217303977388
=== minPhi ===
-1.5633898923781344
=== vMax ===
[0.1318742  0.83994679 0.52640154]
=== vMin ===
[0.50880927 0.76755897 0.38982862]
=== HE ===
[[0.50880927 0.1318742 ]
 [0.76755897 0.83994679]
 [0.38982862 0.52640154]]
=== maxC ===
[2.92766359 1.50944988]
=== cov ===
[[0.13174686 0.18141829 0.0921291 ]
 [0.18141829 0.28680508 0.14713076]
 [0.0921291  0.14713076 0.07973138]]
=== eigh ===
[[ 0.13116037 -0.85438694 -0.502812  ]
 [-0.52123265  0.3719983  -0.76807148]
 [ 0.84327545  0.36282257 -0.39654293]]
=== phi ===
[-1.46852253 -1.46852253 -1.46852253 ... -1.44003349 -1.43655907
 -1.40685865]
=== maxPhi ===
-1.220631

In [25]:
# from shiprec.macenko import NumpyMacenkoNormalizer

# n2 = NumpyMacenkoNormalizer()
# n2.fit(target_np)
# out2 = n2.normalize(img_np)
# plt.imshow(out2)