In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [17]:
class QueryNetwork(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_heads, batch_norm = True):
        super(QueryNetwork, self).__init__()
        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.num_heads = num_heads
        self.batch_norm = True
        self.Wq = nn.Linear(dim_in, dim_hidden * num_heads, bias = False)
        self.bn = nn.BatchNorm1d(dim_hidden * num_heads)
    def forward(self, x):
        # for memory of size less than 384*384 the paper says that batch norm gives no significant improvement
        # when using batch norm, padding tokens in the sequence can skew mean and variance estimate
        query = self.Wq(x)
        
        if self.batch_norm:
            original_shape = query.shape
            query = query.reshape(-1, self.dim_hidden * self.num_heads)
            normalized = self.bn(query)
            normalized = normalized.reshape(original_shape)
            return normalized.reshape(*normalized.shape[:-1], self.num_heads, self.dim_hidden)
        else:
            return query.reshape(*query.shape[:-1], self.num_heads, self.dim_hidden) # batch size x context length x num heads x query dim

In [18]:
qn = QueryNetwork(7,5, 4)
qn(torch.rand(100,10,7)).shape

torch.Size([100, 10, 4, 5])

In [62]:
class ProductKey(nn.Module):
    def __init__(self, dim, num_subkeys, top_k, num_heads):
        super(ProductKey, self).__init__()
        assert dim % 2 == 0, "key must be able to be split into 2"
        self.dim = dim
        self.subkey_size = dim // 2
        self.top_k = top_k
        self.num_subkeys = num_subkeys
        
        keyl = torch.empty(num_heads, num_subkeys, self.subkey_size)
        keyr = torch.empty(num_heads, num_subkeys, self.subkey_size)

        std = 1/math.sqrt(dim)
        keyl.uniform_(-std, std)
        keyr.uniform_(-std, std)

        self.keyl = nn.Parameter(keyl)
        self.keyr = nn.Parameter(keyr)

    def forward(self, query):
        # multihead query
        batch, context_length, num_heads, query_size = query.shape
        
        queryl = query[..., :self.subkey_size]
        queryr = query[..., self.subkey_size:]

        scorel = torch.einsum('bcnq,nkq->bcnk', queryl, self.keyl) # batch size x context length x num head x subquery length , num heads x num keys x subquery length
        scorer = torch.einsum('bcnq,nkq->bcnk', queryr, self.keyr)

        top_keys_l, top_idx_l = scorel.topk(self.top_k) # batch, context, heads, top k
        top_keys_r, top_idx_r = scorer.topk(self.top_k)

        #duplicate along the rows
        product_scores_l = top_keys_l.reshape(*top_keys_l.shape[:-1], top_keys_l.shape[-1], 1).expand(*top_keys_l.shape[:-1], top_keys_l.shape[-1], top_keys_l.shape[-1])
        # duplicate along the columns
        product_scores_r = top_keys_r.reshape(*top_keys_r.shape[:-1], 1, top_keys_r.shape[-1]).expand(*top_keys_r.shape[:-1], top_keys_r.shape[-1], top_keys_r.shape[-1])

        product_scores = (product_scores_l + product_scores_r) # batch, context, heads, top k, top k
        product_scores = product_scores.reshape(batch, context_length, num_heads, self.top_k * self.top_k)

        product_indices = top_idx_l.reshape(*top_idx_l.shape[:-1], top_idx_l.shape[-1], 1).expand(*top_idx_l.shape[:-1], top_idx_l.shape[-1], top_idx_l.shape[-1]) * self.num_subkeys
        product_indices += top_idx_r.reshape(*top_idx_r.shape[:-1], top_idx_r.shape[-1], 1).expand(*top_idx_r.shape[:-1], top_idx_r.shape[-1], top_idx_r.shape[-1]) 
        product_indices = product_indices.reshape(batch, context_length, num_heads, self.top_k * self.top_k)

        top_product_scores, top_product_indices = product_scores.topk(self.top_k)
        selected_value_weights = F.softmax(top_product_scores, dim=-1)
        selected_value_indices = torch.gather(product_indices, -1, top_product_indices)
        #print(top_product_scores)
        #print(torch.gather(product_scores, -1, top_product_indices)) # they should be equal
        return selected_value_weights, selected_value_indices
        
        

In [23]:
block = torch.rand(1,2,3,4)
print(block)
print(block[...,:2])
print(block[...,2:])

tensor([[[[0.1053, 0.8178, 0.5525, 0.3539],
          [0.4601, 0.4129, 0.9341, 0.3508],
          [0.3305, 0.7884, 0.3015, 0.9904]],

         [[0.1591, 0.5620, 0.1787, 0.1385],
          [0.9471, 0.3718, 0.5186, 0.1728],
          [0.8403, 0.5152, 0.0786, 0.8085]]]])
tensor([[[[0.1053, 0.8178],
          [0.4601, 0.4129],
          [0.3305, 0.7884]],

         [[0.1591, 0.5620],
          [0.9471, 0.3718],
          [0.8403, 0.5152]]]])
tensor([[[[0.5525, 0.3539],
          [0.9341, 0.3508],
          [0.3015, 0.9904]],

         [[0.1787, 0.1385],
          [0.5186, 0.1728],
          [0.0786, 0.8085]]]])


