In [1]:
import time
from collections import deque
from enum import Enum
import random
from typing import SupportsFloat, Any
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data as D
import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import RenderFrame, ActType, ObsType
from pytorch_grad_cam import GradCAM
from scipy.stats import ks_2samp

from GradCam import GradCam
from torchvision.models import resnet101, resnet50
from utils import image_net_postprocessing
import numpy as np
from torchvision import transforms
import os
import glob
from PIL import Image
import PIL
import json
from pathlib import Path
from transformers import GPT2Tokenizer, GPT2Model, GPT2Config
import torch.optim as optim
import tianshou as ts
from tianshou.utils import WandbLogger
from torch.utils.tensorboard import SummaryWriter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute device. The notebook prefers the second GPU ('cuda:1'), but
# that index only exists on multi-GPU machines; fall back to the first GPU
# when there is only one, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')
# device = 'cpu'

## Dataset

In [3]:
class COCODataset(D.Dataset):
    """Images from a COCO-2017 style directory, paired with their image ids.

    Expected layout under ``root`` (as read by this class):

    * ``annotations/captions_train2017.json`` / ``annotations/captions_val2017.json``
      -- JSON with an ``images`` list whose entries carry ``file_name`` and ``id``.
    * ``cap_dict.json`` -- JSON loaded verbatim into ``captions_dict``
      (presumably image id -> captions; its schema is not checked here).
    * ``train2017/*.jpg`` / ``val2017/*.jpg`` -- the image files.

    Args:
        root: Dataset root directory (``str`` or ``Path``).
        train: Use the ``train2017`` split when true, ``val2017`` otherwise.
    """

    def __init__(self, root, train=True):
        root = Path(root)
        split = 'train2017' if train else 'val2017'

        # Read the annotation file that belongs to the selected split, so that
        # every globbed image file name can be resolved to an id in __getitem__.
        with open(root / 'annotations' / f'captions_{split}.json', 'r') as f:
            images_info = json.load(f)
        self.file_name_to_id = {
            image_info['file_name']: image_info['id']
            for image_info in images_info['images']
        }

        with open(root / 'cap_dict.json', 'r') as f:
            self.captions_dict = json.load(f)

        # glob order is filesystem-dependent; sort so an index always refers
        # to the same file across runs.
        self.image_files = sorted(glob.glob(os.path.join(root / split, "*.jpg")))

    def __getitem__(self, index):
        """Return ``(image_tensor, image_id)`` for the file at ``index``.

        The image is converted to RGB and passed through
        ``transforms.ToTensor()``. Raises ``KeyError`` if the file name is
        missing from the annotation file.
        """
        image_file = self.image_files[index]
        # Close the file handle as soon as the pixel data has been converted.
        with PIL.Image.open(image_file) as image:
            image_tensor = transforms.ToTensor()(image.convert('RGB'))
        # basename handles the platform's path separator, unlike split('/').
        file_name = os.path.basename(image_file)

        return image_tensor, self.file_name_to_id[file_name]

    def __len__(self):
        """Number of image files found for the selected split."""
        return len(self.image_files)

## Environment

