In [49]:
#!/usr/bin/env python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time as tm
import math
import sys
import pickle as pkl
import matplotlib.pyplot as plt

from oampnet import oampnet
from sample_generator import sample_generator
# from matrix_models import * 
from utils import *

torch.manual_seed(123)
np.random.seed(123)

#Parameters of the system
NT = 32
NR = 64
mod_n = 64


#Parametes of the model
num_layers = 10
train_iter = 10000
train_batch_size = 1000
test_batch_size = 5000
mini_validtn_batch_size = 5000

learning_rate = 1e-3

#Flags and parameters for signal generation and testing
QR = True
test_set_flag = True
corr_flag = True
batch_corr = True
rho_low = 0.55
rho_high = 0.75
rho = 0.6

save_interim_model = False
save_to_file = False

model_filename = './validtn_results/oampnet_fullcorr_q' + str(mod_n) + '_' + str(NT) + '_' + str(NR) + '.pth'
curr_accr = './validtn_results/curr_accr_' + str(NT) + '.txt'

#Load test set and covariance matrix
# with open ('/home/nicoz/MMNet-tests/learning_based/Torch/Tests_sets/Test_set_6T_24R', 'rb') as fp:
#     R_test, Cu = pkl.load(fp)
    
validtn_NT_list = np.asarray([NT])

snrdb_list = {16:np.arange(11.0, 22.0), 32:np.arange(23.0, 29.0)}

def sym_accuracy(out, j_indices):
    accuracy = (out == j_indices).sum().to(dtype=torch.float32)
    return accuracy.item()/out.numel()

def sym_detection(x_hat, j_indices, real_QAM_const, imag_QAM_const):
    real_QAM_const = real_QAM_const.to(device=x_hat.device)
    imag_QAM_const = imag_QAM_const.to(device=x_hat.device)
    x_real, x_imag = torch.chunk(x_hat, 2, dim=-1)
    x_real = x_real.unsqueeze(dim=-1).expand(-1,-1, real_QAM_const.numel())
    x_imag = x_imag.unsqueeze(dim=-1).expand(-1, -1, imag_QAM_const.numel())

    x_real = torch.pow(x_real - real_QAM_const, 2)
    x_imag = torch.pow(x_imag - imag_QAM_const, 2)
    x_dist = x_real + x_imag
    x_indices = torch.argmin(x_dist, dim=-1)

    return x_indices

def generate_big_validtn_data(generator, batch_size, corr_flag, rho, batch_corr, rho_low, rho_high):
    validtn_data_dict = {int(NT):{} for NT in validtn_NT_list}
    with open('/home/nicolas/MIMO_detection_project/H_5000bs_3264', 'rb') as fp:
        big_validtn_H = pkl.load(fp)
    for NT in validtn_NT_list:
        for snr in snrdb_list[NT]:
            big_validtn_y, _, big_validtn_j_indices, big_noise_sigma = generator.give_batch_data_Hinput(big_validtn_H, int(NT), snr_db_min=snr, snr_db_max=snr, batch_size = batch_size)
#             big_validtn_H, big_validtn_y, _, big_validtn_j_indices, big_noise_sigma = generator.give_batch_data(int(NT), snr_db_min=snr, snr_db_max=snr, batch_size=batch_size, correlated_flag=corr_flag, rho=rho)
            validtn_data_dict[int(NT)][snr] = (big_validtn_H, big_validtn_y , big_validtn_j_indices, big_noise_sigma)
    return validtn_data_dict


def validate_model_given_data(model, validtn_H, validtn_y, validtn_j_indices, big_noise_sigma, real_QAM_const, imag_QAM_const, device):

    with torch.no_grad():
        H = validtn_H.to(device=device)
        y = validtn_y.to(device=device)
        noise_sigma = big_noise_sigma.to(device=device)

        list_batch_x_predicted = model.forward(H, y, noise_sigma)
        validtn_out = list_batch_x_predicted[-1].to(device='cpu')
        indices_oampnet = sym_detection(validtn_out, validtn_j_indices, real_QAM_const, imag_QAM_const)
        accr = sym_accuracy(indices_oampnet, validtn_j_indices)

        del H, y, noise_sigma, list_batch_x_predicted

    return accr

