In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time as tm
import math
import sys
import pickle as pkl
import matplotlib.pyplot as plt
# from gurobipy import *
from multiprocessing.dummy import Pool as ThreadPool 
from numpy import linalg as LA
from classic_detectors import *
from sample_generator import *
from utils import *
import os

# --- Simulation parameters ---------------------------------------------------
# Channel matrices in this notebook have shape (..., 2*NR, 2*NT) (real-valued
# form), so NT / NR are presumably the transmit / receive antenna counts.
NT = 2
NR = 4

# SNR sweep in dB, keyed by NT (the evaluation cells read snrdb_list[NT]).
snrdb_list = {16:np.arange(5.0, 15.0), 2:np.arange(5.0, 15.0), 6:np.arange(10.0, 21.0)}
mod_n = 4  # constellation size; sqrt(mod_n) amplitude levels per real dimension

# NOTE(review): corr_flag / batch_corr are not read by any visible cell.
corr_flag = True
batch_corr = True
batch_size = 100  # number of channel sequences indexed by the time-step cell
time_seq = 5      # rows of H_test per sequence in the time-step cell

parallel = False  # if True, ML() dispatches samples to a thread pool

# Fix the RNG state so the Monte-Carlo SER curves are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Real-valued amplitude alphabet passed to mlSolver: M evenly spaced levels,
# normalised to unit mean energy, then scaled by 1/sqrt(2) (presumably so a
# complex symbol with real+imag parts has unit energy).
M = int(np.sqrt(mod_n))
sigConst = np.linspace(-M+1, M-1, M) 
sigConst /= np.sqrt((sigConst ** 2).mean())
sigConst /= np.sqrt(2.)

PATH = os.getcwd()  # directory holding H_test and the result pickles

In [None]:
def ZF(NT, snr_min, snr_max, batch_size, repeat_sample, generator, device, Cu = None, H = None, iterations = 50):
    """Monte-Carlo symbol error rate of the zero-forcing detector over an SNR sweep.

    For every SNR (dB) on the 1 dB grid from snr_min to snr_max, draws
    `iterations` batches through the fixed channels `H` with
    `generator.give_batch_data_Hinput`. Each batch is equalised with the
    pseudo-inverse of H, sliced with `sym_detection`, and the accuracies are
    averaged.

    Args:
        NT: forwarded to the generator.
        snr_min, snr_max: first / last SNR in dB (inclusive).
        batch_size: number of distinct channel matrices in `H`.
        repeat_sample: each channel is repeated this many times along dim 0;
            the generated batch therefore has batch_size * repeat_sample samples.
        generator: object providing `real_QAM_const`, `imag_QAM_const` and
            `give_batch_data_Hinput`.
        device: torch device for the detection math.
        Cu: unused; kept so all detectors share one signature.
        H: channel matrices (tensor or array-like). Required.
        iterations: number of batches averaged per SNR point.

    Returns:
        List of (snr_db, ser) tuples, where ser = 1 - mean symbol accuracy.

    Raises:
        ValueError: if H is None.
    """
    if H is None:
        raise ValueError("ZF requires the channel matrices H")
    SNR_dBs = np.linspace(int(snr_min), int(snr_max), int(snr_max - snr_min + 1))
    accs_NN = []
    real_QAM_const = generator.real_QAM_const.to(device=device)
    imag_QAM_const = generator.imag_QAM_const.to(device=device)
    bs = batch_size * repeat_sample

    # H is fixed for the whole sweep: repeat and move it once, and compute its
    # pseudo-inverse once instead of on every iteration.
    H = torch.as_tensor(H).repeat_interleave(repeat_sample, dim=0).to(device=device)
    H_pinv = torch.pinverse(H).double()

    for snr_db in SNR_dBs:
        acc_sum = 0.
        for _ in range(iterations):
            # Ask for bs samples so the batch matches the repeat_interleaved H
            # (as MMSE and ML do).
            y, _, j_indices, _ = generator.give_batch_data_Hinput(H, NT, snr_db_min=snr_db, snr_db_max=snr_db, batch_size=bs)
            y = y.to(device=device)
            j_indices = j_indices.to(device=device)

            y_ZF = batch_matvec_mul(H_pinv, y.double())
            # sym_detection returns the fraction of correctly detected symbols.
            acc_sum += sym_detection(y_ZF, j_indices, real_QAM_const, imag_QAM_const)
        ser = 1. - acc_sum / iterations
        accs_NN.append((snr_db, ser))
        print([snr_db, ser])

    return accs_NN

