# Experiment 1 - Network Convergence

### Summary

In this experiment, the objective is to understand whether the backpropagation through time (BPTT) algorithm converges in this network when it is trained on different input signals.

In [1]:
# imports
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
from ssnn import SSNN
import os
from tqdm.notebook import tqdm

# Prefer the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if device.type == 'cuda':
    # Report which GPU was picked and its current memory footprint (bytes -> MiB).
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    for label, query in (('Allocated:', torch.cuda.memory_allocated),
                         ('Cached:   ', torch.cuda.memory_reserved)):
        print(label, round(query(0)/1024**2,1), 'MB')


NVIDIA GeForce RTX 3080
Memory Usage:
Allocated: 0.0 MB
Cached:    0.0 MB


In [2]:
# for reproducibility: seed every random number generator used in this notebook
# (Python's `random`, NumPy, PyTorch CPU and PyTorch CUDA) with the same fixed value
import random
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# ask cuDNN for deterministic kernels and switch PyTorch to its deterministic-algorithms mode
torch.backends.cudnn.deterministic=True
torch.use_deterministic_algorithms(True)

# cuBLAS workspace configuration; per the PyTorch reproducibility notes this must be set
# for deterministic CUDA matrix multiplications -- NOTE(review): it has to be in place
# before the first cuBLAS call, so keep this cell ahead of any GPU computation
%env CUBLAS_WORKSPACE_CONFIG=:4096:8

env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [8]:
# support functions

def observability(A:torch.Tensor, C:torch.Tensor):
    """
    Returns observability matrix

    Stacks the blocks C, C*A, C*A^2, ..., C*A^(n-1) on top of each other,
    where n = A.shape[0].

    Args:
        A: square state matrix of shape (n, n).
        C: output matrix of shape (p, n).

    Returns:
        Tensor of shape (n*p, n) on the same device and with the same dtype
        as A, detached from the autograd graph. For a single-output system
        (p == 1) this is the square (n, n) matrix whose row k is C*A^k.
    """
    with torch.no_grad():
        n_states = A.shape[0]
        if n_states == 0:
            # nothing to stack; keep the (0, n) shape the row-by-row fill would give
            return A.new_zeros((0, C.shape[1]))

        # Work on A's device/dtype so no per-row host copy or precision drop occurs.
        block = C.to(device=A.device, dtype=A.dtype)
        blocks = []
        for _ in range(n_states):
            blocks.append(block)
            # next block is C*A^(k+1); reuses the previous product instead of
            # recomputing the matrix power from scratch
            block = torch.matmul(block, A)
        O = torch.cat(blocks, dim=0)
    return O

def clamp(n, minn, maxn):
    """
    Restricts n to the closed interval [minn, maxn]:
    values above maxn become maxn, values below minn become minn.
    """
    # cap from above first ...
    capped = n if n < maxn else maxn
    # ... then lift to the lower bound if needed
    return minn if minn > capped else capped