In [4]:
class ResnetGradCam(torch.nn.Module):
    """Samples a starting patch of an image in proportion to its Grad-CAM score.

    A pretrained ResNet-50 is wrapped by the project's ``GradCam`` helper. The
    resulting activation map is averaged over non-overlapping patches with a
    fixed box-filter convolution, and one patch is drawn from the resulting
    (unnormalised) distribution.

    Args:
        patch_size: ``(height, width)`` of a patch in pixels; also used as the
            stride, so patches do not overlap.
    """

    def __init__(self, patch_size):
        super(ResnetGradCam, self).__init__()
        feature_extractor = resnet50(pretrained=True).to(device).eval()
        self.vis = GradCam(feature_extractor, device)
        self.patch_size = patch_size
        self.conv_mean = nn.Conv2d(1, 1, kernel_size=self.patch_size, stride=self.patch_size, padding=0,  bias=False)
        # Constant averaging kernel: every weight is 1 / (patch area). It is a
        # fixed filter, not something to be learned, so it takes no gradient.
        self.conv_mean.weight = torch.nn.Parameter(
            torch.ones_like(self.conv_mean.weight) / (patch_size[0] * patch_size[1]),
            requires_grad=False,
        )

    def forward(self, x, h, w):
        '''
        Return the row/col index of a single patch chosen randomly based on
        the gradcam score.

        Args:
            x: input passed straight to the ``GradCam`` helper.
            h, w: height and width (in pixels) the activation map is resized
                to before patch scores are computed.

        Returns:
            An integer tensor of shape (1, 2) holding ``[[row, col]]``.
        '''
        # The GradCam call is expected to leave its map in ``self.vis.cam``
        # (that attribute is read right below).
        self.vis(x, None, postprocessing=image_net_postprocessing)

        # Min-max normalise in place. A constant map has a peak of 0 after the
        # shift; dividing by it would fill the map with NaNs, so skip that.
        self.vis.cam -= torch.min(self.vis.cam)
        peak = torch.max(self.vis.cam)
        if peak > 0:
            self.vis.cam /= peak
        resized_cam = transforms.Resize(size=(h, w))(self.vis.cam.unsqueeze(0).unsqueeze(0))

        # One mean score per non-overlapping patch; the last dimension of the
        # convolution output is the number of patch columns.
        patch_scores = self.conv_mean(resized_cam)
        number_of_cols = patch_scores.shape[-1]

        # torch.multinomial needs finite, non-negative weights with a positive
        # sum. Fall back to a uniform draw when the map carries no signal.
        weights = torch.nan_to_num(patch_scores.flatten(), nan=0.0, posinf=0.0, neginf=0.0).clamp_min(0)
        if weights.sum() <= 0:
            weights = torch.ones_like(weights)

        choice = torch.multinomial(weights, 1).item()

        # Flat index -> (row, col) in the patch grid.
        row_choice, col_choice = divmod(choice, number_of_cols)

        return torch.tensor([[row_choice, col_choice]])

In [5]:
class Actions(Enum):
    """Discrete moves of the patch window, numbered clockwise starting at UP."""

    # Values 0..3 double as the indices of the discrete action space.
    UP, RIGHT, DOWN, LEFT = range(4)
    # STAY = 4

