In [0]:
# Mount Google Drive (Colab only). The dataset path used below,
# "drive/My Drive/VOC2012", is relative — it assumes the working directory
# is /content, where this mount point lives.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
#!tar -C drive/My\ Drive/VOC2012/ -xvf drive/My\ Drive/VOC2012/VOCtrainval_11-May-2012.tar 

In [0]:
# import packages and load dataset

import torchvision
import numpy as np
from PIL import Image
from IPython.display import display
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.nn.functional as F
from torchvision import transforms
import random
import math
from collections import namedtuple
from itertools import count
import cv2

# Pascal VOC 2012 detection dataset on the mounted Drive. Later cells index it as
# (PIL image, annotation dict) pairs. No download argument is passed, so this
# presumably expects the archive to be extracted already (see the tar cell above).
VOC2012 = torchvision.datasets.VOCDetection("drive/My Drive/VOC2012")

# Run models on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
# test data input extraction

# Index the dataset once; each VOC2012[i] lookup presumably reloads the image
# and re-parses the annotation from disk.
img, target = VOC2012[0]
annotation = target['annotation']
h = int(annotation['size']['height'])
w = int(annotation['size']['width'])

print(w, h)

# 'object' may be a single dict rather than a list (default_collate handles
# both forms as well); take the first annotated object either way.
objects = annotation['object']
first_object = objects[0] if isinstance(objects, list) else objects
bbox = first_object['bndbox']
# int(float(...)) also accepts coordinate strings that carry a decimal part.
left = int(float(bbox['xmin']))
upper = int(float(bbox['ymin']))
right = int(float(bbox['xmax']))
lower = int(float(bbox['ymax']))

bbox_original = (0, 0, w, h)
print(bbox_original)
bbox = (left, upper, right, lower)
print(bbox)

display(img)
img = img.crop(bbox)
display(img)

In [0]:
# define DQN with resnet preprocessing step

class Net(nn.Module):
    """Deep Q-network on top of a frozen, ImageNet-pretrained ResNet-50.

    forward(img_t, action_history) returns one raw (unbounded) Q-value per
    action, shape (batch, n_actions). No softmax is applied: Q-values are
    regression targets for discounted returns, which with rewards of +-1/+-3
    are not confined to a probability simplex.

    Args:
        n_actions: number of actions / output units (9 in this notebook).
        history_size: length of the flattened action-history vector that is
            concatenated to the 2048 pooled ResNet features (90 here).
    """

    def __init__(self, n_actions=9, history_size=90):
        super(Net, self).__init__()

        # Pre-trained convolutional network with its final fc layer dropped,
        # used as a fixed feature extractor: no gradients, and its BatchNorm
        # layers stay in inference mode (see train()).
        conv = torchvision.models.resnet50(pretrained=True)
        self.conv = nn.Sequential(*list(conv.children())[:-1])
        for p in self.conv.parameters():
            p.requires_grad = False
        self.conv.eval()

        # Deep Q-network head; the only trainable part of the model.
        self.dqn = nn.Sequential(
            nn.Linear(2048 + history_size, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 1024),
            nn.ReLU(True),
            nn.Linear(1024, n_actions),
        )

    def train(self, mode=True):
        """Set train/eval mode on the head while pinning the backbone to eval.

        Without this, a training-mode forward pass would normalize with batch
        statistics and overwrite the pretrained BatchNorm running stats even
        though the backbone's weights are frozen.
        """
        super(Net, self).train(mode)
        self.conv.eval()
        return self

    def forward(self, img_t, action_history):
        """Return Q-values for a batch of images and their action histories.

        Args:
            img_t: image batch accepted by the ResNet trunk — presumably
                (batch, 3, 224, 224) normalized tensors (see `transform`).
            action_history: (batch, history_size) tensor.
        """
        x = self.conv(img_t)
        x = torch.flatten(x, 1)  # (batch, 2048, 1, 1) -> (batch, 2048)
        x = torch.cat((x, action_history), dim=1)
        return self.dqn(x)

In [0]:
# data loading and preprocessing functions

