In [32]:
import numpy as np 
import random
from scipy.optimize import minimize

# System dimensions and the true channel used to generate training and testing data.
Nt, Nr = 2, 4  # H has Nr rows and Nt columns
# Draw the channel: complex entries whose real and imaginary parts are
# independent standard normals, each scaled by sqrt(1/2).
H = np.sqrt(1/2) * (np.random.randn(Nr, Nt) + 1j * np.random.randn(Nr, Nt))

In [33]:
# generate transmit signal
def generate_random_bit_sequence(length):
    """Return a string of `length` characters, each picked from '0'/'1' with random.choice."""
    drawn = []
    for _ in range(length):
        drawn.append(random.choice('01'))
    return ''.join(drawn)

# Unnormalised 16-QAM constellation keyed by 4-character bit label.
# Built once here rather than on every call, because the detector modulates
# every candidate label for every received vector.
_QAM16_CONSTELLATION = {
    '0000': (-3-3j),
    '0001': (-3-1j),
    '0010': (-3+3j),
    '0011': (-3+1j),
    '0100': (-1-3j),
    '0101': (-1-1j),
    '0110': (-1+3j),
    '0111': (-1+1j),
    '1000': (3-3j),
    '1001': (3-1j),
    '1010': (3+3j),
    '1011': (3+1j),
    '1100': (1-3j),
    '1101': (1-1j),
    '1110': (1+3j),
    '1111': (1+1j)
}


def qam16_modulation(binary_input):
    """Map a 4-character bit label to its 16-QAM symbol.

    Parameters
    ----------
    binary_input : str
        One of the sixteen labels '0000' ... '1111'.

    Returns
    -------
    complex
        The constellation point divided by sqrt(10). The table's points have
        mean squared magnitude 10, so the scaled symbols have unit average
        energy.

    Raises
    ------
    ValueError
        If `binary_input` is not one of the sixteen labels. Previously the
        lookup fell back to a string that was then divided by a float, which
        surfaced as an unrelated TypeError.
    """
    try:
        symbol = _QAM16_CONSTELLATION[binary_input]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs such as lists.
        raise ValueError("Invalid binary input: %r" % (binary_input,)) from None
    return symbol/np.sqrt(10)

def generate_x_sequence(length, Nt):
    """Generate random bit labels and the matching transmit vectors.

    Parameters
    ----------
    length : int
        Number of transmit vectors to generate.
    Nt : int
        Number of 16-QAM symbols per transmit vector.

    Returns
    -------
    bits_sequence : list of str
        length*Nt four-character bit labels, in transmission order.
    x_sequence : list of numpy.ndarray
        `length` arrays of shape (Nt,); entry k is built from labels
        bits_sequence[k*Nt : (k+1)*Nt].
    """
    total_bits_sequence = generate_random_bit_sequence(length*Nt*4)
    bits_sequence = [total_bits_sequence[i:i+4] for i in range(0, len(total_bits_sequence), 4)]
    # Modulate exactly Nt labels per vector. The earlier version always took
    # two labels ([i] and [i+1]), which only worked for Nt == 2.
    x_sequence = [
        np.array([qam16_modulation(label) for label in bits_sequence[i:i+Nt]])
        for i in range(0, len(bits_sequence), Nt)
    ]
    return bits_sequence, x_sequence

# Noise level: the SNR is specified in dB and converted to a linear ratio,
# which generate_noise uses to scale the noise samples.
SNR_dB = 10
SNR = 10.0**(SNR_dB/10.0)
def generate_noise(SNR, Nr):
    """Draw one (Nr, 1) complex noise column.

    Real and imaginary parts are independent standard-normal draws, both
    scaled by sqrt(1/(2*SNR)).
    """
    amplitude = np.sqrt(1/(2*SNR))
    # The real part is drawn before the imaginary part; keep this order so the
    # random stream is consumed the same way.
    real_part = np.random.randn(Nr, 1)
    imag_part = np.random.randn(Nr, 1)
    return amplitude*(real_part + 1j*imag_part)

