Skip to content

Commit f5464d8

Browse files
committed
Initial commit
1 parent 2042cc4 commit f5464d8

17 files changed

+2054
-1
lines changed

README.md

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,55 @@
1-
# sysid-transfer-functions-pytorch
1+
# Transfer functions and deep learning with dynoNet : new applications in system
2+
identification
3+
4+
5+
This repository contains the Python code to reproduce the results of the paper "Transfer functions and deep learning with dynoNet : new applications in system
6+
identification" by Marco Forgione and Dario Piga.
7+
8+
We describe the linear dynamical operator as a differentiable layer compatible with back-propagation-based training.
9+
The operator is parametrized as a rational transfer function and thus can represent an infinite impulse response (IIR)
10+
filtering operation, as opposed to the Convolutional layer of 1D-CNNs that is equivalent to finite impulse response (FIR) filtering.
11+
12+
In the dynoNet architecture (already introduced [here](https://github.com/forgi86/dynonet)), linear dynamical operators are combined with static (i.e., memoryless) non-linearities which can be either elementary
13+
activation functions applied channel-wise; fully connected feed-forward neural networks; or other differentiable operators.
14+
15+
In this work, we show how to non-standard learning problems may be tackled using the differentiable
16+
transfer function block, namely:
17+
18+
* Learning with quantized measurements
19+
* Learning in the presence of colored noise
20+
21+
# Folders:
22+
* [torchid](torchid_nb): PyTorch implementation of the linear dynamical operator (aka G-block in the paper) used in dynoNet
23+
* [examples](examples): examples using dynoNet for system identification
24+
* [util](util): definition of metrics R-square, RMSE, fit index
25+
26+
Two [examples](examples) discussed in the paper are:
27+
28+
* [Parallel Wiener-Hammerstein](examples/ParWH): A circuit with Wiener-Hammerstein behavior. Experimental dataset from http://www.nonlinearbenchmark.org
29+
* [BW](examples/BW): Bouc-Wen. A nonlinear dynamical system describing hysteretic effects in mechanical engineering. Experimental dataset from http://www.nonlinearbenchmark.org
30+
31+
32+
For the [WH2009](examples/WH2009) example, the main scripts are:
33+
34+
* ``WH2009_train.py``: Training of the dynoNet model
35+
* ``WH2009_test.py``: Evaluation of the dynoNet model on the test dataset, computation of metrics.
36+
37+
Similar scripts are provided for the other examples.
38+
39+
NOTE: the original data sets are not included in this project. They have to be manually downloaded from
40+
http://www.nonlinearbenchmark.org and copied in the data sub-folder of the example.
41+
# Software requirements:
42+
Simulations were performed on a Python 3.7 conda environment with
43+
44+
* numpy
45+
* scipy
46+
* matplotlib
47+
* pandas
48+
* pytorch (version 1.4)
49+
50+
These dependencies may be installed through the commands:
51+
52+
```
53+
conda install numpy scipy pandas matplotlib
54+
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
55+
```

examples/ParWH/__init__.py

Whitespace-only changes.

examples/ParWH/models.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import torch
2+
from torchid_nb.module.lti import MimoLinearDynamicalOperator, SisoLinearDynamicalOperator
3+
from torchid_nb.module.static import MimoStaticNonLinearity, MimoChannelWiseNonLinearity
4+
5+
6+
class ParallelWHNet(torch.nn.Module):
7+
def __init__(self):
8+
super(ParallelWHNet, self).__init__()
9+
self.nb_1 = 12
10+
self.na_1 = 12
11+
self.nb_2 = 13
12+
self.na_2 = 12
13+
self.G1 = MimoLinearDynamicalOperator(1, 2, n_b=self.nb_1, n_a=self.na_1, n_k=1)
14+
self.F_nl = MimoChannelWiseNonLinearity(2, n_hidden=10)
15+
#self.F_nl = MimoStaticNonLinearity(2, 2, n_hidden=10)
16+
self.G2 = MimoLinearDynamicalOperator(2, 1, n_b=self.nb_2, n_a=self.na_2, n_k=0)
17+
#self.G3 = SisoLinearDynamicalOperator(n_b=3, n_a=3, n_k=1)
18+
19+
def forward(self, u):
20+
y1_lin = self.G1(u)
21+
y1_nl = self.F_nl(y1_lin) # B, T, C1
22+
y2_lin = self.G2(y1_nl) # B, T, C2
23+
24+
return y2_lin #+ self.G3(u)
25+
26+
27+
class ParallelWHNetVar(torch.nn.Module):
28+
def __init__(self):
29+
super(ParallelWHNetVar, self).__init__()
30+
self.nb_1 = 3
31+
self.na_1 = 3
32+
self.nb_2 = 3
33+
self.na_2 = 3
34+
self.G1 = MimoLinearDynamicalOperator(1, 16, n_b=self.nb_1, n_a=self.na_1, n_k=1)
35+
self.F_nl = MimoStaticNonLinearity(16, 16) #MimoChannelWiseNonLinearity(16, n_hidden=10)
36+
self.G2 = MimoLinearDynamicalOperator(16, 1, n_b=self.nb_2, n_a=self.na_2, n_k=1)
37+
38+
def forward(self, u):
39+
y1_lin = self.G1(u)
40+
y1_nl = self.F_nl(y1_lin) # B, T, C1
41+
y2_lin = self.G2(y1_nl) # B, T, C2
42+
43+
return y2_lin
44+
45+
46+
class ParallelWHResNet(torch.nn.Module):
47+
def __init__(self):
48+
super(ParallelWHResNet, self).__init__()
49+
self.nb_1 = 4
50+
self.na_1 = 4
51+
self.nb_2 = 4
52+
self.na_2 = 4
53+
self.G1 = MimoLinearDynamicalOperator(1, 2, n_b=self.nb_1, n_a=self.na_1, n_k=1)
54+
self.F_nl = MimoChannelWiseNonLinearity(2, n_hidden=10)
55+
self.G2 = MimoLinearDynamicalOperator(2, 1, n_b=self.nb_2, n_a=self.na_2, n_k=1)
56+
self.G3 = SisoLinearDynamicalOperator(n_b=6, n_a=6, n_k=1)
57+
58+
def forward(self, u):
59+
y1_lin = self.G1(u)
60+
y1_nl = self.F_nl(y1_lin) # B, T, C1
61+
y2_lin = self.G2(y1_nl) # B, T, C2
62+
63+
return y2_lin + self.G3(u)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pandas as pd
2+
import numpy as np
3+
import os
4+
import matplotlib.pyplot as plt
5+
6+
if __name__ == '__main__':
7+
8+
N = 16384 # number of samples per period
9+
M = 20 # number of random phase multisine realizations
10+
P = 2 # number of periods
11+
nAmp = 5 # number of different amplitudes
12+
13+
# Column names in the dataset
14+
COL_F = ['fs']
15+
TAG_U = 'u'
16+
TAG_Y = 'y'
17+
18+
# Load dataset
19+
#df_X = pd.read_csv(os.path.join("data", "WH_CombinedZeroMultisineSinesweep.csv"))
20+
df_X = pd.read_csv(os.path.join("data", "ParWHData_Estimation_Level2.csv"))
21+
df_X.columns = ['amplitude', 'fs', 'lines'] + [TAG_U + str(i) for i in range(M)] + [TAG_Y + str(i) for i in range(M)] + ['?']
22+
23+
# Extract data
24+
y = np.array(df_X['y0'], dtype=np.float32)
25+
u = np.array(df_X['u0'], dtype=np.float32)
26+
fs = np.array(df_X[COL_F].iloc[0], dtype = np.float32)
27+
N = y.size
28+
ts = 1/fs
29+
t = np.arange(N)*ts
30+
31+
32+
# In[Plot]
33+
fig, ax = plt.subplots(2, 1, sharex=True)
34+
ax[0].plot(t, y, 'k', label="$y$")
35+
ax[0].legend()
36+
ax[0].grid()
37+
38+
ax[1].plot(t, u, 'k', label="$u$")
39+
ax[1].legend()
40+
ax[1].grid()
41+

examples/ParWH/parWH_test.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import pandas as pd
2+
import numpy as np
3+
import os
4+
import matplotlib.pyplot as plt
5+
import torch
6+
import torch.nn as nn
7+
import control
8+
from torchid_nb.module.lti import MimoLinearDynamicalOperator
9+
from torchid_nb.module.static import MimoStaticNonLinearity
10+
import util.metrics
11+
from examples.ParWH.models import ParallelWHNet
12+
13+
14+
if __name__ == '__main__':
15+
16+
model_name = "PWH_quant"
17+
18+
# Dataset constants
19+
amplitudes = 5 # number of different amplitudes
20+
realizations = 20 # number of random phase multisine realizations
21+
samp_per_period = 16384 # number of samples per period
22+
n_skip = 1000
23+
periods = 1 # number of periods
24+
seq_len = samp_per_period * periods # data points per realization
25+
26+
# Column names in the dataset
27+
TAG_U = 'u'
28+
TAG_Y = 'y'
29+
30+
test_signal = "100mV" #"ramp" #ramp"#"320mV" #"1000mV"#"ramp"
31+
32+
33+
# In[Load dataset]
34+
35+
dict_test = {"100mV": 0, "320mV": 1, "550mV": 2, "775mV": 3, "1000mV": 4, "ramp": 5}
36+
dataset_list_level = ['ParWHData_Validation_Level' + str(i) for i in range(1, amplitudes + 1)]
37+
dataset_list = dataset_list_level + ['ParWHData_ValidationArrow']
38+
39+
df_X_lst = []
40+
for dataset_name in dataset_list:
41+
dataset_filename = dataset_name + '.csv'
42+
df_Xi = pd.read_csv(os.path.join("data", dataset_filename))
43+
df_X_lst.append(df_Xi)
44+
45+
46+
df_X = df_X_lst[dict_test[test_signal]] # first
47+
48+
# Extract data
49+
y_meas = np.array(df_X['y'], dtype=np.float32)
50+
u = np.array(df_X['u'], dtype=np.float32)
51+
fs = np.array(df_X['fs'].iloc[0], dtype=np.float32)
52+
N = y_meas.size
53+
ts = 1/fs
54+
t = np.arange(N)*ts
55+
56+
# In[Set-up model]
57+
58+
net = ParallelWHNet()
59+
model_folder = os.path.join("models", model_name)
60+
net.load_state_dict(torch.load(os.path.join(model_folder, f"{model_name}.pt")))
61+
62+
# In[Predict]
63+
u_torch = torch.tensor(u[None, :, None], dtype=torch.float, requires_grad=False)
64+
65+
with torch.no_grad():
66+
y_hat = net(u_torch)
67+
68+
# In[Detach]
69+
70+
y_hat = y_hat.detach().numpy()[0, :, 0]
71+
72+
# In[Plot]
73+
fig, ax = plt.subplots(2, 1, sharex=True)
74+
ax[0].plot(t, y_meas, 'k', label="$y$")
75+
ax[0].plot(t, y_hat, 'r', label="$\hat y$")
76+
ax[0].plot(t, y_meas - y_hat, 'g', label="$e$")
77+
ax[0].legend()
78+
ax[0].grid()
79+
80+
ax[1].plot(t, u, 'k', label="$u$")
81+
ax[1].legend()
82+
ax[1].grid()
83+
84+
# In[Inspect linear model]
85+
86+
# First linear block
87+
# a_coeff_1 = net.G1.a_coeff.detach().numpy()
88+
# b_coeff_1 = net.G1.b_coeff.detach().numpy()
89+
# a_poly_1 = np.empty_like(a_coeff_1, shape=(2, 2, net.na_1 + 1))
90+
# a_poly_1[:, :, 0] = 1
91+
# a_poly_1[:, :, 1:] = a_coeff_1[:, :, :]
92+
# b_poly_1 = np.array(b_coeff_1)
93+
# G1_sys = control.TransferFunction(b_poly_1, a_poly_1, ts)
94+
#
95+
# plt.figure()
96+
# mag_G1_1, phase_G1_1, omega_G1_1 = control.bode(G1_sys[0, 0])
97+
# plt.figure()
98+
# mag_G1_2, phase_G1_2, omega_G1_2 = control.bode(G1_sys[1, 0])
99+
#
100+
# # Second linear block
101+
# a_coeff_2 = net.G2.a_coeff.detach().numpy()
102+
# b_coeff_2 = net.G2.b_coeff.detach().numpy()
103+
# a_poly_2 = np.empty_like(a_coeff_2, shape=(2, 1, net.na_2 + 1))
104+
# a_poly_2[:, :, 0] = 1
105+
# a_poly_2[:, :, 1:] = a_coeff_2[:, :, :]
106+
# b_poly_2 = np.array(b_coeff_2)
107+
# G2_sys = control.TransferFunction(b_poly_2, a_poly_2, ts)
108+
109+
# plt.figure()
110+
# mag_G2_1, phase_G2_1, omega_G2_1 = control.bode(G2_sys[0, 0])
111+
# plt.figure()
112+
# mag_G2_2, phase_G2_2, omega_G2_2 = control.bode(G2_sys[0, 1])
113+
114+
115+
# In[Metrics]
116+
117+
idx_test = range(n_skip, N)
118+
119+
e_rms = 1000*util.metrics.error_rmse(y_meas[idx_test], y_hat[idx_test])
120+
mae = 1000 * util.metrics.error_mae(y_meas[idx_test], y_hat[idx_test])
121+
fit_idx = util.metrics.fit_index(y_meas[idx_test], y_hat[idx_test])
122+
r_sq = util.metrics.r_squared(y_meas[idx_test], y_hat[idx_test])
123+
u_rms = 1000*util.metrics.error_rmse(u, 0)
124+
125+
print(f"RMSE: {e_rms:.2f}mV\nMAE: {mae:.2f}mV\nFIT: {fit_idx:.1f}%\nR_sq: {r_sq:.1f}\nRMSU: {u_rms:.2f}mV")

0 commit comments

Comments
 (0)