<a href="https://colab.research.google.com/github/kwn-w/python/blob/main/pondernet_pr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from typing import Tuple

In [3]:
import torch
from torch import nn


In [6]:
pip install Module

Collecting Module
  Downloading module-0.0.4-py3-none-any.whl (13 kB)
Installing collected packages: Module
Successfully installed Module-0.0.4


In [11]:
pip install labml-helpers

Collecting labml-helpers
  Downloading labml_helpers-0.4.82-py3-none-any.whl (18 kB)
Collecting labml>=0.4.129
  Downloading labml-0.4.132-py3-none-any.whl (121 kB)
[K     |████████████████████████████████| 121 kB 5.2 MB/s 
Collecting gitpython
  Downloading GitPython-3.1.18-py3-none-any.whl (170 kB)
[K     |████████████████████████████████| 170 kB 22.8 MB/s 
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-4.0.7-py3-none-any.whl (63 kB)
[K     |████████████████████████████████| 63 kB 1.6 MB/s 
Collecting smmap<5,>=3.0.1
  Downloading smmap-4.0.0-py2.py3-none-any.whl (24 kB)
Installing collected packages: smmap, gitdb, gitpython, labml, labml-helpers
Successfully installed gitdb-4.0.7 gitpython-3.1.18 labml-0.4.132 labml-helpers-0.4.82 smmap-4.0.0


In [13]:
from labml_helpers.module import Module

**PonderNet with GRU for Parity Task**

In [14]:
class ParityPonderGRU(Module):
  def __init__(self, n_elems: int, n_hidden: int, max_steps: int):
    super().__init__()
    self.max_steps = max_steps
    self.n_hidden = n_hidden

    self.gru = nn.GRUCell(n_elems, n_hidden)

    self.output_layer = nn.Linear(n_hidden, 1)

    self.lambda_layer = nn.Linear(n_hidden, 1)
    self.lambda_prob = nn.Sigmoid()

    self.is_halt = False


  def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    batch_size = x.shape[0]

    h = x.new_zeros((x.shape[0], self.n_hidden))
    h = self.gru(x, h)

    p = []
    y = []

    un_halted_prob = h.new_ones((batch_size,))

    halted = h.new_zeros((batch_size,))

    p_m = h.new_zeros((batch_size,))
    y_m = h.new_zeros((batch_size,))

    for n in range(1, self.max_steps + 1):
      if n == self.max_steps:
        lambda_n = h.new_ones(h.shape[0])
      else:
        lambda_n = self.lambda_prob(self.lambda_layer(h))[:, 0]

      y_n = self.output_layer(h)[:, 0]

      p_n = un_halted_prob * lambda_n

      un_halted_prob = un_halted_prob * (1 - lambda_n)
      halt = torch.bernoulli(lambda_n) * (1 - halted)

      p.append(p_n)
      y.append(y_n)

      p_m = p_m * (1 - halt) + p_n * halt
      y_m = y_m * (1 - halt) + y_n * halt

      halted = halted + halt

      h = self.gru(x, h)

      if self.is_halt and halted.sum() == batch_size:
        break

      return torch.stack(p), torch.stack(y), p_m, y_m

In [15]:
class ReconstructionLoss(Module):
  def __init__(self, loss_func: nn.Module):
    super().__init__()
    self.loss_func = loss_func


  def forward(self, p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
    total_loss = p.new_tensor(0.)
    for n in range(p.shape[0]):
      loss = (p[n] * self.loss_func(y_hat[n], y)).mean()
      total_loss = total_loss + loss
    return total_loss

In [16]:
class RegularizationLoss(Module):
  def __init__(self, lambda_p: float, max_steps: int = 1_000):
    super().__init__()
    p_g = torch.zeros((max_steps,))
    not_halted = 1.
    
    for k in range(max_steps):
      p_g[k] = not_halted * lambda_p
      not_halted = not_halted * (1 - lambda_p)
      
    self.p_g = nn.Parameter(p_g, requires_grad=False)
    self.kl_div = nn.KLDivLoss(reduction='batchmean')
    
    
  def forward(self, p: torch.Tensor):
    p = p.transpose(0, 1)
    p_g = self.p_g[None, :p.shape[1]].expand_as(p)
    return self.kl_div(p.log(), p_g)