In [None]:
## import torch
import numpy as np
import torch.nn as nn
import pickle
import os

from sample_generator import sample_generator
from iterative_classifier import iterative_classifier
from classic_detectors import *

# Parameters
NR = 12
# NT_list = np.arange(16, 32)
NT_list = np.arange(4,8)
NT_prob = NT_list/NT_list.sum()
mod_n = 4
d_transmitter_encoding = NR
d_model = 512
n_head = 2
nhid = d_model*4
nlayers = 16
dropout = 0.0

epoch_size = 5000
train_iter = 130*epoch_size


# Batch sizes for training and validation sets
train_batch_size = 1000
mini_validtn_batch_size = 1000

learning_rate = 1e-4

corr_flag = True
batch_corr = True
rho_low = 0.6
rho_high = 0.6

validtn_NT_list = np.asarray([6, 6])
snrdb_list = {6:np.arange(11.0, 22.0), 16:np.arange(11.0, 22.0), 32:np.arange(16.0, 27.0)}
factor_list = (validtn_NT_list/validtn_NT_list.sum())/snrdb_list[16].size

model_filename = './validtn_results/model_with1thres.pth'
# curr_accr = './validtn_results/curr_accr.txt'
load_pretrained_model = True
save_interim_model = True
save_to_file = False

def sym_detection(x_hat, j_indices, real_QAM_const, imag_QAM_const):
    #Convierte a complejo
    x_real, x_imag = torch.chunk(x_hat, 2, dim=-1)
    #Lo expande a los 4 posibles simbolos para comparar
    x_real = x_real.unsqueeze(dim=-1).expand(-1,-1, real_QAM_const.numel())
    x_imag = x_imag.unsqueeze(dim=-1).expand(-1, -1, imag_QAM_const.numel())

    #Calcula la resta
    x_real = torch.pow(x_real - real_QAM_const, 2)
    x_imag = torch.pow(x_imag - imag_QAM_const, 2)
    x_dist = x_real + x_imag
    x_indices = torch.argmin(x_dist, dim=-1)

    accuracy = (x_indices == j_indices).sum().to(dtype=torch.float32)
    return accuracy.item()/j_indices.numel()


def batch_matvec_mul(A,b):
    '''Multiplies a matrix A of size batch_sizexNxK
       with a vector b of size batch_sizexK
       to produce the output of size batch_sizexN
    '''    
    C = torch.matmul(A, torch.unsqueeze(b, dim=2))
    return torch.squeeze(C, -1) 

def batch_identity_matrix(row, cols, batch_size):
    eye = torch.eye(row, cols)
    eye = eye.reshape((1, row, cols))
    
    return eye.repeat(batch_size, 1, 1)

def batch_trace(H):
    return H.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)

def get_snr_range(NT):
    peak = NT*(5.0/16.0) + 6.0
    snr_low = peak
    snr_high = peak+10.0
    return (snr_low, snr_high)

def accuracy(out, j_indices):
    out = out.permute(1,2,0)
    out = out.argmax(dim=1)
    accuracy = (out == j_indices).sum().to(dtype=torch.float32)
    del out
    return accuracy/j_indices.numel()
#     return accuracy/1.numel()

def loss_function(criterion, out, j_indices):
    out = torch.cat(out, dim=1).permute(1,2,0)
    j_indices = j_indices.repeat(nlayers, 1)
    loss = criterion(out, j_indices)
    del out, j_indices
    return loss

def validate_model_given_data(model, validtn_H, validtn_y, validtn_j_indices, validtn_noise_sigma, device, criterion=None, attn_weights = None):
    with torch.no_grad():

        validtn_H = validtn_H.to(device=device).float()
        validtn_y = validtn_y.to(device=device).float()
        validtn_noise_sigma = validtn_noise_sigma.to(device=device).float()
        validtn_out = model.forward(validtn_H, validtn_y, validtn_noise_sigma)

        if (criterion):
            validtn_j_indices = validtn_j_indices.to(device=device)
            loss = loss_function(criterion, validtn_out, validtn_j_indices)
            validtn_j_indices = validtn_j_indices.to(device='cpu')

        validtn_out = validtn_out[-1].to(device='cpu')
        accr = accuracy(validtn_out, validtn_j_indices)

        del validtn_H, validtn_y, validtn_noise_sigma, validtn_out, validtn_j_indices

        if (criterion):
            return accr, loss.item()
        else:
            return accr, None

