In [4]:
import glob
import json
import os
from enum import Enum
from pathlib import Path
from typing import SupportsFloat, Any

import gymnasium as gym
import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.utils.data as D
from gymnasium import spaces
from gymnasium.core import RenderFrame, ActType, ObsType
from PIL import Image
from pytorch_grad_cam import GradCAM
from scipy.stats import ks_2samp
from torchvision import transforms
from torchvision.models import resnet101

from utils import image_net_postprocessing

In [6]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

## Dataset

In [12]:
class COCODataset(D.Dataset):
    """COCO 2017 images paired with their numeric image ids.

    Files read from ``root`` (``<split>`` is ``train2017`` or ``val2017``):

    * ``annotations/captions_<split>.json`` -- only its ``images`` list is
      used, to map each ``file_name`` to its ``id``.
    * ``cap_dict.json`` -- loaded as-is into ``self.captions_dict``.
    * ``<split>/*.jpg`` -- the images themselves.

    Args:
        root: Dataset root directory (``str`` or ``Path``).
        train: ``True`` selects the ``train2017`` split, ``False`` the
            ``val2017`` split (both the image folder and the annotation file).
    """

    def __init__(self, root, train=True):
        root = Path(root)
        split = 'train2017' if train else 'val2017'

        # The annotation file must match the image split; otherwise every
        # validation file name is missing from the file-name -> id mapping.
        with open(root / 'annotations' / f'captions_{split}.json', 'r') as f:
            images_info = json.load(f)
        self.file_name_to_id = {
            image_info['file_name']: image_info['id']
            for image_info in images_info['images']
        }

        with open(root / 'cap_dict.json', 'r') as f:
            self.captions_dict = json.load(f)

        # glob returns files in arbitrary order; sort so that an index always
        # refers to the same image across runs and machines.
        self.image_files = sorted(glob.glob(os.path.join(root / split, "*.jpg")))

    def __getitem__(self, index):
        """Return ``(image_tensor, image_id)`` for the image at ``index``.

        The image is converted to RGB and passed through
        ``transforms.ToTensor()``.

        Raises:
            KeyError: if the image file name is absent from the annotations.
        """
        image_file = self.image_files[index]
        # The context manager closes the file handle once the pixels have been
        # decoded by convert().
        with PIL.Image.open(image_file) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = transforms.ToTensor()(image)

        # basename() works with any path separator, unlike split('/').
        file_name = os.path.basename(image_file)

        return image_tensor, self.file_name_to_id[file_name]

    def __len__(self):
        """Number of ``*.jpg`` files found in the selected split."""
        return len(self.image_files)

## GradCam

In [None]:
class Resnet101GradCam(torch.nn.Module):
    """Choose an image patch by sampling from a Grad-CAM saliency map.

    A pretrained ResNet-101 is wrapped in ``GradCAM``. The resulting map is
    rescaled, resized to the image size, averaged over non-overlapping
    ``patch_size`` windows, and one window is drawn with probability
    proportional to its mean value.

    Args:
        patch_size: ``(height, width)`` of a patch in pixels.
    """

    def __init__(self, patch_size = (64, 64)):
        super(Resnet101GradCam, self).__init__()
        feature_extractor = resnet101(pretrained=True).to(device).eval()
        self.vis = GradCAM(feature_extractor, device)
        self.patch_size = patch_size
        # A convolution whose kernel and stride both equal the patch size and
        # whose weights are all 1 / (patch area) computes the mean of every
        # non-overlapping patch.
        self.conv_mean = nn.Conv2d(1, 1, kernel_size=self.patch_size, stride=self.patch_size, padding=0,  bias=False)
        self.conv_mean.weight = torch.nn.Parameter(torch.ones_like(self.conv_mean.weight)  / (patch_size[0] * patch_size[1]))

    def forward(self, x, h, w):
        """Sample a patch position for an image of size ``h`` x ``w``.

        Args:
            x: Input passed to the Grad-CAM visualiser.
            h: Height in pixels that the saliency map is resized to.
            w: Width in pixels that the saliency map is resized to.

        Returns:
            Tensor of shape ``(1, 2)`` holding ``[[row, col]]`` of the sampled
            patch, indexed on the ``(h // patch_h, w // patch_w)`` grid.
        """
        # Only the side effect is needed: the map is read from self.vis.cam.
        self.vis(x, None, postprocessing=image_net_postprocessing)

        # Rescale the map in place to [0, 1].
        # NOTE(review): assumes self.vis.cam is a 2-D tensor -- confirm.
        cam = self.vis.cam
        cam -= torch.min(cam)
        peak = torch.max(cam)
        if peak > 0:
            # A constant map has peak == 0; dividing would fill it with NaNs.
            cam /= peak
        resized_cam = transforms.Resize(size=(h, w))(cam.unsqueeze(0).unsqueeze(0))

        scores = self.conv_mean(resized_cam)
        row_choice, col_choice = self._sample_patch(scores)

        return torch.tensor([[row_choice, col_choice]])

    @staticmethod
    def _sample_patch(scores):
        """Draw one ``(row, col)`` cell with probability proportional to its score.

        Args:
            scores: Tensor whose last two dimensions form the patch grid.

        Returns:
            ``(row, col)`` as Python ints.
        """
        # The flat index must be decoded with the width of the grid actually
        # produced by the convolution (floor division, since padding is 0),
        # not with a ceil-based column count.
        number_of_cols = scores.shape[-1]
        weights = scores.detach().flatten().clamp(min=0)

        # torch.multinomial rejects NaN/inf weights and an all-zero vector;
        # fall back to a uniform draw in those degenerate cases.
        if not bool(torch.isfinite(weights).all()) or float(weights.sum()) <= 0:
            weights = torch.ones_like(weights)

        choice = torch.multinomial(weights, 1).item()
        return choice // number_of_cols, choice % number_of_cols

