In [4]:
from box_world_env import BoxWorld
import torch
from torch.utils.data import Dataset, DataLoader
import torch.autograd as autograd         # computation graph
from torch import Tensor                  # tensor node in the computation graph
import torch.nn as nn                     # neural networks
import torch.nn.functional as F           # layers, activations and more
import torch.optim as optim               # optimizers e.g. gradient descent, ADAM, etc.
from torch.jit import script, trace       # hybrid frontend decorator and tracing jit

In [399]:
import numpy as np
from torchvision import transforms

In [400]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``, which keeps the spatial size
    of the input unchanged whenever ``stride`` is 1.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    The main path is conv -> norm -> ReLU -> conv -> norm. The skip path is
    the input itself, or ``downsample(input)`` when a ``downsample`` module
    is supplied. Both paths are summed and passed through a final ReLU.

    Only ``groups=1``, ``base_width=64`` and ``dilation<=1`` are accepted;
    anything else raises ``ValueError`` / ``NotImplementedError``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # When stride != 1 the first convolution shrinks the feature map, so
        # the caller-supplied `downsample` has to shrink the skip path too.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main path.
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))

        # Skip path: identity unless a projection module was provided.
        shortcut = x if self.downsample is None else self.downsample(x)

        features += shortcut
        return self.relu(features)

In [884]:
class ChannelMaxPool(nn.Module):
    """Global max pooling: keep the first two axes, take the max over the rest.

    An input of shape ``(N, C, ...)`` becomes ``(N, C)``.
    """

    def forward(self, x):
        batch, channels = x.shape[0], x.shape[1]
        flattened = x.reshape(batch, channels, -1)
        pooled, _ = flattened.max(dim=2)
        return pooled

In [885]:
class ConcatCoords(nn.Module):
    """Append two normalized position channels to an ``(N, C, H, W)`` tensor.

    The output has shape ``(N, C + 2, H, W)``:

    * channel ``C``     holds ``2 * i / H - 1`` for row index ``i``
      (constant along the width axis);
    * channel ``C + 1`` holds ``2 * j / W - 1`` for column index ``j``
      (constant along the height axis).

    The coordinate channels are created on ``x``'s device and, for floating
    point inputs, in ``x``'s dtype, so the layer works on GPU and with
    float32 as well as float64 models. Non-floating inputs fall back to
    float64 coordinates. For square float64 CPU inputs the result is the
    same as the earlier implementation, which only supported ``H == W``.
    """

    # TODO: generalize to arbitrary dimensions
    def forward(self, x):
        N, C, H, W = x.shape
        coord_dtype = x.dtype if x.is_floating_point() else torch.double

        # Normalized positions in [-1, 1) along each spatial axis.
        rows = 2 * (torch.arange(H, dtype=coord_dtype, device=x.device) / H) - 1
        cols = 2 * (torch.arange(W, dtype=coord_dtype, device=x.device) / W) - 1

        # Broadcast each axis over the other one, then over the batch.
        grid = torch.stack([
            rows.view(H, 1).expand(H, W),
            cols.view(1, W).expand(H, W),
        ])
        coords = grid.unsqueeze(0).expand(N, 2, H, W)
        return torch.cat([x, coords], 1)
    