def mini_validation(model, mini_validation_dict, i, device, criterion=None, save_to_file=True, attn_weights = None):
    result_dict = {int(NT):{} for NT in validtn_NT_list}
    loss_list = []
    for index,NT in enumerate(validtn_NT_list):
        for snr in snrdb_list[NT]:
            big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma = mini_validation_dict[NT][snr]
            ###--------------------------------------------------------------------------------###                   
            big_validtn_H = big_validtn_H.to(device=device).float()
            big_validtn_y = big_validtn_y.to(device=device).float()
            big_noise_sigma = big_noise_sigma.to(device=device).float()
            big_validtn_j_indices = big_validtn_j_indices.to(device=device)
            
            y_MMSE = mmse(big_validtn_y, big_validtn_H, big_noise_sigma, device)
            index_wrong = []

            for ii in range(y_MMSE.shape[0]):
                if sym_detection(y_MMSE[ii:ii+1,:], big_validtn_j_indices[ii:ii+1,:], generator.real_QAM_const.to(device=device), generator.imag_QAM_const.to(device=device)) != 1.0:
                    index_wrong.append(ii)

    #         print(attn_weights)
            if len(index_wrong) != 0:
                big_validtn_H = big_validtn_H[index_wrong,:,:].to(device=device)
                big_validtn_y = big_validtn_y[index_wrong,:].to(device=device)
                big_noise_sigma = big_noise_sigma[index_wrong].to(device=device)
                big_validtn_j_indices = big_validtn_j_indices[index_wrong].to(device=device)


                accr, loss = validate_model_given_data(model, big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma, device, criterion, attn_weights)
                out = model.forward(big_validtn_H, big_validtn_y, big_noise_sigma)[-1].permute(1,2,0)
                out = out.argmax(dim=1)

                index_wrongg = []
                for ii in range(out.shape[0]):
                    if torch.sum(out[ii:ii+1,:] != big_validtn_j_indices[ii:ii+1,:]) != 0:
                        index_wrongg.append(ii)
                    
            result_dict[NT][snr] = (mini_validtn_batch_size - len(index_wrongg)) / mini_validtn_batch_size
            ###--------------------------------------------------------------------------------###       
#             result_dict[NT][snr] = accr
            loss_list.append(loss*factor_list[index])

    print('Validtn result, Accr for 16 : ', result_dict[6])
#     print('Validation resut, Accr for 32 : ', result_dict[32])
    if (save_to_file):
        with open(curr_accr, 'w') as f:
            print((i, result_dict), file=f)
        print('Saved intermediate validation results at : ', curr_accr)

    if (criterion):
        return np.sum(loss_list)

def generate_big_validtn_data(generator, batch_size, corr_flag, rho, batch_corr, rho_low, rho_high):
    validtn_data_dict = {int(NT):{} for NT in validtn_NT_list}
    for NT in validtn_NT_list:
        for snr in snrdb_list[NT]:
            big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma = generator.give_batch_data(int(NT), snr_db_min=snr, snr_db_max=snr, batch_size=batch_size, correlated_flag=corr_flag, rho=rho, batch_corr=batch_corr,rho_low=rho_low, rho_high=rho_high)
#             attn_weights_test = attn_mask(big_validtn_H, NT)
            validtn_data_dict[int(NT)][snr] = (big_validtn_H, big_validtn_y , big_validtn_j_indices, big_noise_sigma)
    return validtn_data_dict

def save_model_func(model, optimizer):
    torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, model_filename)
    print('******Model Saved********** at directory : ', model_filename)


def train(model, optimizer, lr_scheduler, generator , device='cpu'):

    mini_validation_dict = generate_big_validtn_data(generator, mini_validtn_batch_size, corr_flag, None, batch_corr, rho_low, rho_high)
    # Fix loss criterion
    criterion = nn.CrossEntropyLoss().to(device=device)
    model.train()
    epoch_count = 1

    for i in range(0, train_iter+1):

        # Randomly select number of transmitters
#         NT = np.random.choice(NT_list, p=NT_prob)
#         rho = np.random.triangular(rho_low, rho_high, rho_high)
        NT = 6
        rho = 0.6
        
        snr_low, snr_high = get_snr_range(NT)
        H, y, j_indices, noise_sigma = generator.give_batch_data(NT, snr_db_min=snr_low, snr_db_max=snr_high, batch_size=train_batch_size, correlated_flag=corr_flag, rho=rho)        
        H = H.to(device=device).float()
        y = y.to(device=device).float()
        noise_sigma = noise_sigma.to(device=device).float()
        j_indices = j_indices.to(device=device)
        
        y_MMSE = mmse(y, H, noise_sigma, device)
        index_wrong = []

        for ii in range(y_MMSE.shape[0]):
            if sym_detection(y_MMSE[ii:ii+1,:], j_indices[ii:ii+1,:], generator.real_QAM_const.to(device=device), generator.imag_QAM_const.to(device=device)) != 1.0:
                index_wrong.append(ii)
        
        H_f = H[index_wrong,:,:].to(device=device)
        y_f = y[index_wrong,:].to(device=device)
        noise_sigma_f = noise_sigma[index_wrong].to(device=device)
        j_indices_f = j_indices[index_wrong].to(device=device)

        out = model.forward(H_f, y_f , noise_sigma_f)

        del H, y, noise_sigma,H_f, y_f, noise_sigma_f

        loss = loss_function(criterion, out, j_indices_f)
        del j_indices, out, j_indices_f
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_item = loss.item()
        del loss

        if (i%epoch_size==0):
            print('iteration number : ', i, 'Epoch : ', epoch_count, 'User : ', NT, 'loss : ', loss_item)
            print('Now validating')

            model.eval()
            mini_validtn_loss = mini_validation(model, mini_validation_dict, i, device, criterion, save_to_file)
            print('Mini validation loss : ', mini_validtn_loss)
            lr_scheduler.step(mini_validtn_loss)

            model.train()
            if (save_interim_model):
                save_model_func(model, optimizer)

            epoch_count = epoch_count+1

