# Implementing Liquid Time-Constant (LTC) Networks

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from enum import Enum

In [12]:
class MappingType(Enum):
    """Transformation applied to raw inputs before they reach the synapses.

    Members (checked by ``map_inputs``):
      * ``Identity`` - inputs are passed through unchanged.
      * ``Linear``   - each input feature is multiplied by a weight.
      * ``Affine``   - like ``Linear``, plus an additive per-feature bias.
    """

    Identity = 0  # no transformation
    Linear = 1    # per-feature scale
    Affine = 2    # per-feature scale and shift

In [13]:
class ODESolver(Enum):
    """Integration scheme used by ``forward`` to advance the neuron state.

    Members and the method ``forward`` dispatches to for each:
      * ``SemiImplicit`` - ``ode_step``
      * ``Explicit``     - ``ode_step_explicit``
      * ``RungeKutta``   - ``ode_step_runge_kutta``
    """

    SemiImplicit = 0
    Explicit = 1
    RungeKutta = 2

In [3]:
class LTCCell(nn.Module):
    """Liquid time-constant (LTC) recurrent cell.

    The constructor only records hyper-parameters; no tensors are allocated
    here. ``forward`` calls ``build`` on its first invocation, once the input
    width is known, so ``self.input_size`` stays ``-1`` until then.

    Args:
        num_units: number of neurons in the cell (the size of the state).
    """

    def __init__(self, num_units):
        super().__init__()

        self.input_size = -1  # filled in by build() on the first forward call
        self.num_units = num_units
        self.is_built = False

        # Number of ODE solver steps in one RNN step
        self.ode_solver_unfolds = 6
        self.solver = ODESolver.SemiImplicit

        # How raw inputs are transformed before reaching the synapses.
        self.input_mapping = MappingType.Affine

        # Scale applied to the random +/-1 reversal potentials at build time.
        self.erev_init_factor = 1

        # Initialisation ranges (min/max) for the synaptic weights, membrane
        # capacitance and leak conductance; how they are sampled is decided
        # by build().
        self.w_init_max = 1.0
        self.w_init_min = 0.01
        self.cm_init_min = 0.5
        self.cm_init_max = 0.5
        self.gleak_init_min = 1
        self.gleak_init_max = 1

        # Clamp bounds applied by get_param_constrain_op() to keep W,
        # sensory_W, gleak and cm_t strictly positive and bounded.
        self.w_min_value = 0.00001
        self.w_max_value = 1000
        self.gleak_min_value = 0.00001
        self.gleak_max_value = 1000
        self.cm_t_min_value = 0.000001
        self.cm_t_max_value = 1000

        # When not None, build() uses these constants (non-trainable) instead
        # of creating learnable cm_t / gleak / vleak parameters.
        self.fix_cm = None
        self.fix_gleak = None
        self.fix_vleak = None


In [4]:
 def forward(self, inputs, state):
        """Advance the cell by one RNN step.

        Builds the parameters on the first call, passes the inputs through
        ``map_inputs`` and integrates the neuron state with the routine
        selected by ``self.solver``.

        Args:
            inputs: input batch; ``inputs.size(1)`` must match the width seen
                on the first call.
            state: current neuron state, handed unchanged to the solver.

        Returns:
            ``(outputs, next_state)`` - both are the same tensor, since the
            cell's output is its full state.

        Raises:
            ValueError: if the input width differs from the built width, or if
                ``self.solver`` is not one of the known ``ODESolver`` members.
        """
        if not self.is_built:
            self.build(inputs.size(1))
        if self.input_size != inputs.size(1):
            raise ValueError("Input size mismatch")

        mapped_inputs = self.map_inputs(inputs)

        # Only the attribute of the selected solver is looked up.
        for solver_kind, step_name in (
            (ODESolver.Explicit, "ode_step_explicit"),
            (ODESolver.SemiImplicit, "ode_step"),
            (ODESolver.RungeKutta, "ode_step_runge_kutta"),
        ):
            if self.solver == solver_kind:
                solver_step = getattr(self, step_name)
                break
        else:
            raise ValueError("Unknown ODE solver")

        next_state = solver_step(mapped_inputs, state)
        return next_state, next_state