In [886]:
class MHDPA(nn.Module):
    """Multi-head dot-product attention over sets of entity vectors.

    Inputs to ``forward`` are indexed as ``(batch, entity, entity_dim)``:
    the first axis is treated as the batch and the last one is fed to the
    linear projections. Each of query, key and value is projected to
    ``qkv_dim`` features, layer-normalized, and split into ``n_heads``
    heads of size ``qkv_dim // n_heads``.

    The attention weights of the most recent forward pass are kept in
    ``self.attention_weights`` (``None`` before the first call).
    """

    def __init__(self, entity_dim, qkv_dim, n_heads):
        super().__init__()
        self.entity_dim = entity_dim
        self.qkv_dim = qkv_dim
        self.n_heads = n_heads
        head_dim, leftover = divmod(self.qkv_dim, n_heads)
        self.d = head_dim
        assert leftover == 0, "Number of heads must evenly divide QKV dimension"

        self.Wq = nn.Linear(entity_dim, qkv_dim)
        self.Wv = nn.Linear(entity_dim, qkv_dim)
        self.Wk = nn.Linear(entity_dim, qkv_dim)

        self.attention_weights = None

    def forward(self, q, k, v):
        batch = q.size(0)

        def to_heads(projection, entities):
            # project -> layer norm -> (batch, heads, entities, d)
            normed = nn.functional.layer_norm(projection(entities), [self.qkv_dim])
            return normed.view(batch, -1, self.n_heads, self.d).transpose(1, 2)

        queries = to_heads(self.Wq, q)
        keys = to_heads(self.Wk, k)
        values = to_heads(self.Wv, v)

        per_head, self.attention_weights = self.attention(queries, keys, values)
        # Merge the heads back into one feature axis of size n_heads * d.
        merged = per_head.transpose(1, 2).contiguous()
        return merged.view(batch, -1, self.n_heads * self.d)

    def attention(self, q, k, v):
        """Scaled dot-product attention; returns ``(weighted values, weights)``."""
        scale = np.sqrt(q.size(-1))
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        weights = torch.nn.functional.softmax(scores, dim=-1)
        return torch.matmul(weights, v), weights