# generate training and testing data
def generate_data(Nr,Nt,SNR,length,H_channel):
    """Simulate `length` channel uses: y = H_channel @ x + noise.

    Returns
    -------
    bits_sequence : list of str
        Bit labels as returned by generate_x_sequence.
    x_sequence : list of numpy.ndarray
        Transmit vectors as returned by generate_x_sequence.
    y_sequence : list of numpy.ndarray
        One received column per channel use: H_channel times the transmit
        vector reshaped to (Nt, 1), plus a noise vector from generate_noise.
    """
    bits_sequence, x_sequence = generate_x_sequence(length, Nt)
    # All noise vectors are drawn before any received vector is formed.
    n_sequence = []
    for _ in range(length):
        n_sequence.append(generate_noise(SNR, Nr))
    y_sequence = []
    for k in range(length):
        tx_column = x_sequence[k].reshape(Nt,1)
        y_sequence.append(np.dot(H_channel, tx_column) + n_sequence[k])
    return bits_sequence, x_sequence, y_sequence

# Training set: training_length channel uses through the true channel H.
# calculate_cost_function reads training_length, bits_sequence and y_sequence
# as module-level globals, so this cell must run before training.
training_length = 1000
bits_sequence, x_sequence, y_sequence = generate_data(Nr,Nt,SNR,training_length,H)

In [34]:
def bits2signals(bits):
    """Convert a bit string into an (Nt, 1) column of 16-QAM symbols.

    `bits` is consumed four characters at a time and each group is passed to
    qam16_modulation. The result is reshaped using the module-level Nt, so
    `bits` is expected to contain 4*Nt characters.
    """
    symbols = []
    for start in range(0, len(bits), 4):
        symbols.append(qam16_modulation(bits[start:start+4]))
    return np.array(symbols).reshape(Nt,1)
def calculate_layer1(H_hat, y):
    """Score every candidate transmit label against the received vector.

    For each of the 2**(4*Nt) bit labels (zero-padded binary strings, in
    increasing numeric order), the label is mapped to a symbol column s via
    bits2signals, and the score exp(-||y - H_hat s||^2) is stored.

    Returns
    -------
    dict
        Maps each bit label to its score.
    """
    n_bits = 4*Nt
    scores = {}
    for candidate in range(2**n_bits):
        label = format(candidate, 'b').zfill(n_bits)
        residual = y - np.dot(H_hat, bits2signals(label))
        distance = np.linalg.norm(residual)
        scores[label] = np.exp(-np.square(distance))
    return scores

def calculate_layer2(layer1_output):
    """Turn per-label scores into per-bit probabilities of the bit being '1'.

    Parameters
    ----------
    layer1_output : dict
        Maps bit labels (strings of '0'/'1' with at least 4*Nt characters) to
        non-negative scores, as produced by calculate_layer1.

    Returns
    -------
    dict
        Maps each bit position 0 .. 4*Nt-1 to
        sum(scores of labels with a '1' there) / sum(all scores).
        This is a probability, not a log-likelihood ratio.
    """
    n_bits = 4*Nt
    # sum_exp[position][b] accumulates the scores of labels whose bit at
    # `position` equals b.
    sum_exp = [[0, 0] for _ in range(n_bits)]
    for bits, value in layer1_output.items():
        for index in range(n_bits):
            # int() parses the '0'/'1' character directly. The earlier eval()
            # call did the same job far more slowly in this hot loop and
            # would execute arbitrary text.
            sum_exp[index][int(bits[index])] += value
    output = {}
    for index in range(n_bits):
        ones, zeros = sum_exp[index][1], sum_exp[index][0]
        output[index] = ones/(ones+zeros)
    return output

