In [381]:
import numpy as np
# Global NumPy seed: generate_data (input x) and lms_filter (initial weights) both
# draw from np.random, so this fixes the data and the starting point of training.
np.random.seed(5)

def compute_poly(x, m, p, poly_power_limit):
    """
    Build the odd-order polynomial basis of a single input sample.

    Row j of the result is (|x| * c) ** (2*j) * x with c = poly_power_limit,
    i.e. the terms x, c^2|x|^2 x, c^4|x|^4 x, ... (odd orders 1, 3, ..., 2p-1).

    x : length-m complex vector (one input sample)
    m : input dimension (length of x)
    p : number of polynomial terms (rows of the result)
    poly_power_limit : scale c applied to |x| before it is raised to a power;
        with a small c (0.05 in this notebook) the higher-order terms are damped

    return : (p, m) complex128 array
    """
    basis = np.zeros((p, m), dtype=np.complex128)
    # |x| * c is the same for every order, so compute it once.
    scaled_mag = np.abs(x) * poly_power_limit
    for order in range(p):
        basis[order] = np.power(scaled_mag, 2 * order) * x
    return basis
def compute_poly_sum(u, PolyW):
    """
    Collapse a polynomial basis into one vector using the polynomial weights.

    u : (p, m) complex array, e.g. the output of compute_poly
    PolyW : length-p complex polynomial weights
    return : length-m complex vector equal to sum_j PolyW[j] * u[j]
    """
    combined = np.dot(PolyW, u)
    return combined

def compute_conv(Us, ConvW):
    """
    Apply the complex convolution (multi-tap linear) layer to a window of vectors.

    Us : (k, m) complex array; row t is the polynomial-mixed vector at window
         position t (the callers in this notebook order the window oldest first)
    ConvW : (k, m, n) complex weights, one m x n matrix per window position
    return : length-n complex vector equal to sum_t Us[t] @ ConvW[t]
    """
    taps, _, n_out = ConvW.shape
    acc = np.zeros((n_out), dtype=np.complex128)
    for tap in range(taps):
        acc += np.dot(Us[tap], ConvW[tap])
    return acc
    

In [382]:
### generate complex data 

def generate_data(n_samples, true_coefficients, noise_level=0.1, poly_power_limit=0.05,
                  add_noise=False):
    """
    Simulate the polynomial -> convolution model on random complex inputs.

    Pipeline per sample i:
        U[i]  = compute_poly(x[i])             (p, m)
        Us[i] = PolyW @ U[i]                    (m,)
        y[i]  = sum_t Us[i-k+1+t] @ ConvW[t]    (n,)   window oldest first

    n_samples : number of samples to generate
    true_coefficients : [ConvW, PolyW] with ConvW of shape (k, m, n) and
        PolyW of shape (p,)
    noise_level : std of the real and imaginary parts of the additive noise;
        only used when add_noise is True
    poly_power_limit : scale passed to compute_poly; must match the value the
        estimator uses
    add_noise : if True, add complex Gaussian noise to y. Defaults to False,
        which keeps the noiseless behaviour this notebook relies on.

    return : (x, y) with x of shape (n_samples, m) and y of shape
        (n_samples, n). y[:k-1] stay zero because no full window of k
        samples exists there.
    """
    ConvW = true_coefficients[0]
    PolyW = true_coefficients[1]
    k, m, n = ConvW.shape
    p = PolyW.shape[0]

    # Random complex input signal
    x = np.random.randn(n_samples, m) + 1j * np.random.randn(n_samples, m)

    # Step 1: polynomial features: (n_samples, m) -> (n_samples, p, m)
    U = np.zeros((n_samples, p, m), dtype=np.complex128)
    for i in range(n_samples):
        U[i] = compute_poly(x[i], m, p, poly_power_limit)

    # Step 2: weighted polynomial sum: (n_samples, p, m) -> (n_samples, m)
    Us = np.zeros((n_samples, m), dtype=np.complex128)
    for i in range(n_samples):
        Us[i] = compute_poly_sum(U[i], PolyW)

    # Step 3: convolution over a sliding window of the last k rows of Us:
    # (n_samples, m) -> (n_samples, n). Index k-1 is the first position with a
    # full window Us[0:k], so start there.
    y = np.zeros((n_samples, n), dtype=np.complex128)
    for i in range(k - 1, n_samples):
        y[i] = compute_conv(Us[i - k + 1 : i + 1], ConvW)

    # Optional additive complex Gaussian noise
    if add_noise:
        noise = noise_level * np.random.randn(n_samples, n) + 1j * noise_level * np.random.randn(n_samples, n)
        y += noise

    return x, y


In [383]:
# Example usage
n_samples = 10000

# p : number of polynomial terms (odd orders 1, 3, ..., 2p-1)
# k : window length (number of past samples that influence one output)
# m : input dimension  (x has shape n_samples x m)
# n : output dimension (y has shape n_samples x n)
p, k, m, n = 3, 5, 3, 2

### part 1 : polynomial layer weights, shape (p,)
# Drawn before ConvW so the random stream (and hence the data) stays the same.
PolyW_true = np.random.randn(p) + 1j*np.random.randn(p)
poly_power_limit = 0.05