class ResidualMLP(nn.Module):
    """Two-layer ReLU MLP with a residual connection and layer normalization.

    Computes ``layer_norm(x + relu(lin2(relu(lin1(x)))))`` over the last
    axis, which must have size ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        self.lin1 = nn.Linear(dim, dim)
        self.lin2 = nn.Linear(dim, dim)

    def forward(self, x):
        hidden = nn.functional.relu(self.lin1(x))
        # The second layer consumes the first layer's output. It used to be
        # fed `x` again, which silently discarded `lin1` altogether.
        hidden = nn.functional.relu(self.lin2(hidden))
        out = hidden + x
        return nn.functional.layer_norm(out, [out.size(-1)])

class RelationalBlock(nn.Module):
    """Self-attention over the spatial cells of a feature map, then one
    residual MLP per cell.

    Every one of the ``H * W`` positions of an ``(N, C, H, W)`` input is
    treated as an entity with ``C`` features (``C`` must equal
    ``entity_dim``). The entities attend to each other through ``MHDPA``
    and entity ``e`` is then passed through its own ``ResidualMLP``,
    ``self.mlps[e]``.

    Args:
        n_entities: number of spatial positions, i.e. ``H * W`` of the input.
        entity_dim: feature size of each entity (input channels).
        qkv_dim: feature size produced by the attention module.
        n_heads: number of attention heads.

    Returns (from ``forward``): tensor of shape ``(N, H * W, qkv_dim)``.

    Raises:
        ValueError: if the input does not contain exactly ``n_entities``
            spatial positions.
    """

    def __init__(self, n_entities, entity_dim, qkv_dim, n_heads):
        super().__init__()
        self.mhdpa = MHDPA(entity_dim, qkv_dim, n_heads)
        self.mlps = nn.ModuleList([
            ResidualMLP(qkv_dim) for _ in range(n_entities)])
        self.n_entities = n_entities

    def forward(self, x):
        N, C, H, W = x.shape
        if H * W != self.n_entities:
            raise ValueError(
                "expected %d entities but the input has %d x %d = %d positions"
                % (self.n_entities, H, W, H * W))

        # -> (N, H * W, C). MHDPA uses axis 0 as the batch axis, so the batch
        # has to stay in front; putting the entities first would make the
        # attention mix different samples instead of different entities.
        entities = x.reshape(N, C, -1).permute(0, 2, 1)
        # -> (N, H * W, qkv_dim)
        attended = self.mhdpa(entities, entities, entities)

        # Apply each entity's MLP to the whole batch at once and assemble a
        # new tensor; writing results back into `attended` in place would
        # overwrite values autograd still needs for the backward pass.
        per_entity = [mlp(attended[:, e]) for e, mlp in enumerate(self.mlps)]
        return torch.stack(per_entity, dim=1)

In [887]:
def res_module():
    """Return a stack of three ``BasicBlock(26, 26)`` residual blocks."""
    blocks = [BasicBlock(26, 26) for _ in range(3)]
    return nn.Sequential(*blocks)

In [888]:
def rel_module():
    """Return an ``nn.Sequential`` wrapping a single ``MHDPABlock(26)``.

    NOTE(review): ``MHDPABlock`` is not defined in any visible cell of this
    notebook, so this only works if the name survives in the kernel from an
    earlier (since removed) cell -- confirm, or switch to a class that is
    defined here.
    """
    block = MHDPABlock(26)
    return nn.Sequential(block)

In [913]:
class RRLAgent(nn.Module):
    """Relational agent: conv encoder -> relational block -> MLP head.

    * ``self.conv``: two unpadded 2x2 convolutions (3 -> 12 -> 24 channels),
      each followed by ReLU, then ``ConcatCoords`` which adds 2 coordinate
      channels (26 in total).
    * ``self.rb``: ``RelationalBlock`` over 144 entities of 26 features,
      producing 64 features with 2 heads. 144 corresponds to the 12 x 12
      map that the two convolutions leave from a 14 x 14 observation.
    * ``self.linear``: max pool over entities, then four hidden layers of
      256 units and a final layer with 5 outputs.
    """

    def __init__(self):
        super().__init__()
        encoder = [
            nn.Conv2d(3, 12, (2, 2), 1),
            nn.ReLU(),
            nn.Conv2d(12, 24, (2, 2), 1),
            nn.ReLU(),
            ConcatCoords(),
        ]
        self.conv = nn.Sequential(*encoder)

        self.rb = RelationalBlock(144, 26, 64, 2)

        head = [ChannelMaxPool(), nn.Linear(64, 256), nn.ReLU()]
        for _ in range(3):
            head.extend([nn.Linear(256, 256), nn.ReLU()])
        head.append(nn.Linear(256, 5))
        self.linear = nn.Sequential(*head)

    def forward(self, x):
        entities = self.rb(self.conv(x))
        # Swap the last two axes so that the pooling at the start of the
        # head reduces over entities and keeps the feature axis.
        return self.linear(entities.transpose(1, 2))

In [914]:
def net():
    """Build the relational network as one flat ``nn.Sequential``.

    Layout: two unpadded 2x2 conv + ReLU stages (3 -> 12 -> 24 channels),
    ``ConcatCoords`` (+2 channels), ``rel_module()``, ``ChannelMaxPool``,
    then an MLP of four 256-unit hidden layers (input size 26) ending in
    5 outputs.
    """
    layers = [
        nn.Conv2d(3, 12, (2, 2), 1),
        nn.ReLU(),
        nn.Conv2d(12, 24, (2, 2), 1),
        nn.ReLU(),
        ConcatCoords(),
        rel_module(),
        ChannelMaxPool(),
        nn.Linear(26, 256),
        nn.ReLU(),
    ]
    for _ in range(3):
        layers += [nn.Linear(256, 256), nn.ReLU()]
    layers.append(nn.Linear(256, 5))
    return nn.Sequential(*layers)

In [985]:
def baseline_net():
    """Build the purely convolutional baseline as one flat ``nn.Sequential``.

    Layout: two unpadded 2x2 conv + ReLU stages (3 -> 12 -> 24 channels),
    ``ConcatCoords`` (+2 channels), ``res_module()``, ``ChannelMaxPool``,
    then an MLP of four 256-unit hidden layers (input size 26) ending in
    4 outputs.
    """
    layers = [
        nn.Conv2d(3, 12, (2, 2), 1),
        nn.ReLU(),
        nn.Conv2d(12, 24, (2, 2), 1),
        nn.ReLU(),
        ConcatCoords(),
        res_module(),
        ChannelMaxPool(),
        nn.Linear(26, 256),
        nn.ReLU(),
    ]
    for _ in range(3):
        layers += [nn.Linear(256, 256), nn.ReLU()]
    layers.append(nn.Linear(256, 4))
    return nn.Sequential(*layers)

In [986]:
# Environment used for the shape/dtype sanity checks in the next cells.
# NOTE(review): the meaning of the four constructor arguments is not visible
# in this notebook -- confirm against box_world_env.
bw = BoxWorld(12, 4, 2, 2)

In [987]:
bw.reset().shape

(14, 14, 3)

In [988]:
# Convert one observation into network input: to_tensor makes it channels-first,
# unsqueeze(0) adds the batch axis, .double() gives float64 (see the outputs below).
state = transforms.functional.to_tensor(bw.reset()).unsqueeze(0).double()

In [989]:
state.shape

torch.Size([1, 3, 14, 14])

In [990]:
state.type()

'torch.DoubleTensor'

In [991]:
# Baseline network cast to float64 so it accepts the float64 `state` tensor.
bn = baseline_net().double()

In [992]:
bn(state)

tensor([[ 0.0921,  0.2308, -0.0159,  0.1954]], dtype=torch.float64,
       grad_fn=<AddmmBackward>)

In [993]:
# Relational agent cast to float64; only constructed here.
n = RRLAgent().double()

In [994]:
# n(state)

# Q-Learning Baseline

In [995]:
from collections import namedtuple
from itertools import count
import random
import math

In [996]:
# inspiration from https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html

In [997]:
# One replay-buffer record: the state acted in, the action taken, the state
# that followed, and the reward received.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)

In [998]:
class ReplayMemory:
    """Fixed-capacity ring buffer of ``Transition`` records.

    Attributes:
        capacity: maximum number of stored transitions.
        memory: list holding the stored transitions.
        position: index that the next ``push`` writes to.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Store ``Transition(*args)``, overwriting the oldest entry once full."""
        slot = self.position
        if len(self.memory) < self.capacity:
            # Still filling up: grow the list by one placeholder slot.
            self.memory.append(None)
        self.memory[slot] = Transition(*args)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [1010]:
# --- Hyper-parameters of the Q-learning baseline ---
BATCH_SIZE = 128      # transitions sampled per optimize_model() call; no update happens before the buffer holds this many
GAMMA = 0.999         # discount factor applied to the next-state value
EPS_START = 0.9       # epsilon of the epsilon-greedy policy at step 0
EPS_END = 0.05        # value epsilon decays towards
EPS_DECAY = 200       # decay constant (in steps) of the exponential epsilon schedule in select_action()
TARGET_UPDATE = 10    # target_net is re-synced with policy_net every this many episodes

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = BoxWorld(12, 2, 1, 1)

# Number of discrete actions random exploration draws from.
# NOTE(review): baseline_net() ends in a 4-unit layer -- confirm that this equals n_actions.
n_actions = env.action_space.n

# policy_net is the network being trained; target_net is a copy that supplies
# the bootstrap targets. The second .double() on each line is redundant.
policy_net = baseline_net().double().to(device).double()
target_net = baseline_net().double().to(device).double()
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000)


# Global count of select_action() calls; drives the epsilon schedule.
steps_done = 0

def select_action(state):
    """Pick an action for ``state`` with an epsilon-greedy policy.

    Epsilon decays exponentially from ``EPS_START`` towards ``EPS_END`` as
    the global ``steps_done`` counter grows; the counter is incremented on
    every call.

    Args:
        state: batch of one observation, shaped as ``policy_net`` expects.
            It may live on any device; it is moved to ``device`` before the
            forward pass.

    Returns:
        A ``(1, 1)`` ``torch.long`` tensor holding the chosen action index.
    """
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # max(1) returns (values, indices) per row; the index of the
            # largest Q-value is the greedy action. The state is moved to the
            # network's device first, otherwise the forward pass fails as soon
            # as CUDA is in use and the caller built the state on the CPU.
            return policy_net(state.to(device)).max(1)[1].view(1, 1)
    # Explore: uniformly random action.
    return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

def optimize_model():
    """Run one Q-learning update of ``policy_net`` on a sampled replay batch.

    Does nothing until ``memory`` holds at least ``BATCH_SIZE`` transitions.
    Otherwise it samples a batch, regresses ``Q(s, a)`` from ``policy_net``
    onto ``reward + GAMMA * max_a' Q_target(s', a')`` with a Huber loss
    (the bootstrap term is 0 for transitions whose ``next_state`` is
    ``None``), clips every gradient element to ``[-1, 1]`` and takes one
    optimizer step.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # a detailed explanation): list of Transitions -> Transition of tuples.
    batch = Transition(*zip(*transitions))

    # True where the transition has a successor state (a final state is the
    # one after which the episode ended and is stored as None).
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)

    # Everything is moved to `device` so the update also works when the
    # stored tensors were created on the CPU but the networks live on a GPU.
    state_batch = torch.cat(batch.state).to(device)
    action_batch = torch.cat(batch.action).to(device)
    reward_batch = torch.cat(batch.reward).to(device)

    # Q(s_t, a_t): evaluate all actions, then pick the column of the action
    # that was actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) from the "older" target_net; stays 0 for final states.
    next_state_values = torch.zeros(BATCH_SIZE, device=device).double()
    if non_final_mask.any():
        # Guarded because torch.cat fails on an empty list, which would
        # happen if every sampled transition were final.
        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None]).double().to(device)
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Expected Q values (the regression target).
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # Element-wise clipping to [-1, 1]; parameters without a gradient are skipped.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()

