In [32]:
import torch 
from torch import nn
import numpy as np

class SSM(nn.Module):
    """Defines a continuous-time state space model (SSM) layer.

       x'(t) = Ax(t) + Bu(t) -> rate of change of the state x(t) over time
       y(t)  = Cx(t) + Du(t) -> we assume D = 0 in this implementation

    Args:
        N : state size, i.e. the size of the parameter matrices.
        print_shape : if True, print the shapes of A, B and C after
            initialisation. Defaults to False.

    Attributes:
        A : nn.Parameter of size N x N
        B : nn.Parameter of size N x 1
        C : nn.Parameter of size 1 x N

        The elements of each matrix are drawn from a uniform distribution
        on [0, 1) (``torch.rand``). They are registered as ``nn.Parameter``
        so they appear in ``self.parameters()`` and can be learned by
        gradient descent.
    """
    def __init__(self, N : int, print_shape : bool = False):
        super().__init__()
        # nn.Parameter registers the matrices with the module. Plain tensors
        # would be invisible to .parameters() and to any optimizer.
        self.A = nn.Parameter(torch.rand(N, N))
        self.B = nn.Parameter(torch.rand(N, 1))
        self.C = nn.Parameter(torch.rand(1, N))

        if print_shape:
            print(f"""
            Shape of A -> {self.A.shape}
            Shape of B -> {self.B.shape} 
            Shape of C -> {self.C.shape}""")
        # No return value: __init__ must return None. The matrices are
        # available as self.A / self.B / self.C.

    def discretize_signal(self, A, B, C, step, print_shape : bool = False):
        """Discretize the continuous-time SSM with the bilinear method.

        > To be applied on a discrete input sequence (u_0, u_1, ...) instead
        of a continuous function u(t), the SSM must be discretized by a step
        size Δ (``step``) that represents the resolution of the input.

        > Conceptually, the inputs u_k can be viewed as sampling an implicit
        underlying continuous signal u(t), where u_k = u(kΔ).

        > To discretize the continuous-time SSM, we use the bilinear method,
        which converts the state matrix A into an approximation Ā. The
        discrete SSM is:

            Ab = (I - Δ/2 * A)^-1 @ (I + Δ/2 * A)
            Bb = (I - Δ/2 * A)^-1 @ (Δ * B)
            C remains the same

        Args:
            A : square state matrix (N x N).
            B : input matrix with N rows (N x 1 for this layer).
            C : output matrix; returned unchanged.
            step : step size Δ.
            print_shape : if True, print the shapes of the intermediate and
                resulting matrices. Defaults to False.

        Returns:
            (Ab, Bb, C) : the discretized A and B, and the unchanged C.
        """
        # torch.eye with A's dtype/device, so the identity can be combined
        # with A without mixing NumPy float64 arrays into torch arithmetic.
        I = torch.eye(A.shape[0], dtype=A.dtype, device=A.device)
        half_step = step / 2.0
        BL = torch.linalg.inv(I - half_step * A)
        Ab = BL @ (I + half_step * A)
        Bb = (BL * step) @ B
        if print_shape:
            # Report the local intermediates; they are not stored on self.
            print(f"""
            Shape of I -> {I.shape}
            Shape of BL-> {BL.shape}
            Shape of Ab-> {Ab.shape}
            Shape of Bb-> {Bb.shape} 
            Shape of C -> {C.shape}""")

        return Ab, Bb, C

In [40]:
# Testing cells: discretize a random SSM by hand (bilinear method) and check
# the shapes of the resulting matrices.
SEED = 0
torch.manual_seed(SEED)  # make the random A, B, C reproducible on re-run

N = 3     # state size
step = 2  # discretization step size Δ
A = torch.rand(N, N)
B = torch.rand(N, 1)
C = torch.rand(1, N)
print(f"""
            Shape of A -> {A.shape}
            Shape of B -> {B.shape} 
            Shape of C -> {C.shape}""")

print('--------------------------------------------------------------------')
# Bilinear discretization:
#   Ab = (I - Δ/2 A)^-1 (I + Δ/2 A),  Bb = (I - Δ/2 A)^-1 (Δ B)
I = torch.eye(A.shape[0])
BL = torch.linalg.inv(I - (step / 2.0) * A)
BL2 = (I + (step / 2.0) * A)
Ab = BL @ BL2
Bb = (BL * step) @ B
print(f"""
            Shape of I -> {I.shape}
            Shape of BL-> {BL.shape}
            Shape of Ab-> {Ab.shape}
            Shape of Bb-> {Bb.shape} 
            Shape of C -> {C.shape}""")

# Inline sanity checks on the discretized system's shapes.
assert Ab.shape == (N, N)
assert Bb.shape == (N, 1)
assert C.shape == (1, N)


            Shape of A -> torch.Size([3, 3])
            Shape of B -> torch.Size([3, 1]) 
            Shape of C -> torch.Size([1, 3])
--------------------------------------------------------------------

            Shape of I -> torch.Size([3, 3])
            Shape of BL-> torch.Size([3, 3])
            Shape of Ab-> torch.Size([3, 3])
            Shape of Bb-> torch.Size([3, 1]) 
            Shape of C -> torch.Size([1, 3])
