In [1]:
import torch
import os
import numpy as np
from PIL import Image, ImageDraw
Image.MAX_IMAGE_PIXELS = 10000000000
from torchvision.transforms import functional as F
import math
from torchvision.utils import save_image


# Per-case visualisation results directory and the x10 test-split image directory.
# Each can be overridden through an environment variable so the notebook can run
# outside the original cluster without editing this cell; defaults are unchanged.
base_dir = os.environ.get(
    'MELANOMA_VIS_DIR',
    '/projects/patho2/melanoma_diagnosis/results/2scale_128x4x32x2/7.5_12/posemb_dropout0.25/visual/all/')
im_dir = os.environ.get('MELANOMA_IM_DIR', '/projects/patho2/melanoma_diagnosis/x10/split/test/')

In [2]:
def _get_image_size(img):
    """Return the ``(width, height)`` of a PIL image or an image tensor.

    PIL images report ``img.size`` directly. Tensors must have more than two
    dimensions; their last two axes are read as ``(height, width)`` and returned
    reversed. Anything else raises ``TypeError``.
    """
    if F._is_pil_image(img):
        return img.size

    is_image_tensor = isinstance(img, torch.Tensor) and img.dim() > 2
    if not is_image_tensor:
        raise TypeError("Unexpected type {}".format(type(img)))
    return img.shape[-2:][::-1]


class DivideToScales(object):
    """Resize one image (and an optional mask) to several scale levels.

    Args:
        scale_levels: non-empty list of scale factors; stored sorted ascending.
        size: optional ``(height, width)`` reference size the factors apply to.
            When omitted, the target sizes are derived from each incoming image.
        interpolation: interpolation used for images (masks always use
            ``Image.NEAREST``).
    """

    def __init__(self, scale_levels: list, size=None, interpolation=Image.BICUBIC):
        assert len(scale_levels) > 0, "Atleast 1 scale is required. Got: {}".format(scale_levels)

        self.scale_levels = sorted(scale_levels)  # e.g. [0.25, 0.5, 0.75, 1.0]
        self.num_scales = len(scale_levels)

        # Without a fixed reference size the target sizes depend on the image, so they
        # must be recomputed on every call instead of cached from the first image seen.
        self._size_from_image = size is None
        if size is not None:
            self.resize = self.get_sizes(scale_levels=scale_levels, size=size)
        else:
            self.resize = None
        self.interpolation = interpolation

    @staticmethod
    def get_sizes(scale_levels, size):
        """Map each scale factor to ``[int(sc * height), int(sc * width)]``.

        Args:
            scale_levels: iterable of scale factors.
            size: ``(height, width)`` of the 1x reference.
        Returns:
            dict: scale factor -> ``[scaled_height, scaled_width]``.
        """
        height_1x, width_1x = size
        return {sc: [int(sc * height_1x), int(sc * width_1x)] for sc in scale_levels}

    @staticmethod
    def get_params(image_sizes, crop_zoom):
        """Return ``(crop_height, crop_width)`` of an image ``(width, height)`` divided by ``crop_zoom``."""
        image_width, image_height = image_sizes
        crop_height = int(round(image_height / crop_zoom))
        crop_width = int(round(image_width / crop_zoom))
        return crop_height, crop_width

    def divide_image_to_scales(self, image, mask, scale):
        """Resize ``image`` (and ``mask`` if given) to the size stored for ``scale``."""
        image = F.resize(image, self.resize[scale], self.interpolation)
        if mask is not None:
            # Nearest-neighbour so mask label values are not blended.
            mask = F.resize(mask, self.resize[scale], Image.NEAREST)
        return image, mask

    def get_scales(self, sample: dict):
        """Return ``{'image': [...], 'mask': [...] or None}`` with one entry per scale (ascending)."""
        image, mask = sample['image'], sample['mask']
        # getattr keeps instances created before `_size_from_image` existed working.
        if self.resize is None or getattr(self, '_size_from_image', False):
            width, height = _get_image_size(img=image)
            self.resize = self.get_sizes(scale_levels=self.scale_levels, size=(height, width))

        images = []
        masks = [] if mask is not None else None
        for scale in self.scale_levels:
            im, m = self.divide_image_to_scales(image, mask, scale)
            images.append(im)
            if masks is not None:
                masks.append(m)

        return {'image': images, 'mask': masks}

    def __call__(self, sample: dict):
        return self.get_scales(sample=sample)