def default_collate(batch):
    """Collate VOC samples into per-field lists for the localization agent.

    Args:
        batch: sequence of (image, target) items as produced by indexing
            VOC2012; target['annotation'] holds 'size' and 'object'.

    Returns:
        images: list of the raw images.
        bboxes_observed: starting observed box per image,
            (0, 0, int(w/2), int(h/2)) — the top-left quarter.
        bboxes_true: (xmin, ymin, xmax, ymax) of the first annotated object.
        action_history: list of independent zero tensors of length 90.
    """
    def to_int(value):
        # Annotation values are strings; going through float also accepts
        # coordinates written with a decimal part.
        return int(float(value))

    images = [item[0] for item in batch]
    bboxes_observed = []
    bboxes_true = []
    # One tensor per sample. `[torch.zeros(90)] * n` would put the SAME tensor
    # in every slot, so an in-place update of one history would change all.
    action_history = [torch.zeros(90) for _ in images]
    for item in batch:
        annotation = item[1]['annotation']
        h = to_int(annotation['size']['height'])
        w = to_int(annotation['size']['width'])
        bboxes_observed.append((0, 0, int(w/2), int(h/2)))
        # 'object' is a list for multi-object images but may be a single dict.
        obj = annotation['object']
        bbox = obj[0]['bndbox'] if isinstance(obj, list) else obj['bndbox']
        bboxes_true.append((to_int(bbox['xmin']), to_int(bbox['ymin']),
                            to_int(bbox['xmax']), to_int(bbox['ymax'])))
    return images, bboxes_observed, bboxes_true, action_history

# Preprocessing for image crops before they enter Net (see the commented-out
# usage in later cells): resize to a fixed 224x224, convert to a tensor, and
# normalize each channel. The mean/std values match the ImageNet statistics
# documented for torchvision's pretrained models, which is what the ResNet-50
# backbone in Net expects.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
    )])

In [0]:
# define replay memory

