In [217]:
import torch
from torchvision import datasets, transforms

from patch_sampler import GridSamplerV2

In [218]:
# Every image is resized to a fixed 224x224 and converted to a tensor
# (the output below shows values in [0, 1]), so all images give the same patch grid.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

flowers = datasets.Flowers102(root="../data/flowers", transform=transform, download=True)

In [219]:
# Take element 0 of dataset sample #4 — the transformed image tensor displayed below.
img = flowers[4][0]
img

tensor([[[0.1412, 0.1373, 0.1373,  ..., 0.0667, 0.0706, 0.0784],
         [0.1490, 0.1451, 0.1412,  ..., 0.0863, 0.0824, 0.0784],
         [0.1529, 0.1451, 0.1333,  ..., 0.0745, 0.0784, 0.0784],
         ...,
         [0.1647, 0.1608, 0.1451,  ..., 0.2392, 0.2510, 0.2667],
         [0.1765, 0.1647, 0.1373,  ..., 0.2235, 0.2549, 0.2706],
         [0.2039, 0.1922, 0.1608,  ..., 0.1961, 0.2314, 0.2510]],

        [[0.2235, 0.2157, 0.2039,  ..., 0.1098, 0.1176, 0.1255],
         [0.2275, 0.2157, 0.2039,  ..., 0.1255, 0.1255, 0.1255],
         [0.2235, 0.2118, 0.1922,  ..., 0.1216, 0.1255, 0.1255],
         ...,
         [0.1451, 0.1451, 0.1333,  ..., 0.3569, 0.3725, 0.3843],
         [0.1529, 0.1490, 0.1255,  ..., 0.3412, 0.3725, 0.3922],
         [0.1804, 0.1765, 0.1490,  ..., 0.3137, 0.3490, 0.3725]],

        [[0.0980, 0.0941, 0.0941,  ..., 0.0627, 0.0392, 0.0392],
         [0.1020, 0.0941, 0.0902,  ..., 0.0706, 0.0471, 0.0392],
         [0.0863, 0.0824, 0.0745,  ..., 0.0471, 0.0431, 0.

In [225]:
# Split the image into patches with the project's GridSamplerV2; it returns the patch
# pixels and their coordinates (the next cell shows patches as 196 x 3 x 16 x 16).
grid_sampler = GridSamplerV2()
patches, coords = grid_sampler(img)

In [226]:
patches.size()

torch.Size([196, 3, 16, 16])

In [222]:
# Reference histogram of channel 0 of patch 0: 16 equal-width bins over [0, 1].
# Empty bins are then forced to 1 (the same pseudo-count trick explored for entropy).
hist_a0 = torch.histogram(patches[0, 0], bins=16, range=[0, 1]).hist
hist_a0.masked_fill_(hist_a0 == 0, 1)
hist_a0

tensor([  9.,  68., 144.,  12.,  10.,   9.,   4.,   1.,   1.,   1.,   1.,   1.,
          1.,   1.,   1.,   1.])

In [223]:
def batch_histogram(data_tensor, num_classes=-1, max_value=None):
    """
    Computes histograms over the last dimension, even in batches
    (as opposed to torch.histc and torch.histogram).

    Each value x is mapped to bin floor(x * nc / (M + 1)), where M is `max_value`
    when given, otherwise the maximum of the WHOLE `data_tensor` (not per row).
    For integer data with the default `num_classes` this is an exact count of each
    integer value; with a smaller `num_classes` the range [0, M] is split into
    `num_classes` equal-width bins. Values are assumed to be non-negative.

    Arguments:
        data_tensor: a D1 x ... x D_n tensor of non-negative values (integer or float).
        num_classes (optional): the number of bins. If <= 0, int(M) + 1 is used
                                (an error is thrown if the tensor is empty).
        max_value (optional): upper end of the binned value range. Defaults to the
                              tensor's maximum, which makes the binning data-dependent;
                              pass a fixed value (e.g. 255) for comparable bins across calls.
    Returns:
        A D1 x ... x D_{n-1} x num_classes 'result' tensor with the dtype of `data_tensor`,
        where result[d_1,...,d_{n-1}, c] = number of elements of
        data_tensor[d_1,...,d_{n-1}, :] that fall into bin c.
    Raises:
        ValueError: if `max_value` is given and the data maximum exceeds it.
    """
    maxd = data_tensor.max()
    if max_value is None:
        max_value = maxd
    elif maxd > max_value:
        raise ValueError(f"data maximum {maxd.item()} exceeds max_value={max_value}")
    # Use a plain Python int so it is valid inside a shape tuple for any dtype.
    nc = int(max_value) + 1 if num_classes <= 0 else num_classes
    hist = torch.zeros((*data_tensor.shape[:-1], nc), dtype=data_tensor.dtype, device=data_tensor.device)
    # A broadcast view of a single 1, so no full-size ones tensor is allocated.
    ones = torch.tensor(1, dtype=hist.dtype, device=hist.device).expand(data_tensor.shape)
    # Since x <= M, x * nc / (M + 1) < nc, so every index is in range.
    bin_index = ((data_tensor * nc) // (max_value + 1)).long()
    hist.scatter_add_(-1, bin_index, ones)
    return hist

# Same 16-bin histogram for channel 0 of patch 0, computed with the batched helper.
hist = batch_histogram(patches.flatten(start_dim=-2) * 255, num_classes=16)[0, 0]
hist

tensor([  8.,  63., 149.,  12.,  11.,   8.,   5.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.])

In [310]:
def find_top_k_entropy_patches(patches, k):
    """
    Returns the indices of the `k` highest-entropy patches, highest first.

    Arguments:
        patches: tensor of shape (..., N, C, H, W); any leading dims are presumably
                 batch dims — TODO confirm once batching is implemented.
        k: number of indices to return per set of N patches (fewer if N < k).
    Returns:
        Long tensor of shape (..., min(k, N)) indexing into the patch axis.
    """
    patch_entropy = calculate_patches_entropy(patches)
    # Sort along the patch axis (the last axis of the entropy scores).
    _, patch_indices = torch.sort(patch_entropy, descending=True, dim=-1)
    # Slice the sorted (last) axis, not the first one, so batched input keeps every batch item.
    return patch_indices[..., :k]

def calculate_patches_entropy(patches):
    """
    Scores each patch by the Shannon entropy (in bits) of its intensity histogram.

    Pixel values are scaled by 255 and binned into 16 bins per channel with
    `batch_histogram`. The per-channel entropies are then summed, giving one score per patch.

    Arguments:
        patches: tensor of shape (..., C, H, W); values presumably in [0, 1]
                 (as produced by ToTensor) — TODO confirm for other callers.
    Returns:
        Tensor of shape (...) with one entropy score per patch.
    """
    # (..., C, H*W) -> (..., C, 16) bin counts per channel.
    histograms = batch_histogram(patches.flatten(-2) * 255, 16)
    probabilities = torch.nn.functional.normalize(histograms, p=1, dim=-1)
    # Convention 0 * log2(0) = 0: empty bins contribute nothing. Pseudo-counts on
    # empty bins would add fake mass and give a constant patch non-zero entropy.
    safe_probabilities = torch.where(probabilities > 0, probabilities, torch.ones_like(probabilities))
    channel_entropy = -torch.sum(probabilities * torch.log2(safe_probabilities), dim=-1)
    # Sum over the channel axis -> one score per patch.
    return torch.sum(channel_entropy, dim=-1)

In [305]:
def divide_patch_coords_in_four(coords):
    """
    Splits each patch's bounding coordinates into the coordinates of its four quadrants.

    The quadrant naming in this function treats columns 0/2 of `coords` as the vertical
    (row) start/end and columns 1/3 as the horizontal (column) start/end — NOTE(review):
    confirm this matches the layout GridSamplerV2 returns.

    Arguments:
        coords: tensor of shape (N, 4).
    Returns:
        Tensor of shape (4 * N, 4), grouped by quadrant: all upper-left boxes, then all
        upper-right, all lower-left, all lower-right. The midpoints use true division,
        so the result is floating point.
    """
    top, left, bottom, right = (coords[:, i] for i in range(4))
    heights = coords[:, 2] - coords[:, 0]
    widths = coords[:, 3] - coords[:, 1]

    # Midlines are computed from both edges exactly as the quadrants need them.
    row_mid_from_bottom = bottom - heights / 2
    col_mid_from_right = right - widths / 2
    row_mid_from_top = top + heights / 2
    col_mid_from_left = left + widths / 2

    def as_boxes(*columns):
        # Stack 1-D columns side by side into an (N, len(columns)) tensor.
        return torch.cat([column.unsqueeze(1) for column in columns], dim=-1)

    return torch.cat((
        as_boxes(top, left, row_mid_from_bottom, col_mid_from_right),
        as_boxes(top, col_mid_from_left, row_mid_from_bottom, right),
        as_boxes(row_mid_from_top, left, bottom, col_mid_from_right),
        as_boxes(row_mid_from_top, col_mid_from_left, bottom, right),
    ), dim=0)


def divide_patches_pixels_in_four(patches):
    """
    Splits every patch into its four quadrants and resizes each back to the patch size.

    Height and width are handled separately, so non-square patches are split correctly.
    For odd sizes the lower/right quadrants get the extra row/column.

    Arguments:
        patches: tensor of shape (N, C, H, W).
    Returns:
        Tensor of shape (4 * N, C, H, W), grouped by quadrant: all upper-left quadrants,
        then all upper-right, all lower-left, all lower-right.
    """
    height, width = patches.shape[-2], patches.shape[-1]
    mid_h, mid_w = height // 2, width // 2
    quadrants = (
        patches[..., :mid_h, :mid_w],  # upper-left
        patches[..., :mid_h, mid_w:],  # upper-right
        patches[..., mid_h:, :mid_w],  # lower-left
        patches[..., mid_h:, mid_w:],  # lower-right
    )
    # Upscale each quadrant back to the original patch size so the output has the
    # same per-patch shape as the input.
    resize = transforms.Resize((height, width), antialias=True)
    return torch.cat([resize(quadrant) for quadrant in quadrants], dim=0)

# TODO make it so that it accepts patch indices and integrates the resulting patches with the rest
# TODO remember about batches!
def divide_patches_in_four(patches, coords):
    """
    Splits patches and their coordinates into quadrants.

    Returns a (pixels, coords) pair; both outputs are grouped by quadrant in the same
    order (upper-left, upper-right, lower-left, lower-right), so row i of one
    corresponds to row i of the other.
    """
    quadrant_coords = divide_patch_coords_in_four(coords)
    return divide_patches_pixels_in_four(patches), quadrant_coords

In [307]:
divide_patches_in_four(patches=patches, coords=coords)

(tensor([[[[0.1412, 0.1402, 0.1382,  ..., 0.1228, 0.1182, 0.1159],
           [0.1432, 0.1422, 0.1402,  ..., 0.1217, 0.1170, 0.1146],
           [0.1472, 0.1462, 0.1440,  ..., 0.1195, 0.1144, 0.1119],
           ...,
           [0.1438, 0.1522, 0.1691,  ..., 0.1658, 0.1612, 0.1589],
           [0.1372, 0.1466, 0.1654,  ..., 0.1698, 0.1657, 0.1636],
           [0.1339, 0.1438, 0.1636,  ..., 0.1718, 0.1679, 0.1660]],
 
          [[0.2235, 0.2214, 0.2171,  ..., 0.1566, 0.1542, 0.1529],
           [0.2244, 0.2221, 0.2173,  ..., 0.1538, 0.1513, 0.1500],
           [0.2263, 0.2234, 0.2176,  ..., 0.1482, 0.1455, 0.1441],
           ...,
           [0.2272, 0.2340, 0.2476,  ..., 0.1774, 0.1660, 0.1604],
           [0.2248, 0.2322, 0.2472,  ..., 0.1844, 0.1723, 0.1662],
           [0.2235, 0.2313, 0.2470,  ..., 0.1879, 0.1754, 0.1692]],
 
          [[0.0980, 0.0971, 0.0951,  ..., 0.0646, 0.0634, 0.0627],
           [0.0988, 0.0976, 0.0951,  ..., 0.0640, 0.0623, 0.0614],
           [0.1003, 0.09