In [5]:
  def map_inputs(self, inputs):
        """Apply the element-wise input mapping selected by ``self.input_mapping``.

        ``Linear`` multiplies every input feature by a learned weight,
        ``Affine`` additionally adds a learned bias, and ``Identity`` returns
        ``inputs`` unchanged.

        The weight (``input_w``, initialised to ones) and bias (``input_b``,
        initialised to zeros) are created once, on first use. They are
        assigned as module attributes so ``nn.Module`` registers them: they
        receive gradients, appear in ``parameters()``/``state_dict()`` and
        follow ``.to()``. Like everything created by ``build``, they only exist
        after the first forward call, so construct the optimiser afterwards.

        Args:
            inputs: tensor whose last dimension has ``self.input_size``
                features (broadcast against the per-feature weight/bias).

        Returns:
            The mapped tensor.
        """
        mapping = self.input_mapping
        if mapping == MappingType.Affine or mapping == MappingType.Linear:
            if getattr(self, "input_w", None) is None:
                # Create once and register; a fresh Parameter on every call
                # would never be updated by the optimiser.
                self.input_w = nn.Parameter(torch.ones(self.input_size, device=inputs.device))
            inputs = inputs * self.input_w
        if mapping == MappingType.Affine:
            if getattr(self, "input_b", None) is None:
                self.input_b = nn.Parameter(torch.zeros(self.input_size, device=inputs.device))
            inputs = inputs + self.input_b
        return inputs


In [6]:
 def build(self, input_size):
        """Allocate the cell's tensors for inputs with ``input_size`` features.

        Called lazily from ``forward``. Two groups of synapses are created:
        sensory (input -> neuron, shape ``(input_size, num_units)``) and
        recurrent (neuron -> neuron, shape ``(num_units, num_units)``). Each
        group gets:

        * ``mu`` ~ U(0.3, 0.8)
        * ``sigma`` ~ U(3, 8)
        * weights ``W`` ~ U(w_init_min, w_init_max)
        * reversal potentials ``erev`` = +/-1 * erev_init_factor

        Per-neuron quantities are also created:

        * ``vleak`` ~ U(-0.2, 0.2)
        * ``gleak`` ~ U(gleak_init_min, gleak_init_max)
        * ``cm_t`` ~ U(cm_init_min, cm_init_max)

        A range with max <= min yields the constant min.

        If ``fix_vleak`` / ``fix_gleak`` / ``fix_cm`` is set, that quantity
        becomes a non-trainable buffer holding the given value. All
        randomness comes from torch's RNG, so ``torch.manual_seed`` alone
        makes initialisation reproducible.

        Args:
            input_size: number of input features per sample.
        """
        self.input_size = input_size
        n_in, n_units = self.input_size, self.num_units

        def uniform(low, high, *shape):
            # U(low, high); a degenerate range (high <= low) is the constant low.
            if high > low:
                return torch.rand(*shape) * (high - low) + low
            return torch.full(shape, float(low))

        def random_sign(*shape):
            # +1 / -1 with equal probability, as float32.
            return torch.randint(0, 2, shape).float() * 2 - 1

        def fixed(name, value):
            # Buffers follow .to()/.cuda() but are not trained. They are
            # non-persistent because they are rebuilt from the fix_* settings.
            self.register_buffer(
                name, torch.as_tensor(value, dtype=torch.float32), persistent=False
            )

        # Sensory synapses (input -> neuron).
        self.sensory_mu = nn.Parameter(torch.rand(n_in, n_units) * 0.5 + 0.3)
        self.sensory_sigma = nn.Parameter(torch.rand(n_in, n_units) * 5 + 3)
        self.sensory_W = nn.Parameter(uniform(self.w_init_min, self.w_init_max, n_in, n_units))
        self.sensory_erev = nn.Parameter(random_sign(n_in, n_units) * self.erev_init_factor)

        # Recurrent synapses (neuron -> neuron).
        self.mu = nn.Parameter(torch.rand(n_units, n_units) * 0.5 + 0.3)
        self.sigma = nn.Parameter(torch.rand(n_units, n_units) * 5 + 3)
        self.W = nn.Parameter(uniform(self.w_init_min, self.w_init_max, n_units, n_units))
        self.erev = nn.Parameter(random_sign(n_units, n_units) * self.erev_init_factor)

        # Per-neuron leak potential, leak conductance and capacitance.
        if self.fix_vleak is None:
            self.vleak = nn.Parameter(torch.rand(n_units) * 0.4 - 0.2)
        else:
            fixed("vleak", self.fix_vleak)

        if self.fix_gleak is None:
            self.gleak = nn.Parameter(uniform(self.gleak_init_min, self.gleak_init_max, n_units))
        else:
            fixed("gleak", self.fix_gleak)

        if self.fix_cm is None:
            self.cm_t = nn.Parameter(uniform(self.cm_init_min, self.cm_init_max, n_units))
        else:
            fixed("cm_t", self.fix_cm)

        self.is_built = True