def calculate_cross_entropy(layer2_output, true_sequence):
    """Binary cross-entropy between predicted bit probabilities and true bits.

    Parameters
    ----------
    layer2_output : mapping
        layer2_output[i] is the predicted probability that bit i is '1'.
    true_sequence : str
        The transmitted bits as a string of '0'/'1' characters.

    Returns
    -------
    float
        Sum over bits of -log(p_i) when the true bit is '1' and
        -log(1 - p_i) otherwise. Returns 0 for an empty sequence.
    """
    entropy = 0
    for index, bit in enumerate(true_sequence):
        if bit == '1':
            entropy += -np.log(layer2_output[index])
        else:
            # '0' bits must be penalised too. Without this term the loss is
            # minimised by predicting 1 for every bit.
            entropy += -np.log(1 - layer2_output[index])
    return entropy

def calculate_square_error(layer2_output, true_sequence):
    """Sum of squared errors between predicted bit probabilities and true bits.

    For each position, the target is 1 when the true bit is '1' and 0
    otherwise; the squared gap to layer2_output[position] is accumulated.
    Returns 0 for an empty sequence.
    """
    loss = 0
    for position, bit in enumerate(true_sequence):
        probability = layer2_output[position]
        gap = 1-probability if bit == '1' else probability
        loss += np.square(gap)
    return loss

def calculate_cost_function(H_hat_vec):
    """Mean squared-error loss of the detector over the training set.

    Parameters
    ----------
    H_hat_vec : numpy.ndarray
        Real vector of length 2*Nr*Nt: the first Nr*Nt entries are the real
        parts and the next Nr*Nt entries the imaginary parts of the candidate
        channel, each reshaped to (Nr, Nt).

    Returns
    -------
    float
        Average of calculate_square_error over the training_length samples
        held in the module-level y_sequence / bits_sequence. The value is
        also printed on every call.
    """
    n_entries = Nr*Nt
    real_part = H_hat_vec[0:n_entries].reshape(Nr,Nt)
    imag_part = H_hat_vec[n_entries:2*n_entries].reshape(Nr,Nt)
    H_hat = real_part+1j*imag_part
    total_loss = 0
    for sample in range(training_length):
        bit_probabilities = calculate_layer2(calculate_layer1(H_hat, y_sequence[sample]))
        # Each received vector carries Nt consecutive 4-bit labels.
        labels = [bits_sequence[sample*Nt+offset] for offset in range(Nt)]
        total_loss += calculate_square_error(bit_probabilities, ''.join(labels))
    mean_loss = total_loss/training_length
    print(mean_loss)
    return mean_loss
        
# calculate_cost_function(H)

In [35]:
def detection(y, H_trained):
    """Hard-decide the transmitted bits for one received vector.

    Runs calculate_layer1 and calculate_layer2 with the channel H_trained,
    then emits '1' for every bit position whose probability is strictly above
    0.5 and '0' otherwise.

    Returns
    -------
    str
        The detected bits, one character per position.
    """
    bit_probabilities = calculate_layer2(calculate_layer1(H_trained, y))
    decisions = [
        '1' if bit_probabilities[k] > 0.5 else '0'
        for k in range(len(bit_probabilities))
    ]
    return ''.join(decisions)

def count_differences(str1, str2):
    """Count positions at which str1 and str2 differ.

    Only the overlapping prefix is compared; extra characters in the longer
    input are ignored.
    """
    mismatches = 0
    for left, right in zip(str1, str2):
        if left != right:
            mismatches += 1
    return mismatches