def save_gif(title:str, dir:str, **kwargs):
    """
    create a gif vizualization of this training process

    One frame is rendered per epoch. The top panel shows the model output
    stored for that epoch, the bottom panel shows the loss history up to
    that epoch together with the current loss and det(O) as text.

    Args:
        title: file name of the gif (e.g. "run.gif").
        dir: output folder, joined onto the current working directory.
            It is created if it does not exist yet.

    Keyword Args:
        plot_size: number of samples on the x axis of the output panel.
        signal_amplitude: amplitude used to set the y limits of the output panel.
        n_epochs: number of frames; every history below is indexed 0..n_epochs-1.
        t_loss: sequence of loss values, indexed by frame.
        t_plot: sequence of tensors of length plot_size, indexed by frame.
        t_observability: sequence of observability determinants, indexed by frame.

    Raises:
        KeyError: if one of the keyword arguments above is missing.
    """

    plot_size = kwargs["plot_size"]
    signal_amplitude = kwargs["signal_amplitude"]
    n_epochs = kwargs["n_epochs"]
    t_loss = kwargs["t_loss"]
    t_plot = kwargs["t_plot"]
    t_observability = kwargs["t_observability"]

    # convert once instead of on every frame
    loss_history = np.asarray(t_loss)

    # the writer does not create missing folders, so make sure the target exists
    out_dir = os.path.join(os.getcwd(), dir)
    os.makedirs(out_dir, exist_ok=True)

    fig, (ax1, ax2) = plt.subplots(2,1)

    ax1.set_title("Output over time")
    ax2.set_title("Loss over epoch")

    line1, = ax1.plot([])
    line2, = ax2.plot([])

    ax1.set_xlim(0, plot_size)
    ax1.set_ylim(-signal_amplitude-5, signal_amplitude+5)

    ax2.set_xlim(0, n_epochs)
    ax2.set_ylim(t_loss[0]-0.5, t_loss[0]+0.5)

    ax1.set_xticks([])

    t_text1 = ax1.text(.5, -signal_amplitude-5+0.5, '', fontsize=10)
    t_text2 = ax2.annotate("", xy = (10, -30), size=10, color='k', xycoords='axes points')
    t_text3 = ax2.annotate("", xy = (80, -30), size=10, color='k', xycoords='axes points')

    # x axis of the output panel is the same for every frame
    x1 = np.arange(0, plot_size, 1)

    def animate(frame_n):
        y1 = t_plot[frame_n].cpu().detach().numpy()

        # loss curve only up to (not including) the current frame
        x2 = np.arange(0, frame_n, 1)
        y2 = loss_history[0:frame_n]

        line1.set_data((x1, y1))
        line1.set_color('b')
        line2.set_data((x2, y2))
        line2.set_color('k')

        t_text1.set_text(f"epoch:{frame_n}")
        t_text2.set_text(f"loss:{t_loss[frame_n]:.2f}")
        t_text3.set_text(f"det(O):{t_observability[frame_n]:.2e}")

        if frame_n != 0:
            # rescale the loss axis to the values drawn so far, limited to [0, 1000]
            ax2.set_ylim(   clamp(y2.min(), 0, 1000)-0.5,
                            clamp(y2.max(), 0, 1000)+0.5)

        return line1, line2

    anim = FuncAnimation(fig, animate, frames=n_epochs, interval=1)
    try:
        anim.save(os.path.join(out_dir, title), writer = 'pillow')
    finally:
        # close this specific figure even if saving fails, so figures do not pile up
        plt.close(fig)

## Experiment with pure sine

In [9]:
# Generate signal, windows it, generate input and output data.
from unittest import signals

# end of the training time axis (the axis runs from 0 up to this value)
train_signal_len = 200
# spacing between consecutive time samples
step = 0.5
# number of input samples per training window (the sample right after it is the target)
window_len = 20
# amplitude of the sine; also used to set the y limits of the output plot
signal_amplitude = 5
# number of samples in the evaluation seed and width of the window fed back to the model
eval_signal_len = 10
# total number of samples shown in the output plot (seed + generated samples)
plot_size = 100

def generate_data(signal_len:int, step:float, window_len:int, **kwargs):
    """
    Builds a windowed sine dataset plus a short evaluation seed signal.

    The time axis 0, step, 2*step, ... (< signal_len) is cut into sliding
    windows of window_len+1 samples with stride 1. The first window_len
    samples of each window are the input, the last one is the target.

    Args:
        signal_len: end of the training time axis (exclusive).
        step: spacing between consecutive time samples.
        window_len: number of input samples per window.

    Keyword Args:
        amplitude: amplitude of the sine.
        eval_len: number of samples in the evaluation signal.

    Returns:
        ((x_train, y_train), eval_signal) with shapes
        (n_windows, window_len, 1), (n_windows, 1) and (1, eval_len, 1).
        The evaluation signal starts at t = signal_len.

    Raises:
        KeyError: if `amplitude` or `eval_len` is not given.
    """
    A = kwargs['amplitude']
    eval_len = kwargs['eval_len']

    # generate a sine dataset
    t_train = torch.arange(0, signal_len, step)
    windows = t_train.unfold(0, window_len+1, 1).unsqueeze(-1)

    # last sample of every window is the target, the rest is the input
    x_train = A*torch.sin(windows[:, :-1, :])
    y_train = A*torch.sin(windows[:, -1, :])

    # Build the evaluation axis from an integer index instead of a float end
    # point: arange(start, start+eval_len*step, step) can return eval_len+1
    # samples when the end point is rounded up.
    t_eval = signal_len + step*torch.arange(eval_len, dtype=torch.float64)
    t_eval = t_eval.to(torch.get_default_dtype())
    t_eval = t_eval.unsqueeze(-1).unsqueeze(0)
    eval_signal = A*torch.sin(t_eval)

    return (x_train, y_train), eval_signal