def MMSE(NT, snr_min, snr_max, batch_size, repeat_sample, generator, device, Cu = None, H=None, iterations = 50):
    """Monte-Carlo symbol error rate of the MMSE detector over an SNR sweep.

    For every SNR (dB) on the 1 dB grid from snr_min to snr_max, draws
    `iterations` batches through the fixed channels `H`. Each batch is
    equalised with `mmse(y, H, noise_sigma, device)` and sliced with
    `sym_detection`, and the accuracies are averaged.

    Args:
        NT: forwarded to the generator.
        snr_min, snr_max: first / last SNR in dB (inclusive).
        batch_size: number of distinct channel matrices in `H`.
        repeat_sample: each channel is repeated this many times along dim 0;
            the generated batch has batch_size * repeat_sample samples.
        generator: object providing `real_QAM_const`, `imag_QAM_const` and
            `give_batch_data_Hinput`.
        device: torch device for the detection math.
        Cu: unused; kept so all detectors share one signature.
        H: channel matrices (tensor or array-like). Required.
        iterations: number of batches averaged per SNR point.

    Returns:
        List of (snr_db, ser) tuples, where ser = 1 - mean symbol accuracy.

    Raises:
        ValueError: if H is None.
    """
    if H is None:
        raise ValueError("MMSE requires the channel matrices H")
    SNR_dBs = np.linspace(int(snr_min), int(snr_max), int(snr_max - snr_min + 1))
    accs_NN = []
    real_QAM_const = generator.real_QAM_const.to(device=device)
    imag_QAM_const = generator.imag_QAM_const.to(device=device)
    bs = repeat_sample * batch_size

    # H is fixed for the whole sweep: repeat, move and cast it once.
    H = torch.as_tensor(H).repeat_interleave(repeat_sample, dim=0).to(device=device).double()
    for snr_db in SNR_dBs:
        acc_sum = 0.
        for _ in range(iterations):
            y, _, j_indices, noise_sigma = generator.give_batch_data_Hinput(H, NT, snr_db_min=snr_db, snr_db_max=snr_db, batch_size=bs)

            y = y.to(device=device).double()
            j_indices = j_indices.to(device=device)
            noise_sigma = noise_sigma.to(device=device).double()

            y_MMSE = mmse(y, H, noise_sigma, device).double()
            # sym_detection returns the fraction of correctly detected symbols.
            acc_sum += sym_detection(y_MMSE, j_indices, real_QAM_const, imag_QAM_const)
        ser = 1. - acc_sum / iterations
        accs_NN.append((snr_db, ser))
        print([snr_db, ser])

    return accs_NN


def ML(NT, snr_min, snr_max, batch_size, repeat_sample, generator, device, Cu = None, H=None, iterations = 500):
    """Monte-Carlo symbol error rate of the maximum-likelihood solver over an SNR sweep.

    For every SNR (dB) on the 1 dB grid from snr_min to snr_max, draws
    `iterations` batches through the fixed channels `H`. Each batch is solved
    with `mlSolver(H, y, sigConst)` on NumPy arrays, or with `ml_proc_star` on
    a thread pool when the module-level `parallel` flag is set. The result is
    sliced with `sym_detection` and the accuracies are averaged.

    Args:
        NT: forwarded to the generator.
        snr_min, snr_max: first / last SNR in dB (inclusive).
        batch_size: number of distinct channel matrices in `H`.
        repeat_sample: each channel is repeated this many times along dim 0;
            the generated batch has batch_size * repeat_sample samples.
        generator: object providing `real_QAM_const`, `imag_QAM_const` and
            `give_batch_data_Hinput`.
        device: torch device used for symbol slicing.
        Cu: unused; kept so all detectors share one signature.
        H: channel matrices (tensor or array-like). Required.
        iterations: number of batches averaged per SNR point.

    Returns:
        List of (snr_db, ser) tuples, where ser = 1 - mean symbol accuracy.

    Raises:
        ValueError: if H is None.
    """
    if H is None:
        raise ValueError("ML requires the channel matrices H")
    SNR_dBs = np.linspace(int(snr_min), int(snr_max), int(snr_max - snr_min + 1))
    accs_NN = []
    real_QAM_const = generator.real_QAM_const.to(device=device)
    imag_QAM_const = generator.imag_QAM_const.to(device=device)
    bs = repeat_sample * batch_size

    H = torch.as_tensor(H).repeat_interleave(repeat_sample, dim=0)
    # The solver works on NumPy arrays and H never changes, so convert it once.
    H_np = H.to(device='cpu').numpy()

    # Only start worker threads when they are used, and always release them.
    pool = ThreadPool(40) if parallel else None
    try:
        for snr_db in SNR_dBs:
            acc_sum = 0.
            for _ in range(iterations):
                y, _, j_indices, _ = generator.give_batch_data_Hinput(H, NT, snr_db_min=snr_db, snr_db_max=snr_db, batch_size=bs)

                j_indices = j_indices.to(device=device)
                y_np = y.to(device='cpu').unsqueeze(dim=-1).numpy()
                if pool is not None:
                    # NOTE(review): assumes ml_proc_star takes one (H_i, y_i) pair
                    # and returns that sample's estimate; confirm its signature.
                    shatBatch = pool.map(ml_proc_star, zip(H_np, y_np))
                else:
                    shatBatch, _status = mlSolver(H_np, y_np, sigConst)
                x_hat = torch.from_numpy(np.array(shatBatch)).to(device=device).double()
                # sym_detection returns the fraction of correctly detected symbols.
                acc_sum += sym_detection(x_hat, j_indices, real_QAM_const, imag_QAM_const)
            ser = 1. - acc_sum / iterations
            accs_NN.append((snr_db, ser))
            print([snr_db, ser])
    finally:
        if pool is not None:
            pool.close()
            pool.join()

    return accs_NN

