In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from PIL import Image
import numpy as np
from tqdm import tqdm
import pandas as pd
import torchvision.transforms.functional as f
import torch
import h5py
import matplotlib.pyplot as plt
import concurrent.futures

### Visualize current DINO augmentations

In [None]:
def inverse_normalize(tensor, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalization.

    The input is left untouched: the arithmetic is applied to a clone, so
    calling this repeatedly on the same crop (for example when a plotting cell
    is re-run) always produces the same result.

    Args:
        tensor: torch tensor whose first dimension is iterated channel by channel.
        mean: per-channel means, consumed in step with the channels.
        std: per-channel standard deviations, consumed in step with the channels.

    Returns:
        A new tensor in which each channel was multiplied by its ``std`` entry
        and then shifted by its ``mean`` entry.
    """
    restored = tensor.clone()
    # Iterating a tensor yields views, so the in-place ops write into `restored`.
    for channel, m, s in zip(restored, mean, std):
        channel.mul_(s).add_(m)
    return restored

In [None]:
# Load the patch list from a text file into a Python list of strings
# (presumably one image path per entry - confirm against the file).
# Alternative sources that were used before are kept below for reference.
# patches = "/home/haicu/sophia.wagner/datasets/TCGA_all_20X_1024px.txt"
# patches = list(Path("/lustre/groups/shared/histology_data/TCGA/ACC/patches").glob("**/*.h5"))
# patches = list(Path("/lustre/groups/shared/tcga/CRC/patches/512px_crc_wonorm_complete_diag_frozen").glob("**/*.jpeg"))
# patches = np.loadtxt("/lustre/groups/shared/histology_data/TCGA/CRC/patches/512px_crc_wonorm_complete_diag_frozen.txt", dtype=str, max_rows=100).tolist()
patches = np.loadtxt(
    "/lustre/groups/shared/histology_data/patch_lists/all.txt",
    dtype=str,
).tolist()

In [None]:
# Number of entries read from the patch list.
len(patches)

In [None]:
from dinov2.data import DataAugmentationDINO

In [None]:
# Build the DINO augmentation with hard-coded values instead of a config
# object; the trailing comment on each argument names the config field the
# value stands in for.
data_transform = DataAugmentationDINO(
    (1., 1.), #cfg.crops.global_crops_scale,
    (0.32, 0.32), #cfg.crops.local_crops_scale,
    8, #cfg.crops.local_crops_number,
    224, #global_crops_size=cfg.crops.global_crops_size,
    local_crops_size=98, #cfg.crops.local_crops_size,
)

In [None]:
# Pick one entry of `patches` uniformly at random and load it as an RGB image.
# The index lives in `patch_index` rather than `id`, so the builtin `id()`
# stays usable for the rest of the kernel session.
if len(patches) == 0:
    raise ValueError("`patches` is empty - there is no patch to sample")
patch_index = np.random.randint(0, len(patches))
patch = Image.open(patches[patch_index]).convert(mode="RGB")

In [None]:
# Display the selected patch.
patch

In [None]:
# Apply the augmentation once; the cells below read the 'global_crops' and
# 'local_crops' entries of the result.
out = data_transform(patch)

In [None]:
# Show every crop produced for `patch` in one row: global crops first,
# followed by the local crops.

# Statistics used to undo the normalization before display (presumably the
# same ones the augmentation normalizes with - TODO confirm).
norm_mean = (0.485, 0.456, 0.406)
norm_std = (0.229, 0.224, 0.225)

crops = list(out['global_crops']) + list(out['local_crops'])

fig = plt.figure(figsize=(20, 2))
for position, crop in enumerate(crops, start=1):
    # Pass a copy so the tensors stored in `out` are never modified; re-running
    # this cell then shows the same images instead of de-normalizing twice.
    rev = inverse_normalize(tensor=crop.clone(), mean=norm_mean, std=norm_std)
    img = f.to_pil_image(rev)
    # The grid width follows the number of crops instead of a fixed 10.
    plt.subplot(1, len(crops), position)
    plt.imshow(img)
    # plt.axis('off')
plt.show()

In [None]:
# determine whether you want to plot local or global crops
# global crop 0 and 1 have different settings, local crops are all the same
def process_image(index, crop_key='global_crops', crop_index=1):
    """Augment the notebook-level `patch` once and draw one crop into a 2x5 grid.

    Args:
        index: zero-based slot in the 2x5 grid of the current figure.
        crop_key: which entry of the augmentation output to show
            ('global_crops' or 'local_crops'). Defaults to 'global_crops'.
        crop_index: position of the crop inside that entry. Defaults to 1.

    Returns:
        None. The crop is drawn on the current matplotlib figure.
    """
    out = data_transform(patch)
    rev = inverse_normalize(
        tensor=out[crop_key][crop_index],
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    img = f.to_pil_image(rev)

    # Draw on the axes object returned by plt.subplot instead of relying on
    # pyplot's implicit "current axes", so the image always lands in the slot
    # that was created for this index.
    ax = plt.subplot(2, 5, index + 1)
    ax.imshow(img)
    ax.axis('off')

In [None]:
# plot the global crops
num_images = 10

# Create a figure and set the size
fig = plt.figure(figsize=(10, 4))

# Draw the crops one after another in the main thread. pyplot keeps global
# "current figure/axes" state and Matplotlib is not thread-safe, so calling
# process_image from a thread pool can put images into the wrong slots.
# A plain loop also lets any exception raised by process_image surface here
# instead of being left unread inside an unconsumed executor.map() result.
for index in range(num_images):
    process_image(index)

# Adjust layout and show the plot
plt.tight_layout()
plt.show()

In [None]:
# adapted from  https://github.com/DIAGNijmegen/pathology-he-auto-augment/blob/main/he-randaugment/custom_hed_transform.py
import numpy as np
from scipy import linalg
from skimage.util import dtype, dtype_limits
from skimage.exposure import rescale_intensity
import time

# 3x3 float32 matrix handed to combine_stains by hed2rgb; its inverse
# (also float32) is handed to separate_stains by rgb2hed.
rgb_from_hed = np.array(
    [
        [0.65, 0.70, 0.29],
        [0.07, 0.99, 0.11],
        [0.27, 0.57, 0.78],
    ],
    dtype='float32',
)
hed_from_rgb = linalg.inv(rgb_from_hed).astype('float32')


def rgb2hed(rgb):
    """Convert an RGB image to the HED representation via `hed_from_rgb`."""
    hed = separate_stains(rgb, hed_from_rgb)
    return hed

def hed2rgb(hed):
    """Convert an HED representation back to RGB via `rgb_from_hed`."""
    rgb = combine_stains(hed, rgb_from_hed)
    return rgb

def separate_stains(rgb, conv_matrix):
    """Project an image into stain space.

    The image is converted to a float32 copy, shifted by +2, passed through
    ``-log`` and multiplied, pixel by pixel, with ``conv_matrix``.

    Args:
        rgb: image array whose last axis has three channels.
        conv_matrix: 3x3 conversion matrix applied to every pixel.

    Returns:
        Array with the same shape as the input holding the stain values.
    """
    shifted = dtype.img_as_float(rgb, force_copy=True).astype('float32')
    shifted += 2
    neg_log_pixels = np.reshape(-np.log(shifted), (-1, 3))
    projected = np.dot(neg_log_pixels, conv_matrix)
    return np.reshape(projected, shifted.shape)


def combine_stains(stains, conv_matrix):
    """Map stain values back to image space (counterpart of separate_stains).

    Args:
        stains: array whose last axis has three stain channels.
        conv_matrix: 3x3 conversion matrix applied to every pixel.

    Returns:
        Array with the same shape as ``stains``, rescaled by
        ``rescale_intensity`` with an input range of (-1, 1).
    """
    # Stain values lie outside [-1, 1]; dtype.img_as_float complains about that
    # unless the array is float64, hence the detour before casting to float32.
    stains_f32 = dtype.img_as_float(stains.astype('float64')).astype('float32')
    negated_pixels = -np.reshape(stains_f32, (-1, 3))
    shifted_rgb = np.exp(np.dot(negated_pixels, conv_matrix))
    # Remove the +2 offset that separate_stains added before taking the log.
    unshifted = np.reshape(shifted_rgb - 2, stains_f32.shape)
    return rescale_intensity(unshifted, in_range=(-1, 1))

In [None]:
# Array version of the selected patch; hed_jitter reads this global.
patch_image = np.array(patch)

In [None]:
# HED color augmentations adapted from  https://github.com/DIAGNijmegen/pathology-he-auto-augment/blob/main/he-randaugment/custom_hed_transform.py
# Tellez et al.
# Images whose mean intensity (scaled to [0, 1]) lies outside this range are
# returned without augmentation.
__cutoff_range = (0.15, 0.85)

def hed_jitter(factor, image=None):
    """Randomly scale and shift each HED channel of an RGB image.

    For every one of the three HED channels a multiplicative perturbation
    ``1 + sigma`` and an additive ``bias`` are drawn uniformly from
    ``[-factor, factor]``.

    Args:
        factor: half-width of the uniform range the perturbations are drawn from.
        image: channel-last RGB array with values on a 0-255 scale. Defaults to
            the notebook-level ``patch_image``.

    Returns:
        The augmented image as a uint8 array, or the input array unchanged when
        its mean intensity falls outside ``__cutoff_range``.
    """
    if image is None:
        image = patch_image
    image = np.asarray(image)

    # Draw three biases first, then three sigmas, one scalar call each, so a
    # seeded run consumes the random stream in the same order as before.
    biases = [np.random.uniform(-factor, factor) for _ in range(3)]
    sigmas = [np.random.uniform(-factor, factor) for _ in range(3)]

    image_mean = np.mean(image) / 255.0
    if not (__cutoff_range[0] <= image_mean <= __cutoff_range[1]):
        # Nothing to augment; previously this path hit an unbound local.
        return image

    # Convert the image patch to HED color coding.
    patch_hed = rgb2hed(rgb=image)

    # Augment the Haematoxylin, Eosin and DAB channels in turn. The in-place
    # updates keep the dtype returned by rgb2hed.
    for channel, (sigma, bias) in enumerate(zip(sigmas, biases)):
        patch_hed[:, :, channel] *= (1.0 + sigma)
        patch_hed[:, :, channel] += bias

    # Convert back to RGB color coding and to the 0-255 uint8 range.
    patch_rgb = hed2rgb(hed=patch_hed)
    patch_rgb = np.clip(a=patch_rgb, a_min=0.0, a_max=1.0)
    patch_rgb *= 255.0
    return patch_rgb.astype(dtype=np.uint8)


In [None]:
# One row of ten panels: the original patch in the first slot, followed by
# nine independently jittered versions (factor 0.05).
fig = plt.figure(figsize=(20, 2))
plt.subplot(1, 10, 1)
plt.imshow(patch)
plt.axis('off')
for i in tqdm(range(9)):
    # Slots 2..10 hold the augmented versions.
    plt.subplot(1, 10, i + 2)
    patch_transformed = hed_jitter(0.05)
    plt.imshow(patch_transformed)
    plt.axis('off')

plt.show()

In [None]:
# Shape of the array that hed_jitter operates on.
patch_image.shape

### Create list of patches

In [None]:
## Create list of patches
# Recursively collect every .png below PATH.
# Note: this rebinds `patches`, replacing the list loaded earlier in the notebook.
PATH = "/lustre/groups/shared/histology_data/tcga_patches/patches/2.0"
patches = [png_path for png_path in Path(PATH).glob("**/*.png")]

In [None]:
# np.savetxt("/home/haicu/sophia.wagner/datasets/TCGA_all_20X_1024px.txt", patches, fmt="%s", delimiter="\n")

In [None]:
# Number of .png patches found below PATH.
len(patches)