In [None]:
%load_ext autoreload
%autoreload 2
import concurrent.futures
import threading
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as f
from PIL import Image
from tqdm import tqdm

### Visualize current DINO augmentations

In [None]:
def inverse_normalize(tensor, mean, std, inplace=False):
    """Undo a per-channel ``(x - mean) / std`` normalization.

    Computes ``tensor * std + mean`` with one mean/std value per entry of
    dim 0 (the channel axis of a ``(C, H, W)`` image tensor).

    Args:
        tensor: normalized tensor whose first dimension indexes channels.
        mean: per-channel means used by the forward normalization.
        std: per-channel standard deviations used by the forward normalization.
        inplace: if True, overwrite ``tensor`` and return it. By default a new
            tensor is returned and the input is left untouched, so re-running a
            plotting cell does not de-normalize the same tensor twice.

    Returns:
        The de-normalized tensor.
    """
    # Keep float tensors in their own precision; other dtypes get torch's default float.
    dtype = tensor.dtype if tensor.is_floating_point() else None
    # Shape (C, 1, 1, ...) so each channel's value broadcasts over the remaining dims.
    shape = (-1,) + (1,) * (tensor.dim() - 1)
    mean_t = torch.as_tensor(mean, dtype=dtype, device=tensor.device).reshape(shape)
    std_t = torch.as_tensor(std, dtype=dtype, device=tensor.device).reshape(shape)
    if inplace:
        return tensor.mul_(std_t).add_(mean_t)
    return tensor * std_t + mean_t

In [None]:
# Load the list of patch image paths (one path per line), as produced by the
# "Create list of patches" cell at the end of this notebook.
PATCH_LIST_FILE = Path("/home/haicu/sophia.wagner/datasets/TCGA_all_20X_1024px.txt")
# Read lines directly instead of np.loadtxt(dtype=str): loadtxt splits each line on
# whitespace (breaking paths with spaces) and, for a single-line file, returns a 0-d
# array whose .tolist() is a plain str rather than a list of paths.
with PATCH_LIST_FILE.open(encoding="utf-8") as fh:
    patches = [line.strip() for line in fh if line.strip()]
# Alternative patch sources used previously:
# patches = list(Path("/lustre/groups/shared/histology_data/TCGA/ACC/patches").glob("**/*.h5"))
# patches = list(Path("/lustre/groups/shared/tcga/CRC/patches/512px_crc_wonorm_complete_diag_frozen").glob("**/*.jpeg"))
# patches = np.loadtxt("/lustre/groups/shared/histology_data/TCGA/CRC/patches/512px_crc_wonorm_complete_diag_frozen.txt", dtype=str, max_rows=100).tolist()

In [None]:
len(patches)  # number of patch paths loaded from the list file

In [None]:
from dinov2.data import DataAugmentationDINO

In [None]:
# Crop settings used for this visualization; the trailing comments name the
# training-config fields (cfg.crops.*) each value stands in for.
GLOBAL_CROPS_SCALE = (1.0, 1.0)   # cfg.crops.global_crops_scale
LOCAL_CROPS_SCALE = (0.32, 0.32)  # cfg.crops.local_crops_scale
LOCAL_CROPS_NUMBER = 8            # cfg.crops.local_crops_number
GLOBAL_CROPS_SIZE = 224           # cfg.crops.global_crops_size
LOCAL_CROPS_SIZE = 98             # cfg.crops.local_crops_size

data_transform = DataAugmentationDINO(
    GLOBAL_CROPS_SCALE,
    LOCAL_CROPS_SCALE,
    LOCAL_CROPS_NUMBER,
    GLOBAL_CROPS_SIZE,
    local_crops_size=LOCAL_CROPS_SIZE,
)

In [None]:
# Draw one random patch to visualize. Set SEED to an int to make the choice
# reproducible across kernel restarts; None keeps it random on every run.
SEED = None
rng = np.random.default_rng(SEED)
# Named patch_idx (not `id`) to avoid shadowing the builtin.
patch_idx = int(rng.integers(len(patches)))
# The context manager closes the file; convert() returns an independent in-memory copy.
with Image.open(patches[patch_idx]) as img:
    patch = img.convert(mode="RGB")