In [10]:
# Cu = createCu()
generator = sample_generator(train_batch_size, mod_n, NR)
device = 'cuda'
model = iterative_classifier(d_model, n_head, nhid, nlayers, mod_n, NR, d_transmitter_encoding, generator.real_QAM_const, generator.imag_QAM_const, generator.constellation, device, dropout)
model = model.to(device=device)
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)

if (load_pretrained_model):
    checkpoint = torch.load(model_filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,'min', 0.91, 0, 0.0001, 'rel', 0, 0, 1e-08, verbose = True)
    print('*******Successfully loaded pre-trained model***********')
else:
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,'min', 0.91, 0, 0.0001, 'rel', 0, 0, 1e-08, verbose = True)

# train(model, optimizer, lr_scheduler, generator, device)
print('******************************** Now Testing **********************************************')
#Last: 60k

*******Successfully loaded pre-trained model***********
******************************** Now Testing **********************************************


In [11]:
NT = 6
rho = 0.6

mini_validation_dict = generate_big_validtn_data(generator, mini_validtn_batch_size, corr_flag, None, batch_corr, rho_low, rho_high)
device = 'cuda'
criterion = nn.CrossEntropyLoss().to(device=device)
mini_validation(model, mini_validation_dict, 0, device, criterion, save_to_file)

Validtn result, Accr for 16 :  {11.0: 0.991, 12.0: 0.996, 13.0: 0.997, 14.0: 1.0, 15.0: 1.0, 16.0: 1.0, 17.0: 1.0, 18.0: 1.0, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0}


0.15376132896000688

In [14]:
validtn_batch_size = 1000
validtn_iter = 500

corr_list = np.asarray([0.60, 0.6])
from collections import defaultdict

def generate_big_validtn_data(generator, batch_size):
    validtn_data_dict = {int(NT):{rho:{} for rho in corr_list} for NT in validtn_NT_list}
    for NT in validtn_NT_list:
        for rho in corr_list:
            for snr in snrdb_list[NT]:
                big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma = generator.give_batch_data(int(NT), snr_db_min=snr, snr_db_max=snr, batch_size=batch_size, correlated_flag=corr_flag, rho=rho)
                validtn_data_dict[int(NT)][rho][snr] = (big_validtn_H, big_validtn_y , big_validtn_j_indices, big_noise_sigma)
    return validtn_data_dict

def validate_model_given_data(model, validtn_H, validtn_y, validtn_j_indices, validtn_noise_sigma, device):
    with torch.no_grad():

        validtn_H = validtn_H.to(device=device)
        validtn_y = validtn_y.to(device=device)
        validtn_noise_sigma = validtn_noise_sigma.to(device=device)
        validtn_out = model.forward(validtn_H, validtn_y, validtn_noise_sigma)

        validtn_out = validtn_out[-1].to(device='cpu')
        accr = accuracy(validtn_out, validtn_j_indices)

        del validtn_H, validtn_y, validtn_out, validtn_noise_sigma

    return accr

def validate_model(model, generator, device, save_result=True):
    result_dict = {int(NT):{rho:defaultdict(float) for rho in corr_list} for NT in validtn_NT_list}
    for iter in range(validtn_iter):
        validtn_data_dict = generate_big_validtn_data(generator, validtn_batch_size)
        for NT in validtn_NT_list:
            for rho in corr_list:
                for snr in snrdb_list[NT]:
                    big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma = validtn_data_dict[NT][rho][snr]
                    ###--------------------------------------------------------------------------------###                   
                    big_validtn_H = big_validtn_H.to(device=device).float()
                    big_validtn_y = big_validtn_y.to(device=device).float()
                    big_noise_sigma = big_noise_sigma.to(device=device).float()
                    big_validtn_j_indices = big_validtn_j_indices.to(device=device)

                    y_MMSE = mmse(big_validtn_y, big_validtn_H, big_noise_sigma, device)
                    index_wrong = []

                    for ii in range(y_MMSE.shape[0]):
                        if sym_detection(y_MMSE[ii:ii+1,:], big_validtn_j_indices[ii:ii+1,:], generator.real_QAM_const.to(device=device), generator.imag_QAM_const.to(device=device)) != 1.0:
                            index_wrong.append(ii)

                    index_wrongg = []
                    if len(index_wrong) != 0:
                #         print(attn_weights)
                        big_validtn_H = big_validtn_H[index_wrong,:,:].to(device=device)
                        big_validtn_y = big_validtn_y[index_wrong,:].to(device=device)
                        big_noise_sigma = big_noise_sigma[index_wrong].to(device=device)
                        big_validtn_j_indices = big_validtn_j_indices[index_wrong].to(device=device)


    #                     accr, loss = validate_model_given_data(model, big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma, device)
                        out = model.forward(big_validtn_H, big_validtn_y, big_noise_sigma)[-1].permute(1,2,0)
                        out = out.argmax(dim=1)

                        index_wrongg = []
                        for ii in range(out.shape[0]):
                            if torch.sum(out[ii:ii+1,:] != big_validtn_j_indices[ii:ii+1,:]) != 0:
                                index_wrongg.append(ii)

                    accr = (validtn_batch_size - len(index_wrongg)) / validtn_batch_size
                    ###--------------------------------------------------------------------------------###       
                    