### part 2 : conv layer weights, shape (k, m, n)
ConvW_true = np.random.randn(k, m, n) + 1j*np.random.randn(k, m, n)
# Passed through to generate_data, which does not add noise in this call.
noise_level = 0.05

# generate_data and lms_filter expect the order [ConvW, PolyW]
true_coefficients = [ConvW_true, PolyW_true]

x, y = generate_data(n_samples, true_coefficients, noise_level, poly_power_limit)

print("Input signal:", x[9])
print("Desired signal:", y[9])

Input signal: [-0.693-1.262j -0.593+0.777j  0.788-1.41j ]
Desired signal: [-0.621+0.989j  1.957+3.844j]


In [384]:
### LMS filter

def _lms_predict(u, i, PolyW, ConvW):
    """Model output y_hat for the window of k samples ending at index i (oldest first)."""
    k = ConvW.shape[0]
    window = np.zeros((k, u.shape[2]), dtype=np.complex128)
    for t in range(k):
        window[t] = compute_poly_sum(u[i - k + 1 + t], PolyW)  # (m,)
    return compute_conv(window, ConvW)


def lms_filter(x, y, k, mu, p, true_coefficients, epoch=5, poly_power_limit=0.05, verbose=True):
    """
    Jointly estimate PolyW and ConvW of the polynomial -> convolution model with
    complex LMS (stochastic gradient descent on |e|^2, e = y_hat - y[i]).

    For every epoch and every time step i >= k:
      1. PolyW    <- PolyW - mu * sum_j conj(u_j @ ConvW[j]) @ e
      2. recompute e with the updated PolyW
      3. ConvW[j] <- ConvW[j] - mu * outer(conj(PolyW @ u_j), e)
    where u_j is the (p, m) polynomial basis of sample i-k+1+j.

    Note: y_hat is bilinear in (PolyW, ConvW), so (a*PolyW, ConvW/a) produces the
    same output for any nonzero complex a. The estimates can therefore match
    true_coefficients at best up to such a scale factor.

    x : (N, m) complex input samples
    y : (N, n) complex desired outputs
    k : window length (number of conv taps)
    mu : step size
    p : number of polynomial terms
    true_coefficients : [ConvW, PolyW]; only referenced by the commented-out
        warm start below, kept for interface compatibility
    epoch : number of passes over the data
    poly_power_limit : scale passed to compute_poly; must match the value used
        to generate the data
    verbose : if True, print e_0 every 10 steps and e_1 / e_2 / y_hat / y
        every 100 steps

    return : list of (ConvW, PolyW) copies, one entry per update step
    """
    Ws = []
    m, n = x.shape[1], y.shape[1]

    # Random initialisation (draw order: PolyW, then ConvW tap by tap)
    PolyW = np.random.randn(p) + 1j*np.random.randn(p)  # (p,)
    ConvW = np.array([np.random.randn(m, n) + 1j*np.random.randn(m, n) for _ in range(k)])  # (k, m, n)

    # # warm start ("backdoor") for debugging: start close to the true weights
    # PolyW = true_coefficients[1]  + 0.1 * np.random.randn(p) + 1j*0.1 * np.random.randn(p)
    # ConvW = true_coefficients[0] + 0.1 * np.random.randn(k, m, n) + 1j*0.1 * np.random.randn(k, m, n)

    # Polynomial features do not depend on the weights, so compute them once: (N, p, m)
    u = np.zeros((len(x), p, m), dtype=np.complex128)
    for i in range(len(x)):
        u[i] = compute_poly(x[i], m, p, poly_power_limit)

    for _ in range(epoch):
        for i in range(k, len(x)):
            # Error before any update
            e = _lms_predict(u, i, PolyW, ConvW) - y[i]
            if verbose and i % 10 == 0:
                print("e_0", e)

            # PolyW update; the gradient term of tap j does not depend on PolyW
            for j in range(k):
                jac_conj_T = np.dot(np.conj(ConvW[j]).T, np.conj(u[i-k+1+j]).T)  # (n, p)
                PolyW -= mu * np.dot(e, jac_conj_T)  # (p,)

            # Error after the PolyW update (drives the ConvW update below)
            e = _lms_predict(u, i, PolyW, ConvW) - y[i]
            if verbose and i % 100 == 0:
                print("e_1", e)

            # ConvW update
            for j in range(k):
                mixed = compute_poly_sum(u[i-k+1+j], PolyW)  # (m,)
                ConvW[j] -= mu * np.outer(np.conj(mixed), e)  # (m, n)

            # Error after both updates; only needed for logging
            if verbose and i % 100 == 0:
                y_hat = _lms_predict(u, i, PolyW, ConvW)
                print("e_2", y_hat - y[i])
                print("y_hat", y_hat)
                print("y", y[i])

            Ws.append((ConvW.copy(), PolyW.copy()))
    return Ws


In [385]:
### Example usage
## Hyperparameters

# Set print precision before training so the errors logged during training
# use it too (reproducible on a fresh kernel).
np.set_printoptions(precision=3)

