In [1]:
import numpy as np
import matplotlib.pyplot as plt

from sysgen import sysgen
from sysid import sysid
from syssim import syssim

In [2]:
# System Data
n = 3
p = 2

# Nominal system:
A_0 = np.array([[0.6, 0.5, 0.4],
                [0, 0.4, 0.3],
                [0, 0, 0.3]])

B_0 = np.array([[1, 0.5],
                [0.5, 1],
                [0.5, 0.5]])

V = np.array([[0, 0, 0],
              [0, 1, 0],
              [0, 0, 1]])  # Modification pattern applied to A_0

U = np.array([[1, 0],
              [0, 0],
              [0, 1]])  # Modification pattern applied to B_0

# Noise levels
sigu = 1
sigw = 1
sigx = 1

# Rollout length
T = 5

# Select FL_solver (Fed_Avg = 0, Fed_Lin = 1)
FL_solver = 0

q = 25  # number of estimations
R = 200  # number of global iterations
s = 1  # fixed system for the error computation

M = [1, 2, 5, 25, 100] 
N = 25  # Fixed number of rollouts
epsilon = 0.01  # Fixed dissimilarity

In [3]:
# Generating the system matrices
# Define `sysgen` to generate similar systems
A, B = sysgen(A_0, B_0, V, U, M, epsilon)

E_avg = np.zeros((len(M), R))

# Numerical results varying the number of clients
Error_matrix = np.zeros((q, R))


In [4]:
true_theta = np.hstack([A_0, B_0])
true_theta

array([[0.6, 0.5, 0.4, 1. , 0.5],
       [0. , 0.4, 0.3, 0.5, 1. ],
       [0. , 0. , 0.3, 0.5, 0.5]])

In [5]:
M=100
R=200

In [6]:

n = A[0].shape[0]
p = B[0].shape[1]
s=0


X = []
Z = []
W = []
for i in range(M):
    X_i, Z_i, W_i, list_x0 = syssim(A, B, T, N, i, sigu, sigw, sigx)
    X.append(X_i)
    Z.append(Z_i)
    W.append(W_i)

# FedSysID

# Initialiser le serveur avec \bar{\Theta}_0 et \alpha
Theta_0 = np.hstack([(1/2) * A[s], (1/2) * B[s]])
alpha = 1e-4  # pas d'apprentissage
K = 10  # nombre d'itérations locales


Theta_s = Theta_0.copy()  # serveur
Theta_c = [None] * M  # clients
print(f"longueur de Theta_c : {len(Theta_c)} \n")
Error = np.zeros(R) # vecteur d'erreur  de dim (R,)

for r in range(R):
       
    print(f"Round {r} \n")

    print(f"valeur de theta_s en début de round : \n{Theta_s} \n")

    # Initialiser chaque client avec \bar{\Theta}_0
    for i in range(M):
        Theta_c[i] = Theta_s.copy()
    
    # Côté client :
    for i in range(M):
        # print(f"Coté client, boucle, i vaut : {i} \n")
        if FL_solver == 0:
            # FedAvg
            for k in range(1, K + 1):
                # tmp= Theta_c[i]
                # print(f"dimension de Theta_c[i] vaut : {tmp.shape} \n")
                new_val = Theta_c[i] + (alpha / k) * ((X[i] - Theta_c[i] @ Z[i]) @ Z[i].T)
                # print(f"dimension de new_val : {new_val.shape} \n")
                Theta_c[i] = new_val

    # # Côté serveur :
    Theta_sum = np.zeros((n, n + p))
    for i in range(M):
        Theta_sum = Theta_sum + Theta_c[i]
    Theta_s = (1 / M) * Theta_sum
        

    print(f"valeur de theta_s en fin de round : \n{Theta_s} \n")
    print(f"FIN Round {r} \n ======================== \n")
    Error[r] = np.linalg.norm(Theta_s - true_theta, ord=2) # TODO : askip ici l'erreur c la norme 2 et jsp pk  # DONE : ord=2 impose norme spectale


longueur de Theta_c : 100 

Round 0 

