In [0]:
# Mount Google Drive inside the Colab runtime at /content/drive.
# NOTE(review): later cells use the relative path "drive/My Drive/...", which only
# resolves if the working directory is /content — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
#!tar -C drive/My\ Drive/VOC2012/ -xvf drive/My\ Drive/VOC2012/VOCtrainval_11-May-2012.tar 

In [0]:
# import packages and load dataset

import torchvision
import numpy as np
from PIL import Image
from IPython.display import display
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.nn.functional as F
from torchvision import transforms
import random
import math
from collections import namedtuple
from itertools import count
import cv2

# Pascal VOC detection dataset read from the Drive folder; no image_set is passed,
# so the library's default split is used.
VOC2012 = torchvision.datasets.VOCDetection("drive/My Drive/VOC2012")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
# test data input extraction

# Sanity check: show the first sample, then crop it to its first annotated object.
print(len(VOC2012))

# Index the dataset once and reuse the result instead of calling VOC2012[0]
# separately for the image, the height and the width.
sample = VOC2012[0]
img = sample[0]
annotation = sample[1]['annotation']
h = int(annotation['size']['height'])
w = int(annotation['size']['width'])

print(w, h)

# 'object' is accepted as either a list of objects or a single object dict,
# the same two shapes default_collate handles.
objects = annotation['object']
first_object = objects[0] if isinstance(objects, list) else objects
bbox = first_object['bndbox']
left = int(bbox['xmin'])
upper = int(bbox['ymin'])
right = int(bbox['xmax'])
lower = int(bbox['ymax'])

# Full-image box versus the ground-truth box, both as (left, upper, right, lower).
bbox_original = (0, 0, w, h)
print(bbox_original)
bbox = (left, upper, right, lower)
print(bbox)

display(img)
img = img.crop(bbox)
display(img)

In [0]:
# define DQN with resnet preprocessing step

class Net(nn.Module):
    """Q-network: a frozen ResNet-50 feature extractor followed by an MLP head.

    The backbone (ResNet-50 without its final layer) turns each image into a
    2048-value feature vector. That vector is concatenated with the 90-value
    action history (2048 + 90 = 2138 inputs) and mapped to 9 Q-values, one
    per action.
    """

    def __init__(self):
        super(Net, self).__init__()

        # pre-trained convolutional network, used purely as a fixed feature extractor
        backbone = torchvision.models.resnet50(pretrained=True)
        self.conv = nn.Sequential(*list(backbone.children())[:-1])
        for p in self.conv.parameters():
            p.requires_grad = False
        # Frozen weights are not enough: in train mode BatchNorm still updates
        # its running statistics, so the backbone is pinned to eval mode.
        self.conv.eval()

        # deep Q-network head
        # The outputs are raw Q-values. A softmax here would squash them into
        # (0, 1), which cannot represent TD targets built from rewards of +/-1
        # and +/-3. argmax-based action selection is unaffected by its removal.
        self.dqn = nn.Sequential(
            nn.Linear(2138, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 9),
        )

    def train(self, mode=True):
        """Switch the head's mode as usual but always keep the backbone in eval mode."""
        super(Net, self).train(mode)
        self.conv.eval()
        return self

    def forward(self, img_t, action_history):
        """Return Q-values of shape (batch, 9).

        Args:
            img_t: batch of preprocessed image tensors fed to the backbone.
            action_history: tensor of shape (batch, 90) with the encoded
                recent actions, on the same device as ``img_t``.
        """
        out = self.conv(img_t)
        out = out.reshape(out.size(0), 2048)
        out = torch.cat((out, action_history), dim=1)
        return self.dqn(out)

In [0]:
# data loading and preprocessing functions

# One localization episode state: the full PIL image, the box currently being
# observed, the ground-truth box, and the encoded history of recent actions.
State = namedtuple(
    'State',
    ['image', 'bbox_observed', 'bbox_true', 'action_history'],
)


