In [1]:
import torch
from torch import nn
import numpy as np 
from math import sqrt

In [2]:
x = torch.FloatTensor([[[-1 for x in range(32)],[-1 for x in range(32)]],
 [[1 for x in range(32)], [1 for x in range(32)]],
 [[2 for x in range(32)], [-2 for x in range(32)]]])
x = x.permute(1, 0, 2)
x = x.repeat(10, 3, 1)

In [3]:
x = torch.FloatTensor(13, 17, 2).uniform_(0, 1)

In [4]:
from rl_with_rnn import *
input_size =2
embedding_dim = 24

In [5]:
ge = GraphEmbedding(input_size, embedding_dim)

In [6]:
encoder_input = ge(x)

## Encoder

In [7]:
class skip_connection(nn.Module):
    def __init__(self, module):
        super(skip_connection, self).__init__()
        self.module = module
    def forward(x):
        return x + self.module(x)

Attention module(Without Batch normalization)

In [8]:
class att_layer(nn.Module):
    def __init__(self, embed_dim, n_heads, feed_forward_hidden=512, bn=False):
        super(att_layer, self).__init__()
        self.mha = torch.nn.MultiheadAttention(embed_dim, n_heads)
        self.embed = nn.Sequential(nn.Linear(embed_dim, feed_forward_hidden), nn.ReLU(), nn.Linear(feed_forward_hidden, embed_dim))
    def forward(self, x):
        #I don't know why, but multiheadattention in pytorch starts with (target_seq_length, batch_size, embedding_size).
        # thus we permute order first. https://pytorch.org/docs/stable/nn.html#multiheadattention
        x = x.permute(1, 0, 2)
        _1 = x + self.mha(x, x, x)[0]
        _1 = _1.permute(1, 0, 2)
        _2 = _1 + self.embed(_1)
        return _1
class attention_module(nn.Sequential):
    def __init__(self, embed_dim, n_heads, feed_forward_hidden=512, bn=False):
        super(attention_module, self).__init__(
            att_layer(embed_dim, n_heads, feed_forward_hidden, bn),
            att_layer(embed_dim, n_heads, feed_forward_hidden, bn),
        )

In [9]:
hidden_dim = 8
mha = attention_module(embedding_dim, hidden_dim)

In [10]:
h = mha(encoder_input)
h.shape

torch.Size([13, 17, 24])

In [11]:
h_bar = h.mean(1)
h_context_embed = nn.Linear(embedding_dim, embedding_dim)

### Calculating query

In [12]:
W_placeholder = nn.Parameter(torch.Tensor(2 * embedding_dim))
W_placeholder.data.uniform_(-1, 1)  # Placeholder should be in range of 
inp = W_placeholder
v_weight_embed = nn.Linear(embedding_dim * 2, embedding_dim)

In [13]:
query = init_embed = h_context_embed(h_bar) + v_weight_embed(inp)

In [14]:
query.shape

torch.Size([13, 24])

### Multihead glimpse

In [20]:
glimpse_key_embedding = nn.Linear(embedding_dim, hidden_dim, bias=False)
glimpse_val_embedding = nn.Linear(embedding_dim, hidden_dim, bias=False)
logit_embedding = nn.Linear(embedding_dim, hidden_dim, bias=False)

In [27]:
key = glimpse_key_embedding(h)
val = glimpse_val_embedding(h)
#g_logits = logit_embedding(x)

In [26]:
g_key.shape

torch.Size([13, 8])

In [None]:
g_key.shape, g_val.shape, g_logits.shape

In [None]:
n_heads = 4
batch_size = g_key.size(0)
g_q = query.view(batch_size, n_heads, -1).permute(1, 0, 2) # (n_head, batch_size, embed_size)

In [None]:
g_q.shape #(n_head, batch_size, embed_size)

In [None]:
def get_logits(query, g_key, g_val, g_logits, n_head=4):
    batch_size = g_key.size(0)
    embed_size = g_key.size(-1)
    d_k = embed_size // n_head
    g_key = g_key.reshape(g_key.shape[0], g_key.shape[1], n_head, d_k) # (batch_size, seq_len, n_head, embed_size)
    g_key = g_key.permute(2, 0, 1, 3) # (n_head, batch_size, seq_len, embed_size)
    g_val = g_val.reshape(g_val.shape[0], g_val.shape[1], n_head, d_k)
    g_val = g_val.permute(2, 0, 1, 3)
    #Must be masked, but not yet implemented
    ret = torch.matmul(g_q.unsqueeze(2), g_key.transpose(-2, -1)) / sqrt(g_key.size(-1)) # Ret value should be masked
    ret_softmax = torch.softmax(ret, -1)
    heads = torch.matmul(ret_softmax, g_val)
    
    ret = heads.permute(1, 2, 0, 3).reshape(batch_size, 1, -1)
    ret.shape, g_logits.shape #batch_size, seq_len, 
    
    logits = torch.matmul(ret, g_logits.transpose(-1, -2))
    logits = torch.tanh(logits) * 10
    logits = logits.squeeze(1)
    return logits

In [None]:
logits = get_logits(query, g_key, g_val, g_logits)
print(logits.shape)

In [None]:
r = torch.softmax(logits, -1)

In [None]:
r.shape

In [None]:
n_head = 4
embed_size = 32
g_key = g_key.reshape(g_key.shape[0], g_key.shape[1], 4, 8) # (batch_size, seq_len, n_head, embed_size)
g_key = g_key.permute(2, 0, 1, 3) # (n_head, batch_size, seq_len, embed_size)
g_val = g_val.reshape(g_val.shape[0], g_val.shape[1], 4, 8)
g_val = g_val.permute(2, 0, 1, 3)

In [None]:
from math import sqrt
ret = torch.matmul(g_q.unsqueeze(2), g_key.transpose(-2, -1)) / sqrt(g_key.size(-1)) # Ret value should be masked

In [None]:
ret_softmax = torch.softmax(ret, -1)
ret_softmax.shape # #(num_head, batch_size, 1, seq_len)

In [None]:
g_val.shape

In [None]:
heads = torch.matmul(ret_softmax, g_val)

In [None]:
batch_size = 2
print(heads.shape) # (num_head, batch_size, 1, hidden_size)
ret = heads.permute(1, 2, 0, 3).reshape(batch_size, 1, -1)

In [None]:
ret.shape, g_logits.shape #batch_size, seq_len, 

In [None]:
logits = torch.matmul(ret, g_logits.transpose(-1, -2))
logits.shape # (batch_size, )

In [None]:
logits = torch.tanh(logits) * 10

In [None]:
glimpse.shape