In [85]:
from collections import OrderedDict

import torch
import torch.nn as nn
from torchid.dynonet.module.lti import MimoLinearDynamicalOperator, SisoLinearDynamicalOperator
from torchid.dynonet.module.static import MimoStaticNonLinearity, MimoChannelWiseNonLinearity

In [86]:
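# Let us sketch a parallel Wiener-Hammerstein (PWH) model:
# - 3 inputs
# - 4 outputs
# - 10 parallel branches
#    u ---> [3x10 LTI] ---> [10x10 FULL STATIC] ---> [10x4 LTI] ---> y
# (FULL STATIC: one static MIMO nonlinearity that couples all 10 branches,
#  rather than 10 independent channel-wise nonlinearities)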


In [87]:
# Fix the RNG so the random input batches and the randomly initialised model
# parameters created in later cells are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# B: batch size, T: sequence length (time steps), U: input channels,
# H: hidden channels (branches), Y: output channels.
# dynoNet tensors use the (batch, time, channels) layout.
B, T, U, H, Y = 32, 10_000, 3, 10, 4

In [88]:
# Random excitation: B sequences of T samples, each with U input channels.
batch_u = torch.randn(B, T, U)

In [89]:
# First LTI stage: U input channels -> H hidden channels.
G1 = MimoLinearDynamicalOperator(U, H, n_b=4, n_a=3, n_k=0)

In [90]:
x = G1(batch_u)  # G1 applied to the random input batch

In [91]:
x.size()

torch.Size([32, 10000, 10])

In [92]:
# Static MIMO nonlinearity mapping H channels to H channels.
F = MimoStaticNonLinearity(H, H)
# G1 is a deterministic filter and `x` already holds G1(batch_u) from the
# previous cell, so reuse it instead of filtering the whole batch again.
x1 = x
assert x1.shape == (B, T, H)
x1.shape

torch.Size([32, 10000, 10])

In [93]:
# Push the G1 features through the static nonlinearity F.
x2 = F(x1)
x2.size()

torch.Size([32, 10000, 10])

In [94]:
# Second LTI stage: H hidden channels -> Y output channels.
G2 = MimoLinearDynamicalOperator(H, Y, n_b=4, n_a=3, n_k=0)

In [95]:
y = G2(x2)  # output of the hand-wired G1 -> F -> G2 chain
y.size()

torch.Size([32, 10000, 4])

In [96]:
# The same G1 -> F -> G2 chain packaged as a single module. The layers are
# constructed in the same order, so parameter initialisation order is unchanged.
model = nn.Sequential(*[
    MimoLinearDynamicalOperator(U, H, n_b=4, n_a=3, n_k=0),  # G1
    MimoStaticNonLinearity(H, H),                            # F
    MimoLinearDynamicalOperator(H, Y, n_b=4, n_a=3, n_k=0),  # G2
])

In [97]:
batch_y = model(batch_u)  # expected (B, T, Y)
batch_y.size()

torch.Size([32, 10000, 4])

In [98]:
# Show the module tree; kept as a bare last expression so the notebook renders its repr.
model

Sequential(
  (0): MimoLinearDynamicalOperator()
  (1): MimoStaticNonLinearity(
    (net): Sequential(
      (0): Linear(in_features=10, out_features=20, bias=True)
      (1): Tanh()
      (2): Linear(in_features=20, out_features=10, bias=True)
    )
  )
  (2): MimoLinearDynamicalOperator()
)

In [99]:
# Same architecture as the plain Sequential above, but built from an
# OrderedDict so every stage is reachable by name (model.G1, model.F, model.G2).
model = nn.Sequential(OrderedDict([
    ('G1', MimoLinearDynamicalOperator(U, H, n_a=3, n_b=4, n_k=0)),
    ('F', MimoStaticNonLinearity(H, H)),
    ('G2', MimoLinearDynamicalOperator(H, Y, n_a=3, n_b=4, n_k=0)),
]))
model.G1

MimoLinearDynamicalOperator()

In [100]:
class CustomDynonet(nn.Module):
    """Single-input/single-output dynoNet: a G1 -> F -> G2 chain plus a parallel LTI branch.

    The output is ``net(u) + G(u)``, where ``net`` is
    ``G1 (1 -> n_hidden_in LTI) -> F (static n_hidden_in -> n_hidden_out) ->
    G2 (n_hidden_out -> 1 LTI)`` and ``G`` is a SISO LTI operator.

    Args:
        n: passed as ``n_a`` to every LTI operator; ``n_b`` is ``n + 1``.
        n_hidden_in: number of channels produced by G1 (default 4).
        n_hidden_out: number of channels produced by F (default 3).
    """

    def __init__(self, n=2, n_hidden_in=4, n_hidden_out=3):
        super().__init__()
        self.net = nn.Sequential(
            MimoLinearDynamicalOperator(1, n_hidden_in, n_a=n, n_b=n+1, n_k=0),  # G1
            MimoStaticNonLinearity(n_hidden_in, n_hidden_out),  # F
            MimoLinearDynamicalOperator(n_hidden_out, 1, n_a=n, n_b=n+1, n_k=0),  # G2
        )
        # n_k=0 is spelled out so this branch reads the same as the MIMO operators above.
        self.G = SisoLinearDynamicalOperator(n_a=n, n_b=n+1, n_k=0)

    def forward(self, u):
        """Return ``net(u) + G(u)``.

        ``u`` must have a single channel in its last dimension, e.g. shape
        (batch, time, 1). Both branches then produce one channel, so the sum is well defined.
        """
        return self.net(u) + self.G(u)

In [101]:
batch_u = torch.randn((B, T, 1))  # single-channel input batch for CustomDynonet
model = CustomDynonet(n=2)
model(batch_u).size()

torch.Size([32, 10000, 1])