# Build the pure-sine dataset from the settings above ...
(train_signal, train_target), eval_signal = generate_data(
    signal_len=train_signal_len,
    step=step,
    window_len=window_len,
    amplitude=signal_amplitude,
    eval_len=eval_signal_len,
)

# ... and move all three tensors to the selected device in one go
train_signal, train_target, eval_signal = (
    tensor.to(device) for tensor in (train_signal, train_target, eval_signal)
)

#plt.plot(eval_signal.squeeze(0))

In [10]:
# Instantiate the SSNN (1 input, 10 states, 1 output) directly on the target device
model = SSNN(u_len = 1, x_len = 10, y_len = 1).to(device)

# Mean squared error loss, optimised with plain SGD
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)

In [12]:
# train loop
# Every history below holds exactly one entry per epoch, which is how save_gif
# indexes them (one frame per epoch).
t_loss = []           # mean training loss of the epoch
t_plot = []           # model rollout produced after the epoch
t_observability = []  # det of the observability matrix after the epoch

n_epochs = 10
batch_size = 32

n_samples = train_signal.shape[0]

for epoch in tqdm(range(n_epochs)):
    epoch_loss = 0.0

    # Walk over the windows in order. Slicing past the end simply yields a
    # shorter final batch, so every sample is used and no batch is ever empty.
    for start in range(0, n_samples, batch_size):
        batch_inputs = train_signal[start:start+batch_size]
        batch_targets = train_target[start:start+batch_size]

        optimizer.zero_grad()
        # only the output at the last time step is compared with the target
        outputs = model(batch_inputs)[:, -1, :]
        loss = criterion(outputs, batch_targets)
        loss.backward()
        optimizer.step()

        # weight by batch size so the epoch value is a per-sample mean even
        # though the final batch is smaller
        epoch_loss += loss.item()*batch_inputs.shape[0]

    t_loss.append(epoch_loss/n_samples)
    with torch.no_grad():
        t_observability.append(torch.linalg.det(observability(model.A, model.C)).item())

    # eval per epoch: start from the seed signal and repeatedly feed the last
    # eval_signal_len samples back into the model to generate the next sample
    y_eval = eval_signal

    with torch.no_grad():
        for t in range(plot_size-eval_signal_len):
            window = y_eval[:, t:eval_signal_len+t, :]
            # last-step prediction, with the time axis re-inserted for concatenation
            next_sample = model(window)[:, -1, :].unsqueeze(1)
            y_eval = torch.cat((y_eval, next_sample), dim=1)
    t_plot.append(y_eval.squeeze())

save_gif("pure-sine.gif", "./results/experiment-1", plot_size=plot_size, signal_amplitude=signal_amplitude, n_epochs=n_epochs, t_loss=t_loss, t_plot=t_plot, t_observability=t_observability)

  0%|          | 0/10 [00:00<?, ?it/s]

## 3 sine harmonics

In [32]:
# Generate signal, windows it, generate input and output data.
from unittest import signals

# end of the training time axis (the axis runs from 0 up to this value)
train_signal_len = 400
# spacing between consecutive time samples
step = 0.5
# number of input samples per training window (the sample right after it is the target)
window_len = 40
# amplitude of the first sine component; also used to set the y limits of the output plot
signal_amplitude = 5
# number of samples in the evaluation seed and width of the window fed back to the model
eval_signal_len = 10
# total number of samples shown in the output plot (seed + generated samples)
plot_size = 100