mu = 0.001  # with seed 5: mu = 0.05 diverges, mu = 0.01 works fine

Ws = lms_filter(x, y, k, mu, p, true_coefficients)

# Ws[t] = (ConvW, PolyW) after update step t; show how PolyW[0] evolves.
print("Estimated PolyW[0] after step 0:", Ws[0][1][0])
print("Estimated PolyW[0] after step 5:", Ws[5][1][0])
print("Estimated PolyW[0] after last step:", Ws[-1][1][0])
print("True PolyW[0]:", true_coefficients[1][0])

# The model is bilinear in (PolyW, ConvW), so PolyW is identifiable only up to
# a complex scale factor: a constant ratio across entries indicates a match.
print("PolyW_est / PolyW_true:", Ws[-1][1] / true_coefficients[1])



e_0 [ 3.279+1.334j -0.305-0.814j]
e_0 [ 0.455+0.987j -0.039+4.47j ]
e_0 [-1.505-2.122j  0.668-2.11j ]
e_0 [0.026+1.723j 1.512+1.656j]
e_0 [-1.631+2.436j -3.214+1.65j ]
e_0 [1.064+2.977j 3.948+1.044j]
e_0 [ 0.31 +0.2j   -0.488-1.437j]
e_0 [-1.649-4.832j -2.451-1.808j]
e_0 [ 0.583+1.627j -0.223+3.966j]
e_0 [1.02 +4.j    0.788+1.023j]
e_1 [1.017+3.967j 0.709+0.946j]
e_2 [1.016+3.965j 0.709+0.946j]
y_hat [0.195-0.075j 0.039-0.844j]
y [-0.822-4.04j -0.67 -1.79j]
e_0 [-0.203-1.074j -4.795+0.183j]
e_0 [ 1.196-0.228j -0.602+1.716j]
e_0 [ 0.457+2.143j -0.175-1.032j]
e_0 [-0.419-0.437j -0.618-0.275j]
e_0 [-0.744-2.695j -0.983-0.225j]
e_0 [-1.789-3.157j  3.12 -3.304j]
e_0 [ 0.129+3.364j -2.268+0.071j]
e_0 [-1.223+1.119j -0.854+4.489j]
e_0 [ 2.531+2.547j -0.525-0.9j  ]
e_0 [-1.441+1.471j  0.105-0.683j]
e_1 [-1.44 +1.389j -0.076-0.611j]
e_2 [-1.44 +1.389j -0.076-0.611j]
y_hat [ 0.452+0.087j -0.204-0.579j]
y [ 1.892-1.302j -0.128+0.032j]
e_0 [0.685+0.918j 1.076+4.066j]
e_0 [0.104-1.537j 1.759-1.909j

In [388]:
# Final (ConvW, PolyW) estimate, i.e. the weights after the last LMS update step
Ws[-1]

(array([[[ 1.195-0.426j,  0.57 -1.912j],
         [-0.237+0.151j,  0.177-2.171j],
         [ 1.667+0.207j,  0.363+0.613j]],
 
        [[ 0.469-0.197j, -0.88 -0.378j],
         [ 2.21 -0.656j,  0.729-1.87j ],
         [-1.759-1.343j, -2.604-0.344j]],
 
        [[ 2.105+0.19j , -0.855+0.248j],
         [ 1.085-2.084j,  1.122-0.45j ],
         [ 1.408+1.623j,  0.69 +0.863j]],
 
        [[-1.246+0.955j, -1.164-1.437j],
         [-0.06 +0.153j,  0.102-3.128j],
         [-0.119-0.89j ,  0.122-0.187j]],
 
        [[-0.998+0.728j,  0.685-1.428j],
         [ 0.05 +0.325j,  0.28 +1.1j  ],
         [-0.027-0.78j , -0.34 +0.028j]]]),
 array([-0.36 +0.14j ,  1.712+0.213j, -0.107+1.425j]))

In [387]:
# Ground-truth [ConvW, PolyW] used by generate_data to produce y
true_coefficients

[array([[[-0.909+0.198j, -0.592+1.335j],
         [ 0.188-0.087j, -0.33 +1.562j],
         [-1.193-0.306j, -0.205-0.478j]],
 
        [[-0.359+0.101j,  0.603+0.355j],
         [-1.665+0.27j , -0.7  +1.292j],
         [ 1.151+1.139j,  1.857+0.494j]],
 
        [[-1.511-0.336j,  0.645-0.101j],
         [-0.981+1.413j, -0.857+0.221j],
         [-0.872-1.311j, -0.423-0.69j ]],
 
        [[ 0.996-0.578j,  0.712+1.152j],
         [ 0.059-0.107j, -0.363+2.26j ],
         [ 0.003+0.657j, -0.106+0.125j]],
 
        [[ 0.793-0.436j, -0.632+0.972j],
         [-0.006-0.241j, -0.101-0.824j],
         [-0.052+0.568j,  0.249+0.013j]]]),
 array([ 0.441-0.252j, -0.331+0.11j ,  2.431+1.582j])]