# One step of experience: state, chosen action, resulting state, reward.
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory:
    """Bounded experience-replay buffer with ring-buffer overwrite.

    Until `capacity` transitions have been pushed, new ones are appended;
    afterwards each push overwrites the oldest slot.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []   # stored Transition records, never more than capacity
        self.position = 0  # index the next push writes to

    def push(self, *args):
        """Saves a transition."""
        record = Transition(*args)
        if len(self.memory) >= self.capacity:
            # full: overwrite the oldest entry in place
            self.memory[self.position] = record
        else:
            # still filling: the write index always equals the current length
            self.memory.append(record)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw `batch_size` distinct stored transitions uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [0]:
# TODO: action selection / training

# DQN hyperparameters. BATCH_SIZE, GAMMA and TARGET_UPDATE are not referenced
# by the cells shown yet (training is still TODO).
BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.9     # epsilon at step 0 (see select_action)
EPS_END = 0.05      # value epsilon decays towards
EPS_DECAY = 200     # steps in the exp(-steps_done / EPS_DECAY) decay
TARGET_UPDATE = 10

# 8 box transformations + the trigger action (see take_action)
n_actions = 9

# Policy and target networks; the target starts as an exact copy of the
# policy network and is only run in eval mode.
policy_net = Net().to(device)
target_net = Net().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# TODO: `net` is undefined here — presumably policy_net was meant; only its
# dqn head has requires_grad=True, so policy_net.dqn.parameters() suffices.
#optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
memory = ReplayMemory(10000)

# global step counter advanced by select_action; drives the epsilon decay
steps_done = 0

def select_action(img_t, action_history):
    """Pick an action epsilon-greedily for the current observation.

    Epsilon decays from EPS_START towards EPS_END as
    EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY); every call
    advances the global `steps_done` counter by one.

    Returns:
        A (1, 1) long tensor holding the chosen action index. The greedy
        branch assumes a batch of one observation — TODO confirm callers.
    """
    global steps_done
    draw = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if draw <= eps_threshold:
        # explore: uniformly random action
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

    # exploit: index of the largest network output per row
    with torch.no_grad():
        _, best_index = policy_net(img_t, action_history).max(1)
        return best_index.view(1, 1)
    
#for (i, [images, bboxes_observed, bboxes_true, action_history]) in enumerate(train_loader):
#    img_observed = [img.crop(bbox) for (img, bbox) in zip(images, bboxes_observed)]
#    img_t = torch.stack([transform(img) for img in img_observed]).to(device)
#    action_history = action_history.to(device)
#    print(net(img_t, action_history))
#    if i == 0:
#        break

In [0]:
def calculate_iou(state):
    """Intersection-over-union of the observed box and the ground-truth box.

    Computed analytically from box coordinates instead of rasterizing two
    full-image masks. Both boxes are clipped to the image first, which gives
    the same areas the original pixel masks covered for in-image,
    non-negative coordinates.

    Args:
        state: (image, bbox_observed, bbox_true, action_history) tuple; only
            image.width / image.height and the two (x1, y1, x2, y2) boxes are
            read.

    Returns:
        IoU as a float in [0, 1]. Returns 0.0 when the union is empty (both
        boxes degenerate or outside the image) instead of dividing by zero.
    """
    image, bbox_observed, bbox_true, _ = state

    def clip(box):
        # clamp every coordinate into [0, width] x [0, height]
        x1, y1, x2, y2 = box
        return (min(max(x1, 0), image.width), min(max(y1, 0), image.height),
                min(max(x2, 0), image.width), min(max(y2, 0), image.height))

    def area(x1, y1, x2, y2):
        # inverted or empty boxes contribute zero area
        return max(x2 - x1, 0) * max(y2 - y1, 0)

    obs = clip(bbox_observed)
    gt = clip(bbox_true)
    intersection = area(max(obs[0], gt[0]), max(obs[1], gt[1]),
                        min(obs[2], gt[2]), min(obs[3], gt[3]))
    union = area(*obs) + area(*gt) - intersection
    if union <= 0:
        return 0.0
    return float(intersection) / float(union)

def update_action_history(action_history, action, n_actions=9):
    """Return a new action history with `action` recorded as the most recent.

    The history is a flat tensor of consecutive one-hot blocks of length
    `n_actions`, newest first (90 = 10 past actions x 9 in this notebook).
    The new action's one-hot block is prepended and the oldest block is
    dropped, so the length is unchanged.

    The input tensor is NOT modified. This keeps any state tuple or stored
    transition that still references the old history intact, e.g. while
    find_positive_actions probes every action from the same state.

    Args:
        action_history: 1-D tensor of length >= n_actions.
        action: index of the action taken (int or single-element tensor).
        n_actions: size of each one-hot block.
    """
    # build the one-hot on the history's device/dtype so concatenation works on GPU
    one_hot = torch.zeros(n_actions, dtype=action_history.dtype, device=action_history.device)
    one_hot[action] = 1
    return torch.cat((one_hot, action_history[:-n_actions]))
 
def _transform_bbox(bbox, action, width, height):
    """Apply one box-transformation action to an (x1, y1, x2, y2) box.

    The step is 20% of the current box width/height (truncated to int); scale
    and aspect actions move each edge by half a step. Edges pushed outward are
    clamped to [0, width] x [0, height]. Action 8 (trigger) and any other value
    leave the box unchanged.
    """
    x1, y1, x2, y2 = bbox
    alph_w = int(0.2 * (x2 - x1))
    alph_h = int(0.2 * (y2 - y1))
    half_w = math.floor(alph_w/2)
    half_h = math.floor(alph_h/2)

    if action == 0:    # horizontal move to the right
        x1 += alph_w
        x2 = min(x2 + alph_w, width)
    elif action == 1:  # horizontal move to the left
        x1 = max(x1 - alph_w, 0)
        x2 -= alph_w
    elif action == 2:  # vertical move up (y grows downward)
        y1 = max(y1 - alph_h, 0)
        y2 -= alph_h
    elif action == 3:  # vertical move down
        y1 += alph_h
        y2 = min(y2 + alph_h, height)
    elif action == 4:  # scale up
        x1 = max(x1 - half_w, 0)
        x2 = min(x2 + half_w, width)
        y1 = max(y1 - half_h, 0)
        y2 = min(y2 + half_h, height)
    elif action == 5:  # scale down
        x1 += half_w
        x2 -= half_w
        y1 += half_h
        y2 -= half_h
    elif action == 6:  # decrease height (aspect ratio)
        y1 += half_h
        y2 -= half_h
    elif action == 7:  # decrease width (aspect ratio)
        x1 += half_w
        x2 -= half_w
    return (x1, y1, x2, y2)


def take_action(state, action):
    """Apply `action` to `state` and score the result.

    Args:
        state: (image, bbox_observed, bbox_true, action_history) tuple.
        action: action index 0-8 (int or single-element tensor); 8 is the
            trigger that ends the episode.

    Returns:
        (reward, state_new, done):
        - reward: for the trigger, +3.0 if the IoU with the ground truth is
          >= 0.6 and -3.0 otherwise; for any other action, the sign of the
          IoU change (-1, 0 or +1).
        - state_new: a new state tuple. The input state, including its action
          history tensor, is left unchanged.
        - done: True only for the trigger.
    """
    image, bbox_observed, bbox_true, action_history = state
    done = bool(action == 8)

    bbox_observed_new = _transform_bbox(bbox_observed, action, image.width, image.height)
    # Update a copy: update_action_history may write into its argument, which
    # would silently rewrite the caller's `state` (and any stored transition
    # that shares the tensor).
    action_history_new = update_action_history(action_history.clone(), action)
    state_new = (image, bbox_observed_new, bbox_true, action_history_new)

    iou_old = calculate_iou(state)
    iou_new = calculate_iou(state_new)

    if done:
        reward = 3.0 if iou_new >= 0.6 else -3.0
    else:
        reward = np.sign(iou_new - iou_old)

    print("Reward Received:", reward)
    return reward, state_new, done

def find_positive_actions(state):
    """Return the indices of actions that would earn a positive reward.

    Every one of the n_actions actions is simulated with take_action. Each
    probe runs on a state whose action-history tensor is a private copy, so
    the caller's history is not altered by the probing, whatever
    take_action/update_action_history do to their input.

    Args:
        state: (image, bbox_observed, bbox_true, action_history) tuple.

    Returns:
        List of action indices whose simulated reward is > 0 (possibly empty).
    """
    image, bbox_observed, bbox_true, action_history = state
    positive_actions = []
    for i in range(n_actions):
        print(i)
        probe_state = (image, bbox_observed, bbox_true, action_history.clone())
        reward, _, _ = take_action(probe_state, i)
        if reward > 0:
            positive_actions.append(i)
    print("TEST:", positive_actions)
    return positive_actions

In [0]:
# Exploration sanity check: on one random training image, follow a
# reward-greedy "oracle" for up to 20 steps. Each step picks uniformly among
# the actions that would increase IoU, or a random action when none do.
train_loader = torch.utils.data.DataLoader(VOC2012, batch_size=1, collate_fn=default_collate, shuffle=True, num_workers=4)

# one shuffled batch: (images, bboxes_observed, bboxes_true, action_history)
states = next(iter(train_loader))

#images, bboxes_observed, bboxes_true, action_history = states

#img_observed = [img.crop(bbox) for (img, bbox) in zip(images, bboxes_observed)]
#img_t = torch.stack([transform(img) for img in img_observed]).to(device)
#action_history = torch.stack(action_history).to(device)

#action = select_action(img_t, action_history).item()

# transpose the per-field lists into per-sample state tuples; keep the first
state = next(zip(*states))
state_orig = state
# pre-step states of actions that were penalized (reward < 0)
state_list = []

for i in range(20):
    image, bbox_observed, bbox_true, action_history = state
    # snapshot before probing/acting, with its own history tensor, so later
    # in-place history updates cannot alter the recorded state
    state_old = (image, bbox_observed, bbox_true, action_history.clone())
    if i == 0:
        display(image)
        display(image.crop(bbox_true))
    display(image.crop(bbox_observed))
    positive_actions = find_positive_actions(state)
    if positive_actions:
        action = random.choice(positive_actions)
    else:
        action = random.randrange(n_actions)
    print("action selected:", action)

    reward, state, done = take_action(state, action)
    if reward < 0:
        state_list.append(state_old)
    if done:
        break