#                     accr = validate_model_given_data(model, big_validtn_H.float(), big_validtn_y.float(), big_validtn_j_indices.float(), big_noise_sigma.float(), device)
                    result_dict[int(NT)][rho][snr] =  result_dict[int(NT)][rho][snr] + (accr-result_dict[int(NT)][rho][snr])/float(iter+1.0)

        if (save_result):
            with open(validtn_filename, 'wb') as handle:
                pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
            print('Intermediate Test results saved at : ', validtn_filename)
        print('Big Validtn result, Accr for 6 : ', result_dict[6])
#         print('Big Validation resut, Accr for 32 : ', result_dict[32])

def test(model, generator, device):
    model.eval()

    # Testing Trained Network
    validate_model(model, generator, device, False)


generator = sample_generator(validtn_batch_size, mod_n, NR)
device = 'cuda'
# model = iterative_classifier(d_model, n_head, nhid, nlayers, mod_n, NR, d_transmitter_encoding, generator.real_QAM_const, generator.imag_QAM_const, generator.constellation, device, dropout)
# model = model.to(device=device)

# checkpoint = torch.load(model_filename)
# model.load_state_dict(checkpoint['model_state_dict'])
# print('*******Successfully loaded pre-trained model*********** from directory : ', model_filename)

test(model, generator, device)
print('******************************** Now Testing **********************************************')

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.98, 12.0: 0.994, 13.0: 0.999, 14.0: 1.0, 15.0: 1.0, 16.0: 1.0, 17.0: 1.0, 18.0: 1.0, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9846874999999999, 12.0: 0.99775, 13.0: 0.9999375, 14.0: 1.0, 15.0: 0.9971874999999999, 16.0: 1.0, 17.0: 1.0, 18.0: 1.0, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9873456790123457, 12.0: 0.9939382716049383, 13.0: 0.9983827160493827, 14.0: 0.9991975308641976, 15.0: 0.998641975308642, 16.0: 0.9991975308641976, 17.0: 1.0, 18.0: 1.0, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9857421875000001, 12.0: 0.99398046875, 13.0: 0.9988046875, 14.0: 0.99701171875, 15.0: 0.99888671875, 16.0: 0.99974609375, 17.0: 1.0, 18.0: 1.0, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.986530850617284, 12.0: 0.9936044333333336, 13.0: 0.9976093580246912, 14.0: 0.9986548197530866, 15.0: 0.9992711777777779, 16.0: 0.9999143506172838, 17.0: 0.9998581135802471, 18.0: 0.9999698543209876, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9864655974255053, 12.0: 0.993653057158419, 13.0: 0.9975344545494903, 14.0: 0.9985743280336885, 15.0: 0.9993607660248116, 16.0: 0.9999248788062209, 17.0: 0.9998755545353057, 18.0: 0.9999735598865646, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870063781738283, 12.0: 0.9935751724243167, 13.0: 0.9978284997940061, 14.0: 0.9987443561553956, 15.0: 0.9994370021820069, 16.0: 0.999933837890625, 17.0: 0.9998903961181642, 18.0: 0.9999767131805419, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'floa

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9874173135262114, 12.0: 0.9943529435057843, 13.0: 0.9974966623250603, 14.0: 0.9992077237393212, 15.0: 0.999601663357175, 16.0: 0.9998072663123092, 17.0: 0.99986141740608, 18.0: 0.9999821232491357, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9875236737590858, 12.0: 0.9945308318935474, 13.0: 0.9973960831161871, 14.0: 0.9991265618792247, 15.0: 0.9996284325112652, 16.0: 0.9998202184669179, 17.0: 0.9998707304805302, 18.0: 0.9999161222696495, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9874229705626497, 12.0: 0.9944957479603028, 13.0: 0.9974359974090067, 14.0: 0.9990521047445892, 15.0: 0.9996529902839408, 16.0: 0.9998321006529394, 17.0: 0.9998792742083031, 18.0: 0.9999216659469004, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'floa

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9867914581991619, 12.0: 0.9947199653741885, 13.0: 0.9973923746487444, 14.0: 0.9988684404541766, 15.0: 0.9995163326727612, 16.0: 0.9997492235343686, 17.0: 0.9998311335884031, 18.0: 0.9999622453277891, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868011001017707, 12.0: 0.9945942081344813, 13.0: 0.9973742332587012, 14.0: 0.9989207579602731, 15.0: 0.9994462252319776, 16.0: 0.9997608181511238, 17.0: 0.9998389411046322, 18.0: 0.9999639909100705, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.986718777851504, 12.0: 0.9946127555208684, 13.0: 0.9973571283104806, 14.0: 0.9989243798483735, 15.0: 0.9994715364209558, 16.0: 0.9997717503519509, 17.0: 0.9998463025670402, 18.0: 0.9999656367649068, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'fl

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9866507872253437, 12.0: 0.9940312286173872, 13.0: 0.9973306911304594, 14.0: 0.9986764932003385, 15.0: 0.9995083712553225, 16.0: 0.999706667268483, 17.0: 0.9999257271999656, 18.0: 0.9999876178536409, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868040483086272, 12.0: 0.9940301281580677, 13.0: 0.9973542767786472, 14.0: 0.9987231320051887, 15.0: 0.9994552180470152, 16.0: 0.9997170039649207, 17.0: 0.9999283444850656, 18.0: 0.9999880541857458, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870554108753714, 12.0: 0.994064006856266, 13.0: 0.9973768326063246, 14.0: 0.9986978722029117, 15.0: 0.9994742478843308, 16.0: 0.9997268893300268, 17.0: 0.999930847491607, 18.0: 0.9999884714662753, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'floa

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9867749985320146, 12.0: 0.9937389469641955, 13.0: 0.9972223115945118, 14.0: 0.9987662713788176, 15.0: 0.9995812523456331, 16.0: 0.9997748804838291, 17.0: 0.9999689110992525, 18.0: 0.999974542195941, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9869522113896274, 12.0: 0.9938033144461551, 13.0: 0.9972159828453306, 14.0: 0.998772925147026, 15.0: 0.9995647052990673, 16.0: 0.9997812891687058, 17.0: 0.9999697961356608, 18.0: 0.9999752669267268, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868687619351312, 12.0: 0.9938936743934821, 13.0: 0.9972664112322993, 14.0: 0.9987510770564344, 15.0: 0.9995770096704499, 16.0: 0.9997592046543103, 17.0: 0.9999706499010828, 18.0: 0.9999759660506039, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'flo

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9871051762874591, 12.0: 0.9939101497276803, 13.0: 0.9972993812829017, 14.0: 0.9988849361277857, 15.0: 0.9995745152400543, 16.0: 0.999779328384943, 17.0: 0.9999674697130755, 18.0: 0.9999875433429374, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9871265442484257, 12.0: 0.9939122953090779, 13.0: 0.9972683526812401, 14.0: 0.9988876837978526, 15.0: 0.9995846756119968, 16.0: 0.9997845979172441, 17.0: 0.9999682465207209, 18.0: 0.999987840801932, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870998025497216, 12.0: 0.99393811504156, 13.0: 0.9972619825749987, 14.0: 0.9988903499390552, 15.0: 0.9995945345053906, 16.0: 0.9997897110919866, 17.0: 0.9999690002789304, 18.0: 0.9999881294347236, 19.0: 1.0, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9869573800536835, 12.0: 0.9940657952706237, 13.0: 0.9973791438428512, 14.0: 0.9987912398130855, 15.0: 0.999665817436987, 16.0: 0.9997773652700326, 17.0: 0.9999818828365808, 18.0: 0.9999930624869009, 19.0: 0.9999837534403352, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870202744684099, 12.0: 0.9939817506582649, 13.0: 0.9973919776457323, 14.0: 0.9987748839862552, 15.0: 0.9996727253706373, 16.0: 0.9997819673831081, 17.0: 0.9999822573389539, 18.0: 0.9999932058931869, 19.0: 0.999984089275201, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870404223928129, 12.0: 0.99404382055193, 13.0: 0.9974250464101847, 14.0: 0.9988000783330737, 15.0: 0.9996794557295785, 16.0: 0.9997864511946868, 17.0: 0.999982622214403, 18.0: 0.9999933456130838, 19.0: 0.9999844164771265, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, A

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868675383455997, 12.0: 0.9941824370193808, 13.0: 0.9974844412283493, 14.0: 0.9988935641531793, 15.0: 0.9996661486247367, 16.0: 0.9997744041064697, 17.0: 0.9999745860319422, 18.0: 0.999995836118612, 19.0: 0.999990248847614, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.986924620434716, 12.0: 0.9942337807863266, 13.0: 0.9974209451008109, 14.0: 0.9989137264653087, 15.0: 0.9996540095598871, 16.0: 0.9997785150857311, 17.0: 0.9999750491445193, 18.0: 0.9999959119960128, 19.0: 0.9999904265404993, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9869441279479249, 12.0: 0.9942658202129186, 13.0: 0.9973951689876897, 14.0: 0.9989334315964183, 15.0: 0.9996602858676193, 16.0: 0.9997825328484172, 17.0: 0.9999755017560051, 18.0: 0.9999959861528912, 19.0: 0.9999906002041932, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result,

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868583755514511, 12.0: 0.9943168463386157, 13.0: 0.9972607116673854, 14.0: 0.9991094586937725, 15.0: 0.9996176471775494, 16.0: 0.9997859336476942, 17.0: 0.9999701517263383, 18.0: 0.9999973520353025, 19.0: 0.9999782355754546, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9869421477215272, 12.0: 0.9942790981040915, 13.0: 0.9972890497849335, 14.0: 0.9990913823516909, 15.0: 0.9996075838901736, 16.0: 0.9997894214175717, 17.0: 0.9999706380424204, 18.0: 0.999997395178428, 19.0: 0.9999785901818817, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9869430864814237, 12.0: 0.9942258886907629, 13.0: 0.997284359419171, 14.0: 0.9991061263485072, 15.0: 0.9996139515651628, 16.0: 0.9997928384433742, 17.0: 0.999971114494325, 18.0: 0.9999974374464613, 19.0: 0.99997893759566, 20.0: 1.0, 21.0: 1.0})}
Big Validtn result, A

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9872201347122396, 12.0: 0.9942697137512848, 13.0: 0.9974242069831291, 14.0: 0.9989504543900052, 15.0: 0.9995742623232975, 16.0: 0.9998191498225536, 17.0: 0.9999389627602016, 18.0: 0.9999833681831444, 19.0: 0.9999852893078504, 20.0: 0.9999859747595974, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9872464538938389, 12.0: 0.9942509382703375, 13.0: 0.9974327213359168, 14.0: 0.9989363998534825, 15.0: 0.9995805577813324, 16.0: 0.999821824085991, 17.0: 0.9999398653286206, 18.0: 0.9999836141207505, 19.0: 0.999985506837447, 20.0: 0.9999861821532983, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.987287021099237, 12.0: 0.9942619739820602, 13.0: 0.9974410788898478, 14.0: 0.9989520695695292, 15.0: 0.9995867373032947, 16.0: 0.9998097163873317, 17.0: 0.9999407512754945, 18.0: 0.999983855529212, 19.0: 0.999985720361056, 20.0: 0.9999863

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868354793804357, 12.0: 0.9940492008824349, 13.0: 0.9973701644857897, 14.0: 0.9988522067839953, 15.0: 0.9995908828698182, 16.0: 0.9998049569676232, 17.0: 0.9999349367157777, 18.0: 0.999977132696836, 19.0: 0.9999897032112737, 20.0: 0.9999901829950769, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9867835616260956, 12.0: 0.9940214625223878, 13.0: 0.9973922262390801, 14.0: 0.9988542073401503, 15.0: 0.9995693483866869, 16.0: 0.999807597105983, 17.0: 0.9999358174243598, 18.0: 0.9999774422328574, 19.0: 0.9999898425904997, 20.0: 0.9999903158798611, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9867595004308939, 12.0: 0.9940346635135487, 13.0: 0.9973464633024298, 14.0: 0.9988696646927003, 15.0: 0.9995751581060598, 16.0: 0.9997967021913766, 17.0: 0.9999366832814507, 18.0: 0.9999777465491372, 19.0: 0.9999899796193576, 20.0: 0.9999

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868034669744079, 12.0: 0.9941514490532938, 13.0: 0.997343998211868, 14.0: 0.998845272770953, 15.0: 0.9995042252364037, 16.0: 0.9998112645991797, 17.0: 0.9999415487929955, 18.0: 0.9999727747567914, 19.0: 0.9999924848918594, 20.0: 0.9999928350619236, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868309662665447, 12.0: 0.9940994753706873, 13.0: 0.9973522109656049, 14.0: 0.9988597292717384, 15.0: 0.9995104320427204, 16.0: 0.999813627454581, 17.0: 0.9999422805674775, 18.0: 0.9999731156007067, 19.0: 0.9999925789765611, 20.0: 0.9999929247627035, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868205955196968, 12.0: 0.9940732731908163, 13.0: 0.9973478152507504, 14.0: 0.9988739602454096, 15.0: 0.9995165420204105, 16.0: 0.9998159534485062, 17.0: 0.9999305205795954, 18.0: 0.9999734511273224, 19.0: 0.9999926715935041, 20.0: 0.99999

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.986745745829596, 12.0: 0.9942322689832968, 13.0: 0.9973216009720933, 14.0: 0.9988549821386127, 15.0: 0.9995735789444918, 16.0: 0.9998159274668647, 17.0: 0.9999467890021523, 18.0: 0.9999796674756545, 19.0: 0.9999943875205062, 20.0: 0.9999946490366766, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9867487065453002, 12.0: 0.994229564278688, 13.0: 0.9972712771884326, 14.0: 0.9988450261208186, 15.0: 0.9995668997848699, 16.0: 0.9998064262293395, 17.0: 0.9999474086287123, 18.0: 0.999979904241974, 19.0: 0.9999944528761939, 20.0: 0.9999947113470848, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9867516242903515, 12.0: 0.9941804551991965, 13.0: 0.9973029601365174, 14.0: 0.9988352146007174, 15.0: 0.9995719284713944, 16.0: 0.9998086737964799, 17.0: 0.9999480192622583, 18.0: 0.9999801375719611, 19.0: 0.9999945172833352, 20.0: 0.99999

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870026960701406, 12.0: 0.994167778785735, 13.0: 0.9973298785007986, 14.0: 0.9989391892011903, 15.0: 0.999519260993158, 16.0: 0.999821432843872, 17.0: 0.9999594633867638, 18.0: 0.9999845105014221, 19.0: 0.9999957243630121, 20.0: 0.9999959235883661, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.986980898220528, 12.0: 0.9941877211439796, 13.0: 0.9973154037674651, 14.0: 0.9989398510812765, 15.0: 0.9995244934778722, 16.0: 0.9998233764138791, 17.0: 0.999959904597496, 18.0: 0.9999846790930351, 19.0: 0.9999957709001246, 20.0: 0.9999959679570595, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9869593961322071, 12.0: 0.9941965382005665, 13.0: 0.9973228348653901, 14.0: 0.9989513586970882, 15.0: 0.9995296549662971, 16.0: 0.9998252936128088, 17.0: 0.9999603398217175, 18.0: 0.9999848453971346, 19.0: 0.9999958168058042, 20.0: 0.9999960

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870772619919662, 12.0: 0.9941253242856694, 13.0: 0.9972658869397409, 14.0: 0.998957698777417, 15.0: 0.9995320417389424, 16.0: 0.9998527750826783, 17.0: 0.999959255928472, 18.0: 0.9999579628230155, 19.0: 0.9999966858368973, 20.0: 0.9999968402619143, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870253874543159, 12.0: 0.9941342608736073, 13.0: 0.9972631703649242, 14.0: 0.998968348000187, 15.0: 0.9995266058525804, 16.0: 0.9998542792840626, 17.0: 0.9999596722118697, 18.0: 0.9999583923181102, 19.0: 0.9999967196978008, 20.0: 0.9999968725450529, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870047467313964, 12.0: 0.9941532746204663, 13.0: 0.997270679395507, 14.0: 0.9989686705657072, 15.0: 0.99953143021219, 16.0: 0.9998455733241433, 17.0: 0.999960083192346, 18.0: 0.9999588163419835, 19.0: 0.9999967531273596, 20.0: 0.999996904

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9869655366314256, 12.0: 0.9940492725513725, 13.0: 0.9972557401572835, 14.0: 0.9990387350095555, 15.0: 0.9995525780228193, 16.0: 0.9998576641380506, 17.0: 0.9999679326830553, 18.0: 0.999966914954071, 19.0: 0.9999973916127023, 20.0: 0.9999975131517577, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.987023629718936, 12.0: 0.9940776788664701, 13.0: 0.9972725319469581, 14.0: 0.9989998545703166, 15.0: 0.9995568853029347, 16.0: 0.9998590343890869, 17.0: 0.9999682413914416, 18.0: 0.9999672334600177, 19.0: 0.999997416723348, 20.0: 0.9999975370923607, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870330065565032, 12.0: 0.9940481215403464, 13.0: 0.9972507070682761, 14.0: 0.9990094597389904, 15.0: 0.999551537103455, 16.0: 0.9998507844186736, 17.0: 0.9999685463938772, 18.0: 0.9999675481423969, 19.0: 0.999997441532548, 20.0: 0.9999975

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870147552508574, 12.0: 0.9942511032835106, 13.0: 0.9972155179225389, 14.0: 0.9990506282957597, 15.0: 0.9995680196957142, 16.0: 0.9998367623476766, 17.0: 0.9999744178475302, 18.0: 0.9999646707841999, 19.0: 0.9999979191223991, 20.0: 0.9999980160818874, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9869873173837138, 12.0: 0.9942579191365075, 13.0: 0.997204455260235, 14.0: 0.9990501675179064, 15.0: 0.999562850039258, 16.0: 0.9998382480049339, 17.0: 0.9999746506756181, 18.0: 0.9999649923221853, 19.0: 0.9999979380608663, 20.0: 0.9999980341379076, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9869511104227096, 12.0: 0.9942283354992449, 13.0: 0.9971753571033837, 14.0: 0.9990587925016562, 15.0: 0.999566819593236, 16.0: 0.9998397167990163, 17.0: 0.9999748808609559, 18.0: 0.9999653102105056, 19.0: 0.9999979567843695, 20.0: 0.99999

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870310675596924, 12.0: 0.9941120293256379, 13.0: 0.9972796888351629, 14.0: 0.9990227269345902, 15.0: 0.9994818570408837, 16.0: 0.999852683321386, 17.0: 0.9999793433862288, 18.0: 0.9999319783671385, 19.0: 0.9999902762548641, 20.0: 0.999998398061686, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870049096322091, 12.0: 0.9941283223969793, 13.0: 0.9972772751376376, 14.0: 0.9990311607413177, 15.0: 0.999486328582901, 16.0: 0.9998539546553038, 17.0: 0.9999795216515411, 18.0: 0.999932565389666, 19.0: 0.9999903601701894, 20.0: 0.9999984118863156, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9870048673535985, 12.0: 0.9941186060066081, 13.0: 0.9972834987816178, 14.0: 0.9990308924049559, 15.0: 0.9994907519924021, 16.0: 0.9998552123043453, 17.0: 0.999979697997972, 18.0: 0.9999245347333598, 19.0: 0.9999904431822344, 20.0: 0.9999984

Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868270057828579, 12.0: 0.9942501000074765, 13.0: 0.9973897942228636, 14.0: 0.9990119821786337, 15.0: 0.9995305557732294, 16.0: 0.9998327051939893, 17.0: 0.9999597167748011, 18.0: 0.9999373235888909, 19.0: 0.9999920627453414, 20.0: 0.9999986923770453, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868366302999629, 12.0: 0.9942398428308178, 13.0: 0.9973783908433116, 14.0: 0.9990118838638412, 15.0: 0.9995344076029522, 16.0: 0.9998258727771999, 17.0: 0.9999600473020828, 18.0: 0.9999378378541635, 19.0: 0.9999921278711895, 20.0: 0.9999987031062026, 21.0: 1.0})}
Big Validtn result, Accr for 6 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9868297797296494, 12.0: 0.9942542555056288, 13.0: 0.9973752924698983, 14.0: 0.9990199748438943, 15.0: 0.9995300317192909, 16.0: 0.9998272985811464, 17.0: 0.9999603744463019, 18.0: 0.999938346855751, 19.0: 0.9999921923304518, 20.0: 0.999