In [6]:
class Environment(gym.Env):
    """Grid-world over the patches of an image.

    Each episode draws one ``(image, image_id)`` batch from ``dataloader``,
    picks a starting patch with ``ResnetGradCam`` and then lets the agent move
    the patch window one cell at a time with the four ``Actions``. The
    observation is the current patch.

    Args:
        dataloader: iterable yielding ``(image, image_id)`` batches of size 1;
            its ``dataset`` must expose ``captions_dict``.
        patch_size: ``(height, width)`` of a patch in pixels.
        input_size: size passed to ``transforms.Resize`` before the image is
            given to the Grad-CAM module.
    """
    metadata = {"render_modes": ["human"], "render_fps": 4}

    def __init__(self, dataloader, patch_size=(64, 64), input_size=224):
        self.dataloader = dataloader
        self.iterator = iter(dataloader)
        self.grad = ResnetGradCam(patch_size).to(device).eval()
        self.transform = transforms.Resize(input_size)
        self.patch_size = patch_size

        # Observations are slices of the image tensor yielded by the
        # dataloader. Assuming ToTensor-style images (float32 in [0, 1], as
        # COCODataset in this notebook produces), declare the space to match.
        self.observation_space = spaces.Box(low=0.0, high=1.0, shape=(3, self.patch_size[0], self.patch_size[1]), dtype=np.float32)
        self.action_space = spaces.Discrete(len(Actions))

        # Declare the per-episode state first ...
        self.row, self.col, self.max_row, self.max_col = None, None, None, None
        self.height, self.width = None, None
        self.current_image, self.image_id, self.captions = None, None, None

        # ... and only then load the first image, so the state filled in by
        # reset() is not overwritten with None afterwards.
        self.reset()

    def reset(self, **kwargs):
        """Load the next image and place the window on a Grad-CAM sampled patch.

        Keyword arguments are accepted for API compatibility and ignored.

        Returns:
            ``(patch, info)`` where ``info`` is an empty dict.
        """
        try:
            # Samples the batch
            self.current_image, self.image_id = next(self.iterator)
        except StopIteration:
            # restart the iterator if the previous iterator is exhausted.
            self.iterator = iter(self.dataloader)
            self.current_image, self.image_id = next(self.iterator)

        # captions_dict is looked up with the id rendered as a string.
        self.image_id = str(self.image_id.item())
        _, _, self.height, self.width = self.current_image.shape
        self.captions = self.dataloader.dataset.captions_dict[self.image_id]
        self.row, self.col = [int(x) for x in self.grad(self.transform(self.current_image.to(device)), self.height, self.width)[0]] # just a single row/col
        # Largest valid patch index along each axis (grid size minus one).
        self.max_row, self.max_col = (self.height - self.patch_size[0]) // self.patch_size[0], (self.width - self.patch_size[1]) // self.patch_size[1]
        # self.calculate_patch_histograms()
        return self._get_patch(), {}

    def _get_patch(self):
        """Return the (channels, patch_h, patch_w) slice at the current row/col."""
        start_row, end_row = self.row * self.patch_size[0], (self.row + 1) * self.patch_size[0]
        start_col, end_col = self.col * self.patch_size[1], (self.col + 1) * self.patch_size[1]
        return self.current_image[0, :, start_row: end_row, start_col: end_col]


    def step(self, action):
        """Move the window one cell; moves past the image border are ignored.

        Returns:
            ``(patch, reward, terminated, truncated, info)``.

        Raises:
            ValueError: if ``action`` is not a valid ``Actions`` value.
        """
        move = Actions(action)  # raises ValueError for unknown values
        if move == Actions.UP:
            self.row = max(self.row - 1, 0)
        elif move == Actions.RIGHT:
            self.col = min(self.col + 1, self.max_col)
        elif move == Actions.DOWN:
            self.row = min(self.row + 1, self.max_row)
        elif move == Actions.LEFT:
            self.col = max(self.col - 1, 0)
        else:
            raise ValueError("Invalid action")

        patch = self._get_patch()
        # NOTE(review): the reward is a constant 0 and the episode is never
        # terminated or truncated here, so anything that waits for an episode
        # to end will not get one from this environment -- confirm intent.
        return patch, 0, False, False, {}

    # def calculate_patch_histograms(self):
    #     max_row, max_col = self.height // self.patch_size[0], self.width // self.patch_size[1]
    #
    #     self.histograms = []
    #
    #     for i in range(max_row):
    #         self.histograms.append([])
    #         for j in range(max_col):
    #             start_row, end_row = i * self.patch_size[0], (i + 1) * self.patch_size[0]
    #             start_col, end_col = j * self.patch_size[1], (j + 1) * self.patch_size[1]
    #
    #             selected_patch = self.current_image[:, :, start_row: end_row, start_col: end_col].detach().clone()
    #
    #             hue = self.rgb2h(selected_patch).cpu()
    #             self.histograms[-1].append(np.histogram(hue, bins=10, range=(0., 1.), density=True)[0])
    #
    # def visualize(self, alpha=0.4):
    #     start_row, end_row = self.row * self.patch_size[0], (self.row + 1) * self.patch_size[0]
    #     start_col, end_col = self.col * self.patch_size[1], (self.col + 1) * self.patch_size[1]
    #
    #     mask = torch.zeros_like(self.current_image)
    #     mask[:, :, start_row: end_row, start_col: end_col] = 1
    #
    #     return alpha * mask + (1 - alpha) * self.current_image
    #
    # def rgb2h(self, rgb: torch.Tensor) -> torch.Tensor:
    #     cmax, cmax_idx = torch.max(rgb, dim=1, keepdim=True)
    #     cmin = torch.min(rgb, dim=1, keepdim=True)[0]
    #     delta = cmax - cmin
    #     hsv_h = torch.empty_like(rgb[:, 0:1, :, :])
    #     cmax_idx[delta == 0] = 3
    #     hsv_h[cmax_idx == 0] = (((rgb[:, 1:2] - rgb[:, 2:3]) / delta) % 6)[cmax_idx == 0]
    #     hsv_h[cmax_idx == 1] = (((rgb[:, 2:3] - rgb[:, 0:1]) / delta) + 2)[cmax_idx == 1]
    #     hsv_h[cmax_idx == 2] = (((rgb[:, 0:1] - rgb[:, 1:2]) / delta) + 4)[cmax_idx == 2]
    #     hsv_h[cmax_idx == 3] = 0.
    #     hsv_h /= 6.
    #     hsv_s = torch.where(cmax == 0, torch.tensor(0.).type_as(rgb), delta / cmax)
    #     hsv_v = cmax
    #     return hsv_h
    #
    # def measure_similarity(self):
    #     current_patch_histogram = self.histograms[self.row][self.col]
    #     max_row, max_col = self.height // self.patch_size[0], self.width // self.patch_size[1]
    #     similarity_matrix = []
    #
    #     for i in range(max_row):
    #         similarity_matrix.append([])
    #         for j in range(max_col):
    #             p_value = ks_2samp(current_patch_histogram,  self.histograms[i][j]).pvalue
    #             similarity_matrix[-1].append(p_value)
    #
    #     return similarity_matrix

