# Some output experiments on the SSNN

import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
from ssnn import SSNN

import random

# Reproducibility: seed every RNG in play and forbid non-deterministic kernels.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)

# Experiment settings.
amp = 10          # amplitude of the sine wave used for training and evaluation
n_epochs = 200    # number of full-batch training steps
plot_size = 100   # length of the autoregressive rollout that is plotted
state_len = 5     # state dimension passed to SSNN

# Dataset: sliding windows of 20 samples over amp*sin(t), t = 0, 0.5, ..., 99.5.
# The first 19 samples of each window form the input sequence, the 20th is the
# target: x_train is (windows, 19, 1), y_train is (windows, 1).
_windows = torch.arange(0, 100, 0.5).unfold(0, 20, 1).unsqueeze(-1)
x_train = amp*torch.sin(_windows[:, :-1, :])
y_train = amp*torch.sin(_windows[:, -1, :])
del _windows

# Observability matrix of a state-space pair (A, C)
def observability(A:torch.Tensor, C:torch.Tensor) -> torch.Tensor:
    """Return the observability matrix of the pair (A, C).

    Stacks the blocks C, CA, CA^2, ..., CA^(n-1) vertically, where
    n = A.shape[0]. For C of shape (p, n) the result has shape (n*p, n); with a
    single output row (p == 1) it is the square (n, n) matrix.

    The result is computed without gradient tracking and keeps the dtype and
    device of the inputs.
    """
    n = A.shape[0]
    with torch.no_grad():
        if n == 0:
            # Degenerate empty state: no blocks to stack.
            return C.new_zeros((0, C.shape[1]))
        blocks = []
        block = C
        for _ in range(n):
            blocks.append(block)
            # Advance C A^k -> C A^(k+1) incrementally instead of recomputing
            # the matrix power for every block.
            block = block @ A
        return torch.cat(blocks, dim=0)

# train loop
ssnn = SSNN(1, state_len, 1)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(ssnn.parameters(), lr=0.001, momentum=0.9)

# Per-epoch records consumed by the animation below.
t_loss = []           # training MSE
t_observability = []  # det of the observability matrix of (ssnn.A, ssnn.C)
t_plot = []           # autoregressive rollout (plot_size samples per epoch)

# Rollout seed: 20 true samples of the sine just past the training range
# (t = 100, 100.5, ..., 109.5). It never changes between epochs, so build it once.
x_eval = torch.arange(100, 110, 0.5).unsqueeze(-1).unsqueeze(0)  # (1, 20, 1)
y_eval_seed = amp*torch.sin(x_eval)

for epoch in range(n_epochs):  # full-batch: one optimizer step per epoch
    optimizer.zero_grad()

    # NOTE(review): assumes SSNN returns (batch, seq, y_len); only the last
    # time step is scored against the target.
    outputs = ssnn(x_train)
    outputs = outputs[:, -1, :]
    loss = criterion(outputs, y_train)

    # Loss, det(O) and the rollout are all measured on the same weights
    # (before this epoch's optimizer step), so each animation frame is
    # internally consistent.
    t_loss.append(loss.item())
    with torch.no_grad():
        t_observability.append(torch.linalg.det(observability(ssnn.A, ssnn.C)).item())

        # Feed the latest 20 values back in and append the one-step
        # prediction until the sequence holds plot_size samples.
        y_eval = y_eval_seed
        for t in range(plot_size-20):
            y_next = ssnn(y_eval[:, t:20+t, :])[:, -1, :].unsqueeze(-1)
            y_eval = torch.cat((y_eval, y_next), dim=1)
    t_plot.append(y_eval.squeeze())

    loss.backward()
    optimizer.step()

# show plots
def clamp(n, minn, maxn):
    """Limit ``n`` to the interval [minn, maxn].

    Comparisons mirror ``max(min(maxn, n), minn)``: when minn <= maxn, a NaN
    ``n`` is mapped to ``maxn``.
    """
    capped = n if n < maxn else maxn
    return minn if minn > capped else capped

fig, (ax1, ax2) = plt.subplots(2, 1)

# Top axes: the autoregressive rollout of the frame being drawn, with an
# "epoch" label near the bottom edge of the plotting range.
ax1.set_title("Output over time")
(line1,) = ax1.plot([])
ax1.set_xlim(0, plot_size)
ax1.set_ylim(-amp - 5, amp + 5)
ax1.set_xticks([])
t_text1 = ax1.text(.5, -amp - 5 + 0.5, '', fontsize=10)

# Bottom axes: the loss history; y-range starts around the first loss value.
# Two text annotations sit 30 points below the axes (loss and det(O)).
ax2.set_title("Loss over epoch")
(line2,) = ax2.plot([])
ax2.set_xlim(0, n_epochs)
ax2.set_ylim(t_loss[0] - 0.5, t_loss[0] + 0.5)
t_text2 = ax2.annotate("", xy=(10, -30), size=10, color='k', xycoords='axes points')
t_text3 = ax2.annotate("", xy=(80, -30), size=10, color='k', xycoords='axes points')

def animate(frame_n):
    """Draw frame ``frame_n`` of the training animation.

    The top line shows the rollout recorded at epoch ``frame_n``; the bottom
    line shows the loss history up to and including that epoch, and the loss
    y-range is rescaled to the losses seen so far (each bound clamped to
    [0, 1000] before padding by 0.5).

    Returns the two line artists that were updated.
    """
    x1 = np.arange(0, plot_size, 1)
    y1 = t_plot[frame_n].detach().numpy()

    # Include the current epoch so the curve ends at the value shown in the
    # "loss:" label and the final epoch's loss is visible in the last frame.
    seen = np.array(t_loss[:frame_n + 1])
    x2 = np.arange(0, frame_n + 1, 1)

    line1.set_data(x1, y1)
    line1.set_color('b')
    line2.set_data(x2, seen)
    line2.set_color('k')

    t_text1.set_text(f"epoch:{frame_n}")
    t_text2.set_text(f"loss:{t_loss[frame_n]:.2f}")
    t_text3.set_text(f"det(O):{t_observability[frame_n]:.2e}")

    # Builtin min/max (not np.min/np.max) so an occasional NaN loss does not
    # poison the axis limits.
    ax2.set_ylim(clamp(min(seen), 0, 1000)-0.5,
                 clamp(max(seen), 0, 1000)+0.5)

    return line1, line2

import os

anim = FuncAnimation(fig, animate, frames=n_epochs, interval=1)
# Saving fails if the target directory is missing, so create it first.
# NOTE(review): ':' in the file name is not allowed on Windows file systems.
os.makedirs("gifs", exist_ok=True)
anim.save(f"gifs/amp{amp}:n_epochs{n_epochs}:plot_size{plot_size}:state_len{state_len}.gif", writer = 'pillow')
plt.close(fig)

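# In [3]: second notebook cell - the same experiment with an LSTM model, swept over amplitudes and state sizes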
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
from ssnn import SSNN

import random

# Reproducibility: seed every RNG in play and forbid non-deterministic kernels.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)

# Sweep grid: a model is trained for every (amplitude, state size) pair.
amps = [1, 5, 10, 20]
n_epochs = 500
plot_size = 100
state_lens = list(range(1, 11))

class lstm_linear(torch.nn.Module):
    """Single-layer LSTM followed by a linear read-out of the last time step.

    Takes ``u_len`` input features per step, keeps an LSTM hidden state of size
    ``x_len`` and produces ``y_len`` outputs.
    """

    def __init__(self, u_len:int, x_len:int, y_len:int):
        super().__init__()

        # batch_first=True: batched inputs are (batch, seq, u_len).
        self.lstm = torch.nn.LSTM(u_len, x_len, batch_first=True)
        self.linear = torch.nn.Linear(x_len, y_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a sequence to a prediction from its final hidden output.

        Accepts batched (batch, seq, u_len) input, returning (batch, y_len),
        or unbatched (seq, u_len) input, returning (y_len,).
        """
        out, _ = self.lstm(x)
        # `...` selects the last time step for both batched and unbatched
        # LSTM outputs.
        last = out[..., -1, :]
        return self.linear(last)

import os


def clamp(n, minn, maxn):
    """Limit ``n`` to [minn, maxn] (a NaN ``n`` maps to ``maxn``)."""
    return max(min(maxn, n), minn)


for amp in amps:
    for state_len in state_lens:
        # Re-seed per configuration: otherwise the initial weights for a given
        # (amp, state_len) depend on every configuration trained before it, so
        # one GIF could not be reproduced without re-running the sweep prefix.
        torch.manual_seed(0)

        # Dataset: sliding windows of 20 samples over amp*sin(t),
        # t = 0, 0.5, ..., 99.5; the first 19 samples are the input sequence,
        # the 20th is the target.
        x_train = torch.arange(0, 100, 0.5)
        x_train = x_train.unfold(0, 20, 1)
        x_train = x_train.unsqueeze(-1)

        y_train = x_train[:,-1,:]
        x_train = x_train[:,0:-1,:]

        x_train = amp*torch.sin(x_train)
        y_train = amp*torch.sin(y_train)

        # train loop
        lstm = lstm_linear(1, state_len, 1)

        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.SGD(lstm.parameters(), lr=0.001, momentum=0.9)

        t_loss = []  # training MSE per epoch
        t_plot = []  # autoregressive rollout per epoch (plot_size samples)

        # Rollout seed: 20 true samples of the sine just past the training
        # range (t = 100, ..., 109.5); constant across epochs, so built once.
        x_eval = torch.arange(100, 110, 0.5).unsqueeze(-1).unsqueeze(0)  # (1, 20, 1)
        y_eval_seed = amp*torch.sin(x_eval)

        for epoch in range(n_epochs):  # full-batch: one optimizer step per epoch
            optimizer.zero_grad()

            outputs = lstm(x_train)
            loss = criterion(outputs, y_train)
            t_loss.append(loss.item())

            # Roll out with the same weights that produced `loss` (before the
            # optimizer step), so each frame pairs a loss with its own model.
            # Each step feeds the latest 20 values back in and appends the
            # one-step prediction.
            with torch.no_grad():
                y_eval = y_eval_seed
                for t in range(plot_size-20):
                    y_next = lstm(y_eval[:, t:20+t, :]).unsqueeze(-1)  # (1, 1, 1)
                    y_eval = torch.cat((y_eval, y_next), dim=1)
            t_plot.append(y_eval.squeeze())

            loss.backward()
            optimizer.step()

        # show plots
        fig, (ax1, ax2) = plt.subplots(2,1)

        ax1.set_title("Output over time")
        ax2.set_title("Loss over epoch")

        line1, = ax1.plot([])
        line2, = ax2.plot([])

        ax1.set_xlim(0, plot_size)
        ax1.set_ylim(-amp-5, amp+5)

        ax2.set_xlim(0, n_epochs)
        ax2.set_ylim(t_loss[0]-0.5, t_loss[0]+0.5)

        ax1.set_xticks([])

        t_text1 = ax1.text(.5, -amp-5+0.5, '', fontsize=10)
        t_text2 = ax2.annotate("", xy = (10, -30), size=10, color='k', xycoords='axes points')

        def animate(frame_n):
            """Draw frame ``frame_n``: that epoch's rollout and the loss history up to and including it."""
            x1 = np.arange(0, plot_size, 1)
            y1 = t_plot[frame_n].detach().numpy()

            # Include the current epoch so the curve ends at the value shown
            # in the "loss:" label and the final loss appears in the last frame.
            seen = np.array(t_loss[:frame_n + 1])
            x2 = np.arange(0, frame_n + 1, 1)

            line1.set_data(x1, y1)
            line1.set_color('b')
            line2.set_data(x2, seen)
            line2.set_color('k')

            t_text1.set_text(f"epoch:{frame_n}")
            t_text2.set_text(f"loss:{t_loss[frame_n]:.2f}")

            ax2.set_ylim(clamp(min(seen), 0, 1000)-0.5,
                         clamp(max(seen), 0, 1000)+0.5)

            return line1, line2

        anim = FuncAnimation(fig, animate, frames=n_epochs, interval=1)
        # Saving fails if the target directory is missing, so create it first.
        # NOTE(review): ':' in the file name is not allowed on Windows file systems.
        os.makedirs("gifs/lstm", exist_ok=True)
        anim.save(f"gifs/lstm/amp{amp}:n_epochs{n_epochs}:plot_size{plot_size}:state_len{state_len}.gif", writer = 'pillow')
        plt.close(fig)