def calculate_BER(H_trained, testing_length=1000):
    """Estimate the bit error rate of the detector on a fresh test set.

    A new test set is generated through the true module-level channel H
    (using the module-level Nr, Nt and SNR), each received vector is detected
    with H_trained, and detected bits are compared with the transmitted bits.

    Parameters
    ----------
    H_trained : numpy.ndarray
        Channel estimate passed to `detection`.
    testing_length : int, optional
        Number of test vectors to generate (default 1000, the value that was
        previously hard-coded).

    Returns
    -------
    float
        Number of bit errors divided by the number of detected bits.

    Raises
    ------
    ValueError
        If testing_length is not positive (the rate would be undefined).
    """
    if testing_length <= 0:
        raise ValueError("testing_length must be a positive integer")
    # testing set
    bits_sequence_testing, _, y_sequence_testing = generate_data(Nr,Nt,SNR,testing_length,H)
    error = 0
    total_bits = 0
    for ii, y in enumerate(y_sequence_testing):
        detect_result = detection(y, H_trained)
        true_sequence = ''.join(bits_sequence_testing[ii*Nt+jj] for jj in range(Nt))
        error += count_differences(detect_result, true_sequence)
        # Count bits as they are detected instead of relying on the loop
        # variable still being defined after the loop.
        total_bits += len(detect_result)
    return error/total_bits

In [36]:
# Train the channel estimate in several short COBYLA runs, reporting the
# test BER after each run. Every run warm-starts from the previous result.
n_rounds = 5
iternum = 20  # COBYLA iterations per round; must be an int for `maxiter`
# Random starting point, defined before the loop. The earlier version tested
# `iternum == 10`, which was never true, so the first pass read `out.x` left
# over from a previous execution and failed on a fresh kernel.
ini = np.sqrt(1/2)*(np.random.randn(Nr*Nt*2))
for _ in range(n_rounds):
    out = minimize(calculate_cost_function, x0=ini, method="COBYLA", options={'maxiter':iternum,'catol':1})
    ini = out.x
    # out.x stacks the real parts (first Nr*Nt entries) and the imaginary
    # parts (next Nr*Nt entries) of the estimated channel.
    H_trained = out.x[0:Nr*Nt].reshape(Nr,Nt)+1j*out.x[Nr*Nt:2*Nr*Nt].reshape(Nr,Nt)
    BER = calculate_BER(H_trained)
    print(iternum, BER)

2.9239424921508324
3.231200233623111
3.019891228400817
2.9297268520267066
3.153884977456572
2.6351420244454964
2.7014225012129156
2.405309206261874
2.4941664400626378
2.820351131689049
2.182016370566822
2.7420720510234777
2.282103068866274
2.8333310030590404
2.0677485614248368
2.1938879472750323
2.2814396235239602
1.6162676939594458
1.6528550792279495
1.6209706465873321
20.0 0.273625
1.6162676939594458
1.7417940545330608
1.7619405681715439
1.9005919494981363
1.8368382478513026
1.6829197373356366
1.4835125411748122
1.5657924889778647
1.5162223232935768
1.620242197801353
1.546190985121788
1.8326250147209064
1.7788868348720868
1.9770845500749967
1.4892387442935493
1.6566903471857723
1.7108223531934588
1.2985599240226229
1.470888067854181
1.3436464077400887
20.0 0.23475
1.2985599240226229
1.3641582189525836
1.436335612626584
1.4901888898632496
1.4263210690862642
1.365881872857748
1.429747131705114
1.3465814715578788
1.2602988946853362
1.3852828693043722
1.2331868036198683
1.309519494290331

In [37]:
# Display the true channel generated at the top of the notebook.
H

array([[-0.56353513-0.04431127j,  0.05779396+0.43643596j],
       [ 0.03654535-1.3070904j , -1.04743208-0.46166622j],
       [ 0.51703921-0.80300923j,  0.78725538+0.35021508j],
       [-0.08077629-0.77437306j, -0.34218109+0.13608486j]])

In [38]:
# Display the final optimiser vector: the first Nr*Nt entries are the real
# parts and the next Nr*Nt entries the imaginary parts of the estimated channel.
out.x

array([-0.9357165 ,  0.10202284,  0.24899947, -0.72990454,  0.86668714,
        1.09658976, -0.17725806, -0.42535233,  0.3867819 ,  0.64961349,
       -0.54642506, -0.630602  , -0.66332603,  0.29110641, -0.68560994,
        0.06957675])