In [1]:
import torch
from torch import nn
import numpy as np 
from math import sqrt
from torch.distributions import Categorical
from rl_with_attention import Glimpse


In [2]:
# Toy batch of constant rows: sample 0 holds rows filled with -1, 1, 2 and
# sample 1 holds rows filled with -1, 1, -2, each row 32 features wide.
# Note: `x` is reassigned with random data in a later cell.
x = (
    torch.FloatTensor([[-1, -1], [1, 1], [2, -2]])  # (row, sample) fill values
    .unsqueeze(-1)
    .expand(3, 2, 32)      # broadcast every constant across 32 features
    .permute(1, 0, 2)      # -> (sample, row, feature)
    .repeat(10, 3, 1)      # tile to 20 samples x 9 rows
)

In [3]:
from rl_with_rnn import *
# Problem / model configuration used by the cells below.
input_size = 2       # features per sequence element
batch_size = 17
target_size = 20     # sequence length (positions to choose from per sample)
embedding_dim = 24
hidden_dim = 16      # NOTE: reassigned to 8 in the encoder cell further down

# Random inputs in [0, 1). The last dimension is tied to `input_size` (it was a
# hard-coded 2), and torch.rand replaces the legacy "allocate uninitialised
# FloatTensor, then fill in place" pattern.
x = torch.rand(batch_size, target_size, input_size, dtype=torch.float32)

In [33]:
# Build the embedding layer (class comes from the wildcard import of rl_with_rnn).
# Presumably maps input_size-dim features to embedding_dim — confirm in rl_with_rnn.
ge = GraphEmbedding(input_size, embedding_dim)

In [34]:
# Display the class to check which module the wildcard import resolved it from.
GraphEmbedding

rl_with_rnn.GraphEmbedding

In [5]:
# Embed the input batch; the result is what the attention encoder consumes below.
encoder_input = ge(x)

## Encoder

In [6]:
class skip_connection(nn.Module):
    """Residual wrapper that returns ``x + module(x)``.

    Args:
        module: an ``nn.Module`` whose output can be added to its input
            (same or broadcastable shape).
    """

    def __init__(self, module):
        super(skip_connection, self).__init__()
        self.module = module

    def forward(self, x):
        """Apply the wrapped module and add its output to the input."""
        # `self` was missing from this signature, so calling the layer raised
        # TypeError before the body was ever reached.
        return x + self.module(x)

Attention module (without batch normalization)

In [7]:
class att_layer(nn.Module):
    """One encoder layer: self-attention and a feed-forward net, each with a residual.

    Args:
        embed_dim: feature size of the input and the output.
        n_heads: number of attention heads (``embed_dim`` must be divisible by it).
        feed_forward_hidden: width of the hidden layer of the feed-forward net.
        bn: accepted for interface compatibility but currently unused; no
            normalization is applied in this layer.
    """

    def __init__(self, embed_dim, n_heads, feed_forward_hidden=512, bn=False):
        super(att_layer, self).__init__()
        self.mha = torch.nn.MultiheadAttention(embed_dim, n_heads)
        self.embed = nn.Sequential(
            nn.Linear(embed_dim, feed_forward_hidden),
            nn.ReLU(),
            nn.Linear(feed_forward_hidden, embed_dim),
        )

    def forward(self, x):
        """Transform ``x`` of shape (batch, seq_len, embed_dim); the shape is preserved."""
        # nn.MultiheadAttention (default batch_first=False) expects
        # (seq_len, batch, embed_dim), so go sequence-first around the attention call.
        # https://pytorch.org/docs/stable/nn.html#multiheadattention
        x = x.permute(1, 0, 2)
        _1 = x + self.mha(x, x, x)[0]
        _1 = _1.permute(1, 0, 2)
        _2 = _1 + self.embed(_1)
        # Return the feed-forward residual. Previously `_1` was returned, which
        # silently dropped the feed-forward sub-layer from the output.
        return _2
class attention_module(nn.Sequential):
    """Stack of ``att_layer`` blocks applied one after another.

    Args:
        embed_dim: feature size, forwarded to every ``att_layer``.
        n_heads: number of attention heads, forwarded to every ``att_layer``.
        feed_forward_hidden: feed-forward hidden width, forwarded to every ``att_layer``.
        bn: forwarded to every ``att_layer``.
        n_layers: number of stacked layers. The default of 2 reproduces the
            previously hard-coded two-layer stack (children are still named "0", "1").
    """

    def __init__(self, embed_dim, n_heads, feed_forward_hidden=512, bn=False, n_layers=2):
        super(attention_module, self).__init__(
            *(
                att_layer(embed_dim, n_heads, feed_forward_hidden, bn)
                for _ in range(n_layers)
            )
        )

In [35]:
# NOTE(review): this overrides the earlier hidden_dim = 16, and the same value is
# then passed as the number of attention heads (the second parameter of
# attention_module is n_heads). embedding_dim must be divisible by it.
hidden_dim = 8
mha = attention_module(embed_dim=embedding_dim, n_heads=hidden_dim)

In [36]:
# Run the attention encoder over the embedded inputs.
h = mha(encoder_input)
# Sanity check: the recorded output is (batch_size, target_size, embedding_dim).
h.shape

torch.Size([17, 20, 24])

In [10]:
# Context vector per sample: average the encoder outputs over the sequence axis.
h_mean = torch.mean(h, dim=1)
# Linear projection applied to that mean embedding in a later cell.
h_context_embed = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

### Calculating query

In [11]:
# Boolean selection mask, all False to begin with; later cells flip chosen positions to True.
mask = torch.zeros((batch_size, target_size), dtype=torch.bool)
mask.shape

torch.Size([17, 20])

