In [67]:
# library
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from torchmetrics import R2Score, MeanSquaredError
import csv

r2score  = R2Score()
msescore = MeanSquaredError()

torch.manual_seed(2)
np.random.seed(2)
torch.set_printoptions(precision=8)

In [68]:
# Model
class Net(torch.nn.Module):
  def __init__(self, n=4, p=8, noi=1, b0_size = 8, bi_size = 2, trunk_size = 8):
    super(Net, self).__init__()
    self.n   = n                    # horizon window length
    self.p   = p                    # size of branch and trunk output
    self.noi = noi                  # number of input

    self.b0_size    = b0_size
    self.bi_size    = bi_size
    self.trunk_size = trunk_size
    
    # Branch x0
    self.input_x0  = torch.nn.Linear(1, self.b0_size)
    self.hidden_x0 = torch.nn.Linear(self.b0_size, self.b0_size)
    self.output_x0 = torch.nn.Linear(self.b0_size, self.p)

    # Branch 1 u
    self.input_u  = torch.nn.Linear(self.noi*self.n, self.bi_size)
    self.hidden_u = torch.nn.Linear(self.bi_size, self.bi_size)
    self.output_u = torch.nn.Linear(self.bi_size, self.p)

    # Trunk
    self.input_t  = torch.nn.Linear(1, self.trunk_size)
    self.hidden_t = torch.nn.Linear(self.trunk_size, self.trunk_size)
    self.output_t = torch.nn.Linear(self.trunk_size, self.p)

  def forward(self, x0, u, t):
    # h
    h = torch.selu(self.input_x0(x0))
    h = torch.selu(self.hidden_x0(h))
    h = self.output_x0(h)

    # f
    f = torch.selu(self.input_u(u))
    f = torch.selu(self.hidden_u(f))
    f = self.output_u(f)

    # g
    g = torch.selu(self.input_t(t))
    g = torch.selu(self.hidden_t(g))
    g = self.output_t(g)

    return torch.sum(h*f*g, dim=1).reshape(-1,1)

In [69]:
# Model error
def eval(model, testset):
    with torch.no_grad():
        pred_Y = model(testset.x0_data, testset.u_data, testset.t_data)

    r2_1  = r2score(pred_Y[0::4], testset.y_data[0::4])
    mse_1 = msescore(pred_Y[0::4], testset.y_data[0::4])
    r2_2  = r2score(pred_Y[1::4], testset.y_data[1::4])
    mse_2 = msescore(pred_Y[1::4], testset.y_data[1::4])
    r2_3  = r2score(pred_Y[2::4], testset.y_data[2::4])
    mse_3 = msescore(pred_Y[2::4], testset.y_data[2::4])
    r2_4  = r2score(pred_Y[3::4], testset.y_data[3::4])
    mse_4 = msescore(pred_Y[3::4], testset.y_data[3::4])
    return r2_1.item(), mse_1.item(), r2_2.item(), mse_2.item(), r2_3.item(), mse_3.item(), r2_4.item(), mse_4.item()

In [70]:
# Data
class Data(torch.utils.data.Dataset):
  def __init__(self, src_file, n, H, noi):
    self.n   = n                                  # horizon length
    self.H   = H                                  # max window length
    self.noi = noi                                # number of input
    self.src_file = src_file                      # source file
    df = pd.read_csv(self.src_file, header=None)

    X0, U, T, Y = np.array([[1]], dtype=np.float32), np.ones((1, n*self.noi)), np.array([[1]], dtype=np.float32), np.array([[1]], dtype=np.float32)
    for i in range(df.shape[0]):
        row = np.array(df.iloc[i])
        for j in range(self.H - self.n):
            x0 = np.array([[row[self.H*self.noi + j]]])
            u  = np.array([row[j:j + self.n*self.noi]])
            for t in range(1, self.n + 1):
                y = np.array([[row[self.H*self.noi + j + t]]])
                t = np.array([[t]])

                X0 = np.concatenate((X0, x0))
                U  = np.concatenate((U, u))
                T  = np.concatenate((T, t))
                Y  = np.concatenate((Y, y))

    X0, U, T, Y = X0[1:], U[1:], T[1:], Y[1:]

    self.x0_data = torch.tensor(X0, dtype=torch.float32)
    self.u_data  = torch.tensor(U,  dtype=torch.float32)
    self.t_data  = torch.tensor(T,  dtype=torch.float32)
    self.y_data  = torch.tensor(Y,  dtype=torch.float32)

  def __len__(self):
    return len(self.x0_data)

  def __getitem__(self, idx):
    if torch.is_tensor(idx):
      idx = idx.tolist()
    x0  = self.x0_data[idx]
    u   = self.u_data[idx]
    t   = self.t_data[idx]
    y   = self.y_data[idx]
    sample = {'x0':x0, 'u':u, 't':t, 'y':y}
    return sample

In [71]:
# Early stopping
def early_stop(list, min_epochs, patience):
    if(len(list) > min_epochs):
        if(np.max(list[-patience:]) < 1.0001*np.max(list[0: -patience])):
            return 1
    return 0