In [29]:
a = torch.rand(2,4,1)
print(a)
print(a.expand(2,4,4))
b = a.reshape(2,1,4)
print(b.expand(2,4,4))

tensor([[[0.3561],
         [0.9040],
         [0.2236],
         [0.9060]],

        [[0.8523],
         [0.1579],
         [0.4793],
         [0.3675]]])
tensor([[[0.3561, 0.3561, 0.3561, 0.3561],
         [0.9040, 0.9040, 0.9040, 0.9040],
         [0.2236, 0.2236, 0.2236, 0.2236],
         [0.9060, 0.9060, 0.9060, 0.9060]],

        [[0.8523, 0.8523, 0.8523, 0.8523],
         [0.1579, 0.1579, 0.1579, 0.1579],
         [0.4793, 0.4793, 0.4793, 0.4793],
         [0.3675, 0.3675, 0.3675, 0.3675]]])
tensor([[[0.3561, 0.9040, 0.2236, 0.9060],
         [0.3561, 0.9040, 0.2236, 0.9060],
         [0.3561, 0.9040, 0.2236, 0.9060],
         [0.3561, 0.9040, 0.2236, 0.9060]],

        [[0.8523, 0.1579, 0.4793, 0.3675],
         [0.8523, 0.1579, 0.4793, 0.3675],
         [0.8523, 0.1579, 0.4793, 0.3675],
         [0.8523, 0.1579, 0.4793, 0.3675]]])


In [59]:
qn = QueryNetwork(7,6, 4)
query = qn(torch.rand(10,2,7))
#print(query.shape)
key = ProductKey(6, 3, 2, 4)
key(query)

tensor([[[[ 0.9406,  0.8643],
          [ 0.1486,  0.1485],
          [ 1.3653,  1.2828],
          [ 0.5725,  0.5647]],

         [[ 1.1675,  0.9089],
          [ 0.0824, -0.0272],
          [ 0.3099,  0.1507],
          [ 0.2130,  0.1989]]],


        [[[ 0.4970,  0.3907],
          [ 0.3884,  0.3348],
          [-0.1141, -0.1393],
          [ 0.4365,  0.1453]],

         [[ 0.3047,  0.1988],
          [ 0.7850,  0.7605],
          [ 0.9413,  0.8459],
          [ 0.4884,  0.4194]]],


        [[[ 0.8606,  0.8399],
          [ 0.3752,  0.2896],
          [ 2.5514,  2.4396],
          [ 0.2540,  0.1578]],

         [[ 1.0256,  0.6873],
          [ 1.2209,  0.9314],
          [ 1.2635,  1.1965],
          [ 0.4429,  0.2562]]],


        [[[ 0.4367,  0.3870],
          [ 0.8675,  0.8348],
          [ 1.7378,  1.5173],
          [ 0.5041,  0.4577]],

         [[ 0.6319,  0.4495],
          [ 0.1568,  0.1509],
          [-0.0641, -0.1277],
          [ 0.6859,  0.2412]]],


        [[[ 0.01

In [73]:
class PKM(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_subkeys, top_k, num_heads, batch_norm = True):
        super(PKM, self).__init__()
        self.dim_in, self.dim_hidden, self.num_subkeys, self.top_k, self.num_heads = dim_in, dim_hidden, num_subkeys, top_k, num_heads
        self.query_network = QueryNetwork(dim_in, dim_hidden, num_heads, batch_norm)
        self.key_table = ProductKey(dim_hidden, num_subkeys, top_k, num_heads)
        self.value_table = nn.Embedding(num_subkeys * num_subkeys, dim_in)
    def forward(self, x):
        queries = self.query_network(x)
        weights, indices = self.key_table(queries) # shape is batch, context length, num heads, top k
        original_shape = weights.shape
        
        weights, indices = weights.reshape(-1, self.top_k), indices.reshape(-1, self.top_k)
        values = self.value_table(indices)
        weights = weights.reshape(original_shape)
        values = values.reshape(*original_shape, self.dim_in)

        weighted_values = torch.einsum('bcnk,bcnkd->bcd', weights, values) # take linear combination of weights & values and sum over all heads

        return weighted_values
        
        
        
        

In [74]:
pkm = PKM(5,8,10,2,4)
pkm(torch.rand(3,4,5))

tensor([[[ 3.4671, -1.9806, -1.0823, -1.2092, -6.0788],
         [-1.9874, -1.2526, -0.0125, -0.1405,  1.1630],
         [-4.6440,  1.8096, -2.4216, -0.6242, -4.7508],
         [-0.6805, -0.1313, -3.2581, -5.1318, -0.4604]],

        [[ 1.7440,  0.5148, -1.8089, -4.4855, -2.7130],
         [-0.9145,  0.6873,  3.1486, -1.0154, -1.3887],
         [-0.9262, -0.1986,  2.8402, -0.4718,  0.1240],
         [ 1.3452,  1.8757, -2.3615,  0.9435, -0.7127]],

        [[ 1.6607, -0.0710, -1.3316, -0.3714, -0.1003],
         [ 1.7431, -0.3094, -3.9427, -3.6363, -3.9978],
         [-0.1268,  0.2422,  0.2941, -0.8417, -2.8110],
         [-3.6644,  1.7039,  2.1051, -1.0184, -1.6720]]],
       grad_fn=<ViewBackward0>)