In [7]:
# Training split of COCO-2017, served one shuffled image at a time.
dataset = COCODataset("../Data/COCO17", train=True)
dataloader = D.DataLoader(dataset=dataset, batch_size=1, shuffle=True)

In [8]:
# One environment per vector env; both are fed from the same dataloader.
train_envs = ts.env.DummyVectorEnv([lambda: Environment(dataloader)])
test_envs = ts.env.DummyVectorEnv([lambda: Environment(dataloader)])



In [9]:
class Q_network(nn.Module):
    """Q-value network: pretrained ResNet-50 followed by a Tianshou dueling head.

    The ResNet's final ``fc`` output is fed to a ``ts.utils.net.common.Net``
    with one 1024-unit hidden layer and separate two-layer (512, 512) streams
    for the action-dependent and state-value parts. The number of actions is
    read from the module-level ``train_envs``, so that must exist before this
    class is instantiated.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.feature_extractor = resnet50(pretrained=True)
        input_dim = self.feature_extractor.fc.out_features
        action_shape = train_envs.get_env_attr('action_space',0)[0].n
        # Both dueling streams take the 1024-wide output of the shared hidden
        # layer declared by ``hidden_sizes=[1024]``.
        self.dueling_head = ts.utils.net.common.Net(input_dim, action_shape, dueling_param=(
            {
                "input_dim": 1024,
                "output_dim": action_shape,
                "hidden_sizes": [512, 512],
            },
            {
                "input_dim": 1024,
                "output_dim": 1,
                "hidden_sizes": [512, 512],
            }
        ), hidden_sizes=[1024],)

    def forward(self, obs, **kwargs):
        """Run a batch of observations through the backbone and dueling head.

        Args:
            obs: batch of image patches as a tensor or array-like.
            **kwargs: accepted and ignored.

        Returns:
            Whatever ``self.dueling_head`` returns, unchanged.
        """
        # as_tensor avoids the copy (and warning) torch.tensor() makes when the
        # input is already a tensor, and puts the batch on the backbone's
        # device as float32, which the convolution layers require.
        backbone_param = next(self.feature_extractor.parameters())
        obs = torch.as_tensor(obs, dtype=torch.float32, device=backbone_param.device)
        duel_out = self.dueling_head(self.feature_extractor(obs))
        return duel_out

In [10]:
# Q-network and its optimiser.
# NOTE(review): binding the name `optim` here shadows the `torch.optim` module
# alias imported at the top of the notebook.
net = Q_network()
optim = torch.optim.Adam(params=net.parameters(), lr=1e-4)

# DQN policy with 3-step returns and a target network refreshed every 32000 updates.
policy = ts.policy.DQNPolicy(
    model=net,
    optim=optim,
    discount_factor=0.99,
    estimation_step=3,
    target_update_freq=32000,
)

# Replay storage (one sub-buffer for the single training env) and collectors.
replay_buffer = ts.data.VectorReplayBuffer(total_size=100000, buffer_num=1)
train_collector = ts.data.Collector(policy, train_envs, buffer=replay_buffer)
test_collector = ts.data.Collector(policy, test_envs)



In [None]:
# Hyper-parameters.
# NOTE(review): lr, train_num, gamma, n_step, target_freq and buffer_size are
# assigned but not read anywhere in this cell.
lr = 1e-3
epoch = 10
batch_size = 8
train_num = 10
test_num = 10
gamma = 0.9
n_step = 3
target_freq = 320
buffer_size = 20000
eps_train = 0.1
eps_test = 0.05
step_per_epoch = 10
step_per_collect = 10

# Weights & Biases logger, backed by a TensorBoard writer under ./logs.
logger = WandbLogger(train_interval=1, test_interval=1, update_interval=1, project="AttentionRL", name="testname")
writer = SummaryWriter("./logs")
logger.load(writer)

# Off-policy training loop; epsilon is switched between the train and test
# values by the two callbacks, and stop_fn ends training once the value it
# receives is non-negative.
result = ts.trainer.offpolicy_trainer(
    policy,
    train_collector,
    test_collector,
    max_epoch=epoch,
    step_per_epoch=step_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=test_num,
    batch_size=batch_size,
    update_per_step=1 / step_per_collect,
    train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
    test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
    stop_fn=lambda mean_rewards: mean_rewards >= 0,
    logger=logger,
)

  return LooseVersion(v) >= LooseVersion(check)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mattentionrl[0m. Use [1m`wandb login --relogin`[0m to force relogin
fatal: unable to create temp-file: Permission denied
  from IPython.core.display import HTML, display  # type: ignore