def calculate_bbox(img_size, resize2, idx):
    """
    Return the pixel bounding box of the ``idx``-th crop in a row-major grid.

    The image is tiled into square crops of side ``resize2``. Only whole crops are
    counted, and crops are numbered left to right, then top to bottom.

    Args:
        img_size: ``(h, w)`` of the image in pixels.
        resize2: side length of one crop in pixels (must be > 0).
        idx: int, ``0 <= idx < (h // resize2) * (w // resize2)``.

    Returns:
        Bounding box (int[]): [top, bottom, left, right], with bottom/right
        clamped to the image size.

    Raises:
        AssertionError: if ``idx`` is outside the grid.
    """
    h, w = img_size
    c_w = w // resize2  # whole crops per row
    c_h = h // resize2  # whole crops per column
    crop_length = c_w * c_h
    assert 0 <= idx < crop_length, "Index Out of Bound"

    # Row-major layout: integer division gives the row, remainder gives the column.
    row, col = divmod(idx, c_w)
    top = row * resize2
    bottom = min(h, top + resize2)
    left = col * resize2
    right = min(w, left + resize2)

    return [top, bottom, left, right]

def resize_image_to_k_crops_size(image, n_crops):
    """Resize a PIL image so each side is an exact multiple of ``n_crops``.

    Each side's crop length is ``ceil(side / n_crops)``. The image is bicubically
    resized to ``n_crops`` crop lengths per side.

    Returns:
        (resized_image, (crop_width, crop_height))
    """
    width, height = image.size
    tile_w = int(math.ceil(width / n_crops))
    tile_h = int(math.ceil(height / n_crops))

    target = [n_crops * tile_h, n_crops * tile_w]  # [height, width] as F.resize expects
    resized = F.resize(img=image, size=target, interpolation=Image.BICUBIC)
    return resized, (tile_w, tile_h)


In [3]:
class KCrops(object):
    """Split each image (and optional mask) of a multi-scale sample into a grid of crops.

    Args:
        scale_levels: scale factors of the incoming sample; stored as ``self.scales``.
        n_crops: per-scale number of crops along each side; entry ``i`` is used for
            the ``i``-th image (and mask) in the sample.
    """

    def __init__(self, scale_levels: list, n_crops=(7,)):
        # Copy into separate lists so the instance never aliases a caller's list
        # (or a shared mutable default).
        self.n_crops_h = list(n_crops)
        self.n_crops_w = list(n_crops)

        self.scales = scale_levels

    def divide_image_to_crops(self, image, scale_ind, interpolation=None):
        """Resize ``image`` to a multiple of the grid and cut it into crops.

        Args:
            image: tensor unpacked as ``channel, h, w = image.shape``.
            scale_ind: index into ``n_crops_h`` / ``n_crops_w``.
            interpolation: resampling passed to ``F.resize``; ``None`` means
                ``Image.BICUBIC``. Use ``Image.NEAREST`` for label masks.

        Returns:
            Tensor ``[n_h * n_w, C, crop_h, crop_w]``; crops are ordered row by row
            (left to right, then top to bottom).
        """
        if interpolation is None:
            interpolation = Image.BICUBIC
        n_h = self.n_crops_h[scale_ind]
        n_w = self.n_crops_w[scale_ind]

        channel, h, w = image.shape
        crop_size_h = int(math.ceil(h / n_h))
        crop_size_w = int(math.ceil(w / n_w))

        # Resize so both sides divide evenly into the grid.
        image = F.resize(img=image, size=[n_h * crop_size_h, n_w * crop_size_w], interpolation=interpolation)
        # [C x N_B_H*B_H x N_B_W*B_W] --> [C x N_B_H x B_H x N_B_W x B_W]
        crops = torch.reshape(image, (channel, n_h, crop_size_h, n_w, crop_size_w))
        # [C x N_B_H x B_H x N_B_W x B_W] --> [C x N_B_H x N_B_W x B_H x B_W]
        crops = crops.permute(0, 1, 3, 2, 4)
        # [C x N_B_H x N_B_W x B_H x B_W] --> [C x N_B_H * N_B_W x B_H x B_W]
        crops = torch.reshape(crops, (channel, n_h * n_w, crop_size_h, crop_size_w))
        # [C x N_B_H * N_B_W x B_H x B_W] --> [N_B_H * N_B_W x C x B_H x B_W]
        return crops.permute(1, 0, 2, 3)

    def __call__(self, sample: dict):
        """Crop every image (and mask, if present) in ``sample``; returns the same dict layout."""
        image, mask = sample['image'], sample['mask']
        images = [self.divide_image_to_crops(im, i) for i, im in enumerate(image)]
        masks = None
        if mask is not None:
            # Nearest-neighbour keeps mask label values intact; bicubic would blend them.
            masks = [self.divide_image_to_crops(m, i, interpolation=Image.NEAREST)
                     for i, m in enumerate(mask)]
        return {'image': images, 'mask': masks}
    

