In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from PIL import Image
import numpy as np
from tqdm import tqdm
import pandas as pd
import torchvision.transforms.functional as f
import torch
import h5py
import matplotlib.pyplot as plt
import concurrent.futures

### Visualize current DINO augmentations

In [None]:
def inverse_normalize(tensor, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalisation.

    The computation is done on a clone, so the input tensor is left untouched.
    (Denormalising in place made every re-run of a plotting cell on the same
    augmentation output apply the inverse transform a second time.)

    :param tensor: torch tensor whose first dimension is the channel axis.
    :param mean: per-channel means, one entry per channel.
    :param std: per-channel standard deviations, one entry per channel.
    :return: a new tensor holding ``tensor * std + mean`` channel by channel.
    """
    restored = tensor.clone()
    # Iterating over a tensor yields views, so the in-place ops write into `restored`.
    for channel, m, s in zip(restored, mean, std):
        channel.mul_(s).add_(m)
    return restored

In [None]:
## Create list of patches
# PATH = "/lustre/groups/shared/histology_data/tcga_patches/patches/2.0"
# patches = list(Path(PATH).glob("**/*.png"))
# np.savetxt("/home/haicu/sophia.wagner/datasets/TCGA_all_20X_1024px.txt", patches, fmt="%s", delimiter="\n")

In [None]:
# Read at most the first 100000 entries of the shared patch list, then draw 100 of them.
# NOTE: np.random.choice samples WITH replacement by default, so the 100 entries may
# contain duplicates; no seed is set, so the selection differs on every run.
# The commented-out lines are alternative patch sources.
# patches = "/home/haicu/sophia.wagner/datasets/TCGA_all_20X_1024px.txt"
patches = "/lustre/groups/shared/histology_data/patch_lists/all.txt"
patches = np.loadtxt(patches, dtype=str, max_rows=100000).tolist()
patches = np.random.choice(patches, 100).tolist()
# patches = list(Path("/lustre/groups/shared/histology_data/TCGA/ACC/patches").glob("**/*.h5"))
# patches = list(Path("/lustre/groups/shared/tcga/CRC/patches/512px_crc_wonorm_complete_diag_frozen").glob("**/*.jpeg"))
# patches = np.loadtxt("/lustre/groups/shared/histology_data/TCGA/CRC/patches/512px_crc_wonorm_complete_diag_frozen.txt", dtype=str, max_rows=100).tolist()

In [None]:
len(patches)

In [None]:
# Pick one random entry of `patches` and load it as an RGB PIL image.
# The index is called `patch_idx` rather than `id` so the built-in id() is not shadowed
# for the rest of the kernel session.
patch_idx = np.random.randint(0, len(patches))
patch = Image.open(patches[patch_idx]).convert(mode="RGB")

In [None]:
patch

In [None]:
from dinov2.data import DataAugmentationDINO

In [None]:
# Build the DINO augmentation with hand-picked values instead of a config object.
# The trailing comments name the config key each argument stands in for.
# NOTE(review): the first four arguments are positional — confirm their order against
# the DataAugmentationDINO signature of the installed dinov2 package.
data_transform = DataAugmentationDINO(
    (1., 1.), #cfg.crops.global_crops_scale,
    (0.32, 0.32), #cfg.crops.local_crops_scale,
    8, #cfg.crops.local_crops_number,
    224, #global_crops_size=cfg.crops.global_crops_size,
    local_crops_size=98, #cfg.crops.local_crops_size,
)

In [None]:
# Apply the augmentation once; the plotting cell below indexes the result with
# 'global_crops' and 'local_crops'.
out = data_transform(patch)

In [None]:
# Show the 2 global crops followed by the 8 local crops of `out` in a single row.
# Each crop is cloned before denormalisation so that `out` itself is never modified
# and this cell can be re-run without the images drifting.
# NOTE(review): assumes the transform normalised with these mean/std values — confirm.
norm_mean = (0.485, 0.456, 0.406)
norm_std = (0.229, 0.224, 0.225)

crops = [out['global_crops'][i] for i in range(2)] + [out['local_crops'][k] for k in range(8)]

fig = plt.figure(figsize=(20, 2))
for position, crop in enumerate(crops, start=1):
    rev = inverse_normalize(tensor=crop.clone(), mean=norm_mean, std=norm_std)
    img = f.to_pil_image(rev)
    plt.subplot(1, len(crops), position)
    plt.imshow(img)
    # plt.axis('off')
plt.show()

In [None]:
# Global crop 0 and 1 have different settings, local crops are all the same.
# Choose what to plot through `crop_key` / `crop_index`.
def process_image(index, crop_key='global_crops', crop_index=1):
    """Augment the notebook-level `patch` once and draw one crop into a 2x5 grid.

    :param index: zero-based position of the panel inside the 2x5 subplot grid.
    :param crop_key: key of the augmentation output to plot
        ('global_crops' or 'local_crops'); defaults to the previous hard-coded choice.
    :param crop_index: which crop under `crop_key` to plot; defaults to 1 as before.
    :return: None; the image is drawn on the current matplotlib figure.
    """
    out = data_transform(patch)
    rev = inverse_normalize(
        tensor=out[crop_key][crop_index],
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    img = f.to_pil_image(rev)

    plt.subplot(2, 5, index + 1)
    plt.imshow(img)
    plt.axis('off')

In [None]:
# plot the global crops
num_images = 10

# Create a figure and set the size
fig = plt.figure(figsize=(10, 4))

# Draw the panels sequentially. pyplot keeps a global "current figure/axes" state and
# is not thread-safe, so calling plt.subplot/plt.imshow from pool workers can put images
# into the wrong panel. In addition, the iterator returned by executor.map was never
# consumed, so any exception raised inside process_image was silently discarded.
for index in range(num_images):
    process_image(index)

# Adjust layout and show the plot
plt.tight_layout()
plt.show()

In [None]:
# adapted from  https://github.com/DIAGNijmegen/pathology-he-auto-augment/blob/main/he-randaugment/custom_hed_transform.py
import numpy as np
from scipy import linalg
from skimage.util import dtype, dtype_limits
from skimage.exposure import rescale_intensity
import time

# Stain mixing matrix: combine_stains multiplies (n_pixels, 3) stain values by it,
# so rows are indexed by stain channel and columns by RGB channel.
rgb_from_hed = np.asarray(
    [
        [0.65, 0.70, 0.29],
        [0.07, 0.99, 0.11],
        [0.27, 0.57, 0.78],
    ],
    dtype='float32',
)
factor = 0.1
# Optional perturbation of the stain matrix itself (disabled):
# color_jitter = np.random.uniform(-factor, factor, (3, 3))
# rgb_from_hed = rgb_from_hed + color_jitter
print(rgb_from_hed)
# Inverse mapping used by separate_stains.
hed_from_rgb = linalg.inv(rgb_from_hed).astype('float32')


def rgb2hed(rgb):
    """Map an RGB image to stain space with the module-level ``hed_from_rgb`` matrix."""
    stains = separate_stains(rgb, conv_matrix=hed_from_rgb)
    return stains

def hed2rgb(hed):
    """Map stain-space values back to RGB with the module-level ``rgb_from_hed`` matrix."""
    rgb = combine_stains(hed, conv_matrix=rgb_from_hed)
    return rgb

def separate_stains(rgb, conv_matrix):
    """Project an RGB image onto stain channels.

    The image is converted to float32, offset by 2, passed through ``-log`` and
    multiplied pixel-wise by ``conv_matrix``.

    :param rgb: image whose last axis has 3 channels.
    :param conv_matrix: 3x3 matrix applied to each pixel's log-transformed RGB triple.
    :return: array with the same shape as the converted input.
    """
    shifted = dtype.img_as_float(rgb, force_copy=True).astype('float32')
    shifted += 2  # offset applied before the logarithm; combine_stains subtracts it again
    neg_log = -np.log(shifted)
    per_pixel = np.reshape(neg_log, (-1, 3))
    projected = np.dot(per_pixel, conv_matrix)
    return np.reshape(projected, shifted.shape)

def combine_stains(stains, conv_matrix):
    """Inverse of ``separate_stains``: turn stain values back into an RGB image.

    :param stains: array whose last axis has 3 stain channels.
    :param conv_matrix: 3x3 matrix mapping stain values to log-RGB.
    :return: result of ``rescale_intensity`` with ``in_range=(-1, 1)``.
    """
    # Stain values can fall outside [-1, 1]; img_as_float only accepts that for float64,
    # hence the detour through float64 before going back to float32.
    as_float = dtype.img_as_float(stains.astype('float64')).astype('float32')
    per_pixel = np.reshape(as_float, (-1, 3))
    log_rgb = np.dot(-per_pixel, conv_matrix)
    shifted_rgb = np.exp(log_rgb)
    # Remove the offset of 2 that separate_stains added before taking the logarithm.
    image = np.reshape(shifted_rgb - 2, as_float.shape)
    return rescale_intensity(image, in_range=(-1, 1))

In [None]:
# HED color augmentations adapted from  https://github.com/DIAGNijmegen/pathology-he-auto-augment/blob/main/he-randaugment/custom_hed_transform.py
# Tellez et al.
# Patches whose mean intensity (scaled to [0, 1]) lies outside this range are not augmented.
__cutoff_range = (0.15, 0.85)

def hed_jitter(factor, image=None):
    """Randomly scale and shift an RGB patch in HED stain space.

    Each stain channel is multiplied by ``1 + sigma`` and shifted by ``bias``, both drawn
    uniformly from ``[-factor, factor]``.

    :param factor: half-width of the uniform range for the per-channel scale and bias.
    :param image: RGB image (array-like, values on a 0-255 scale). When omitted the
        notebook-level ``patch_image`` is used, which keeps ``hed_jitter(factor)`` working.
    :return: uint8 RGB array. If the mean intensity is outside ``__cutoff_range`` an
        unmodified copy of the input is returned (previously this path raised
        UnboundLocalError because no result had been assigned).
    """
    if image is None:
        image = patch_image  # notebook-level default input
    image = np.asarray(image)

    # Same draw order as before: three biases first, then three sigmas.
    biases = np.random.uniform(-factor, factor, 3)
    sigmas = np.random.uniform(-factor, factor, 3)

    # The brightness check and the transform now use the same array.
    patch_mean = np.mean(image) / 255.0
    if not (__cutoff_range[0] <= patch_mean <= __cutoff_range[1]):
        return image.copy()

    patch_hed = rgb2hed(rgb=image)
    patch_hed *= (1.0 + sigmas)
    patch_hed += biases

    patch_rgb = hed2rgb(hed=patch_hed)
    patch_rgb = np.clip(a=patch_rgb, a_min=0.0, a_max=1.0)
    patch_rgb *= 255.0
    return patch_rgb.astype(dtype=np.uint8)

# patch_image = np.array(patch)
# patch_transformed = hed_jitter(0.05)



In [None]:
# Array copy of the PIL patch; hed_jitter reads this notebook-level variable.
patch_image = np.array(patch)

In [None]:
# One row of ten panels: the untouched patch followed by nine HED-jittered versions.
fig = plt.figure(figsize=(20, 2))

plt.subplot(1, 10, 1)
plt.imshow(patch)
plt.axis('off')

for panel in tqdm(range(2, 11)):
    patch_transformed = hed_jitter(0.07)
    plt.subplot(1, 10, panel)
    plt.imshow(patch_transformed)
    plt.axis('off')

plt.show()

In [None]:
patch_image.shape

### Stain normalization Reinhard

In [None]:
# stain utils

def standardize_brightness(I):
    """Rescale an image so that its 90th intensity percentile maps to 255.

    :param I: image array (any shape); values are expected on a 0-255 scale.
    :return: uint8 array of the same shape, clipped to [0, 255].
    """
    p = np.percentile(I, 90)
    if p <= 0:
        # Nothing to scale by (e.g. an all-black image); dividing by p would produce
        # inf/NaN that is then cast to uint8. Return the clipped input instead.
        return np.clip(I, 0, 255).astype(np.uint8)
    return np.clip(I * 255.0 / p, 0, 255).astype(np.uint8)

def build_stack(tup):
    """Stack equally shaped images along a new leading axis.

    Every item is converted with ``np.asarray`` first, so PIL images and nested lists
    are accepted as well as arrays, and images of any rank are supported (the previous
    version only handled rank 2 and 3 and failed for anything without ``.shape``).

    :param tup: non-empty sequence of images with identical shapes.
    :return: float64 array of shape ``(len(tup), *image_shape)``.
    """
    arrays = [np.asarray(im) for im in tup]
    stack = np.zeros((len(arrays),) + arrays[0].shape)
    for i, arr in enumerate(arrays):
        stack[i] = arr
    return stack

def patch_grid(ims, width=5, sub_sample=None, rand=False, save_name=None):
    """Display a stack of images as a grid of subplots.

    :param ims: stack of images, indexed along the first axis.
    :param width: number of columns in the grid.
    :param sub_sample: if given, show only this many images (capped at the stack size).
    :param rand: with ``sub_sample``, pick the images at random from the whole stack
        instead of taking the first ones.
    :param save_name: optional path; when given the figure is saved there before showing.
    :return: None.
    """
    N0 = np.shape(ims)[0]
    if sub_sample is None:
        N = N0
        stack = ims
    else:
        N = min(sub_sample, N0)
        if rand:
            # Draw from all N0 images. Sampling from range(N) only permuted the first
            # N images and never reached the rest of the stack.
            idx = np.random.choice(N0, N, replace=False)
            stack = np.asarray(ims)[idx]
        else:
            stack = ims[:N]
    height = int(np.ceil(float(N) / width))
    # Note: this changes the global default figure size for later figures as well.
    plt.rcParams['figure.figsize'] = (18, (18 / width) * height)
    plt.figure()
    for i in range(N):
        plt.subplot(height, width, i + 1)
        show(stack[i], now=False, fig_size=None)
    if save_name is not None:
        plt.savefig(save_name)
    plt.show()

def show(image, now=True, fig_size=(10, 10)):
    """Show an image (np.array) on the current matplotlib axes.

    Caution! The image is min-max rescaled to [0, 1] before display.

    :param image: numpy array to display.
    :param now: call ``plt.show()`` immediately when true.
    :param fig_size: ``(width, height)`` written to the global ``figure.figsize``
        rcParam, or ``None`` to leave the setting alone.
    :return: None.
    """
    image = image.astype(np.float32)
    m, M = image.min(), image.max()
    if fig_size is not None:
        plt.rcParams['figure.figsize'] = (fig_size[0], fig_size[1])
    span = M - m
    if span > 0:
        scaled = (image - m) / span
    else:
        # Constant image: avoid 0/0 (NaN pixels) and display it as uniform zeros.
        scaled = np.zeros_like(image)
    plt.imshow(scaled, cmap='gray')
    plt.axis('off')
    if now:
        plt.show()


In [None]:
"""
Normalize a patch stain to the target image using the method of:

E. Reinhard, M. Ashikhmin, B. Gooch, and P. Shirley, ‘Color transfer between images’, IEEE Computer Graphics and Applications, vol. 21, no. 5, pp. 34–41, Sep. 2001.
"""

from __future__ import division

import cv2 as cv
import numpy as np


### Some functions ###


def lab_split(I):
    """
    Convert an RGB uint8 image to LAB and return the three channels separately.

    The 8-bit LAB values are cast to float32; the first channel is divided by 2.55
    and 128 is subtracted from the second and third.
    :param I: uint8 RGB image
    :return: tuple of three float32 channel arrays
    """
    lab = cv.cvtColor(I, cv.COLOR_RGB2LAB).astype(np.float32)
    lightness, chroma_a, chroma_b = cv.split(lab)
    lightness /= 2.55
    chroma_a -= 128.0
    chroma_b -= 128.0
    return lightness, chroma_a, chroma_b


def merge_back(I1, I2, I3):
    """
    Take separate LAB channels (as produced by lab_split) and merge them back into
    an RGB uint8 image.

    The channels are rescaled into new arrays; the arguments are no longer modified
    in place, so callers can keep using them after the call.
    :param I1: first LAB channel, scaled as returned by lab_split
    :param I2: second LAB channel, offset as returned by lab_split
    :param I3: third LAB channel, offset as returned by lab_split
    :return: RGB uint8 image
    """
    lab = cv.merge((I1 * 2.55, I2 + 128.0, I3 + 128.0))
    lab_u8 = np.clip(lab, 0, 255).astype(np.uint8)
    return cv.cvtColor(lab_u8, cv.COLOR_LAB2RGB)


def get_mean_std(I):
    """
    Compute mean and standard deviation of every LAB channel of an RGB image.
    :param I: uint8 RGB image
    :return: (means, stds), each a 3-tuple holding the cv.meanStdDev result per channel
    """
    per_channel = [cv.meanStdDev(channel) for channel in lab_split(I)]
    means = tuple(mean for mean, _ in per_channel)
    stds = tuple(std for _, std in per_channel)
    return means, stds


### Main class ###

class Normalizer(object):
    """
    Reinhard-style stain normalizer with random jitter of the target statistics.

    ``fit`` stores the per-channel LAB mean/std of a target image. ``transform`` maps an
    image onto a randomly perturbed version of those statistics, so repeated calls on
    the same image give different results.
    """

    def __init__(self, factor=0.1):
        """
        :param factor: half-width of the uniform range used to scale the target
            standard deviations (default 0.1, the previously hard-coded value).
        """
        self.factor = factor
        self.target_means = None
        self.target_stds = None

    def fit(self, target):
        """
        Store the LAB channel statistics of the brightness-standardized target.
        :param target: RGB uint8 image
        """
        target = standardize_brightness(target)
        means, stds = get_mean_std(target)
        self.target_means = means
        self.target_stds = stds

    def transform(self, I):
        """
        Normalize an image towards the jittered target statistics.
        :param I: RGB uint8 image
        :return: RGB uint8 image
        :raises RuntimeError: if ``fit`` has not been called
        """
        if self.target_means is None or self.target_stds is None:
            raise RuntimeError("Normalizer.fit() must be called before transform()")

        # Copies, so the fitted statistics are not altered by the jitter.
        # The stored per-channel stats are indexed as [channel][0][0], i.e. shape (3, 1, 1).
        target_means = np.copy(self.target_means)
        target_stds = np.copy(self.target_stds)

        # Jitter of the target statistics (these are LAB statistics, see get_mean_std).
        # Scalars are drawn in the same order as before: 3 uniform values, then 3 normal ones.
        std_scale = np.random.uniform(-self.factor, self.factor, (3,))
        target_stds = target_stds * (1.0 + std_scale).reshape(3, 1, 1)
        mean_noise = np.random.randn(3).reshape(3, 1, 1)
        target_means = target_means + mean_noise * target_stds

        I = standardize_brightness(I)
        I1, I2, I3 = lab_split(I)
        means, stds = get_mean_std(I)
        norm1 = ((I1 - means[0]) * (target_stds[0] / stds[0])) + target_means[0]
        norm2 = ((I2 - means[1]) * (target_stds[1] / stds[1])) + target_means[1]
        norm3 = ((I3 - means[2]) * (target_stds[2] / stds[2])) + target_means[2]
        return merge_back(norm1, norm2, norm3)

class Augmentor(object):
    """
    Colour augmentation that randomly shifts an image's own LAB channel means.

    Unlike Normalizer there is no target image: the statistics of the input itself are
    used, and only the means are perturbed (the standard deviations are kept).
    """
    def __init__(self, factor):
        """
        :param factor: stored on the instance; ``transform`` currently does not read it
            and uses a fixed noise scale of 0.5 standard deviations.
        """
        self.factor = factor

    def transform(self, I):
        """
        Return a copy of the image with each LAB channel mean shifted by
        ``N(0, 1) * 0.5 * channel_std``.
        :param I: RGB uint8 image
        :return: RGB uint8 image
        """
        I = standardize_brightness(I)
        means, stds = get_mean_std(I)
        # The per-channel stats are indexed as [channel][0][0], i.e. shape (3, 1, 1) once copied.
        target_means = np.copy(means)
        target_stds = np.copy(stds)

        # Shift the channel means (LAB statistics). Drawing scalars via randn(3) avoids
        # assigning size-1 arrays to scalar elements, which newer NumPy versions deprecate.
        mean_noise = np.random.randn(3).reshape(3, 1, 1)
        target_means = target_means + mean_noise * 0.5 * target_stds

        I1, I2, I3 = lab_split(I)
        # target_stds equals stds here, so the ratio is 1 and only the mean changes.
        norm1 = ((I1 - means[0]) * (target_stds[0] / stds[0])) + target_means[0]
        norm2 = ((I2 - means[1]) * (target_stds[1] / stds[1])) + target_means[1]
        norm3 = ((I3 - means[2]) * (target_stds[2] / stds[2])) + target_means[2]
        return merge_back(norm1, norm2, norm3)


In [None]:
# Reference image for the Reinhard experiments.
template = patch # Image.open(patches[np.random.randint(len(patches))])
# Ten random patches used as normalisation sources. Later cells index `sources`,
# so it has to be defined for the notebook to run top to bottom.
sources = [
    Image.open(patches[np.random.randint(len(patches))]).convert(mode="RGB")
    for _ in range(10)
]

In [None]:
# One row of eleven panels: the template first, then the ten source patches.
panels = [template] + [sources[i] for i in range(10)]
for column, image in enumerate(panels, start=1):
    plt.subplot(1, 11, column)
    plt.imshow(image)
    plt.axis('off')

In [None]:
# Alternative (disabled): fit a Reinhard Normalizer on the template instead.
# reinhard = Normalizer()
# reinhard.fit(np.array(template))

aug = Augmentor(0.1)

# Eleven independent augmentations of the same patch, shown side by side.
# (Disabled variant: augment the patch plus each of sources[0..9] once.)
augmented_views = tuple(aug.transform(np.array(patch)) for _ in range(11))
normalized = build_stack(augmented_views)
patch_grid(normalized, width=11)

In [None]:
# fit target
target = standardize_brightness(np.array(template))
means, stds = get_mean_std(target)
# Copy instead of aliasing: a later cell jitters target_means/target_stds in place,
# which would otherwise also overwrite the arrays held by `means` and `stds`.
target_means = np.copy(means)
target_stds = np.copy(stds)

In [None]:
target_means, target_stds

In [None]:
# Jitter the fitted target statistics in place.
# NOTE: this cell is not idempotent — every re-run perturbs the already perturbed values.
factor = 0.1
std_scale = np.random.uniform(-factor, factor, (3,))
for channel in range(3):
    target_stds[channel][0][0] *= (1.0 + std_scale[channel])
for channel in range(3):
    # np.random.randn() returns a scalar; assigning the size-1 array from randn(1,)
    # to a single element is deprecated in newer NumPy versions.
    target_means[channel][0][0] += np.random.randn() * target_stds[channel][0][0]

In [None]:
target_means, target_stds

In [None]:
# transform
# lab_split and merge_back are defined once in the Reinhard cell above; the byte-identical
# copies that used to live here were removed so there is a single definition to maintain.

factor = 0.5  # NOTE(review): not used by the statements below
# NOTE(review): `sources` must be defined (list of RGB images) before this cell runs.
I = standardize_brightness(np.array(sources[0]))
I1, I2, I3 = lab_split(I)
means, stds = get_mean_std(I)
print(means)
print(stds)
# Reinhard mapping per LAB channel: centre on the source mean, rescale by the ratio of
# target to source std, then shift to the target mean.
norm1 = ((I1 - means[0]) * (target_stds[0] / stds[0])) + target_means[0]
norm2 = ((I2 - means[1]) * (target_stds[1] / stds[1])) + target_means[1]
norm3 = ((I3 - means[2]) * (target_stds[2] / stds[2])) + target_means[2]
normalized = merge_back(norm1, norm2, norm3)

In [None]:
# Display the manually normalised image computed in the previous cell.
plt.imshow(normalized)
plt.show()

In [None]:
# NOTE(review): identical to the previous cell — probably a leftover duplicate.
plt.imshow(normalized)
plt.show()

In [None]:
def lab_color_augmentation(image, alpha_range=0.1, beta_range=0.1):
    """
    Apply a random colour augmentation in LAB colour space.

    The image is converted with COLOR_RGB2LAB / COLOR_LAB2RGB, i.e. it is treated as RGB
    (not BGR). The arithmetic is done in float32 on the 8-bit LAB encoding and rounded
    back to uint8 afterwards. The previous version worked directly on the uint8 buffer
    and clipped the A/B channels to [-128, 127], which truncated every stored value
    above 127.

    Parameters:
    - image: RGB uint8 image (numpy array or anything ``np.array`` converts, e.g. PIL).
    - alpha_range: the L channel is multiplied by ``1 + U(-alpha_range, alpha_range)``.
    - beta_range: the A and B channels are each shifted by ``U(-beta_range, beta_range)``,
      expressed in 8-bit LAB units. NOTE: shifts below 0.5 vanish when rounding back to
      uint8, so the default of 0.1 leaves A and B unchanged.

    Returns:
    - Augmented RGB uint8 image.
    """
    rgb = np.array(image)
    lab = cv.cvtColor(rgb, cv.COLOR_RGB2LAB).astype(np.float32)

    # Draw order kept as before: alpha, then beta_a, then beta_b.
    alpha = 1.0 + np.random.uniform(-alpha_range, alpha_range)
    beta_a = np.random.uniform(-beta_range, beta_range)
    beta_b = np.random.uniform(-beta_range, beta_range)

    lab[:, :, 0] *= alpha
    lab[:, :, 1] += beta_a
    lab[:, :, 2] += beta_b

    # All three channels of the 8-bit encoding live in [0, 255].
    lab_u8 = np.clip(np.rint(lab), 0, 255).astype(np.uint8)
    return cv.cvtColor(lab_u8, cv.COLOR_LAB2RGB)

In [None]:
# `patch` is a PIL image; the stacking and OpenCV-based helpers need numpy arrays,
# so convert once and reuse the array for every augmentation.
patch_rgb = np.array(patch)
normalized = build_stack((
    patch_rgb,
    *[lab_color_augmentation(patch_rgb) for _ in range(10)]
))
patch_grid(normalized, width=11)