def mini_validation(model, mini_validation_dict, i, device, real_QAM_const, imag_QAM_const, save_to_file=False):
    result_dict = {int(NT):{} for NT in validtn_NT_list}
    for index,NT in enumerate(validtn_NT_list):
        for snr in snrdb_list[NT]:
            big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma = mini_validation_dict[NT][snr]
            accr = validate_model_given_data(model, big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma, real_QAM_const, imag_QAM_const, device)
            result_dict[NT][snr] = 1. - accr

    if (save_to_file):
        with open(curr_accr, 'w') as f:
            print((i, result_dict), file=f)
            print('Intermediate validation results stored at : ', curr_accr)

    return result_dict


def loss_fn(x, list_batch_x_predicted, j_indices, real_QAM_const, imag_QAM_const, criterion, ser_only=False):
    if (ser_only):
        SER_final = sym_detection(list_batch_x_predicted[-1], j_indices, real_QAM_const, imag_QAM_const)
        return SER_final
    else:
        x_out = torch.cat(list_batch_x_predicted, dim=0)
        x = x.repeat(num_layers, 1)
        loss = criterion(x_out, x)
        SER_final = sym_detection(list_batch_x_predicted[-1], j_indices, real_QAM_const, imag_QAM_const)
        return loss, SER_final


def train(model, optimizer, generator, device='cpu'):

    mini_validation_dict = generate_big_validtn_data(generator, mini_validtn_batch_size, corr_flag, rho, batch_corr, rho_low, rho_high)

    criterion = nn.MSELoss().to(device=device)
    model.train()
    real_QAM_const = generator.real_QAM_const.to(device=device)
    imag_QAM_const = generator.imag_QAM_const.to(device=device)

    for i in range(train_iter):
        H, y, x, j_indices, noise_sigma = generator.give_batch_data(NT, snr_db_min=snrdb_list[NT][0], snr_db_max=snrdb_list[NT][-1], batch_size=train_batch_size, correlated_flag=corr_flag, rho=rho)
        H = H.to(device=device).double()
        y = y.to(device=device).double()
        noise_sigma = noise_sigma.to(device=device).double()

        list_batch_x_predicted = model.forward(H, y, noise_sigma)

        x = x.to(device=device).double()
        j_indices = j_indices.to(device=device)

        loss, SER = loss_fn(x, list_batch_x_predicted, j_indices, real_QAM_const, imag_QAM_const, criterion)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        del H, y, x, j_indices, noise_sigma, list_batch_x_predicted

        if (i%1000==0):
            print('iteration number : ', i, 'User : ', NT, 'loss : ', loss.item())
            print('Now validating')

            model.eval()
            mini_validtn_result = mini_validation(model, mini_validation_dict, i, device, real_QAM_const, imag_QAM_const, save_to_file)
            print('Mini validation result : ', mini_validtn_result)

            model.train()
            if (save_interim_model):
                torch.save(model.state_dict(), model_filename)
                print('********Model Saved******* at directory : ', model_filename)