class ToTensor(object):
    """Convert the image(s) and mask(s) of a sample dict to tensors.

    Images go through ``F.to_tensor``: a PIL Image or numpy.ndarray (H x W x C) in
    [0, 255] becomes a float tensor (C x H x W) in [0.0, 1.0] for the usual PIL
    modes / uint8 arrays; other inputs are not scaled.
    Masks are converted without scaling to a ``uint8`` tensor of shape 1 x H x W,
    so label values are preserved.
    ``sample['image']`` and ``sample['mask']`` may each be a single item or a list.
    """

    @staticmethod
    def _mask_to_tensor(mask):
        # Keep raw label values (no /255 scaling) and add a leading channel axis.
        return torch.as_tensor(np.array(mask), dtype=torch.uint8).unsqueeze(0)

    def __call__(self, sample: dict):
        """
        Args:
            sample (dict): ``{'image': ..., 'mask': ...}``; ``'mask'`` may be None.
        Returns:
            dict: ``{'image': tensor or list of tensors,
                     'mask': uint8 tensor, list of uint8 tensors, or None}``.
        """
        image, mask = sample['image'], sample['mask']
        image_tensor = [F.to_tensor(im) for im in image] if isinstance(image, list) else F.to_tensor(image)
        mask_tensor = None
        if mask is not None:
            # One conversion for both shapes: F.to_tensor would rescale a single uint8
            # mask to [0, 1] floats and disagree with the list path.
            if isinstance(mask, list):
                mask_tensor = [self._mask_to_tensor(m) for m in mask]
            else:
                mask_tensor = self._mask_to_tensor(mask)

        return {'image': image_tensor, 'mask': mask_tensor}

    def __repr__(self):
        return self.__class__.__name__ + '()'

In [None]:
def draw_border_as_color(crop, rgb=(1, 0, 0), border=10):
    """Paint a solid ``border``-pixel frame of colour ``rgb`` around ``crop``, in place.

    Args:
        crop: tensor indexed as [C, H, W] with one ``rgb`` entry per channel.
        rgb: per-channel fill value (cast to ``crop``'s dtype).
        border: frame thickness in pixels; ``<= 0`` leaves ``crop`` unchanged.

    Returns:
        The same ``crop`` tensor (mutated in place), so views into a larger tensor
        are updated too.
    """
    # Guard: a slice of `-0:` would select the whole axis and flood the crop.
    if border <= 0:
        return crop
    # Built once with the crop's dtype/device; shape [C, 1, 1] broadcasts over each strip.
    color = torch.as_tensor(rgb, dtype=crop.dtype, device=crop.device).view(-1, 1, 1)
    crop[:, :, :border] = color   # left strip
    crop[:, :, -border:] = color  # right strip
    crop[:, :border, :] = color   # top strip
    crop[:, -border:, :] = color  # bottom strip
    return crop

