Review note: the explicit softmax dimension must be dim=1, not dim=0.
_similarity views k as (batch_size, 1, -1) and reduces cosine_similarity over
dim=-1, so the result carries the batch on dim 0; dim=0 would normalise across
batch elements instead of across memory locations. The shift weighting `s` in
_address_memory is presumed to be batched the same way (batch on dim 0) --
confirm against the caller before applying.

diff --git a/ntm/head.py b/ntm/head.py
index d53a4d2..f22890f 100644
--- a/ntm/head.py
+++ b/ntm/head.py
@@ -45,7 +45,7 @@ def _address_memory(self, k, β, g, s, γ, w_prev):
         k = k.clone()
         β = F.softplus(β)
         g = F.sigmoid(g)
-        s = F.softmax(F.softplus(s))
+        s = F.softmax(F.softplus(s), dim=1)
         γ = 1 + F.softplus(γ)
 
         w = self.memory.address(k, β, g, s, γ, w_prev)
diff --git a/ntm/memory.py b/ntm/memory.py
index 9c2ef3c..0a71655 100644
--- a/ntm/memory.py
+++ b/ntm/memory.py
@@ -32,8 +32,7 @@ def __init__(self, N, M):
 
         # The memory bias allows the heads to learn how to initially address
         # memory locations by content
-        self.mem_bias = Variable(torch.Tensor(N, M))
-        self.register_buffer('mem_bias', self.mem_bias.data)
+        self.register_buffer('mem_bias', Variable(torch.Tensor(N, M)))
 
         # Initialize memory bias
         stdev = 1 / (np.sqrt(N + M))
@@ -84,7 +83,7 @@ def address(self, k, β, g, s, γ, w_prev):
 
     def _similarity(self, k, β):
         k = k.view(self.batch_size, 1, -1)
-        w = F.softmax(β * F.cosine_similarity(self.memory + 1e-16, k + 1e-16, dim=-1))
+        w = F.softmax(β * F.cosine_similarity(self.memory + 1e-16, k + 1e-16, dim=-1), dim=1)
         return w
 
     def _interpolate(self, w_prev, wc, g):