In [7]:
def ode_step(self, inputs, state):
        """Integrate the neuron state over one RNN step (the ``SemiImplicit`` solver).

        Runs ``self.ode_solver_unfolds`` updates of the form::

            v <- (cm_t * v + gleak * vleak + sum(a * erev)) / (cm_t + gleak + sum(a))

        Here ``a = W * sigmoid(...)`` are the synaptic activations, and the
        sums run over dim 1. Both the recurrent synapses (driven by ``v``) and
        the sensory synapses (driven by ``inputs``) contribute.

        Args:
            inputs: mapped input batch.
            state: neuron state at the start of the step.

        Returns:
            The neuron state after the final unfold.
        """
        # The sensory synapses do not depend on the state; evaluate them once.
        sensory_act = self.sensory_W * self.sigmoid(inputs, self.sensory_mu, self.sensory_sigma)
        sensory_num = torch.sum(sensory_act * self.sensory_erev, dim=1)
        sensory_den = torch.sum(sensory_act, dim=1)

        v = state
        for _ in range(self.ode_solver_unfolds):
            syn_act = self.W * self.sigmoid(v, self.mu, self.sigma)
            syn_num = torch.sum(syn_act * self.erev, dim=1) + sensory_num
            syn_den = torch.sum(syn_act, dim=1) + sensory_den
            v = (self.cm_t * v + self.gleak * self.vleak + syn_num) / (
                self.cm_t + self.gleak + syn_den
            )
        return v

In [8]:
 def sigmoid(self, v_pre, mu, sigma):
        """Return the synaptic activation ``sigmoid(sigma * (v_pre - mu))``.

        ``v_pre`` holds one value per presynaptic neuron in its last
        dimension. It is reshaped to ``(-1, n_pre, 1)`` so that it broadcasts
        against the ``(n_pre, n_post)`` matrices ``mu`` and ``sigma``. The
        result has shape ``(batch, n_pre, n_post)``, which ``ode_step``
        reduces over dim 1.

        Without the reshape, a ``(batch, n_pre)`` tensor would be aligned with
        the post-synaptic axis of ``mu``. It would then either fail to
        broadcast or silently pair the wrong neurons.

        Args:
            v_pre: presynaptic values, shape ``(..., n_pre)``.
            mu: activation midpoints, shape ``(n_pre, n_post)``.
            sigma: activation slopes, shape ``(n_pre, n_post)``.

        Returns:
            Activations of shape ``(batch, n_pre, n_post)``.
        """
        # Flatten leading dims into the batch and add a post-synaptic axis:
        # (..., n_pre) -> (batch, n_pre, 1).
        v_pre = v_pre.reshape(-1, v_pre.shape[-1], 1)
        return torch.sigmoid(sigma * (v_pre - mu))

In [9]:
  def get_param_constrain_op(self):
        """Clamp the positive-valued parameters into their allowed ranges, in place.

        ``cm_t`` is clamped to ``[cm_t_min_value, cm_t_max_value]``, ``gleak``
        to ``[gleak_min_value, gleak_max_value]``, and ``W`` and ``sensory_W``
        to ``[w_min_value, w_max_value]``. The update is written back into
        the tensors under ``torch.no_grad()``, so calling this after each
        optimiser step actually enforces the constraint. Computing clamped
        copies alone would leave the parameters untouched.

        Returns:
            ``[cm_t, gleak, W, sensory_W]`` - the (now clamped) tensors, in the
            same order as before.
        """
        bounds = (
            (self.cm_t, self.cm_t_min_value, self.cm_t_max_value),
            (self.gleak, self.gleak_min_value, self.gleak_max_value),
            (self.W, self.w_min_value, self.w_max_value),
            (self.sensory_W, self.w_min_value, self.w_max_value),
        )
        # In-place edits of leaf parameters must bypass autograd.
        with torch.no_grad():
            for tensor, low, high in bounds:
                tensor.clamp_(low, high)
        return [tensor for tensor, _, _ in bounds]