In [72]:
# Plot
def plot(net, dataset, size):
    with torch.no_grad():
        pred_Y = net(dataset.x0_data, dataset.u_data, dataset.t_data)

    plt.figure(figsize=size)
    plt.plot(dataset.y_data[0::4], 'b',   label=r'real',      linewidth=3)
    plt.plot(pred_Y[0::4],         'r--', label=r'predicted', linewidth=1)
    plt.ylabel(r'x(t)')
    plt.legend()
    plt.show()

In [73]:
# Train function
def train(net, train_ds, test_ds, min_epochs=200, max_epochs=100000, patience=100):
    loss_func  = torch.nn.MSELoss()
    optimizer  = torch.optim.Adam(net.parameters(), lr=0.001)

    train_ldr = torch.utils.data.DataLoader(train_ds, batch_size=train_ds.y_data.shape[0], shuffle=True)

    R2  = np.array([])
    MSE = np.array([])
    for epoch in range(0, max_epochs+1):
        net.train()
        loss  = 0
        count = 0
        for (_, batch) in enumerate(train_ldr):
            X0 = batch['x0']
            U  = batch['u']
            T  = batch['t']
            Y  = batch['y']

            optimizer.zero_grad()
            output = net(X0, U, T)             # compute the output of the Network
            loss_val = loss_func(output, Y)    # loss function
            loss += loss_val.item()            # accumulate
            loss_val.backward()                # gradients
            optimizer.step()                   # update paramters
            count += 1
        
        net.eval()
        R2_1  = np.append(R2, eval(net, test_ds)[0])
        MSE_1 = np.append(MSE, eval(net, test_ds)[1])
        R2_2  = np.append(R2, eval(net, test_ds)[2])
        MSE_2 = np.append(MSE, eval(net, test_ds)[3])
        R2_3  = np.append(R2, eval(net, test_ds)[4])
        MSE_3 = np.append(MSE, eval(net, test_ds)[5])
        R2_4  = np.append(R2, eval(net, test_ds)[6])
        MSE_4 = np.append(MSE, eval(net, test_ds)[7])
        
        if(early_stop(list = R2, min_epochs = min_epochs, patience = patience) == 1):
            break
    
    return R2_1, MSE_1, R2_2, MSE_2, R2_3, MSE_3, R2_4, MSE_4

In [74]:
with open('stat_original.csv', 'w') as file:
     writer = csv.writer(file)
     writer.writerow(['p', 'b0', 'bi', 'trunk', 'R2_1', 'MSE_1', 'R2_2', 'MSE_2', 'R2_3', 'MSE_3', 'R2_4', 'MSE_4'])

In [75]:
df_result = pd.DataFrame({'p':[], 'b0':[], 'bi':[], 'trunk':[], 'R2_1':[], 'MSE_1':[], 'R2_2':[], 'MSE_2':[], 'R2_3':[], 'MSE_3':[], 'R2_4':[], 'MSE_4':[]})

for _p in [4, 8]:
    for b0 in [4, 8]:
        for bi in [2, 8]:
            for trunk in [4, 8]:
                # Hyperparameters
                p   = _p          # size of branch and trunk ouput
                n   = 4           # horizon window length
                noi = 2           # number of inputs
                H   = 512         # maximum window length

                # Create Dataset and DataLoader objects
                src_file_train = '0. Data/data_0.csv'
                train_ds       = Data(src_file_train, n, H, noi)

                src_file_test  = '0. Data/data_1.csv'
                test_ds        = Data(src_file_test, n, H, noi)

                # Create network
                device = torch.device("cpu")
                net = Net(n, p, noi, b0_size=b0, bi_size=bi, trunk_size=trunk).to(device)

                # train model
                min_epochs = 2
                max_epochs = 3
                patience   = 1
                R2_1, MSE_1, R2_2, MSE_2, R2_3, MSE_3, R2_4, MSE_4 = train(net, train_ds, test_ds, min_epochs, max_epochs, patience)

                with open('stat_original.csv', 'a') as file:
                    writer = csv.writer(file)
                    writer.writerow([_p, b0, bi, trunk, np.max(R2_1), np.min(MSE_1), np.max(R2_2), np.min(MSE_2), np.max(R2_3), np.min(MSE_3), np.max(R2_4), np.min(MSE_4)])

4 4 2 4 -499.6763916015625 1585.557373046875 -516.3147583007812 1643.5946044921875 -552.5969848632812 1766.7698974609375 -643.6036987304688 2065.142822265625
4 4 2 8 -1626.4930419921875 5153.9951171875 -1624.177490234375 5163.45751953125 -1618.4967041015625 5168.521484375 -1613.1136474609375 5171.20068359375
4 4 8 4 -18219.80859375 57702.21484375 -12880.5263671875 40926.73828125 -9879.7646484375 31533.837890625 -7294.45849609375 23372.75390625
4 4 8 8 -1275.292724609375 4041.80322265625 -1197.5587158203125 3808.019287109375 -1221.8734130859375 3902.723388671875 -1267.8443603515625 4065.0478515625
4 8 2 4 -3534.916015625 11197.6474609375 -5804.04541015625 18443.587890625 -8353.96484375 26664.34375 -11147.9931640625 35718.4765625
4 8 2 8 -1278.7371826171875 4052.711181640625 -1240.6412353515625 3944.899658203125 -1187.62353515625 3793.417236328125 -1136.2911376953125 3643.585205078125
4 8 8 4 -486.31298828125 1543.2376708984375 -64.16004180908203 207.02423095703125 -150.71522521972656 48