In [None]:
def accuracy(out, j_indices):
    out = out.permute(1,2,0)
    out = out.argmax(dim=1)
    accuracy = (out == j_indices).sum().to(dtype=torch.float32)
    return accuracy.item()/out.numel()

def bit_indices(indices, mod_n):
    real_indices = (indices//np.sqrt(mod_n)).to(dtype=torch.int32)
    imag_indices = (indices%np.sqrt(mod_n)).to(dtype=torch.int32)
    joint_bit_indices = torch.cat((real_indices, imag_indices), dim=-1)
    return joint_bit_indices

def sym_accuracy(out, j_indices):
    accuracy = (out == j_indices).sum().to(dtype=torch.float32)
    return accuracy.item()/out.numel()

def bit_accuracy(out, j_indices):
    out = out.permute(1,2,0)
    out = out.argmax(dim=1)
    bit_out_indices = bit_indices(out, mod_n)
    bit_j_indices = bit_indices(j_indices, mod_n)
    return sym_accuracy(bit_out_indices, bit_j_indices)

def validate_model_given_data(model, validtn_H, validtn_y, validtn_j_indices, validtn_noise_sigma, device):
    with torch.no_grad():

        validtn_H = validtn_H.to(device=device).float()
        validtn_y = validtn_y.to(device=device).float()
        validtn_noise_sigma = validtn_noise_sigma.to(device=device).float()
        validtn_out = model.forward(validtn_H, validtn_y, validtn_noise_sigma)

        validtn_out = validtn_out[-1].to(device='cpu')
        accr = accuracy(validtn_out, validtn_j_indices)

        del validtn_H, validtn_y, validtn_out, validtn_noise_sigma

    return accr


def validate_model(model, generator, device, save_result=True, Cu = None):
    result_dict = {int(NT): {snr: float(0) for snr in snrdb_list[6]} for NT in validtn_NT_list}    
    for iter in range(validtn_iter):
        validtn_data_dict = generate_big_validtn_data(generator, validtn_batch_size, QR = QR, Cu = Cu)
#         for NT in validtn_NT_list:
        for snr in snrdb_list[NT]:
            big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma = validtn_data_dict[NT][snr]
            accr = validate_model_given_data(model, big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma, device)
            result_dict[NT][snr] = result_dict[NT][snr] + (accr-result_dict[NT][snr])/float(iter+1.0)
#                 result_dict[NT][snr] = result_dict[NT][snr] + accr
            print('Big Validtn result, Accr for 16 : ', accr)
        if (save_result):
            with open('REMIMO_result2_testSet2', 'wb') as handle:
                pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
#             print('Intermediate Test results saved at : ', validtn_filename)
        print('Big Validtn result, Accr for 16 : ', result_dict[6])
    #         print('Big Validation resut, Accr for 32 : ', result_dict[32])


def test(model, generator, device, Cu = None):
    model.eval()

    # Testing Trained Network
    validate_model(model, generator, device, True, Cu = Cu)

# corr_list = np.asarray([0.60, 0.70])
NT = np.random.choice(NT_list, p=NT_prob)
generator = sample_generator(validtn_batch_size, mod_n, NT)
# device = 'cuda'
# model = iterative_classifier(d_model, n_head, nhid, nlayers, mod_n, NT, d_transmitter_encoding, generator.real_QAM_const, generator.imag_QAM_const, generator.constellation, device, dropout)
# model = model.to(device=device)

# checkpoint = torch.load('re_mimo_localScatteringModel.pth')
# model.load_state_dict(checkpoint)
# print('*******Successfully loaded pre-trained model*********** from directory : ', 're_mimo_localScatteringModel.pth')

test(model, generator, device, Cu = Cu)
print('******************************** Now Testing **********************************************')