In [47]:
device = 'cuda'
generator = sample_generator(train_batch_size, mod_n, NR)
model = oampnet(num_layers, generator.constellation, generator.real_QAM_const, generator.imag_QAM_const, device=device)
model = model.to(device=device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
train(model, optimizer, generator, device)
print('******************************** Now Testing **********************************************')

iteration number :  0 User :  32 loss :  0.0007897598978224027
Now validating
Mini validation result :  {32: {23.0: 0.94064375, 24.0: 0.975275, 25.0: 0.99151875, 26.0: 0.99686875, 27.0: 0.999225, 28.0: 0.99984375}}
iteration number :  1000 User :  32 loss :  0.0007239999860287529
Now validating
Mini validation result :  {32: {23.0: 0.9557875, 24.0: 0.9854125, 25.0: 0.996425, 26.0: 0.99905, 27.0: 0.99985, 28.0: 0.99998125}}
iteration number :  2000 User :  32 loss :  0.0005843981632409133
Now validating
Mini validation result :  {32: {23.0: 0.95653125, 24.0: 0.98585, 25.0: 0.9965625, 26.0: 0.99914375, 27.0: 0.99985625, 28.0: 1.0}}
iteration number :  3000 User :  32 loss :  0.0006959515238682098
Now validating
Mini validation result :  {32: {23.0: 0.956425, 24.0: 0.9864375, 25.0: 0.99691875, 26.0: 0.9992125, 27.0: 0.99989375, 28.0: 0.99999375}}
iteration number :  4000 User :  32 loss :  0.0007278761741643484
Now validating
Mini validation result :  {32: {23.0: 0.95644375, 24.0: 0.98624

KeyboardInterrupt: 

In [70]:
from collections import defaultdict

validtn_batch_size = 1
validtn_iter = 5

corr_list = np.asarray([0.60, 0.60])

def generate_big_validtn_data(generator, batch_size):
    validtn_data_dict = {int(NT):{rho:{} for rho in corr_list} for NT in validtn_NT_list}
#     with open('/home/nicolas/MIMO_detection_project/H_5000bs_3264', 'rb') as fp:
#         big_validtn_H = pkl.load(fp)
    for NT in validtn_NT_list:
        for rho in corr_list:
            for snr in snrdb_list[NT]:
#                 big_validtn_y, _, big_validtn_j_indices, big_noise_sigma = generator.give_batch_data_Hinput(big_validtn_H, int(NT), snr_db_min=snr, snr_db_max=snr, batch_size = batch_size)
                big_validtn_H, big_validtn_y, _, big_validtn_j_indices, big_noise_sigma = generator.give_batch_data(int(NT), snr_db_min=snr, snr_db_max=snr, batch_size=batch_size, correlated_flag=corr_flag, rho=rho)
                validtn_data_dict[int(NT)][rho][snr] = (big_validtn_H, big_validtn_y , big_validtn_j_indices, big_noise_sigma)
    return validtn_data_dict


def validate_oampnet(model, generator, device, real_QAM_const, imag_QAM_const, save_result=True):
    result_dict = {int(NT):{rho:defaultdict(float) for rho in corr_list} for NT in validtn_NT_list}
    for iter in range(validtn_iter):
        validtn_data_dict = generate_big_validtn_data(generator, validtn_batch_size)
        for NT in validtn_NT_list:
            for rho in corr_list:
                for snr in snrdb_list[NT]:
                    big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma = validtn_data_dict[NT][rho][snr]
                    accr = validate_model_given_data(model, big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma, real_QAM_const, imag_QAM_const, device)
                    result_dict[NT][rho][snr] =  result_dict[NT][rho][snr] + (accr-result_dict[NT][rho][snr])/float(iter+1.0)

        if (save_result):
            with open(oampnet_validtn_filename, 'wb') as handle:
                pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
            print('Big validation results saved at directory : ', oampnet_validtn_filename)
        print('Big Validation resut, Accr for ' + str(NT) + ' : ', result_dict[NT])


def test(model, generator, device):
    model.eval()

    # Testing Trained Network
    validate_oampnet(model, generator, device, generator.real_QAM_const, generator.imag_QAM_const, False)

# test_batch_size = 5000
# mini_validation_dict = generate_big_validtn_data(generator, test_batch_size, corr_flag, rho, batch_corr, rho_low, rho_high)


# real_QAM_const = generator.real_QAM_const.to(device=device)
# imag_QAM_const = generator.imag_QAM_const.to(device=device)
# mini_validtn_result = mini_validation(model, mini_validation_dict, 0, device, real_QAM_const, imag_QAM_const, save_to_file)
# print('Mini validation result : ', mini_validtn_result)

generator = sample_generator(validtn_batch_size, mod_n, NR)
device = 'cuda'
model = oampnet(num_layers, generator.constellation, generator.real_QAM_const, generator.imag_QAM_const, device=device)
model = model.to(device=device)
# model.load_state_dict(torch.load(model_filename))
print('*******Successfully loaded pre-trained model***********')

test(model, generator, device)

*******Successfully loaded pre-trained model***********
Big Validation resut, Accr for 32 :  {0.6: defaultdict(<class 'float'>, {23.0: 0.90625, 24.0: 0.9375, 25.0: 1.0, 26.0: 1.0, 27.0: 1.0, 28.0: 1.0})}
Big Validation resut, Accr for 32 :  {0.6: defaultdict(<class 'float'>, {23.0: 0.9765625, 24.0: 0.984375, 25.0: 1.0, 26.0: 1.0, 27.0: 1.0, 28.0: 1.0})}
Big Validation resut, Accr for 32 :  {0.6: defaultdict(<class 'float'>, {23.0: 0.954861111111111, 24.0: 0.9756944444444444, 25.0: 0.9652777777777778, 26.0: 1.0, 27.0: 1.0, 28.0: 1.0})}
Big Validation resut, Accr for 32 :  {0.6: defaultdict(<class 'float'>, {23.0: 0.90625, 24.0: 0.986328125, 25.0: 0.98046875, 26.0: 1.0, 27.0: 1.0, 28.0: 1.0})}
Big Validation resut, Accr for 32 :  {0.6: defaultdict(<class 'float'>, {23.0: 0.92875, 24.0: 0.99125, 25.0: 0.9875, 26.0: 1.0, 27.0: 1.0, 28.0: 1.0})}


In [68]:
1 - 0.93923125, 1 -  0.97480625, 1 -  0.99055, 1 - 0.99725, 1 -  0.99931875, 1 -  0.9998

(0.06076875000000004,
 0.025193750000000015,
 0.009449999999999958,
 0.00275000000000003,
 0.0006812499999999666,
 0.00019999999999997797)

In [20]:
batch_size = 100
time_seq = 5
H0 = torch.empty((batch_size, 2 * NR, 2 * NT))
H1 = torch.empty((batch_size, 2 * NR, 2 * NT))
H2 = torch.empty((batch_size, 2 * NR, 2 * NT))
H3 = torch.empty((batch_size, 2 * NR, 2 * NT))
H4 = torch.empty((batch_size, 2 * NR, 2 * NT))

with open('/home/nicolas/MIMO_detection_project/HyperMIMO/rho_model_kron/H_test', 'rb') as fp:
    H = pkl.load(fp)
for ii in range(0, batch_size):
    H0[ii] = H[0 + ii * time_seq:1 + ii*time_seq,:,:]
    H1[ii] = H[1 + ii * time_seq:2 + ii*time_seq,:,:]
    H2[ii] = H[2 + ii * time_seq:3 + ii*time_seq,:,:]
    H3[ii] = H[3 + ii * time_seq:4 + ii*time_seq,:,:]
    H4[ii] = H[4 + ii * time_seq:5 + ii*time_seq,:,:]
    
generator = sample_generator(train_batch_size, mod_n, NR)
H = H4.repeat_interleave(5, dim=0)
print('**************************Starting testing*******************************************')
# accs_NN = model_eval(H_test, H_inv, H_tilde, model, snrdb_classical_list[NT][0], snrdb_classical_list[NT][-1], test_batch_size, generator, 'cuda', iterations=150)
accs_NN = model_eval(NT, model, snrdb_list[NT][0], snrdb_list[NT][-1], train_batch_size, generator, 'cuda', test_set_flag = True, test_set = H, QR = QR, iterations = 500)
results_total.append(accs_NN)
accs_NN

**************************Starting testing*******************************************


[(5.0, 0.053140000000000076),
 (6.0, 0.03411399999999842),
 (7.0, 0.02009999999999812),
 (8.0, 0.011774000000000395),
 (9.0, 0.006613999999997899),
 (10.0, 0.003635999999997641),
 (11.0, 0.0017639999999968792),
 (12.0, 0.0008819999999969408),
 (13.0, 0.0004919999999976055),
 (14.0, 0.00022599999999872722)]

In [21]:
with open('/home/nicolas/MIMO_detection_project/results/H_seq_oampnet_time', 'wb') as fp:
    pkl.dump(results_total, fp)