In [None]:
import os
import time
import re
import uuid
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
import scanpy as sc
import tifffile
from PIL import Image, ImageOps

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
from violet.utils.preprocessing import get_svs_array, is_background, extract_st_tiles, extract_svs_tiles
from violet.utils.dataloaders import listfiles

In [None]:
# Root folder for everything this notebook produces.
out_dir = '/data/violet/sandbox/multiresolution_dino_tcia'

# raw_dir is where extracted tiles are saved; norm_dir is listed further down
# to load the normalized tiles.
raw_dir, norm_dir = (
    os.path.join(out_dir, 'raw_input', 'test_A'),
    os.path.join(out_dir, 'normalized_input'),
)

# Create both folders, including missing parents; existing folders are left alone.
for _tile_dir in (raw_dir, norm_dir):
    os.makedirs(_tile_dir, exist_ok=True)

In [None]:
from torchvision.transforms import RandomResizedCrop

def extract_multiresolution_tiles(fp, out_dir, s=(64, 2048), target=512, n=1000, coverage=.5, delta=15*60):
    """Sample random tiles from one image and save them as JPEGs.

    A square window of side ``s[1]`` is drawn at a uniformly random position.
    Windows for which ``is_background(patch, coverage=coverage)`` is falsy are
    passed through ``RandomResizedCrop`` (square aspect ratio, output size
    ``target``) and saved as ``{sample}_{count}.jpeg``, where ``sample`` is the
    file name up to its first ``.`` with spaces replaced by underscores.

    Parameters
    ----------
    fp : str
        Path to a ``tif``/``tiff`` image (read with ``tifffile.imread``) or an
        ``svs`` slide (read with ``get_svs_array(fp, scale=1.)``). The
        extension is matched case-insensitively.
    out_dir : str
        Currently unused; tiles are written to the notebook-level ``raw_dir``.
        NOTE(review): the existing caller passes the root output folder here,
        so switching the save location to ``out_dir`` needs a matching change
        at the call site.
    s : tuple
        ``s[1]`` is the side of the sampled window. ``s[0] / s[1]`` is the
        lower bound of the ``scale`` range given to ``RandomResizedCrop``.
        NOTE(review): torchvision defines ``scale`` as a fraction of the
        window *area*; if ``s[0]`` is meant to be the smallest crop *side*,
        the bound should be squared - confirm the intent.
    target : int
        ``size`` argument of ``RandomResizedCrop``.
    n : int
        Stop after this many tiles have been saved.
    coverage : float
        Forwarded to ``is_background``.
    delta : float
        Time budget in seconds. It is checked after every attempt, so at
        least one window is always tried when ``n > 0``.

    Returns
    -------
    int
        Number of tiles saved; less than ``n`` when the time budget ran out.

    Raises
    ------
    RuntimeError
        If the extension is not tif/tiff/svs, or if either image side is
        shorter than ``s[1]``.
    """
    ext = fp.rsplit('.', 1)[-1].lower()
    if ext in ('tif', 'tiff'):
        img = tifffile.imread(fp)
    elif ext == 'svs':
        img = get_svs_array(fp, scale=1.)
    else:
        raise RuntimeError('invalid extension')

    sample = fp.split('/')[-1].split('.')[0].replace(' ', '_')

    print(img.shape)
    t = RandomResizedCrop(target, scale=(float(s[0]) / s[1], 1.), ratio=(1., 1.))
    if min(img.shape[0], img.shape[1]) < s[1]:
        raise RuntimeError('crop size must be smaller than image size')

    # Largest valid top-left corner of the window, inclusive. Drawing with
    # randint avoids building index arrays as long as the slide and keeps the
    # last valid position (including the image-side == s[1] case) reachable.
    max_r, max_c = img.shape[0] - s[1], img.shape[1] - s[1]

    count = 0
    start = time.time()
    while count < n:
        r = np.random.randint(0, max_r + 1)
        c = np.random.randint(0, max_c + 1)
        patch = img[r:r + s[1], c:c + s[1]]

        if not is_background(patch, coverage=coverage):
            tile = t(Image.fromarray(patch))
            tile.save(os.path.join(raw_dir, f'{sample}_{count}.jpeg'))
            count += 1
        if time.time() - start > delta:
            break
    return count

In [None]:
# # do st
# filemap = pd.read_csv('/home/estorrs/spatial-analysis/data/sample_map.txt', sep='\t')
# filemap = filemap[~pd.isnull(filemap['highres_image'])]
# keep = [d for d in set(filemap['disease'])
#         if d=='pdac']
# filemap = filemap[[True if d in keep and t == 'ffpe' else False
#                   for d, t in zip(filemap['disease'], filemap['tissue_type'])]]
# filemap

In [None]:
# data_map = {row['sample_id']: {'spatial': row['spaceranger_output'], 'tif': row['highres_image']}
#             for i, row in filemap.iterrows()}
# regions = [4, 8]
# imgs, img_ids = extract_st_tiles(data_map, normalize=False, regions=regions)

In [None]:
# import matplotlib.pyplot as plt
# for i in imgs[0]:
#     print(i.shape)
#     plt.imshow(i)
#     plt.show()

In [None]:
# for tiles, img_id in zip(imgs, img_ids):
#     w = 8
#     fname = os.path.join(in_dir, f'{img_id}_{w}.jpeg')
#     im = Image.fromarray(tiles[2]).resize((512, 512))
#     im.save(fname)


In [None]:

# Candidate slides: whatever `listfiles` returns for the TCIA PDA folder with
# this pattern, sorted before sampling. The dot is escaped so that only a
# literal ".svs" suffix matches.
fps = sorted(listfiles('/data/tcia/PDA/', regex=r'\.svs$'))
# Draw up to 10 distinct slides. Capping at len(fps) stops replace=False from
# raising when fewer than 10 slides are available.
fps = np.random.choice(fps, size=min(10, len(fps)), replace=False)
# Key each chosen slide by its file name up to the first '.'.
data_map = {fp.split('/')[-1].split('.')[0]:fp for fp in fps}
data_map

In [None]:
# Extract tiles from every selected slide; `s` is the sample name and `fp` the
# slide path. Each slide stops after 5000 saved tiles or once the 5-minute
# budget (delta is in seconds) has elapsed, whichever comes first.
for s, fp in data_map.items():
    print(s)
    extract_multiresolution_tiles(fp, out_dir, n=5000, delta=60 * 5)

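Run normalization on the extracted tiles before continuing; the cells below read the normalized tiles from `norm_dir`.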

In [None]:
# Sorted list of what `listfiles` returns for norm_dir, followed by its length.
# Note: this rebinds `fps`, which earlier held the sampled slide paths.
fps = sorted(listfiles(norm_dir))
len(fps)

In [None]:
# Visual spot check of the first 20 entries of `fps`.
# NOTE(review): Image.show() hands the image to PIL's external viewer instead
# of rendering it in the notebook - confirm this works where the kernel runs.
for fp in fps[:20]:
    im = Image.open(fp)
    im.show()
    

Calculate the per-channel mean and standard deviation of the dataset (pixel values divided by 255).

In [None]:
from einops import rearrange
# Number of tiles to draw from `fps` (np.random.choice samples with
# replacement by default).
n = 1000

# Per-channel running totals over every sampled pixel. Pooling sums and sums
# of squares gives the standard deviation of the sampled pixels as a whole;
# averaging per-image stds would ignore the variation between images.
pixel_sum, pixel_sq_sum = np.zeros(3), np.zeros(3)
n_pixels = 0
for fp in np.random.choice(fps, size=n):
    img = np.asarray(Image.open(fp))
    # Cast to float64 before squaring so integer pixel data cannot overflow.
    pixels = rearrange(img, 'h w c -> (h w) c').astype(np.float64)
    pixel_sum += pixels.sum(axis=0)
    pixel_sq_sum += np.square(pixels).sum(axis=0)
    n_pixels += pixels.shape[0]

dataset_mean = pixel_sum / n_pixels
# Var[x] = E[x^2] - E[x]^2; the clip absorbs tiny negative values from rounding.
dataset_std = np.sqrt(np.clip(pixel_sq_sum / n_pixels - dataset_mean ** 2, 0., None))
dataset_mean / 255., dataset_std / 255.

Run DINO