num_episodes = 50
# Length (in steps) of every finished episode. Defined here because no other
# cell of the notebook creates it.
episode_durations = []
for i_episode in range(num_episodes):
    print(i_episode)
    # Initialize the environment and state
    state = env.reset()
    state = transforms.functional.to_tensor(state).unsqueeze(0).double().to(device)
    for t in count():
        # Select and perform an action
        action = select_action(state)
        observation, reward, done, _ = env.step(action.item())

        reward = torch.tensor([reward], device=device)
        print(t, action, reward, done)

        # A terminal transition is stored with next_state=None so that
        # optimize_model() does not bootstrap from a state after the episode
        # has ended.
        if done:
            next_state = None
        else:
            next_state = transforms.functional.to_tensor(observation).unsqueeze(0).double().to(device)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()
        if done:
            # The plot_durations() call that used to follow is dropped: no
            # cell in this notebook defines that function. Plot
            # `episode_durations` from a separate cell instead.
            episode_durations.append(t + 1)
            break
    # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

0
0 tensor([[1]]) tensor([0]) False
optimizing
1 tensor([[3]]) tensor([0]) False
optimizing
2 tensor([[2]]) tensor([0]) False
optimizing
3 tensor([[3]]) tensor([0]) False
optimizing
4 tensor([[2]]) tensor([0]) False
optimizing
5 tensor([[1]]) tensor([0]) False
optimizing
6 tensor([[1]]) tensor([0]) False
optimizing
7 tensor([[1]]) tensor([0]) False
optimizing
8 tensor([[2]]) tensor([0]) False
optimizing
9 tensor([[2]]) tensor([0]) False
optimizing
10 tensor([[2]]) tensor([0]) False
optimizing
11 tensor([[2]]) tensor([0]) False
optimizing
12 tensor([[2]]) tensor([0]) False
optimizing
13 tensor([[2]]) tensor([0]) False
optimizing
14 tensor([[0]]) tensor([0]) False
optimizing
15 tensor([[1]]) tensor([0]) False
optimizing
16 tensor([[2]]) tensor([0]) False
optimizing
17 tensor([[1]]) tensor([0]) False
optimizing
18 tensor([[2]]) tensor([0]) False
optimizing
19 tensor([[2]]) tensor([0]) False
optimizing
20 tensor([[0]]) tensor([0]) False
optimizing
21 tensor([[1]]) tensor([0]) False
optimiz

KeyboardInterrupt: 

In [955]:
def action(state):
    

SyntaxError: unexpected EOF while parsing (<ipython-input-955-76694e87e56d>, line 2)

In [None]:
import torch.optim as optim

In [926]:
# Setup for training the relational agent.
# NOTE(review): this rebinds `device`, `env`, `n_actions` and `optimizer`,
# which select_action() and optimize_model() of the Q-learning baseline read
# as globals -- re-running that cell after this one changes what it trains.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = BoxWorld(12, 4, 2, 2)
n_actions = env.action_space.n

# `target` starts as an exact copy of `agent`.
agent = RRLAgent().to(device)
target = RRLAgent().to(device)
target.load_state_dict(agent.state_dict())

# The optimizer class is spelled `RMSprop`; `optim.RMSProp` does not exist
# and raised AttributeError.
optimizer = optim.RMSprop(agent.parameters())

def action(state):


SyntaxError: unexpected EOF while parsing (<ipython-input-926-dcee57fe88c8>, line 11)