def generate_data(signal_len:int, step:float, window_len:int, **kwargs):
    """
    Builds a windowed dataset from a sum of three sines plus a short
    evaluation seed signal.

    The signal is amp1*sin(t) + amp2*sin(t/3) + amp3*sin(t/5). The time axis
    0, step, 2*step, ... (< signal_len) is cut into sliding windows of
    window_len+1 samples with stride 1. The first window_len samples of each
    window are the input, the last one is the target.

    Args:
        signal_len: end of the training time axis (exclusive).
        step: spacing between consecutive time samples.
        window_len: number of input samples per window.

    Keyword Args:
        amp1: amplitude of sin(t).
        amp2: amplitude of sin(t/3).
        amp3: amplitude of sin(t/5).
        eval_len: number of samples in the evaluation signal.

    Returns:
        ((x_train, y_train), eval_signal) with shapes
        (n_windows, window_len, 1), (n_windows, 1) and (1, eval_len, 1).
        The evaluation signal starts at t = signal_len.

    Raises:
        KeyError: if one of the keyword arguments above is not given.
    """
    A = kwargs['amp1']
    B = kwargs['amp2']
    C = kwargs['amp3']
    eval_len = kwargs['eval_len']

    def signal(t):
        # the same three-component signal is used for inputs, targets and evaluation
        return A*torch.sin(t)+B*torch.sin(t/3)+C*torch.sin(t/5)

    # generate a sine dataset
    t_train = torch.arange(0, signal_len, step)
    windows = t_train.unfold(0, window_len+1, 1).unsqueeze(-1)

    # last sample of every window is the target, the rest is the input
    x_train = signal(windows[:, :-1, :])
    y_train = signal(windows[:, -1, :])

    # Build the evaluation axis from an integer index instead of a float end
    # point: arange(start, start+eval_len*step, step) can return eval_len+1
    # samples when the end point is rounded up.
    t_eval = signal_len + step*torch.arange(eval_len, dtype=torch.float64)
    t_eval = t_eval.to(torch.get_default_dtype())
    t_eval = t_eval.unsqueeze(-1).unsqueeze(0)
    eval_signal = signal(t_eval)

    return (x_train, y_train), eval_signal

# Build the three-component dataset from the settings above ...
(train_signal, train_target), eval_signal = generate_data(
    signal_len=train_signal_len,
    step=step,
    window_len=window_len,
    amp1=signal_amplitude,
    amp2=2,
    amp3=1,
    eval_len=eval_signal_len,
)

# ... and move all three tensors to the selected device in one go
train_signal, train_target, eval_signal = (
    tensor.to(device) for tensor in (train_signal, train_target, eval_signal)
)

In [33]:
# Instantiate a fresh SSNN (1 input, 10 states, 1 output) directly on the target device
model = SSNN(u_len = 1, x_len = 10, y_len = 1).to(device)

# Mean squared error loss, optimised with plain SGD
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)

In [34]:
# train loop
# - Loss: MSE
# - SGD without momentum
#
# Every history below holds exactly one entry per epoch, which is how save_gif
# indexes them (one frame per epoch).
t_loss = []           # mean training loss of the epoch
t_plot = []           # model rollout produced after the epoch
t_observability = []  # det of the observability matrix after the epoch

n_epochs = 200
batch_size = 32

n_samples = train_signal.shape[0]

for epoch in tqdm(range(n_epochs)):
    epoch_loss = 0.0

    # Walk over the windows in order. Slicing past the end simply yields a
    # shorter final batch, so every sample is used and no batch is ever empty.
    for start in range(0, n_samples, batch_size):
        batch_inputs = train_signal[start:start+batch_size]
        batch_targets = train_target[start:start+batch_size]

        optimizer.zero_grad()
        # only the output at the last time step is compared with the target
        outputs = model(batch_inputs)[:, -1, :]
        loss = criterion(outputs, batch_targets)
        loss.backward()
        optimizer.step()

        # weight by batch size so the epoch value is a per-sample mean even
        # though the final batch is smaller
        epoch_loss += loss.item()*batch_inputs.shape[0]

    t_loss.append(epoch_loss/n_samples)
    with torch.no_grad():
        t_observability.append(torch.linalg.det(observability(model.A, model.C)).item())

    # eval per epoch: start from the seed signal and repeatedly feed the last
    # eval_signal_len samples back into the model to generate the next sample
    y_eval = eval_signal

    with torch.no_grad():
        for t in range(plot_size-eval_signal_len):
            window = y_eval[:, t:eval_signal_len+t, :]
            # last-step prediction, with the time axis re-inserted for concatenation
            next_sample = model(window)[:, -1, :].unsqueeze(1)
            y_eval = torch.cat((y_eval, next_sample), dim=1)
    t_plot.append(y_eval.squeeze())

save_gif("3-harmonics.gif", "./results/experiment-1", plot_size=plot_size, signal_amplitude=signal_amplitude, n_epochs=n_epochs, t_loss=t_loss, t_plot=t_plot, t_observability=t_observability)

  0%|          | 0/200 [00:00<?, ?it/s]