
# Implicit layers

With an implicit layer, we do not specify an explicit function for computing the layer's output. Instead, we specify a set of conditions that the output must satisfy. Any explicit layer can be recast in this form. For example, the explicit formulation
\begin{equation}
z = f(x), \qquad f:X \rightarrow Z \subset \mathbb{R}^n
\end{equation}
becomes
\begin{equation}
\text{find} \quad z \quad \text{s.t.} \quad g(x,z)=0, \qquad g(x,z)=z-f(x), \qquad g:X\times Z \rightarrow \mathbb{R}^n
\end{equation}
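As a small illustration of this equivalence, the sketch below takes a hypothetical explicit layer `f` (an elementwise `tanh`, chosen only for this example) and checks that its output is a root of the residual `g(x, z) = z - f(x)`.

```python
import numpy as np

def f(x):
    # hypothetical explicit layer, chosen only for illustration
    return np.tanh(x)

def g(x, z):
    # residual of the implicit formulation: zero exactly when z = f(x)
    return z - f(x)

x = np.array([0.5, -1.0, 2.0])
z = f(x)                                 # explicit evaluation
print(np.allclose(g(x, z), 0.0))         # True: z solves g(x, z) = 0
print(np.allclose(g(x, z + 0.1), 0.0))   # False: a perturbed z does not
```

Here the root can be read off directly. The implicit form becomes useful when, as in the example that follows, no such closed-form expression for $z$ is available.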

### Example

Let us now find $z$ such that $z = \tanh(Wz+x)$.

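This means finding a root of $g(x,z)$ in $z$, where $g(x,z)=z-\tanh(Wz+x)$.

Implicit equations like this one generally have no closed-form solution, so we solve them iteratively. The simplest approach is the fixed point iteration $z_{n+1} = \tanh(W z_n + x)$, starting from $z_0 = 0$.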

In [1]:
import numpy as np
import torch
import torch.nn as nn

In [35]:
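max_iter = 100
tol = 1e-3
n = 3
z = np.zeros(shape=(n, 1))
x = np.ones(shape=(n, 1))
W = np.random.rand(n, n)  # random weights, so the result changes between runs
n_iter = 0  # renamed from `iter` to avoid shadowing the built-in
while n_iter < max_iter:
  z_next = np.tanh(np.matmul(W, z) + x)
  err = np.linalg.norm(z_next - z)  # size of the last update
  z = z_next
  n_iter += 1
  if err < tol:  # stop once the iterates have stopped moving
    break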

In [36]:
z

array([[0.9974062 ],
       [0.95144457],
       [0.97637945]])
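One way to check a result like this is to substitute it back into the residual $g(x,z)=z-\tanh(Wz+x)$. The sketch below repeats the iteration and prints the residual norm. The fixed seed, the tighter tolerance, and the helper name `fixed_point` are assumptions made for this illustration, not part of the original cell.

```python
import numpy as np

def fixed_point(W, x, tol=1e-8, max_iter=1000):
    """Iterate z <- tanh(W z + x) from z = 0 until the update is below tol."""
    z = np.zeros_like(x)
    for k in range(1, max_iter + 1):
        z_next = np.tanh(W @ z + x)
        err = np.linalg.norm(z_next - z)
        z = z_next
        if err < tol:
            break
    return z, k

rng = np.random.default_rng(0)   # assumed seed, only for reproducibility
n = 3
W = rng.random((n, n))           # entries in [0, 1), as in the cell above
x = np.ones((n, 1))

z, n_iter = fixed_point(W, x)
residual = np.linalg.norm(z - np.tanh(W @ z + x))
print(n_iter, residual)          # iterations used and ||g(x, z)||
```

A residual close to zero confirms that the returned $z$ solves the implicit equation up to the chosen tolerance.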

PyTorch implementation via a linear layer

In [17]:
class TanhFixedPointLayer(nn.Module):
  def __init__(self, out_features, tol = 1e-4, max_iter = 50):
    super().__init__()
    self.linear = nn.Linear(out_features, out_features, bias = False)
    self.tol = tol
    self.max_iter = max_iter

  def forward(self,x):
    z = torch.zeros_like(x)
    self.iterations = 0

    while (self.iterations < self.max_iter):
      z_next = torch.tanh(self.linear(z)+x)
      self.err = torch.norm(z-z_next)
      z = z_next
      self.iterations += 1
      if (self.err < self.tol):
        break
      
    return z

In [37]:
layer = TanhFixedPointLayer(n)
x = torch.tensor([1.,1.,1.])
z = layer(x)
z
print(z)
print(layer.iterations)
print(layer.err)

tensor([0.4906, 0.7258, 0.7385], grad_fn=<TanhBackward0>)
13
tensor(4.9090e-05, grad_fn=<CopyBackwards>)


The fixed point iteration above is not the only way to find the roots of $g(x,z)$. Instead of iterating

\begin{equation}
z_{n+1} = \tanh(W z_n + x)
\end{equation}

we can apply Newton's method directly to the implicit equation $g(x,z)=0$. Since $z$ is a vector, the step uses the inverse of the Jacobian $\partial_z g$ rather than a scalar division:

\begin{equation}
z_{n+1} = z_n - \left(\partial_z g(x,z)|_{z=z_n}\right)^{-1} g(x,z_n)
\end{equation}
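
For the example considered here, the Jacobian in the Newton step can be written out explicitly. Assuming $g(x,z)=z-\tanh(Wz+x)$ with $\tanh$ applied elementwise, the chain rule gives

\begin{equation}
\partial_z g(x,z) = I - \operatorname{diag}\left(\operatorname{sech}^2(Wz+x)\right) W, \qquad \operatorname{sech}^2(u) = \frac{1}{\cosh^2(u)} = 1 - \tanh^2(u)
\end{equation}

so each Newton step amounts to solving one $n \times n$ linear system per input $x$.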



In [46]:
class TanhNewtonLayer(nn.Module):
    def __init__(self, out_features, tol = 1e-4, max_iter=50):
        super().__init__()
        self.linear = nn.Linear(out_features, out_features, bias=False)
        self.tol = tol
        self.max_iter = max_iter
  
    def forward(self, x):
        # initialize z at tanh(x), i.e. the first fixed point iterate starting from z = 0
        z = torch.tanh(x)
        self.iterations = 0
    
        # iterate until convergence
        while self.iterations < self.max_iter:
            z_linear = self.linear(z) + x
            g = z - torch.tanh(z_linear)
            self.err = torch.norm(g)
            if self.err < self.tol:
                break

            # newton step
            J = torch.eye(z.shape[1])[None,:,:] - (1 / torch.cosh(z_linear)**2)[:,:,None]*self.linear.weight[None,:,:]
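            # torch.solve(B, A) has been removed from recent PyTorch versions;
            # torch.linalg.solve(A, B) replaces it (note the swapped argument order)
            z = z - torch.linalg.solve(J, g[:,:,None])[:,:,0]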
            self.iterations += 1

        g = z - torch.tanh(self.linear(z) + x)
        z[torch.norm(g,dim=1) > self.tol,:] = 0
        return z

In [47]:
layer = TanhNewtonLayer(50)
X = torch.randn(10,50)
Z = layer(X)
print(f"Terminated after {layer.iterations} iterations with error {layer.err}")

Terminated after 3 iterations with error 1.0844622693184647e-06
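
To compare the two solvers on equal footing, the sketch below runs both on the same single-sample problem in NumPy. The weight scaling, the seed, and the function names are assumptions chosen so that both methods converge. It is an illustration, not the implementation used in the cells above.

```python
import numpy as np

def g(W, x, z):
    # residual whose root is the layer output
    return z - np.tanh(W @ z + x)

def solve_fixed_point(W, x, tol=1e-10, max_iter=500):
    z = np.zeros_like(x)
    for k in range(max_iter):
        if np.linalg.norm(g(W, x, z)) < tol:
            return z, k
        z = np.tanh(W @ z + x)
    return z, max_iter

def solve_newton(W, x, tol=1e-10, max_iter=50):
    z = np.zeros_like(x)
    for k in range(max_iter):
        r = g(W, x, z)
        if np.linalg.norm(r) < tol:
            return z, k
        u = W @ z + x
        # Jacobian of g with respect to z: I - diag(sech^2(u)) W
        J = np.eye(len(x)) - (1.0 / np.cosh(u) ** 2)[:, None] * W
        z = z - np.linalg.solve(J, r)
    return z, max_iter

rng = np.random.default_rng(0)     # assumed seed, only for reproducibility
n = 5
W = rng.standard_normal((n, n))
W *= 0.5 / np.linalg.norm(W, 2)    # assumed scaling: spectral norm 0.5
x = rng.standard_normal(n)

z_fp, it_fp = solve_fixed_point(W, x)
z_nt, it_nt = solve_newton(W, x)
print(it_fp, it_nt, np.linalg.norm(z_fp - z_nt))
```

On a well-conditioned problem like this one, Newton's method usually needs far fewer iterations than the fixed point iteration, at the price of one linear solve per step.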


However, this approach is inefficient, for two reasons:

*   At each iteration we have to compute the Jacobian matrix and solve a linear system with it, once for each sample in the minibatch.
*   Implementing Newton's method directly within an automatic differentiation toolkit means that:
    *  The intermediate iterates of the hidden units must be saved for the backward pass. Here that includes the Jacobian terms of every iteration, so memory consumption is high.
    *  Backpropagating through the repeated linear solves can be numerically unstable.
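
To get a feel for the memory cost, the sketch below gives a back-of-the-envelope estimate. It counts only one per-sample Jacobian tensor per Newton iteration and assumes 4-byte floats. The function name and the larger layer size are hypothetical, and a real autograd graph keeps additional buffers, so this is a rough lower bound.

```python
def newton_jacobian_bytes(batch, n, iterations, bytes_per_float=4):
    """Bytes needed to keep one (batch, n, n) Jacobian per iteration alive for backprop."""
    floats_per_iteration = batch * n * n
    return floats_per_iteration * iterations * bytes_per_float

# Setting from the cell above: 10 samples, 50 hidden units, 3 Newton iterations
print(newton_jacobian_bytes(batch=10, n=50, iterations=3))                # 300000

# A hypothetical larger layer: 128 samples, 1024 hidden units, 10 iterations
print(newton_jacobian_bytes(batch=128, n=1024, iterations=10) / 2**30)    # 5.0 (GiB)
```

The cost grows quadratically with the number of hidden units and linearly with both the batch size and the number of iterations.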
    

