In [22]:
from concurrent.futures import ProcessPoolExecutor, as_completed
import glob
from skimage import io
import os
import numpy as np
from typing import List, Tuple, Dict
import time
import itertools

In [3]:
def brenner_gradient(img: np.array) -> int:
        return np.sum((img[2:]-img[:-2])**2)

In [4]:
h, w = 720, 1280
patch_size = [360, 256]
x_steps = int(h / patch_size[0])
y_steps = int(w / patch_size[1])
x_coord = np.linspace(int(patch_size[0]/2), h - int(patch_size[0]/2), x_steps)
y_coord = np.linspace(int(patch_size[1]/2), w - int(patch_size[1]/2), y_steps)
patch_coords = [(x, y) for x in x_coord for y in y_coord]

In [5]:
def get_label_dict(sample_box: str, data_dir: str = "/n/data2/hms/dbmi/kyu/lab/maf4031/focus_dataset", patch_coords: List = patch_coords, patch_size: List = [360, 256]) -> Dict:
    samples = glob.glob(os.path.join(data_dir, sample_box,'sample*'))
    sample_dict = {sample: {} for sample in samples}
    for sample in samples:
        images = glob.glob(os.path.join(sample,'*distance*.jpg'))
        image_dict = {patch: [0, 0, 0] for patch in range(len(patch_coords))}
        for image in images:
            img = io.imread(image)
            for patch_idx, (x, y) in enumerate(patch_coords):
                patch = img[int(x-int(patch_size[0]/2)):int(x+int(patch_size[0]/2)), int(y-int(patch_size[1]/2)):int(y+int(patch_size[1]/2))]
                brenner_value = brenner_gradient(patch)
                if brenner_value > image_dict[patch_idx][-1]:
                    image_dict[patch_idx] = x, y, brenner_value
        sample_dict[sample] = image_dict
    return sample_dict


In [23]:
def get_label_samples_dict(sample: str, data_dir: str = "/n/data2/hms/dbmi/kyu/lab/maf4031/focus_dataset", patch_coords: List = patch_coords, patch_size: List = [360, 256]) -> Dict:
    images = glob.glob(os.path.join(sample,'*distance*.jpg'))
    image_dict = {patch: [0, 0, 0] for patch in range(len(patch_coords))}
    for image in images:
        img = io.imread(image)
        for patch_idx, (x, y) in enumerate(patch_coords):
            patch = img[int(x-int(patch_size[0]/2)):int(x+int(patch_size[0]/2)), int(y-int(patch_size[1]/2)):int(y+int(patch_size[1]/2))]
            brenner_value = brenner_gradient(patch)
            if brenner_value > image_dict[patch_idx][-1]:
                image_dict[patch_idx] = x, y, brenner_value
    return image_dict

In [29]:
data_dir = "/n/data2/hms/dbmi/kyu/lab/maf4031/focus_dataset"
sample_boxes = next(os.walk(data_dir))[1]
samples = list(itertools.chain.from_iterable([glob.glob(os.path.join(data_dir, sample_box,'sample*')) for sample_box in sample_boxes]))
start = time.time()
with ProcessPoolExecutor(max_workers=61) as executor:
        futures = executor.map(
            get_label_samples_dict,
            samples,
        )
end = time.time()
print(end - start)
len(list(futures))

155.92690253257751


186

In [26]:
for f in futures:
    print(f)
    break

{0: (180.0, 128.0, 1760758), 1: (180.0, 384.0, 4257502), 2: (180.0, 640.0, 4614287), 3: (180.0, 896.0, 1020184), 4: (180.0, 1152.0, 1547338), 5: (540.0, 128.0, 1507739), 6: (540.0, 384.0, 7319263), 7: (540.0, 640.0, 11780154), 8: (540.0, 896.0, 3492988), 9: (540.0, 1152.0, 2379728)}


185

In [23]:
import torch

from patch_dataset import FocusDataset

In [24]:
dataset = torch.load("/home/maf4031/focus_model/data/test_patch_dataset.pt")

In [25]:
dataset.label_dict

{'Inflammation_3_4': {'/n/data2/hms/dbmi/kyu/lab/maf4031/focus_dataset/Inflammation_3_4/sample_13': {0: 1760758,
   1: 4257502,
   2: 4614287,
   3: 1020184,
   4: 1547338,
   5: 1507739,
   6: 7319263,
   7: 11780154,
   8: 3492988,
   9: 2379728},
  '/n/data2/hms/dbmi/kyu/lab/maf4031/focus_dataset/Inflammation_3_4/sample_62': {0: 15601880,
   1: 20695909,
   2: 22807902,
   3: 20838763,
   4: 15558733,
   5: 17863837,
   6: 21145248,
   7: 21302495,
   8: 19286478,
   9: 15056769},
  '/n/data2/hms/dbmi/kyu/lab/maf4031/focus_dataset/Inflammation_3_4/sample_3': {0: 1692953,
   1: 3582010,
   2: 3552537,
   3: 6320795,
   4: 4048874,
   5: 1466343,
   6: 1634523,
   7: 5820465,
   8: 5918882,
   9: 4592091},
  '/n/data2/hms/dbmi/kyu/lab/maf4031/focus_dataset/Inflammation_3_4/sample_7': {0: 1798281,
   1: 1682947,
   2: 4512255,
   3: 869728,
   4: 1339381,
   5: 1576597,
   6: 6761006,
   7: 4208716,
   8: 768905,
   9: 2479194},
  '/n/data2/hms/dbmi/kyu/lab/maf4031/focus_dataset/Inflam