In [None]:
patch  # display the sampled patch inline (last expression of the cell)

In [None]:
# Run the DINO augmentation pipeline once on the sampled patch; the result is
# indexed below by 'global_crops' and 'local_crops'.
out = data_transform(patch) 

In [None]:
# Show every global crop followed by every local crop in a single row, with the
# ImageNet normalization undone so the crops are viewable as images.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Derive the layout from `out` instead of hard-coding 2 global + 8 local crops.
labeled_crops = [(f"global {i}", crop) for i, crop in enumerate(out["global_crops"])]
labeled_crops += [(f"local {i}", crop) for i, crop in enumerate(out["local_crops"])]

fig, axes = plt.subplots(
    1, len(labeled_crops), figsize=(2 * len(labeled_crops), 2), squeeze=False
)
for ax, (title, crop) in zip(axes[0], labeled_crops):
    # Work on a copy: inverse_normalize may modify its input in place, and
    # mutating `out` would make a re-run of this cell de-normalize twice.
    rev = inverse_normalize(tensor=crop.clone(), mean=IMAGENET_MEAN, std=IMAGENET_STD)
    # to_pil_image scales floats by 255 and casts to uint8; clamp so values
    # slightly outside [0, 1] cannot wrap around.
    ax.imshow(f.to_pil_image(rev.clamp(0, 1)))
    ax.set_title(title)
plt.show()

In [None]:
# Choose local vs. global crops via process_image's crop_key / crop_idx arguments.
# Global crop 0 and 1 have different settings; local crops are all the same.

# pyplot keeps global "current figure/axes" state and is not thread-safe, so all
# plotting from worker threads is serialized through this lock.
_PLOT_LOCK = threading.Lock()


def process_image(index, crop_key="global_crops", crop_idx=1, nrows=2, ncols=5):
    """Augment the global `patch` once and draw one crop into grid cell `index`.

    The augmentation may run concurrently in several threads; only the
    matplotlib calls are serialized.

    Args:
        index: zero-based cell position in an ``nrows`` x ``ncols`` grid on the
            current figure.
        crop_key: which crop list of the augmentation output to plot
            (``"global_crops"`` or ``"local_crops"``).
        crop_idx: position of the crop within that list.
        nrows: number of grid rows.
        ncols: number of grid columns.
    """
    crops = data_transform(patch)
    rev = inverse_normalize(
        tensor=crops[crop_key][crop_idx],
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    # to_pil_image scales floats by 255 before casting to uint8; clamp to avoid wrap-around.
    img = f.to_pil_image(rev.clamp(0, 1))

    with _PLOT_LOCK:
        # Draw into an explicit Axes so no other thread can redirect imshow to a
        # different "current" subplot between the two calls.
        ax = plt.gcf().add_subplot(nrows, ncols, index + 1)
        ax.imshow(img)
        ax.axis("off")

In [None]:
# Plot several independent augmentations of the same patch (which crop is drawn
# is decided inside process_image).
num_images = 10

# Create the figure that process_image draws into (it targets the current figure).
fig = plt.figure(figsize=(10, 4))

# Run the augmentations in parallel. Consuming the iterator returned by
# executor.map re-raises any exception raised in a worker; without it, errors in
# process_image would be silently discarded and subplots would just be missing.
with concurrent.futures.ThreadPoolExecutor() as executor:
    list(executor.map(process_image, range(num_images)))

# Adjust layout and show the plot
plt.tight_layout()
plt.show()

### Create list of patches

In [None]:
## Create list of patches
PATH = "/lustre/groups/shared/histology_data/tcga_patches/patches/2.0"
# Sort so the saved list (and any random index drawn from it) is reproducible;
# raw glob order depends on the filesystem's directory iteration order.
patches = sorted(Path(PATH).glob("**/*.png"))

In [None]:
# Write one path per line as UTF-8. np.savetxt encodes as latin-1 by default (so
# paths with other characters raise) and first converts the whole list of Path
# objects into a numpy object array; streaming lines avoids both.
with Path("/home/haicu/sophia.wagner/datasets/TCGA_all_20X_1024px.txt").open("w", encoding="utf-8") as fh:
    fh.writelines(f"{p}\n" for p in patches)

In [None]:
len(patches)  # number of PNG patches found under PATH