In [42]:
!pip install torchray wandb

import os
import wandb
import torch
import random
import time

import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
from scipy.ndimage.filters import gaussian_filter

import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset

from torchray.attribution.rise import rise
from torchray.benchmark.datasets import get_dataset, coco_as_mask, voc_as_mask
from torchray.utils import imsc, get_device, xmkdir
import torchray.attribution.extremal_perturbation as elp



  from scipy.ndimage.filters import gaussian_filter


In [55]:
import torch
import numpy as np
from torchray.utils import imsc
from matplotlib import pyplot as plt
import torch.nn.functional as F
from scipy.ndimage.filters import gaussian_filter
import pandas as pd


def gkern(klen, nsig):
	"""Returns a Gaussian kernel array.
	Convolution with it results in image blurring."""
	# create nxn zeros
	inp = np.zeros((klen, klen))
	# set element at the middle to one, a dirac delta
	inp[klen // 2, klen // 2] = 1
	# gaussian-smooth the dirac, resulting in a gaussian filter mask
	k = gaussian_filter(inp, nsig)
	kern = np.zeros((3, 3, klen, klen))
	kern[0, 0] = k
	kern[1, 1] = k
	kern[2, 2] = k
	return torch.from_numpy(kern.astype('float32'))


def blur(x, klen=11, ksig=5):
	kern = gkern(klen, ksig)
	return F.conv2d(x, kern, padding=klen // 2)


def normalise(x):
	return (x - x.min()) / max(x.max() - x.min(), 0.0001)


def hierarchical_perturbation(model,
							  input,
							  target,
							  vis=False,
							  interp_mode='nearest',
							  resize=None,
							  batch_size=32,
							  perturbation_type='mean', threshold_mode='mid-range', return_info=False):
	print('\nBelieve the HiPe!')
	with torch.no_grad():

		dev = input.device
		bn, channels, input_y_dim, input_x_dim = input.shape
		dim = min(input_x_dim, input_y_dim)
		total_masks = 0
		depth = 0
		num_cells = int(max(np.ceil(np.log2(dim)), 1)/2)
		print('Num cells: {}'.format(num_cells))
		max_depth = int(np.log2(dim / num_cells)) - 1
		print('Max depth: {}'.format(max_depth))
		saliency = torch.zeros((1, 1, input_y_dim, input_x_dim), device=dev)
		max_batch = batch_size

		thresholds_d_list = []
		masks_d_list = []

		output = model(input)[:, target]

		if perturbation_type == 'blur':
			pre_b_image = blur(input.clone().cpu()).to(dev)

		while (num_cells*2) <= (dim/4):

			masks_list = []
			b_list = []
			num_cells *= 2
			depth += 1
			if threshold_mode == 'mean':
				threshold = torch.mean(saliency)
			else:
				threshold = torch.min(saliency) + ((torch.max(saliency) - torch.min(saliency)) / 2)

			thresholds_d_list.append(threshold.item())

			y_ixs = range(-1, num_cells)
			x_ixs = range(-1, num_cells)
			x_cell_dim = input_x_dim // num_cells
			y_cell_dim = input_y_dim // num_cells

			print('Depth: {}, {} x {} Cell Dim'.format(depth, y_cell_dim, x_cell_dim))
			print('Threshold: {}'.format(threshold))
			print('Range {:.1f} to {:.1f}'.format(saliency.min(), saliency.max()))

			for x in x_ixs:
				for y in y_ixs:
					x1, y1 = max(0, x), max(0, y)
					x2, y2 = min(x + 2, num_cells), min(y + 2, num_cells)

					mask = torch.zeros((1, 1, num_cells, num_cells), device=dev)
					mask[:, :, y1:y2, x1:x2] = 1.0
					local_saliency = F.interpolate(mask, (input_y_dim, input_x_dim), mode=interp_mode) * saliency

					if depth > 1:
						local_saliency = torch.max(local_saliency)
					else:
						local_saliency = 0

					# If salience of region is greater than the average, generate higher resolution mask
					if local_saliency >= threshold:

						masks_list.append(abs(mask - 1))

						if perturbation_type == 'blur':

							b_image = input.clone()
							b_image[:, :, y1 * y_cell_dim:y2 * y_cell_dim, x1 * x_cell_dim:x2 * x_cell_dim] = pre_b_image[:, :, y1 * y_cell_dim:y2 * y_cell_dim, x1 * x_cell_dim:x2 * x_cell_dim]
							b_list.add(b_image)

						if perturbation_type == 'mean':
							b_image = input.clone()
							mean = torch.mean(b_image[:, :, y1 * y_cell_dim:y2 * y_cell_dim, x1 * x_cell_dim:x2 * x_cell_dim],
											  axis=(-1, -2), keepdims=True)

							b_image[:, :, y1 * y_cell_dim:y2 * y_cell_dim, x1 * x_cell_dim:x2 * x_cell_dim] = mean
							b_list.append(b_image)

			num_masks = len(masks_list)
			print('Selected {} masks at depth {}'.format(num_masks, depth))
			print('Masks: {}'.format(num_masks))
			if num_masks == 0:
				depth -= 1
				break
			total_masks += num_masks
			masks_d_list.append(num_masks)

			while len(masks_list) > 0:
				m_ix = min(len(masks_list), max_batch)
				if perturbation_type != 'fade':
					b_imgs = torch.cat(b_list[:m_ix])
					del b_list[:m_ix]
				masks = torch.cat(masks_list[:m_ix])
				del masks_list[:m_ix]

				# resize low-res masks to input size
				masks = F.interpolate(masks, (input_y_dim, input_x_dim), mode=interp_mode)

				if perturbation_type == 'fade':
					perturbed_outputs = torch.relu(output - model(input * masks)[:, target])
				else:
					perturbed_outputs = torch.relu(output - model(b_imgs)[:, target])

				sal = perturbed_outputs.reshape(-1,1,1,1) * torch.abs(masks - 1)
				saliency += torch.sum(sal, dim=(0, 1))

			if vis:
				plt.figure(figsize=(8, 4))
				plt.subplot(1, 3, 1)
				plt.title('Depth: {}, Threshold: {:.1f}'.format(depth, threshold))
				imsc(torch.sum(input.cpu(), dim=(0)).unsqueeze(0))
				plt.subplot(1, 3, 2)
				if perturbation_type == 'fade':
					imsc(torch.sum(input.cpu() * masks.cpu(), dim=(0)).unsqueeze(0))
				else:
					imsc(torch.sum(b_imgs.cpu(), dim=(0)).unsqueeze(0))
				plt.subplot(1, 3, 3)
				imsc(torch.sum(saliency.cpu(), dim=(0, 1)).unsqueeze(0))
				plt.show()
				plt.figure(figsize=(8, 4))
				pd.Series(normalise(saliency).cpu().reshape(-1)).plot(label='Saliency ({})'.format(threshold_mode))
				pd.Series(normalise(input).cpu().reshape(-1)).plot(label='Actual')
				plt.legend()
				plt.show()

		print('Used {} masks in total.'.format(total_masks))
		if resize is not None:
			saliency = F.interpolate(saliency, (resize[1], resize[0]), mode=interp_mode)
		if return_info:
			return saliency, {'thresholds': thresholds_d_list, 'masks': masks_d_list, 'total_masks': total_masks}
		else:
			return saliency, total_masks


def resize_saliency(tensor, saliency, size, mode):
    """Resize a saliency map.

    Args:
        tensor (:class:`torch.Tensor`): reference tensor.
        saliency (:class:`torch.Tensor`): saliency map.
        size (bool or tuple of int): if a tuple (i.e., (width, height),
            resize :attr:`saliency` to :attr:`size`. If True, resize
            :attr:`saliency: to the shape of :attr:`tensor`; otherwise,
            return :attr:`saliency` unchanged.
        mode (str): mode for :func:`torch.nn.functional.interpolate`.

    Returns:
        :class:`torch.Tensor`: Resized saliency map.
    """
    if size is not False:
        if size is True:
            size = tensor.shape[2:]
        elif isinstance(size, tuple) or isinstance(size, list):
            # width, height -> height, width
            size = size[::-1]
        else:
            assert False, "resize must be True, False or a tuple."
        saliency = F.interpolate(
            saliency, size, mode=mode, align_corners=False)
    return saliency


def _upsample_reflect(x, size, interpolate_mode="bilinear"):
    r"""Upsample 4D :class:`torch.Tensor` with reflection padding.

    Args:
        x (:class:`torch.Tensor`): 4D tensor to interpolate.
        size (int or list or tuple of ints): target size
        interpolate_mode (str): mode to pass to
            :function:`torch.nn.functional.interpolate` function call
            (default: "bilinear").

    Returns:
        :class:`torch.Tensor`: upsampled tensor.
    """
    # Check and get input size.
    assert len(x.shape) == 4
    orig_size = x.shape[2:]

    # Check target size.
    if not isinstance(size, tuple) and not isinstance(size, list):
        assert isinstance(size, int)
        size = (size, size)
    assert len(size) == 2

    # Ensure upsampling.
    for i, o_s in enumerate(orig_size):
        assert o_s <= size[i]

    # Get size of input cell when interpolated.
    cell_size = [int(np.ceil(s / orig_size[i])) for i, s in enumerate(size)]

    # Get size of interpolated input with padding.
    pad_size = [int(cell_size[i] * (orig_size[i] + 2))
                for i in range(len(orig_size))]

    # Pad input with reflection padding.
    x_padded = F.pad(x, (1, 1, 1, 1), mode="reflect")

    # Interpolated padded input.
    x_up = F.interpolate(x_padded,
                         pad_size,
                         mode=interpolate_mode,
                         align_corners=False)

    # Slice interpolated input to size.
    x_new = x_up[:,
                 :,
                 cell_size[0]:cell_size[0] + size[0],
                 cell_size[1]:cell_size[1] + size[1]]

    return x_new


def rise(model,
         input,
         target=None,
         seed=0,
         num_masks=8000,
         num_cells=7,
         filter_masks=None,
         batch_size=32,
         p=0.5,
         resize=False,
         resize_mode='bilinear'):
    r"""RISE.

    Args:
        model (:class:`torch.nn.Module`): a model.
        input (:class:`torch.Tensor`): input tensor.
        seed (int, optional): manual seed used to generate random numbers.
            Default: ``0``.
        num_masks (int, optional): number of RISE random masks to use.
            Default: ``8000``.
        num_cells (int, optional): number of cells for one spatial dimension
            in low-res RISE random mask. Default: ``7``.
        filter_masks (:class:`torch.Tensor`, optional): If given, use the
            provided pre-computed filter masks. Default: ``None``.
        batch_size (int, optional): batch size to use. Default: ``128``.
        p (float, optional): with prob p, a low-res cell is set to 0;
            otherwise, it's 1. Default: ``0.5``.
        resize (bool or tuple of ints, optional): If True, resize saliency map
            to size of :attr:`input`. If False, don't resize. If (width,
            height) tuple, resize to (width, height). Default: ``False``.
        resize_mode (str, optional): If resize is not None, use this mode for
            the resize function. Default: ``'bilinear'``.

    Returns:
        :class:`torch.Tensor`: RISE saliency map.
    """
    with torch.no_grad():
        # Get device of input (i.e., GPU).
        dev = input.device

        # Initialize saliency mask and mask normalization term.
        input_shape = input.shape
        saliency_shape = list(input_shape)

        height = input_shape[2]
        width = input_shape[3]

        out = model(input)
        num_classes = out.shape[1]

        saliency_shape[1] = num_classes
        saliency = torch.zeros(saliency_shape, device=dev)

        # Number of spatial dimensions.
        nsd = len(input.shape) - 2
        assert nsd == 2

        # Spatial size of low-res grid cell.
        cell_size = tuple([int(np.ceil(s / num_cells))
                           for s in input_shape[2:]])

        # Spatial size of upsampled mask with buffer (input size + cell size).
        up_size = tuple([input_shape[2 + i] + cell_size[i]
                         for i in range(nsd)])

        # Save current random number generator state.
        state = torch.get_rng_state()

        # Set seed.
        torch.manual_seed(seed)

        if filter_masks is not None:
            assert len(filter_masks) == num_masks

        num_chunks = (num_masks + batch_size - 1) // batch_size
        for chunk in range(num_chunks):
            # Generate RISE random masks on the fly.
            mask_bs = min(num_masks - batch_size * chunk, batch_size)

            if filter_masks is None:
                # Generate low-res, random binary masks.
                grid = (torch.rand(mask_bs, 1, *((num_cells,) * nsd),
                                   device=dev) < p).float()

                # Upsample low-res masks to input shape + buffer.
                masks_up = _upsample_reflect(grid, up_size)

                # Save final RISE masks with random shift.
                masks = torch.empty(mask_bs, 1, *input_shape[2:], device=dev)
                shift_x = torch.randint(0,
                                        cell_size[0],
                                        (mask_bs,),
                                        device='cpu')
                shift_y = torch.randint(0,
                                        cell_size[1],
                                        (mask_bs,),
                                        device='cpu')
                for i in range(mask_bs):
                    masks[i] = masks_up[i,
                                        :,
                                        shift_x[i]:shift_x[i] + height,
                                        shift_y[i]:shift_y[i] + width]
            else:
                masks = filter_masks[
                    chunk * batch_size:chunk * batch_size + mask_bs]

            # Accumulate saliency mask.
            for i, inp in enumerate(input):
                out = model(inp.unsqueeze(0) * masks)
                if len(out.shape) == 4:
                    # TODO: Consider handling FC outputs more flexibly.
                    assert out.shape[2] == 1
                    assert out.shape[3] == 1
                    out = out[:, :, 0, 0]
                sal = torch.matmul(out.data.transpose(0, 1),
                                   masks.view(mask_bs, height * width))
                sal = sal.view((num_classes, height, width))
                saliency[i] = saliency[i] + sal

        # Normalize saliency mask.
        saliency /= num_masks

        # Restore original random number generator state.
        torch.set_rng_state(state)

        # Resize saliency mask if needed.
        saliency = resize_saliency(input,
                                   saliency,
                                   resize,
                                   mode=resize_mode)
        return saliency


  from scipy.ndimage.filters import gaussian_filter


In [53]:
class SyntheticDataset(Dataset):
	def __init__(self, num_samples=1000, dim=(100, 100), min_coverage=0.1, max_coverage=0.8, max_salient_regions=10, random_sal_val=False):
		self.samples, self.salient_region_sizes, self.num_salient_regions = self.generate_synthetic(num_samples, dim, min_coverage, max_coverage, max_salient_regions, random_sal_val)

	def __len__(self):
		return len(self.samples)

	def __getitem__(self, idx):
		return self.samples[idx].unsqueeze(0), self.salient_region_sizes[idx], self.num_salient_regions[idx]

	def generate_synthetic(self, num_samples, dim, min_coverage, max_coverage, max_salient_regions, random_sal_val):
		samples = []
		salient_region_sizes = []
		num_salient_regions = []

		num_covg = num_samples//max_salient_regions

		# Generate uniform distribution for coverage and number of salient regions
		coverage_values = np.linspace(min_coverage, max_coverage, num_covg)
		salient_region_counts = range(1,max_salient_regions+1)
		print('coverage', coverage_values)
		print('region counts', salient_region_counts)

		for coverage in coverage_values:
			for count in salient_region_counts:
				base = torch.zeros(dim)
				region_size_x = int(dim[0] * (coverage/count))
				region_size_y = int(dim[1] * (coverage/count))

				for region in range(count):

					start_x, start_y = -1, -1
					attempts = 0
					max_attempts = 10000  # Limit the number of attempts to find a non-overlapping region
					# Check for overlaps
					while start_x < 0 or torch.sum(base[start_x:start_x + region_size_x, start_y:start_y + region_size_y]) > 0:
						start_x = random.randint(0, dim[0] - region_size_x)
						start_y = random.randint(0, dim[1] - region_size_y)
						attempts += 1
						if attempts >= max_attempts:
							break
					base[start_x:start_x + region_size_x, start_y:start_y + region_size_y] = 1 if not random_sal_val else min(random.random(), 0.99) + 0.01
				samples.append(base)
				salient_region_sizes.append(base.mean())
				num_salient_regions.append(count)

		return samples, salient_region_sizes, num_salient_regions


class ProxyModel(torch.nn.Module):

	def __init__(self, numel):
		super().__init__()
		self.linear = torch.nn.Linear(numel, 1)
		self.linear.weight=torch.nn.Parameter(torch.ones_like(self.linear.weight), requires_grad=False)

	def forward(self, x):
		sum = self.linear(x.reshape(x.shape[0], -1))
		return sum


def calculate_precision_recall_f1(ground_truth, predicted_saliency):
	# Epsilon may be added for numerical stability to avoid division by zero
	epsilon = 1e-10

	# True Positives (TP): Sum of product of ground truth and predicted saliency
	tp = torch.sum(ground_truth * predicted_saliency)

	# False Positives (FP): Sum of predicted saliency where ground truth is not present
	fp = torch.sum(predicted_saliency * (1 - ground_truth))

	# True Negatives (TN): Sum of the inverse of ground truth and predicted saliency
	tn = torch.sum((1 - ground_truth) * (1 - predicted_saliency))

	# False Negatives (FN): Sum of ground truth where predicted saliency is not present
	fn = torch.sum(ground_truth * (1 - predicted_saliency))

	# Precision
	precision = (tp + epsilon) / (tp + fp + epsilon)

	# Recall
	recall = (tp + epsilon) / (tp + fn + epsilon)

	# F1 Score
	f1 = (2 * precision * recall + epsilon) / (precision + recall + epsilon)

	return precision, recall, f1

In [None]:
from tqdm import tqdm
log_imgs = True
synthetic = True
dim_x, dim_y = 224,224
i_dim_x, i_dim_y = 224,224

methods = ['rise', 'hipe', 'random', 'extremal_perturbation']

num_samples=1000
min_coverage=0.1
max_coverage=0.8
max_salient_regions=5
random_sal_val=False

config = {"log_imgs":log_imgs,"synthetic":synthetic, "synthetic_dim":(dim_x, dim_y),"input_dim":(i_dim_x, i_dim_y), "num_samples":num_samples, "min_coverage":min_coverage, "max_coverage":max_coverage, "max_salient_regions":max_salient_regions, "random_sal_val":random_sal_val}

run = wandb.init(project='proxy_benchmark_new', entity="jessicamarycooper", reinit=True, config=config)
columns=['method', 'time', 'salient region size', 'number of salient regions', 'mae', 'mse', 'cos', 'precision', 'recall', 'f1']
if log_imgs:
	columns.extend(['target', 'output'])

table = wandb.Table(columns=columns)


if synthetic:
	data = SyntheticDataset(num_samples=num_samples, dim=(dim_x, dim_y), min_coverage=min_coverage, max_coverage=max_coverage, max_salient_regions=max_salient_regions, random_sal_val=random_sal_val)
else:
	data = get_dataset(name='coco', subset='val2014', download=True, limiter=num_samples, transform=transforms.ToTensor())

loader = DataLoader(data, batch_size=1, shuffle=False, num_workers=0)

data_iterator = iter(loader)

device = get_device()
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(0)

model = ProxyModel(numel=dim_x*dim_y).to(device)
model.eval()

class_id = 0
all_results = []

for i, xy in tqdm(enumerate(data_iterator)):
	x, sal_size, sal_num = xy
	x = F.interpolate(x, (i_dim_x, i_dim_y), mode='nearest')
	xd, yd = x.shape[-1], x.shape[-2]
	xn = normalise(x.to(device))
	if synthetic:
		target = xn.clone()[0]
	else:
		target = xn.sum(dim=1)

	for method in methods:

		tic = time.process_time()

		if method == 'hipe':
			saliency, num_ops = hierarchical_perturbation(model, xn, class_id, resize=(i_dim_x, i_dim_y), perturbation_type='fade')

		elif method == 'random':
			saliency = torch.rand(xn.shape).squeeze(0).to(device)


		elif method == "rise":
			rise_saliency = rise(model, xn.clone().detach(), resize=(i_dim_x, i_dim_y))
			saliency = rise_saliency[:, class_id, :, :]

		else:
			areas = [0.018, 0.025, 0.05, 0.1]

			mask, energy = elp.extremal_perturbation(model,
					xn,
					class_id,
					areas=areas,
					num_levels=8,
					step=7,
					sigma=7 * 3,
					max_iter=800,
					debug=False,
					jitter=True,
					smooth=0.09,
					resize=(i_dim_x, i_dim_y),
					perturbation='fade',
					reward_func=elp.simple_reward,
					variant=elp.PRESERVE_VARIANT, )
			saliency = mask.sum(dim=0)

		toc = time.process_time()
		output = normalise(saliency.clone())

		mae = torch.abs(output-target).mean()
		mse = ((output-target)**2).mean()
		cos = F.cosine_similarity(output.view(-1).unsqueeze(0)-0.5, target.view(-1).unsqueeze(0)-0.5).mean()

		prec, rec, f1 = calculate_precision_recall_f1(target, output)

		data=[method, toc - tic, sal_size.item(), sal_num.item(), mae, mse, cos, prec, rec, f1]

		if log_imgs:
			print(data)
			data.extend([wandb.Image(target, caption=f'{output.sum()}'), wandb.Image(output, caption=f'{output.sum()}')])
		table.add_data(*data)
wandb.log({'Table':table})
run.finish()


coverage [0.1        0.10351759 0.10703518 0.11055276 0.11407035 0.11758794
 0.12110553 0.12462312 0.1281407  0.13165829 0.13517588 0.13869347
 0.14221106 0.14572864 0.14924623 0.15276382 0.15628141 0.15979899
 0.16331658 0.16683417 0.17035176 0.17386935 0.17738693 0.18090452
 0.18442211 0.1879397  0.19145729 0.19497487 0.19849246 0.20201005
 0.20552764 0.20904523 0.21256281 0.2160804  0.21959799 0.22311558
 0.22663317 0.23015075 0.23366834 0.23718593 0.24070352 0.24422111
 0.24773869 0.25125628 0.25477387 0.25829146 0.26180905 0.26532663
 0.26884422 0.27236181 0.2758794  0.27939698 0.28291457 0.28643216
 0.28994975 0.29346734 0.29698492 0.30050251 0.3040201  0.30753769
 0.31105528 0.31457286 0.31809045 0.32160804 0.32512563 0.32864322
 0.3321608  0.33567839 0.33919598 0.34271357 0.34623116 0.34974874
 0.35326633 0.35678392 0.36030151 0.3638191  0.36733668 0.37085427
 0.37437186 0.37788945 0.38140704 0.38492462 0.38844221 0.3919598
 0.39547739 0.39899497 0.40251256 0.40603015 0.4095477

0it [00:00, ?it/s]

['rise', 0.4961156219999907, 0.009646045975387096, 1, tensor(0.0806, device='cuda:0'), tensor(0.0221, device='cuda:0'), tensor(0.9583, device='cuda:0'), tensor(0.1008, device='cuda:0'), tensor(0.9278, device='cuda:0'), tensor(0.1818, device='cuda:0')]

Believe the HiPe!
Num cells: 4
Max depth: 4
Depth: 1, 28 x 28 Cell Dim
Threshold: 0.0
Range 0.0 to 0.0
Selected 81 masks at depth 1
Masks: 81
Depth: 2, 14 x 14 Cell Dim
Threshold: 814.0
Range 0.0 to 1628.0
Selected 21 masks at depth 2
Masks: 21
Depth: 3, 7 x 7 Cell Dim
Threshold: 1462.0
Range 0.0 to 2924.0
Selected 45 masks at depth 3
Masks: 45
Used 147 masks in total.
['hipe', 0.19320235899999716, 0.009646045975387096, 1, tensor(0.0265, device='cuda:0'), tensor(0.0096, device='cuda:0'), tensor(0.9807, device='cuda:0'), tensor(0.2464, device='cuda:0'), tensor(0.8483, device='cuda:0'), tensor(0.3819, device='cuda:0')]
['random', 0.0004948110000100314, 0.009646045975387096, 1, tensor(0.5000, device='cuda:0'), tensor(0.3332, device='cuda:0'

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
1it [00:04,  4.79s/it]

['extremal_perturbation', 3.768780879000019, 0.009646045975387096, 1, tensor(0.0789, device='cuda:0'), tensor(0.0418, device='cuda:0'), tensor(0.9126, device='cuda:0'), tensor(0.0959, device='cuda:0'), tensor(0.8524, device='cuda:0'), tensor(0.1724, device='cuda:0')]
['rise', 0.4683553059999781, 0.004823022987693548, 2, tensor(0.1582, device='cuda:0'), tensor(0.0610, device='cuda:0'), tensor(0.8743, device='cuda:0'), tensor(0.0272, device='cuda:0'), tensor(0.9158, device='cuda:0'), tensor(0.0529, device='cuda:0')]

Believe the HiPe!
Num cells: 4
Max depth: 4
Depth: 1, 28 x 28 Cell Dim
Threshold: 0.0
Range 0.0 to 0.0
Selected 81 masks at depth 1
Masks: 81
Depth: 2, 14 x 14 Cell Dim
Threshold: 242.0
Range 0.0 to 484.0
Selected 42 masks at depth 2
Masks: 42
Depth: 3, 7 x 7 Cell Dim
Threshold: 484.0
Range 0.0 to 968.0
Selected 62 masks at depth 3
Masks: 62
Used 185 masks in total.
['hipe', 0.18731156100000135, 0.004823022987693548, 2, tensor(0.0518, device='cuda:0'), tensor(0.0175, device=

2it [00:08,  4.25s/it]

['extremal_perturbation', 3.447236607000036, 0.004823022987693548, 2, tensor(0.1427, device='cuda:0'), tensor(0.0731, device='cuda:0'), tensor(0.8412, device='cuda:0'), tensor(0.0313, device='cuda:0'), tensor(0.9543, device='cuda:0'), tensor(0.0606, device='cuda:0')]
['rise', 0.46985200900002155, 0.0029296875, 3, tensor(0.2667, device='cuda:0'), tensor(0.1186, device='cuda:0'), tensor(0.7309, device='cuda:0'), tensor(0.0106, device='cuda:0'), tensor(0.9703, device='cuda:0'), tensor(0.0209, device='cuda:0')]

Believe the HiPe!
Num cells: 4
Max depth: 4
Depth: 1, 28 x 28 Cell Dim
Threshold: 0.0
Range 0.0 to 0.0
Selected 81 masks at depth 1
Masks: 81
Depth: 2, 14 x 14 Cell Dim
Threshold: 98.0
Range 0.0 to 196.0
Selected 92 masks at depth 2
Masks: 92
Depth: 3, 7 x 7 Cell Dim
Threshold: 196.0
Range 0.0 to 392.0
Selected 87 masks at depth 3
Masks: 87
Used 260 masks in total.
['hipe', 0.1878767370000105, 0.0029296875, 3, tensor(0.0800, device='cuda:0'), tensor(0.0282, device='cuda:0'), tensor

3it [00:12,  4.06s/it]

['extremal_perturbation', 3.416011835000006, 0.0029296875, 3, tensor(0.1827, device='cuda:0'), tensor(0.0837, device='cuda:0'), tensor(0.8165, device='cuda:0'), tensor(0.0130, device='cuda:0'), tensor(0.8165, device='cuda:0'), tensor(0.0255, device='cuda:0')]
['rise', 0.47692245500002173, 0.001992984674870968, 4, tensor(0.2612, device='cuda:0'), tensor(0.1203, device='cuda:0'), tensor(0.7228, device='cuda:0'), tensor(0.0066, device='cuda:0'), tensor(0.8671, device='cuda:0'), tensor(0.0131, device='cuda:0')]

Believe the HiPe!
Num cells: 4
Max depth: 4
Depth: 1, 28 x 28 Cell Dim
Threshold: 0.0
Range 0.0 to 0.0
Selected 81 masks at depth 1
Masks: 81
Depth: 2, 14 x 14 Cell Dim
Threshold: 62.5
Range 0.0 to 125.0
Selected 55 masks at depth 2
Masks: 55
Depth: 3, 7 x 7 Cell Dim
Threshold: 112.5
Range 0.0 to 225.0
Selected 118 masks at depth 3
Masks: 118
Used 254 masks in total.
['hipe', 0.18925471800002924, 0.001992984674870968, 4, tensor(0.0962, device='cuda:0'), tensor(0.0377, device='cuda:

4it [00:16,  3.99s/it]

['extremal_perturbation', 3.4564528760000144, 0.001992984674870968, 4, tensor(0.1727, device='cuda:0'), tensor(0.0960, device='cuda:0'), tensor(0.7862, device='cuda:0'), tensor(0.0090, device='cuda:0'), tensor(0.7831, device='cuda:0'), tensor(0.0178, device='cuda:0')]
['rise', 0.4826404860000366, 0.0015943878097459674, 5, tensor(0.3278, device='cuda:0'), tensor(0.1781, device='cuda:0'), tensor(0.5437, device='cuda:0'), tensor(0.0047, device='cuda:0'), tensor(0.9692, device='cuda:0'), tensor(0.0093, device='cuda:0')]

Believe the HiPe!
Num cells: 4
Max depth: 4
Depth: 1, 28 x 28 Cell Dim
Threshold: 0.0
Range 0.0 to 0.0
Selected 81 masks at depth 1
Masks: 81
Depth: 2, 14 x 14 Cell Dim
Threshold: 36.0
Range 0.0 to 72.0
Selected 74 masks at depth 2
Masks: 74
Depth: 3, 7 x 7 Cell Dim
Threshold: 62.0
Range 0.0 to 124.0
Selected 173 masks at depth 3
Masks: 173
Used 328 masks in total.
['hipe', 0.1987070870000025, 0.0015943878097459674, 5, tensor(0.1243, device='cuda:0'), tensor(0.0504, device

5it [00:20,  3.95s/it]

['extremal_perturbation', 3.4454264780000017, 0.0015943878097459674, 5, tensor(0.1919, device='cuda:0'), tensor(0.0997, device='cuda:0'), tensor(0.7757, device='cuda:0'), tensor(0.0056, device='cuda:0'), tensor(0.6752, device='cuda:0'), tensor(0.0111, device='cuda:0')]
['rise', 0.46680587299999843, 0.01054288912564516, 1, tensor(0.0940, device='cuda:0'), tensor(0.0291, device='cuda:0'), tensor(0.9438, device='cuda:0'), tensor(0.0941, device='cuda:0'), tensor(0.9176, device='cuda:0'), tensor(0.1707, device='cuda:0')]

Believe the HiPe!
Num cells: 4
Max depth: 4
Depth: 1, 28 x 28 Cell Dim
Threshold: 0.0
Range 0.0 to 0.0
Selected 81 masks at depth 1
Masks: 81
Depth: 2, 14 x 14 Cell Dim
Threshold: 877.5
Range 0.0 to 1755.0
Selected 21 masks at depth 2
Masks: 21
Depth: 3, 7 x 7 Cell Dim
Threshold: 1562.0
Range 0.0 to 3124.0
Selected 41 masks at depth 3
Masks: 41
Used 143 masks in total.
['hipe', 0.18201320200000737, 0.01054288912564516, 1, tensor(0.0375, device='cuda:0'), tensor(0.0124, dev

6it [00:24,  3.91s/it]