def sym_detection(x_hat, j_indices, real_QAM_const, imag_QAM_const):
    """Fraction of symbols whose nearest constellation point equals j_indices.

    x_hat is 2-D: its last dim holds the real parts in the first half and the
    imaginary parts in the second half. Each (real, imag) pair is mapped to the
    constellation point with the smallest squared Euclidean distance, and the
    resulting index is compared element-wise with j_indices.

    Returns:
        Python float in [0, 1]: correct symbols / j_indices.numel().
    """
    half_real, half_imag = torch.chunk(x_hat, 2, dim=-1)
    n_real = real_QAM_const.numel()
    n_imag = imag_QAM_const.numel()

    # Broadcast every estimate against every constellation coordinate.
    dist_real = (half_real.unsqueeze(dim=-1).expand(-1, -1, n_real) - real_QAM_const) ** 2
    dist_imag = (half_imag.unsqueeze(dim=-1).expand(-1, -1, n_imag) - imag_QAM_const) ** 2
    nearest = (dist_real + dist_imag).argmin(dim=-1)

    n_correct = (nearest == j_indices).sum().to(dtype=torch.float32)
    return n_correct.item() / j_indices.numel()


def batch_matvec_mul(A,b):
    '''Batched matrix-vector product.

    A has shape (batch_size, N, K) and b has shape (batch_size, K); returns
    A[i] @ b[i] for every i, with shape (batch_size, N).
    '''
    b_col = b.unsqueeze(2)          # (batch_size, K, 1): each vector as a column
    return (A @ b_col).squeeze(-1)  # drop the trailing singleton dim

def batch_identity_matrix(row, cols, batch_size):
    """Return `batch_size` stacked copies of the (row x cols) identity matrix.

    The result has shape (batch_size, row, cols). It is built with `repeat`, so
    each copy is materialised rather than being a broadcast view.
    """
    single = torch.eye(row, cols).unsqueeze(0)
    return single.repeat(batch_size, 1, 1)

def batch_trace(H):
    """Batched trace: sum of the main diagonal taken over the last two dims of H."""
    diag = torch.diagonal(H, 0, -1, -2)
    return diag.sum(dim=-1)

In [None]:
# Evaluate the ML detector on the stored test channels and save its SER curve.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Must match the batch arguments passed to ML below.
test_batch_size = 500
repeat_sample = 1
size_sample = repeat_sample * test_batch_size
generator = sample_generator(size_sample, mod_n, NR)

# NOTE(review): pickle.load can execute arbitrary code; H_test must be trusted.
with open(os.path.join(PATH, 'H_test'), 'rb') as fp:
    H = torch.as_tensor(pkl.load(fp))

snr_min, snr_max = snrdb_list[NT][0], snrdb_list[NT][-1]
accs_NN = [ML(NT, snr_min, snr_max, test_batch_size, repeat_sample, generator, device, H=H.double(), iterations=500)]

with open(os.path.join(PATH, 'ML'), 'wb') as fp:
    pkl.dump(accs_NN, fp)

In [None]:
# Evaluate the ML detector separately on each time step of the stored channels.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Must match the batch arguments passed to ML below.
test_batch_size = 500
repeat_sample = 1
size_sample = repeat_sample * test_batch_size
generator = sample_generator(size_sample, mod_n, NR)

# NOTE(review): pickle.load can execute arbitrary code; H_test must be trusted.
with open(os.path.join(PATH, 'H_test'), 'rb') as fp:
    H = torch.as_tensor(pkl.load(fp))

# Each of the batch_size channels is repeated to fill one ML batch.
channel_repeats = test_batch_size // batch_size
assert channel_repeats * batch_size == test_batch_size

snr_min, snr_max = snrdb_list[NT][0], snrdb_list[NT][-1]
accs_NN = []
for t in range(time_seq):
    # Time step t of sequence ii is row t + ii * time_seq of H.
    H_t = H[t::time_seq][:batch_size]
    assert H_t.shape == (batch_size, 2 * NR, 2 * NT), H_t.shape
    H_eval = H_t.repeat_interleave(channel_repeats, dim=0).to(device=device).double()
    accs_NN.append(ML(NT, snr_min, snr_max, test_batch_size, repeat_sample, generator, device, H=H_eval, iterations=500))

with open(os.path.join(PATH, 'ML_time'), 'wb') as fp:
    pkl.dump(accs_NN, fp)