valeur de theta_s en début de round : 
[[0.3        0.25       0.2        0.50361596 0.25      ]
 [0.         0.20354723 0.15       0.25       0.5       ]
 [0.         0.         0.15354723 0.25       0.25361596]] 

valeur de theta_s en fin de round : 
[[0.37594423 0.2966288  0.22721317 0.52188194 0.25904557]
 [0.01721366 0.22486    0.16341509 0.25941865 0.51808802]
 [0.00461743 0.0039315  0.16158082 0.25969196 0.26287208]] 

FIN Round 0 

Round 1 

valeur de theta_s en début de round : 
[[0.37594423 0.2966288  0.22721317 0.52188194 0.25904557]
 [0.01721366 0.22486    0.16341509 0.25941865 0.51808802]
 [0.00461743 0.0039315  0.16158082 0.25969196 0.26287208]] 

valeur de theta_s en fin de round : 
[[0.43456055 0.33351357 0.24931542 0.53949055 0.26775715]
 [0.02958036 0.24281995 0.17498281 0.26849989 0.53552675]
 [0.00791057 0.00700281 0.16892285 0.2690372  0.27179586]] 

FIN Round 1 

Round 2 

valeur de theta_s en début de round : 
[[0.43456055 0.

In [7]:
Error

array([0.8814924 , 0.83127145, 0.78907911, 0.75230941, 0.71932942,
       0.68910784, 0.66098536, 0.63453392, 0.60947022, 0.58560215,
       0.56279549, 0.54095304, 0.52000148, 0.49988315, 0.48055074,
       0.46196398, 0.44408754, 0.4268896 , 0.41034103, 0.39441474,
       0.37908535, 0.3643289 , 0.35012267, 0.33644504, 0.3232754 ,
       0.31059406, 0.29838221, 0.28662185, 0.27529573, 0.26438735,
       0.25388088, 0.24376115, 0.23401363, 0.22462437, 0.21557999,
       0.20686765, 0.19847505, 0.19039036, 0.18260224, 0.1750998 ,
       0.16787259, 0.16091058, 0.15420414, 0.14774401, 0.14152133,
       0.13552756, 0.12975453, 0.12419439, 0.11883959, 0.11368292,
       0.10871743, 0.10393647, 0.09933367, 0.09490291, 0.09063833,
       0.08653433, 0.08258554, 0.07878683, 0.07513329, 0.07162025,
       0.06824323, 0.06499798, 0.06188043, 0.05888674, 0.05601324,
       0.05325646, 0.05061309, 0.04808002, 0.0456543 , 0.04333312,
       0.04111383, 0.03899392, 0.036971  , 0.03504277, 0.03320

In [8]:
Error2 = sysid(A,B,T,N,M,R,sigu,sigx,sigw,FL_solver,s,true_theta)

In [9]:
print(Error2)

[0.88030443 0.82954616 0.78709366 0.7501633  0.71704456 0.6866771
 0.65839506 0.63177329 0.60653452 0.58249299 0.55951992 0.53752232
 0.51642986 0.49618681 0.47674701 0.45807065 0.44012243 0.42287021
 0.4062843  0.39033693 0.37500192 0.36025445 0.3460709  0.33242878
 0.31930659 0.30668379 0.2945407  0.28285853 0.27161924 0.26080558
 0.250401   0.24038966 0.23075636 0.22148653 0.21256621 0.203982
 0.19572104 0.18777101 0.18012008 0.17275691 0.16567059 0.15885068
 0.15228714 0.14597034 0.13989103 0.13404032 0.1284097  0.12299098
 0.11777629 0.11275808 0.1079291  0.10328239 0.09881126 0.0945093
 0.09037032 0.08638843 0.08255792 0.07887334 0.07532945 0.07192123
 0.06864385 0.06549269 0.0624633  0.05955143 0.056753   0.05406411
 0.05148101 0.04900012 0.04661803 0.04433145 0.04213726 0.04003249
 0.03801428 0.03607994 0.03422691 0.03245273 0.03075511 0.02913188
 0.02758098 0.02610052 0.02468879 0.02334442 0.0220671  0.02086613
 0.02031776 0.02045387 0.02060691 0.02076332 0.02092146 0.02108076