In [4]:
# Case to visualise: per-case results live under base_dir, the source slide under im_dir.
# The case id is defined once so the two paths cannot drift apart.
# NOTE(review): results are read from sub-folder '3' but the slide from sub-folder '5';
# confirm this is intended.
case_name = 'MP_0020_x10_z0_0'
case_dir = os.path.join(base_dir, '3', case_name, 'scale_1')
im_p = os.path.join(im_dir, '5', case_name + '.tif')

In [6]:
# extract k crops
patch_grad = torch.load(os.path.join(case_dir, 'patch_grads.pth'))
ind = np.argsort(patch_grad)
original = Image.open(im_p).convert('RGB')
t = DivideToScales([1.25])
sample = t({'image': original, 'mask': None})
image_scales = sample['image']
image, crop_size = resize_image_to_k_crops_size(image_scales[0], len(patch_grad))
t2 = ToTensor()
t3= KCrops([1.25], [int(math.sqrt(len(patch_grad)))])
sample = t2({'image': image, 'mask': None})
image = sample['image']
sample = t3({'image': [image], 'mask': None})
crops = sample['image'][0]
#os.mkdir(os.path.join(case_dir, 'bags'))
for i in range(10):
    #save_image(crops[ind[-1-i]], os.path.join(case_dir, 'top_bags', 'crop_{}.jpg'.format(i)))
    # draw rectangles on top 5
    crop = draw_border_as_color(crops[ind[-1-i]])
# re-arrange crops
orig_from_crops = crops.permute(1, 0, 2, 3)
orig_from_crops = orig_from_crops.reshape(-1, int(sqrt(crops)), 
                                          int(sqrt(crops)), 
                                          crops.shape[2], crops.shape[3])
orig_from_crops = orig_from_crops.permute(0, 1, 3, 2, 4)
orig_from_crops = orig_from_crops.reshape(-1,  int(sqrt(crops))*crops.shape[2],
                                          int(sqrt(crops))*crops.shape[3])

In [45]:
# Write the RGB slide opened from im_p (before any rescaling or cropping) to whole.jpg
# in the current working directory.
original.save('whole.jpg')

In [52]:
crops.size()

torch.Size([81, 3, 180, 1278])

In [53]:
orig_from_crops = crops.transpose(0, 1)  # [N, C, bh, bw] -> [C, N, bh, bw]

In [54]:
# Split the flattened crop axis back into a (rows, cols) grid. Grid and crop sizes are
# read from the tensor instead of hard-coding 9 / 180 / 1278.
_, n_crops_total, crop_h, crop_w = orig_from_crops.shape
n_side = int(round(math.sqrt(n_crops_total)))
assert n_side * n_side == n_crops_total, "Expected a square number of crops, got {}".format(n_crops_total)
orig_from_crops = orig_from_crops.reshape(-1, n_side, n_side, crop_h, crop_w)

In [56]:
orig_from_crops = orig_from_crops.transpose(2, 3)  # [C, nh, nw, bh, bw] -> [C, nh, bh, nw, bw]

In [58]:
# Merge (rows, crop_h) and (cols, crop_w) into the full image height and width,
# using the tensor's own sizes instead of hard-coded 9 / 180 / 1278.
_, n_rows, crop_h, n_cols, crop_w = orig_from_crops.shape
orig_from_crops = orig_from_crops.reshape(-1, n_rows * crop_h, n_cols * crop_w)

In [59]:
orig_from_crops.size()

torch.Size([3, 1620, 11502])

In [60]:
# Write the image re-assembled from the crops to test.jpg in the current working directory.
# Crop borders drawn in place on `crops` (if that loop ran) appear in this image.
save_image(orig_from_crops, 'test.jpg')