In [1]:
import torch
import torch.nn as nn

from pideq.net import PINN, PIDEQ
from pideq.deq.solvers import forward_iteration

In [2]:
net = PINN(1., n_nodes=3)

In [7]:
deq = PIDEQ.from_pinn(net)

x = torch.rand(1,2)

print(net(x[...,0].unsqueeze(-1), x[...,1].unsqueeze(-1)))
print(deq(x[...,0].unsqueeze(-1), x[...,1].unsqueeze(-1))[0])

tensor([[-0.3152,  0.4197]], grad_fn=<AddmmBackward0>)
tensor([[-0.3152,  0.4197]], grad_fn=<AddmmBackward0>)


In [8]:
# define number of states for each hidden layer
n_states = 0
for l in net.fcn[:-2]:  # exclude ouptut layer
    try:
        n_states += l.out_features
    except AttributeError:
        pass
n_states

12

In [9]:
deq = PIDEQ(1., n_in=net.fcn[0].in_features, n_out=net.fcn[-1].out_features, n_states=n_states, n_hidden=0, solver=forward_iteration)
deq

PIDEQ(
  (B): Linear(in_features=12, out_features=12, bias=True)
  (A): Linear(in_features=2, out_features=12, bias=True)
  (h): Linear(in_features=12, out_features=2, bias=True)
)

In [10]:
l1 = net.fcn[0]

A_w = torch.zeros_like(deq.A.weight)
A_w[:l1.weight.shape[0]] = l1.weight
deq.A.weight = nn.Parameter(A_w)

A_b = torch.zeros_like(deq.A.bias)
A_b[:l1.bias.shape[0]] = l1.bias
deq.A.bias = nn.Parameter(A_b)

In [11]:
B_w = torch.zeros_like(deq.B.weight)
B_b = torch.zeros_like(deq.B.bias)

l0 = net.fcn[0].out_features  # end of the last hidden layer's output
for l in net.fcn[1:-2]:  # skip first and last layers
    if isinstance(l, nn.Linear):
        B_w[l0:l0 + l.out_features,l0 - l.in_features:l0] = l.weight
        B_b[l0:l0 + l.out_features] = l.bias
        l0 = l0 + l.out_features

deq.B.weight = nn.Parameter(B_w)
deq.B.bias = nn.Parameter(B_b)

In [12]:
ll = net.fcn[-1]

h_w = torch.zeros_like(deq.h.weight)
h_w[:,-ll.weight.shape[-1]:] = ll.weight
deq.h.weight = nn.Parameter(h_w)

h_b = torch.zeros_like(deq.h.bias)
h_b[-ll.bias.shape[0]:] = ll.bias
deq.h.bias = nn.Parameter(h_b)

In [13]:
x = torch.rand(1,2)
x

tensor([[0.4703, 0.3210]])

In [14]:
z0 = torch.zeros(1,12)
z = z0

In [21]:
z = deq.nonlin(deq.A(x) + deq.B(z))
z

tensor([[-0.1996,  0.2331,  0.2224,  0.2395,  0.3311, -0.1186,  0.0319,  0.1266,
         -0.2079, -0.1951,  0.2081,  0.1212]], grad_fn=<TanhBackward0>)

In [22]:
deq.h(z)

tensor([[0.3377, 0.1848]], grad_fn=<AddmmBackward0>)

In [23]:
# z = torch.zeros(1,2)
z = x
for l in net.fcn:
    z = l(z)
    print(z)

tensor([[-0.2023,  0.2375,  0.2261]], grad_fn=<AddmmBackward0>)
tensor([[-0.1996,  0.2331,  0.2224]], grad_fn=<TanhBackward0>)
tensor([[ 0.2443,  0.3440, -0.1191]], grad_fn=<AddmmBackward0>)
tensor([[ 0.2395,  0.3311, -0.1186]], grad_fn=<TanhBackward0>)
tensor([[ 0.0319,  0.1273, -0.2110]], grad_fn=<AddmmBackward0>)
tensor([[ 0.0319,  0.1266, -0.2079]], grad_fn=<TanhBackward0>)
tensor([[-0.1976,  0.2112,  0.1218]], grad_fn=<AddmmBackward0>)
tensor([[-0.1951,  0.2081,  0.1212]], grad_fn=<TanhBackward0>)
tensor([[0.3377, 0.1848]], grad_fn=<AddmmBackward0>)


In [26]:
net(x[...,0], x[...,1])

tensor([0.0477, 0.0570], grad_fn=<AddBackward0>)

In [28]:
deq(x[...,0], x[...,1])

(tensor([[0.0477, 0.0570],
         [0.0477, 0.0570]], grad_fn=<AddmmBackward0>),
 tensor(0.6040, grad_fn=<DivBackward0>))