## Environment

In [None]:
class Actions(Enum):
    """Moves available on the patch grid.

    Values: STAY = 0, UP = 1, RIGHT = 2, DOWN = 3, LEFT = 4.
    """

    STAY, UP, RIGHT, DOWN, LEFT = range(5)

In [None]:
class Environment():
    """Patch-traversal environment over images drawn from a dataloader.

    Each episode loads one image, picks a starting patch with
    ``Resnet101GradCam`` and precomputes a hue histogram for every full patch
    of the image. The agent then moves between patches with ``step``.

    Args:
        dataloader: Iterable yielding ``(image, image_id)`` batches. The batch
            size must be 1 because ``reset`` calls ``image_id.item()``.
        device: Device the image and the Grad-CAM module are moved to.
        patch_size: ``(height, width)`` of a patch in pixels.
        input_size: Size passed to ``transforms.Resize`` before Grad-CAM.
    """

    def __init__(self, dataloader, device, patch_size=(64, 64), input_size=224):
        self.dataloader = dataloader
        self.iterator = iter(dataloader)
        self.grad = Resnet101GradCam(patch_size).to(device).eval()
        self.transform = transforms.Resize(input_size)
        self.patch_size = patch_size
        self.device = device
        self.reset()

    def reset(self):
        """Load the next image and re-initialise position, captions and histograms."""
        try:
            # Samples the batch
            self.current_image, self.image_id = next(self.iterator)
        except StopIteration:
            # restart the iterator if the previous iterator is exhausted.
            self.iterator = iter(self.dataloader)
            self.current_image, self.image_id = next(self.iterator)

        self.current_image = self.current_image.to(self.device)
        self.image_id = str(self.image_id.item())
        _, _, self.height, self.width = self.current_image.shape
        self.captions = self._captions_source()[self.image_id]

        start = self.grad(self.transform(self.current_image), self.height, self.width)[0]
        # Keep the start inside the grid of full patches for which histograms
        # exist, whatever grid convention the sampler used.
        self.row, self.col = self._clamp_to_grid(int(start[0]), int(start[1]))
        self.calculate_patch_histograms()

    def _captions_source(self):
        """Return the mapping from image id (str) to captions."""
        # Prefer the mapping carried by the dataset so the environment does
        # not depend on a notebook-level variable being defined.
        dataset = getattr(self.dataloader, 'dataset', None)
        captions = getattr(dataset, 'captions_dict', None)
        if captions is not None:
            return captions
        # Legacy fallback: a module-level ``captions_dict``.
        return captions_dict

    def _grid_shape(self):
        """Return ``(rows, cols)`` of full patches that fit in the current image."""
        return self.height // self.patch_size[0], self.width // self.patch_size[1]

    def _clamp_to_grid(self, row, col):
        """Clip ``(row, col)`` to valid indices of the patch grid."""
        max_row, max_col = self._grid_shape()
        row = min(max(row, 0), max(max_row - 1, 0))
        col = min(max(col, 0), max(max_col - 1, 0))
        return row, col

    def _patch_bounds(self, row, col):
        """Return ``(start_row, end_row, start_col, end_col)`` in pixels."""
        start_row, end_row = row * self.patch_size[0], (row + 1) * self.patch_size[0]
        start_col, end_col = col * self.patch_size[1], (col + 1) * self.patch_size[1]
        return start_row, end_row, start_col, end_col

    def calculate_patch_histograms(self):
        """Fill ``self.histograms[i][j]`` with the hue histogram of patch ``(i, j)``.

        Each histogram has 10 bins over ``[0, 1]`` and is density-normalised.
        """
        max_row, max_col = self._grid_shape()

        self.histograms = []

        for i in range(max_row):
            self.histograms.append([])
            for j in range(max_col):
                start_row, end_row, start_col, end_col = self._patch_bounds(i, j)

                # rgb2h never writes to its input, so a detached view suffices.
                selected_patch = self.current_image[:, :, start_row: end_row, start_col: end_col].detach()

                hue = self.rgb2h(selected_patch).cpu().numpy()
                self.histograms[-1].append(np.histogram(hue, bins=10, range=(0., 1.), density=True)[0])

    def measure_similarity(self):
        """Compare the current patch with every patch of the image.

        Returns:
            Nested list ``similarity[i][j]`` holding the two-sample
            Kolmogorov-Smirnov p-value between the histogram of the current
            patch and that of patch ``(i, j)``.
        """
        # NOTE(review): the KS test is applied to the 10 bin densities, not to
        # the underlying hue samples -- confirm this is the intended statistic.
        current_patch_histogram = self.histograms[self.row][self.col]
        max_row, max_col = self._grid_shape()
        similarity_matrix = []

        for i in range(max_row):
            similarity_matrix.append([])
            for j in range(max_col):
                p_value = ks_2samp(current_patch_histogram,  self.histograms[i][j]).pvalue
                similarity_matrix[-1].append(p_value)

        return similarity_matrix

    def get_patch(self):
        """Return the pixels of the current patch (a view into ``current_image``)."""
        start_row, end_row, start_col, end_col = self._patch_bounds(self.row, self.col)
        return self.current_image[:, :, start_row: end_row, start_col: end_col]

    def step(self, action):
        """Apply ``action`` and return the patch at the resulting position.

        Args:
            action: An ``Actions`` member, or a value convertible to one
                (e.g. the integer code). Anything else is treated as STAY.

        Returns:
            The current patch after the move. A move that would leave the
            patch grid keeps the position unchanged.
        """
        if not isinstance(action, Actions):
            try:
                action = Actions(action)
            except ValueError:
                action = Actions.STAY

        # STAY has no entry, so it maps to a zero offset.
        d_row, d_col = {
            Actions.UP: (-1, 0),
            Actions.RIGHT: (0, 1),
            Actions.DOWN: (1, 0),
            Actions.LEFT: (0, -1),
        }.get(action, (0, 0))

        # Without clamping a negative index would slice from the far edge (or
        # yield an empty patch) and an index past the grid an empty patch.
        self.row, self.col = self._clamp_to_grid(self.row + d_row, self.col + d_col)

        return self.get_patch()

    def visualize(self, alpha=0.4):
        """Return the image blended with a white overlay on the current patch.

        Args:
            alpha: Weight of the overlay mask; the image gets ``1 - alpha``.
        """
        start_row, end_row, start_col, end_col = self._patch_bounds(self.row, self.col)

        mask = torch.zeros_like(self.current_image)
        mask[:, :, start_row: end_row, start_col: end_col] = 1

        return alpha * mask + (1 - alpha) * self.current_image

    def rgb2h(self, rgb: torch.Tensor) -> torch.Tensor:
        """Compute the hue channel of an RGB tensor.

        Args:
            rgb: Tensor with the three colour channels (R, G, B) along dim 1.

        Returns:
            Tensor with a single channel along dim 1 holding hue divided by 6,
            i.e. in ``[0, 1)``. Pixels whose channels are all equal get hue 0.
        """
        cmax, cmax_idx = torch.max(rgb, dim=1, keepdim=True)
        cmin = torch.min(rgb, dim=1, keepdim=True)[0]
        delta = cmax - cmin

        # Index 3 marks achromatic pixels (no dominant channel).
        achromatic = delta == 0
        cmax_idx[achromatic] = 3
        # Avoid 0/0 in the intermediates; those pixels are never selected below.
        safe_delta = torch.where(achromatic, torch.ones_like(delta), delta)

        red, green, blue = rgb[:, 0:1], rgb[:, 1:2], rgb[:, 2:3]
        hsv_h = torch.zeros_like(red)
        hsv_h[cmax_idx == 0] = (((green - blue) / safe_delta) % 6)[cmax_idx == 0]
        hsv_h[cmax_idx == 1] = (((blue - red) / safe_delta) + 2)[cmax_idx == 1]
        hsv_h[cmax_idx == 2] = (((red - green) / safe_delta) + 4)[cmax_idx == 2]
        hsv_h /= 6.
        return hsv_h

In [None]:
class TraversalEnv(gym.Env):
    """The traversal environment as an implementation of the ``gym.Env`` class.

    Skeleton only: ``step`` is not implemented yet and ``render`` draws nothing.
    """

    # Gymnasium reads the supported modes from the "render_modes" key; the
    # legacy gym spelling 'render.modes' is ignored by it.
    metadata = {"render_modes": ["human"], "render_fps": 2}

    def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Advance the environment by one action.

        Raises:
            NotImplementedError: always, until the transition logic is written.
        """
        # Returning None would violate the declared 5-tuple and only fail later
        # at the caller's unpacking; fail here with a clear message instead.
        raise NotImplementedError("TraversalEnv.step is not implemented yet")

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        """Render the environment; currently a no-op that returns ``None``."""
        return None