In [12]:
# Learnable placeholder of width 2 * embedding_dim, used as the input to
# v_weight_embed before any position has been chosen.
W_placeholder = nn.Parameter(torch.empty(2 * embedding_dim))
nn.init.uniform_(W_placeholder, -1, 1)  # placeholder values are drawn uniformly from [-1, 1)
inp = W_placeholder
# Projects a (first, last) embedding pair back down to embedding_dim.
v_weight_embed = nn.Linear(2 * embedding_dim, embedding_dim)

In [13]:
# Projected mean embedding; reused as the context term of every query.
h_bar = h_context_embed(h_mean)
# Initial query: context plus the embedded placeholder (a 1-D vector that
# broadcasts over the batch).
init_embed = h_bar + v_weight_embed(inp)
query = init_embed

In [14]:
# Sanity check: the recorded output is (batch_size, embedding_dim).
query.shape

torch.Size([17, 24])

In [15]:
from rl_with_attention import Glimpse, Pointer

### Multi-head glimpse

In [16]:
# Decoder components imported from rl_with_attention.
# NOTE(review): the trailing 4 and 1 look like head counts (multi-head glimpse,
# single-head pointer) — confirm against the class signatures.
# NOTE(review): hidden_dim is 8 at this point (reassigned in the encoder cell),
# not the 16 set in the configuration cell.
g = Glimpse(embedding_dim, hidden_dim, 4)
p = Pointer(embedding_dim, hidden_dim, 1)

In [17]:
# Apply the glimpse to the current query; only the second return value is kept
# (presumably the refined query — confirm in rl_with_attention).
_ , nq = g(query, h, mask)

In [18]:
# Apply the pointer; its first return value is used as the selection
# probabilities for the Categorical distribution built next.
prob, _ = p(nq, h, mask)

In [19]:
# Categorical distribution over the last dimension of prob (one per sample).
cat = Categorical(prob)

In [20]:
# Sample one index per sample from the distribution.
chosen = cat.sample()

In [21]:
# Log-probability of each sampled index under the distribution it was drawn from.
logprobs = cat.log_prob(chosen)

In [22]:
# Fresh all-False mask, then mark each sample's chosen position as True.
mask = torch.zeros((batch_size, target_size), dtype=torch.bool)
mask[torch.arange(batch_size), chosen] = True

In [23]:
# Reminder of the encoder output shape before gathering from it.
h.shape

torch.Size([17, 20, 24])

In [24]:
# Gather index of shape (batch, 1, embed): each sample's chosen position repeated
# across the feature axis. The width is read from `h` (it was a hard-coded 24),
# so the cell keeps working when embedding_dim changes.
cc = chosen.unsqueeze(1).unsqueeze(2).repeat(1, 1, h.size(-1))

In [25]:
# Pick out the encoder output at each sample's chosen position -> (batch, embed).
# On this first step the "first" and "most recent" choice are the same position.
first_chosen_hs = torch.gather(h, dim=1, index=cc).squeeze(dim=1)
chosen_hs = torch.gather(h, dim=1, index=cc).squeeze(dim=1)

In [26]:
# Concatenate first and most recent chosen embeddings along the feature axis,
# giving the 2 * embedding_dim input width that v_weight_embed expects.
v_weight = torch.cat([first_chosen_hs, chosen_hs], dim=-1)

In [27]:
# Next-step query: shared context term plus the embedded (first, last) pair.
query = h_bar + v_weight_embed(v_weight)

In [28]:
# Reset all decoding state before the full rollout.
prev_chosen_indices = []
prev_chosen_logprobs = []
first_chosen_hs = None
mask = torch.zeros(batch_size, target_size, dtype=torch.bool)
# Restart from the placeholder-based initial query as well. Without this the
# rollout begins with the query left over from the single manual step above,
# i.e. conditioned on a position the fresh mask no longer marks as chosen.
query = init_embed

In [29]:
# Sequential decoding: sample one position per sample at every step, for
# target_size steps. The mask is passed to both g and p (presumably so chosen
# positions are excluded — confirm in rl_with_attention).
batch_rows = torch.arange(batch_size)  # loop-invariant row index for the mask update
embed_width = h.size(-1)               # replaces the hard-coded 24
for _index in range(target_size):
    print(_index)
    _, nq = g(query, h, mask)  # applying glimpse
    prob, _ = p(nq, h, mask)  # applying ptr
    cat = Categorical(prob)
    chosen = cat.sample()
    logprobs = cat.log_prob(chosen)
    prev_chosen_indices.append(chosen)
    prev_chosen_logprobs.append(logprobs)
    # Mark the positions just chosen.
    mask[batch_rows, chosen] = True
    # Gather the encoder output at the chosen position -> (batch, embed).
    cc = chosen.unsqueeze(1).unsqueeze(2).repeat(1, 1, embed_width)
    chosen_hs = h.gather(1, cc).squeeze(1)
    if first_chosen_hs is None:
        # Remember the very first choice; it stays part of every later query.
        first_chosen_hs = chosen_hs
    v_weight = torch.cat([first_chosen_hs, chosen_hs], dim=-1)
    query = h_bar + v_weight_embed(v_weight)
    

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19


In [30]:
# Both lists should hold one entry per decoding step (target_size).
len(prev_chosen_logprobs), len(prev_chosen_indices)

(20, 20)

In [31]:
# Stack the per-step tensors along a new step axis -> (batch_size, target_size).
ret_1 = torch.stack(prev_chosen_logprobs, dim=1)  # log-probabilities
ret_2 = torch.stack(prev_chosen_indices, dim=1)   # chosen indices

In [32]:
# Sanity check: the recorded output is (batch_size, target_size).
ret_1.shape

torch.Size([17, 20])