def default_collate(batch):
    """Turn a batch of (image, annotation) dataset items into initial States.

    Each state starts with an all-zero 90-value action history and an observed
    box covering the top-left quadrant of the image. Only the first annotated
    object of an image is used as the ground-truth box.

    Args:
        batch: iterable of (image, target) pairs where ``target['annotation']``
            holds the 'size' and 'object' entries.

    Returns:
        A list of State tuples, one per item.
    """
    def _initial_state(item):
        picture = item[0]
        annotation = item[1]['annotation']
        history = torch.zeros(90)

        height = int(annotation['size']['height'])
        width = int(annotation['size']['width'])
        start_box = (0, 0, int(width / 2), int(height / 2))

        # 'object' is a list when several objects are annotated, a dict otherwise.
        objects = annotation['object']
        first = objects[0] if isinstance(objects, list) else objects
        corners = first['bndbox']
        truth_box = tuple(
            int(corners[key]) for key in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        return State(picture, start_box, truth_box, history)

    return [_initial_state(item) for item in batch]

# Preprocessing applied to every cropped observation before the backbone:
# resize to 224x224, convert to a tensor, then normalize each channel with the
# mean/std that torchvision documents for its pretrained models.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def state_transform(states):
    """Build the network inputs for a list of states.

    Each state's image is cropped to its currently observed box and run
    through ``transform``; the crops and the action histories are stacked into
    two batch tensors and moved to ``device``.

    Returns:
        (img_t, action_history) batch tensors, in the order of ``states``.
    """
    crops = []
    histories = []
    for state in states:
        observed = state.image.crop(state.bbox_observed)
        crops.append(transform(observed))
        histories.append(state.action_history)

    img_t = torch.stack(crops).to(device)
    action_history = torch.stack(histories).to(device)
    return img_t, action_history

In [0]:
# define replay memory

# One step of experience stored in the replay buffer.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)


class ReplayMemory(object):
    """Fixed-capacity ring buffer of Transition tuples.

    Once ``capacity`` entries are stored, each new transition overwrites the
    oldest one.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        # Index of the slot the next push will write to.
        self.position = 0

    def push(self, *args):
        """Store one transition built from ``args`` (state, action, next_state, reward)."""
        entry = Transition(*args)
        still_growing = len(self.memory) < self.capacity
        if still_growing:
            # Reserve the slot first; it is filled just below.
            self.memory.append(None)
        self.memory[self.position] = entry
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [0]:
# Reinforcement learning actions/state updates

n_actions = 9

def calculate_iou(state):
    """Intersection-over-union of a state's observed box and its ground-truth box.

    The value is computed directly from the box coordinates instead of
    rasterising two full-image masks, so the cost no longer depends on the
    image size. Boxes are clipped to the image bounds first, which is what
    painting them into an image-sized mask did for coordinates past the
    right/bottom edge.

    Args:
        state: 4-tuple (image, bbox_observed, bbox_true, action_history);
            ``image`` must expose ``width`` and ``height`` and boxes are
            (x1, y1, x2, y2).

    Returns:
        IoU as a float in [0, 1]; 0.0 when both boxes are empty (the previous
        implementation raised ZeroDivisionError in that case).
    """
    image, bbox_observed, bbox_true, action_history = state
    width, height = image.width, image.height

    def _clip(box):
        # Clamp every corner into the image rectangle.
        x1, y1, x2, y2 = box
        return (
            min(max(x1, 0), width),
            min(max(y1, 0), height),
            min(max(x2, 0), width),
            min(max(y2, 0), height),
        )

    def _area(x1, y1, x2, y2):
        # Inverted or degenerate boxes cover no pixels.
        return max(0, x2 - x1) * max(0, y2 - y1)

    ox1, oy1, ox2, oy2 = _clip(bbox_observed)
    tx1, ty1, tx2, ty2 = _clip(bbox_true)

    intersection = _area(max(ox1, tx1), max(oy1, ty1), min(ox2, tx2), min(oy2, ty2))
    union = _area(ox1, oy1, ox2, oy2) + _area(tx1, ty1, tx2, ty2) - intersection
    if union <= 0:
        return 0.0

    return float(intersection) / float(union)

def update_action_history(action_history, action):
    """Return a new history with ``action`` pushed to the front.

    The history is a flat 90-value tensor holding ten 9-value one-hot slots,
    most recent first. The new action occupies the first slot, the previous
    first 81 values shift back by one slot, and the oldest slot is dropped.
    The input tensor is left unmodified.
    """
    one_hot = torch.zeros(9)
    one_hot[action] = 1

    shifted = action_history.clone()
    # Read from the untouched input so source and destination never overlap.
    shifted[9:] = action_history[:81]
    shifted[:9] = one_hot
    return shifted
 
def take_action(state, action):
    """Apply one action to a state and compute the resulting reward.

    Actions (step sizes are 20% of the current box width/height):
        0 move right      1 move left       2 move up        3 move down
        4 scale up        5 scale down      6 decrease height
        7 decrease width  8 trigger (ends the episode)

    Moves and scale-up are clamped to the image borders.

    Args:
        state: State tuple (image, bbox_observed, bbox_true, action_history).
        action: integer action index in [0, 8].

    Returns:
        (reward, next_state, done). ``reward`` is always a Python float:
        +/-3.0 for a trigger depending on whether IoU >= 0.6, otherwise the
        sign of the IoU change (-1.0, 0.0 or 1.0). Previously the non-trigger
        reward was a numpy.float64, so tensors built from the two kinds of
        reward had different dtypes.
    """
    image, bbox_observed, bbox_true, action_history = state
    x1, y1, x2, y2 = bbox_observed
    alph_w = int(0.2 * (x2 - x1))
    alph_h = int(0.2 * (y2 - y1))
    # Scaling / aspect-ratio actions move each side by half a step.
    half_w = alph_w // 2
    half_h = alph_h // 2

    done = False

    if action == 0:  # horizontal move to the right
        x1 += alph_w
        x2 = min(x2 + alph_w, image.width)
    elif action == 1:  # horizontal move to the left
        x1 = max(x1 - alph_w, 0)
        x2 -= alph_w
    elif action == 2:  # vertical move up
        y1 = max(y1 - alph_h, 0)
        y2 -= alph_h
    elif action == 3:  # vertical move down
        y1 += alph_h
        y2 = min(y2 + alph_h, image.height)
    elif action == 4:  # scale up
        x1 = max(x1 - half_w, 0)
        x2 = min(x2 + half_w, image.width)
        y1 = max(y1 - half_h, 0)
        y2 = min(y2 + half_h, image.height)
    elif action == 5:  # scale down
        x1 += half_w
        x2 -= half_w
        y1 += half_h
        y2 -= half_h
    elif action == 6:  # decrease height (aspect ratio)
        y1 += half_h
        y2 -= half_h
    elif action == 7:  # decrease width (aspect ratio)
        x1 += half_w
        x2 -= half_w
    elif action == 8:  # trigger
        done = True

    bbox_observed_new = (x1, y1, x2, y2)
    action_history_new = update_action_history(action_history, action)
    next_state = State(image, bbox_observed_new, bbox_true, action_history_new)

    iou_old = calculate_iou(state)
    iou_new = calculate_iou(next_state)

    if done:
        # The trigger leaves the box unchanged, so it is judged on absolute IoU.
        reward = 3.0 if iou_new >= 0.6 else -3.0
    else:
        # Cast to a plain float so every reward has the same type.
        reward = float(np.sign(iou_new - iou_old))

    return reward, next_state, done

def find_positive_actions(state):
    """List the action indices whose immediate reward from ``state`` is positive.

    Every one of the ``n_actions`` actions is simulated with ``take_action``;
    the resulting states are discarded.
    """
    return [
        candidate
        for candidate in range(n_actions)
        if take_action(state, candidate)[0] > 0
    ]

In [0]:
# Hyperparameters / utilities

BATCH_SIZE = 64       # images per training batch and transitions per replay sample
GAMMA = 0.999         # discount applied to next-state values in optimize_model
EPS_START = 0.9       # exploration threshold at steps_done == 0 (see select_action)
EPS_END = 0.05        # value the exploration threshold decays towards
EPS_DECAY = 200       # decay constant of the exponential epsilon schedule
TARGET_UPDATE = 10    # the target net is synced whenever batch_steps is a multiple of this

# Two copies of the network: policy_net is optimized, target_net starts with the
# same weights and is put in eval mode.
policy_net = Net().to(device)
target_net = Net().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# The optimizer is given all policy_net parameters; the replay buffer holds up to
# 10000 transitions.
optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-4)
memory = ReplayMemory(10000)


# training loop

# Shuffled training batches; default_collate converts raw items into initial States.
train_loader = torch.utils.data.DataLoader(VOC2012, batch_size=BATCH_SIZE, collate_fn=default_collate, shuffle=True)

# Validation split, loaded one image at a time and used only for the per-batch
# visualization in the training loop.
VOCtest = torchvision.datasets.VOCDetection("drive/My Drive/VOC2012", image_set='val')
test_loader = torch.utils.data.DataLoader(VOCtest, batch_size=1, collate_fn=default_collate, shuffle=True)
# A single iterator shared across the whole run; it is advanced once per training batch.
test_iter = enumerate(test_loader)

# Global count of select_action calls; drives the epsilon decay.
steps_done = 0

import timeit

def select_action(img_t, action_history, states):
    """Pick one action per state with an epsilon-greedy rule.

    A single random draw decides for the whole batch. Above the current
    threshold, every state gets the policy network's highest-scoring action.
    Otherwise each state gets a random action among those with a positive
    immediate reward, or any random action if there is none. Each call
    increments the global ``steps_done``, which lowers the threshold.

    Returns:
        1-D tensor of action indices, one per state.
    """
    global steps_done
    sample = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if sample <= eps_threshold:
        # Guided exploration: prefer actions known to improve the IoU.
        chosen = []
        for state in states:
            improving = find_positive_actions(state)
            if improving:
                chosen.append(random.choice(improving))
            else:
                chosen.append(random.randrange(n_actions))
        return torch.tensor(chosen, device=device)

    # Exploitation: greedy action from the policy network.
    with torch.no_grad():
        q_values = policy_net(img_t, action_history)
    return torch.max(q_values, 1).indices

# Main training loop: one pass over train_loader, with each batch of images rolled
# out for at most 100 agent steps.
# NOTE(review): optimize_model and localize are defined in later cells, so those
# cells must be executed before this one.
total_time = 0
print("First 10 DQN params (initialization):", policy_net.state_dict()['dqn.0.weight'][0][:10])
for i, states in enumerate(train_loader):
    print("Running batch", i)
    batch_steps = 0
    start = timeit.default_timer()
    # Keep stepping until every episode in the batch has triggered or the step cap is hit.
    while len(states) > 0 and batch_steps < 100:
        img_t, action_history = state_transform(states)
        actions = select_action(img_t, action_history, states)
        states_new = []
        for j in range(actions.shape[0]):
            action = actions[j].item()
            state = states[j]
            reward, next_state, done = take_action(state, action)
            reward = torch.tensor([reward], device=device)
            memory.push(state, action, next_state, reward)
            # Finished episodes drop out of the batch.
            if not done:
                states_new.append(next_state)
        # One optimization step per environment step of the batch.
        optimize_model()
        
        # batch_steps restarts at 0 for every batch, so the target net is also
        # synced on the first step of each batch.
        if batch_steps % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())
            
        states = states_new
        batch_steps+=1
    
    # save visualization: run the current policy on the next validation image
    _, s = test_iter.__next__()
    localize(s[0], "img_{}".format(i))
    
    # The reported time covers the rollout and the visualization, in minutes.
    stop = timeit.default_timer()
    t = (stop-start)/60
    total_time += t
    print("Finished batch {0} in {1:.2f} minutes.".format(i, t))
    print("Total time: {0:.2f} minutes.".format(total_time))
    print("First 10 DQN params after batch {0}:".format(i), policy_net.state_dict()['dqn.0.weight'][0][:10])

First 10 DQN params (initialization): tensor([ 3.5972e-03, -2.9684e-03, -8.1722e-03,  1.4963e-02,  1.0381e-02,
         8.4182e-05,  1.4594e-02, -3.7783e-03,  1.9429e-02,  1.3995e-02],
       device='cuda:0')
Running batch 0
Finished batch 0 in 2.24 minutes.
Total time: 2.24 minutes.
First 10 DQN params after batch 0: tensor([ 0.0045, -0.0029, -0.0073,  0.0141,  0.0117, -0.0002,  0.0147, -0.0041,
         0.0198,  0.0127], device='cuda:0')
Running batch 1
Finished batch 1 in 3.62 minutes.
Total time: 5.87 minutes.
First 10 DQN params after batch 1: tensor([ 0.0072, -0.0033, -0.0067,  0.0155,  0.0135, -0.0010,  0.0162, -0.0024,
         0.0198,  0.0146], device='cuda:0')
Running batch 2
Finished batch 2 in 3.58 minutes.
Total time: 9.44 minutes.
First 10 DQN params after batch 2: tensor([ 0.0076, -0.0016, -0.0036,  0.0162,  0.0143,  0.0002,  0.0186, -0.0029,
         0.0201,  0.0150], device='cuda:0')
Running batch 3
Finished batch 3 in 4.22 minutes.
Total time: 13.66 minutes.
First 10 

In [0]:
# Smoke test: pull one training batch and apply a single selected action to its
# first state. Calling select_action here also advances the global steps_done.
i, S = enumerate(train_loader).__next__()
print(S[0])

img_t, action_history = state_transform(S)
actions = select_action(img_t, action_history, S)
action = actions[0].item()
print(actions, action)
# Printed again to compare with the first print after action selection.
print(S[0])
reward, next_state, done = take_action(S[0], action)
print("reward:", reward)
print("next_state:", next_state)
print("done:", done)

In [0]:
# Optimization

def get_last_action(state):
    """Return the index of the action stored in the first 9-value history slot.

    The slot must contain exactly one non-zero entry; otherwise ``.item()``
    raises.
    """
    newest_slot = state.action_history[:9]
    return torch.nonzero(newest_slot).item()

def optimize_model():
    """Run one DQN optimization step on a random minibatch from the replay memory.

    Does nothing until the memory holds at least BATCH_SIZE transitions.
    Targets are ``reward + GAMMA * max_a Q_target(next_state, a)``, with the
    next-state value fixed at zero for transitions that ended with the
    trigger action (index 8). The Huber loss is minimised with gradients
    clipped element-wise to [-1, 1].
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of tuples.
    batch = Transition(*zip(*transitions))

    # A next state is final when the action that produced it was the trigger.
    # Computed once and reused for both the mask and the filtered list.
    is_non_final = [get_last_action(s) != 8 for s in batch.next_state]
    non_final_mask = torch.tensor(is_non_final, device=device, dtype=torch.bool)
    non_final_next_states = [
        s for s, keep in zip(batch.next_state, is_non_final) if keep
    ]

    action_batch = torch.tensor(batch.action, device=device)
    # Cast each reward so the concatenated batch is float32 even if individual
    # rewards were created with different dtypes.
    reward_batch = torch.cat([r.float() for r in batch.reward])

    img_t, action_history = state_transform(batch.state)

    # Q(s, a) for the actions that were actually taken (single forward pass).
    state_action_values = policy_net(img_t, action_history).gather(1, action_batch.view(-1, 1))

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # Skip the target network entirely when every sampled transition is final;
    # stacking an empty list of states would raise.
    if non_final_next_states:
        non_final_img_t, non_final_action_history = state_transform(non_final_next_states)
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(
                non_final_img_t, non_final_action_history
            ).max(1)[0]

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        # Frozen parameters have no gradient.
        if param.grad is not None:
            # In-place clamp: the out-of-place clamp() returned a new tensor
            # that was discarded, so gradients were never clipped.
            param.grad.data.clamp_(-1, 1)
    optimizer.step()

In [0]:
# visualization

from PIL import ImageDraw

def draw_boxes(state):
    """Return a copy of the state's image with both boxes outlined.

    The ground-truth box is drawn in magenta first, then the observed box in
    cyan on top. The state's own image is not modified.
    """
    canvas = state.image.copy()
    pen = ImageDraw.Draw(canvas)
    outlines = (
        (state.bbox_true, (255,0,255)),
        (state.bbox_observed, (0,255,255)),
    )
    for box, colour in outlines:
        pen.rectangle(box, outline=colour)
    return canvas

def localize(state, name):
    """Roll out the greedy policy from ``state`` and save a film strip of the steps.

    The strip starts with the initial state and gains one frame per action
    (ground-truth and observed boxes drawn on the image), for at most 20
    actions or until the trigger action ends the episode. The result is
    written to "drive/My Drive/visualization/<name>.png".

    Args:
        state: State tuple to start from.
        name: file name (without extension) for the saved image.
    """
    import os

    vis = draw_boxes(state)
    w = state.image.width
    h = state.image.height
    # Pure inference: no gradients are needed for visualization.
    with torch.no_grad():
        for i in range(20):
            img_t, action_history = state_transform([state])
            action = policy_net(img_t, action_history).max(1).indices[0].item()
            reward, state, done = take_action(state, action)
            # Grow the strip by one frame and paste the new frame on the right.
            vis_new = Image.new('RGB', (vis.width + w, h))
            vis_new.paste(vis)
            vis_new.paste(draw_boxes(state), (vis.width, 0))
            vis = vis_new
            if done:
                break
    # Create the output folder on first use so a missing directory does not
    # abort a long training run.
    os.makedirs("drive/My Drive/visualization", exist_ok=True)
    vis.save("drive/My Drive/visualization/{}.png".format(name))