In [2]:
import torch
import numpy as np
import torch.nn as nn
import pickle
import os

from sample_generator import sample_generator
from iterative_classifier import iterative_classifier

# Parameters
NR = 64  # rows of the complex channel H (real/imag halves are stacked to 2*NR rows)
NT_list = np.arange(16, 33)  # candidate transmitter counts sampled during training
# NT_list = np.arange(4,8)
NT_prob = NT_list/NT_list.sum()  # larger NT values are sampled proportionally more often
mod_n = 16  # constellation size
d_transmitter_encoding = NR
d_model = 512
n_head = 8
nhid = d_model*4
nlayers = 16
dropout = 0.0

epoch_size = 5000  # iterations between mini-validations / LR-scheduler steps
train_iter = 130*epoch_size
# train_iter = 50001


# Batch sizes for training and validation sets
train_batch_size = 256
mini_validtn_batch_size = 5000

learning_rate = 1e-4

corr_flag = True
batch_corr = True
rho_low = 0.55
rho_high = 0.75

validtn_NT_list = np.asarray([16, 32])
snrdb_list = {16:np.arange(11.0, 22.0), 32:np.arange(16.0, 27.0)}
# Per-(NT, SNR) loss weight: NT-proportional share divided by the number of SNR points.
factor_list = (validtn_NT_list/validtn_NT_list.sum())/snrdb_list[16].size

model_filename = './validtn_results/model.pth'
# Read by mini_validation whenever save_to_file is True (its default); it was
# previously commented out, which made that path raise NameError.
curr_accr = './validtn_results/curr_accr.txt'
load_pretrained_model = False
save_interim_model = True
save_to_file = False

def get_snr_range(NT):
    """Return the (low, high) training SNR window in dB for ``NT`` transmitters.

    The window starts at ``6 + NT * 5/16`` dB and is 10 dB wide.
    """
    base = NT * (5.0 / 16.0) + 6.0
    return (base, base + 10.0)

def accuracy(out, j_indices):
    """Return the fraction of correctly decided symbols as a float32 tensor.

    ``out`` is permuted with ``(1, 2, 0)`` and arg-maxed over the new dim 1,
    so it is presumably laid out as ``(NT, batch, classes)`` - TODO confirm.
    The hard decisions are compared element-wise against ``j_indices``.
    """
    predictions = torch.argmax(out.permute(1, 2, 0), dim=1)
    n_correct = torch.eq(predictions, j_indices).sum().to(dtype=torch.float32)
    return n_correct / j_indices.numel()

def loss_function(criterion, out, j_indices):
    """Apply ``criterion`` jointly to every layer's output.

    Args:
        criterion: loss module such as ``nn.CrossEntropyLoss``.
        out: sequence of per-layer outputs. Each is presumably
            ``(NT, batch, classes)`` (inferred from the permute) - TODO confirm.
        j_indices: target indices, presumably ``(batch, NT)``.

    Returns:
        The scalar loss over all layers stacked along the batch dimension.
    """
    # Concatenating along dim 1 stacks the layers layer-major along the batch
    # axis; after permute(1, 2, 0) this gives (layers*batch, classes, NT).
    stacked = torch.cat(out, dim=1).permute(1, 2, 0)
    # The targets must be tiled once per element of ``out`` so the rows line
    # up; len(out) is always the right count, unlike the global nlayers.
    targets = j_indices.repeat(len(out), 1)
    return criterion(stacked, targets)

def validate_model_given_data(model, validtn_H, validtn_y, validtn_j_indices, validtn_noise_sigma, device, criterion=None):
    """Evaluate ``model`` on one validation batch without tracking gradients.

    Inputs are moved to ``device`` and cast to float32 before the forward pass.

    Returns:
        ``(accuracy, loss)``: ``accuracy`` is computed on the CPU from the last
        element of the model output. ``loss`` is the Python float from
        ``loss_function``, or ``None`` when no ``criterion`` is given.
    """
    with torch.no_grad():
        H_dev = validtn_H.to(device=device).float()
        y_dev = validtn_y.to(device=device).float()
        sigma_dev = validtn_noise_sigma.to(device=device).float()
        outputs = model.forward(H_dev, y_dev, sigma_dev)

        loss_value = None
        if criterion:
            j_dev = validtn_j_indices.to(device=device)
            loss_value = loss_function(criterion, outputs, j_dev).item()
            validtn_j_indices = j_dev.to(device='cpu')
            del j_dev

        final_out = outputs[-1].to(device='cpu')
        accr = accuracy(final_out, validtn_j_indices)
        del H_dev, y_dev, sigma_dev, outputs, final_out
    return accr, loss_value

def mini_validation(model, mini_validation_dict, i, device, criterion=None, save_to_file=True):
    """Evaluate ``model`` on the pre-generated mini-validation set.

    Args:
        model: network under evaluation.
        mini_validation_dict: ``{int(NT): {snr: (H, y, j_indices, noise_sigma)}}``,
            as built by ``generate_big_validtn_data``.
        i: training iteration, written alongside the results when saving.
        device: device used for the forward passes.
        criterion: optional loss module; when given, a weighted loss is returned.
        save_to_file: if true, write ``(i, result_dict)`` to the global ``curr_accr`` path.

    Returns:
        The sum of per-(NT, SNR) losses weighted by ``factor_list`` when a
        ``criterion`` is given, otherwise ``None``.
    """
    result_dict = {int(NT):{} for NT in validtn_NT_list}
    loss_list = []
    for index,NT in enumerate(validtn_NT_list):
        for snr in snrdb_list[NT]:
            big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma = mini_validation_dict[NT][snr]
            accr, loss = validate_model_given_data(model, big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma, device, criterion)
            result_dict[NT][snr] = accr
            # Without a criterion the loss is None; multiplying it by the
            # weight used to raise TypeError, so only collect real losses.
            if loss is not None:
                loss_list.append(loss*factor_list[index])

    print('Validtn result, Accr for 16 : ', result_dict[16])
    print('Validation resut, Accr for 32 : ', result_dict[32])
    if (save_to_file):
        with open(curr_accr, 'w') as f:
            print((i, result_dict), file=f)
        print('Saved intermediate validation results at : ', curr_accr)

    if (criterion):
        return np.sum(loss_list)

def generate_big_validtn_data(generator, batch_size, corr_flag, rho, batch_corr, rho_low, rho_high):
    """Pre-generate one fixed-SNR validation batch per (NT, SNR) pair.

    Each batch is drawn with ``snr_db_min == snr_db_max == snr``, and the
    correlation arguments are forwarded unchanged to ``give_batch_data``.

    Returns:
        ``{int(NT): {snr: (H, y, j_indices, noise_sigma)}}`` over the global
        ``validtn_NT_list`` / ``snrdb_list``.
    """
    data = {}
    for NT in validtn_NT_list:
        per_snr = {}
        for snr in snrdb_list[NT]:
            H, y, j_idx, sigma = generator.give_batch_data(int(NT), snr_db_min=snr, snr_db_max=snr, batch_size=batch_size, correlated_flag=corr_flag, rho=rho, batch_corr=batch_corr,rho_low=rho_low, rho_high=rho_high)
            per_snr[snr] = (H, y, j_idx, sigma)
        data[int(NT)] = per_snr
    return data

def save_model_func(model, optimizer):
    """Write model and optimizer state dicts to the global ``model_filename``."""
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, model_filename)
    print('******Model Saved********** at directory : ', model_filename)


def train(model, optimizer, lr_scheduler, generator , device='cpu'):
    """Train ``model`` for ``train_iter`` iterations on freshly sampled batches.

    A fixed mini-validation set is generated once up front. Each iteration:
    - samples NT from ``NT_list`` with weights ``NT_prob``;
    - samples a correlation ``rho`` from a triangular distribution on
      [rho_low, rho_high] with mode ``rho_high``;
    - draws a batch over the SNR window from ``get_snr_range``;
    - takes one optimizer step on the all-layer loss.

    Every ``epoch_size`` iterations the model is validated and
    ``lr_scheduler`` is stepped with the mini-validation loss. If
    ``save_interim_model`` is set, a checkpoint is also written.

    Args:
        model: network whose ``forward(H, y, noise_sigma)`` output is consumed
            by ``loss_function``.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: scheduler whose ``step`` takes the validation loss
            (a ``ReduceLROnPlateau`` in this notebook).
        generator: ``sample_generator`` returning ``(H, y, j_indices, noise_sigma)``.
        device: device for the model inputs and the loss module.
    """

    # rho=None with batch_corr=True: presumably each validation sample draws its own
    # rho in [rho_low, rho_high] - TODO confirm against sample_generator.
    mini_validation_dict = generate_big_validtn_data(generator, mini_validtn_batch_size, corr_flag, None, batch_corr, rho_low, rho_high)
    # Fix loss criterion
    criterion = nn.CrossEntropyLoss().to(device=device)
    model.train()
    epoch_count = 1

    for i in range(1, train_iter+1):

        # Randomly select number of transmitters
        NT = np.random.choice(NT_list, p=NT_prob)
        # mode == right edge, so draws are skewed towards rho_high
        rho = np.random.triangular(rho_low, rho_high, rho_high)

        snr_low, snr_high = get_snr_range(NT)
        # batch_size=None presumably means the generator's construction-time batch size - TODO confirm
        H, y, j_indices, noise_sigma = generator.give_batch_data(NT, snr_db_min=snr_low, snr_db_max=snr_high, batch_size=None, correlated_flag=corr_flag, rho=rho)

        H = H.to(device=device).float()
        y = y.to(device=device).float()
        noise_sigma = noise_sigma.to(device=device).float()

        out = model.forward(H,y, noise_sigma)

        # Drop input references as soon as possible (presumably to limit device memory).
        del H, y, noise_sigma

        j_indices = j_indices.to(device=device)
        loss = loss_function(criterion, out, j_indices)
        del j_indices, out
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_item = loss.item()
        del loss

        if (i%epoch_size==0):
            print('iteration number : ', i, 'Epoch : ', epoch_count, 'User : ', NT, 'loss : ', loss_item)
            print('Now validating')

            model.eval()
            mini_validtn_loss = mini_validation(model, mini_validation_dict, i, device, criterion, save_to_file)
            print('Mini validation loss : ', mini_validtn_loss)
            # Plateau-based LR decay driven by the weighted mini-validation loss.
            lr_scheduler.step(mini_validtn_loss)

            model.train()
            if (save_interim_model):
                save_model_func(model, optimizer)

            epoch_count = epoch_count+1


In [None]:
# Build the generator, model and optimizer, optionally resume from a checkpoint, then train.
generator = sample_generator(train_batch_size, mod_n, NR)
device = 'cuda'
model = iterative_classifier(d_model, n_head, nhid, nlayers, mod_n, NR, d_transmitter_encoding, generator.real_QAM_const, generator.imag_QAM_const, generator.constellation, device, dropout)
model = model.to(device=device)
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)

if (load_pretrained_model):
    checkpoint = torch.load(model_filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print('*******Successfully loaded pre-trained model***********')

# The scheduler is the same whether or not a checkpoint was loaded. Keyword
# arguments avoid depending on the positional order, which changed between
# torch releases (older versions took `verbose` as the 5th positional arg).
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.91, patience=0, threshold=0.0001,
    threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=True)

train(model, optimizer, lr_scheduler, generator, device)
print('******************************** Now Testing **********************************************')



In [8]:
from collections import defaultdict

# Parameters
# NOTE: this cell reassigns module globals also read by functions from earlier
# cells (e.g. validtn_NT_list, snrdb_list, model_filename).
NR = 64
NT = 32
mod_n = 16
d_transmitter_encoding = NR
d_model = 512
n_head = 8
nhid = d_model*4
nlayers = 16
dropout = 0.0

# Batch sizes for training and validation sets
validtn_batch_size = 5000
validtn_iter = 2000  # number of fresh validation sets averaged by validate_model

M = int(np.sqrt(mod_n))  # levels per real dimension
sigConst = np.linspace(-M+1, M-1, M)  # evenly spaced levels -(M-1) .. (M-1)
sigConst /= np.sqrt((sigConst ** 2).mean())  # normalise to unit mean power
sigConst /= np.sqrt(2.) #Each complex transmitted signal will have two parts

validtn_NT_list = np.asarray([16, 32])
snrdb_list = {16:np.arange(11.0, 19.0), 32:np.arange(15.0, 24.0)}
corr_list = np.asarray([0.60, 0.70])  # rho values evaluated at test time

corr_flag = True
save_result = False

validtn_filename = './final_results/network_fullcorr_validtn_results.pickle'
model_filename = './validtn_results/model.pth'

def accuracy(out, j_indices):
    """Return the fraction of correctly decided symbols as a Python float.

    ``out`` is permuted ``(1, 2, 0)`` and arg-maxed over the new dim 1 to get
    hard decisions. The count of matches with ``j_indices`` is divided by the
    number of decisions.
    """
    predicted = out.permute(1, 2, 0).argmax(dim=1)
    hits = (predicted == j_indices).sum().to(dtype=torch.float32).item()
    return hits / predicted.numel()

def bit_indices(indices, mod_n):
    """Split joint symbol indices into per-dimension indices.

    Each index is split by ``sqrt(mod_n)`` into a quotient ("real" part) and a
    remainder ("imag" part). Both are cast to int32 and concatenated along the
    last dimension as ``[quotients..., remainders...]``.
    """
    side = np.sqrt(mod_n)
    quotient = (indices // side).to(dtype=torch.int32)
    remainder = (indices % side).to(dtype=torch.int32)
    return torch.cat((quotient, remainder), dim=-1)

def sym_accuracy(out, j_indices):
    """Return the fraction of entries where ``out`` equals ``j_indices`` (Python float)."""
    n_equal = torch.eq(out, j_indices).sum().to(dtype=torch.float32).item()
    return n_equal / out.numel()

def bit_accuracy(out, j_indices):
    """Per-dimension accuracy: hard-decide ``out`` and compare split indices.

    Both predictions and targets go through ``bit_indices`` (using the global
    ``mod_n``) before the element-wise comparison in ``sym_accuracy``.
    """
    decisions = torch.argmax(out.permute(1, 2, 0), dim=1)
    return sym_accuracy(bit_indices(decisions, mod_n), bit_indices(j_indices, mod_n))

def validate_model_given_data(model, validtn_H, validtn_y, validtn_j_indices, validtn_noise_sigma, device):
    """Return ``accuracy`` of the model's last output on one batch (no gradients).

    ``H``, ``y`` and ``noise_sigma`` are moved to ``device`` as float32, and
    the last element of the model output is moved back to the CPU for scoring.
    """
    def _prepare(tensor):
        return tensor.to(device=device).float()

    with torch.no_grad():
        inputs = [_prepare(t) for t in (validtn_H, validtn_y, validtn_noise_sigma)]
        last_out = model.forward(*inputs)[-1].to(device='cpu')
        result = accuracy(last_out, validtn_j_indices)
        del inputs, last_out
    return result


def validate_model(model, generator, device, save_result=True):
    """Average accuracy over ``validtn_iter`` freshly generated validation sets.

    Keeps an incremental mean per (NT, rho, SNR) in
    ``{int(NT): {rho: defaultdict(float){snr: mean_accuracy}}}``. After every
    round the dict is optionally pickled to ``validtn_filename``, and the
    NT=16 / NT=32 entries are printed.
    """
    result_dict = {int(NT): {rho: defaultdict(float) for rho in corr_list} for NT in validtn_NT_list}
    for round_idx in range(validtn_iter):
        round_data = generate_big_validtn_data(generator, validtn_batch_size)
        for NT in validtn_NT_list:
            for rho in corr_list:
                running = result_dict[NT][rho]
                for snr in snrdb_list[NT]:
                    H, y, j_idx, sigma = round_data[NT][rho][snr]
                    accr = validate_model_given_data(model, H, y, j_idx, sigma, device)
                    # incremental mean: m_k = m_{k-1} + (x_k - m_{k-1}) / k
                    running[snr] = running[snr] + (accr - running[snr]) / float(round_idx + 1.0)

        if save_result:
            with open(validtn_filename, 'wb') as handle:
                pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
            print('Intermediate Test results saved at : ', validtn_filename)
        print('Big Validtn result, Accr for 16 : ', result_dict[16])
        print('Big Validation resut, Accr for 32 : ', result_dict[32])


def generate_big_validtn_data(generator, batch_size):
    """Draw one fixed-SNR validation batch for every (NT, rho, SNR) combination.

    Returns:
        ``{int(NT): {rho: {snr: (H, y, j_indices, noise_sigma)}}}`` over the
        global ``validtn_NT_list``, ``corr_list`` and ``snrdb_list``, with
        ``correlated_flag=corr_flag``.
    """
    validtn_data_dict = {int(nt): {rho: {} for rho in corr_list} for nt in validtn_NT_list}
    for nt in validtn_NT_list:
        nt_key = int(nt)
        for rho in corr_list:
            bucket = validtn_data_dict[nt_key][rho]
            for snr in snrdb_list[nt]:
                H, y, j_idx, sigma = generator.give_batch_data(nt_key, snr_db_min=snr, snr_db_max=snr, batch_size=batch_size, correlated_flag=corr_flag, rho=rho)
                bucket[snr] = (H, y, j_idx, sigma)
    return validtn_data_dict

def test(model, generator, device):
    """Switch ``model`` to eval mode and run the full validation sweep.

    Whether results are pickled is controlled by the global ``save_result``.
    """
    model.eval()  # evaluation-mode layers for testing
    validate_model(model, generator, device, save_result=save_result)

# Rebuild the model, load the trained weights from model_filename and run the test sweep.
generator = sample_generator(validtn_batch_size, mod_n, NR)
device = 'cuda'
model = iterative_classifier(d_model, n_head, nhid, nlayers, mod_n, NR, d_transmitter_encoding, generator.real_QAM_const, generator.imag_QAM_const, generator.constellation, device, dropout)
model = model.to(device=device)

# Only the model weights are restored; the optimizer state in the checkpoint is ignored here.
checkpoint = torch.load(model_filename)
model.load_state_dict(checkpoint['model_state_dict'])
print('*******Successfully loaded pre-trained model*********** from directory : ', model_filename)

test(model, generator, device)
# NOTE: this banner prints after test() has already finished.
print('******************************** Now Testing **********************************************')


*******Successfully loaded pre-trained model*********** from directory :  ./validtn_results/model.pth
Big Validtn result, Accr for 16 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.9306125, 12.0: 0.9727375, 13.0: 0.9908125, 14.0: 0.9976625, 15.0: 0.999375, 16.0: 0.9998375, 17.0: 0.99995, 18.0: 1.0}), 0.7: defaultdict(<class 'float'>, {11.0: 0.8429, 12.0: 0.9090375, 13.0: 0.962825, 14.0: 0.9870875, 15.0: 0.9962375, 16.0: 0.9991125, 17.0: 0.99975, 18.0: 0.9999375})}
Big Validation resut, Accr for 32 :  {0.6: defaultdict(<class 'float'>, {15.0: 0.842275, 16.0: 0.91753125, 17.0: 0.96865625, 18.0: 0.98995, 19.0: 0.99790625, 20.0: 0.99968125, 21.0: 0.99989375, 22.0: 0.99986875, 23.0: 1.0}), 0.7: defaultdict(<class 'float'>, {15.0: 0.666425, 16.0: 0.7496125, 17.0: 0.84105625, 18.0: 0.9145375, 19.0: 0.9643375, 20.0: 0.987375, 21.0: 0.995925, 22.0: 0.99865, 23.0: 0.99943125})}
Big Validtn result, Accr for 16 :  {0.6: defaultdict(<class 'float'>, {11.0: 0.93065, 12.0: 0.9724312500000001, 13.0: 0

KeyboardInterrupt: 

In [None]:
def attn_mask(H, NT, min_value=1, num_heads=None, nr=None):
    """Build an additive attention mask from pairwise column correlations of H.

    ``H`` is read as a stacked real/imag channel. Rows ``[0, nr)`` are the real
    part and rows ``[nr, ...)`` the imaginary part, and only the first ``NT``
    columns are used. For every batch item and ordered pair of distinct
    columns (i, j), the normalised correlation
    ``|h_j^H h_i| / sqrt(||h_i||^2 ||h_j||^2)`` is compared with ``min_value``.
    Off-diagonal entries below the threshold become ``-inf``; every other
    entry, including the diagonal, is ``0.0``.

    Args:
        H: real tensor indexed as ``H[batch, row, column]``.
        NT: number of columns to use.
        min_value: correlation threshold (default 1, as in the original code).
        num_heads: masks per batch item; defaults to the global ``n_head``.
        nr: row at which the imaginary part starts; defaults to the global ``NR``.

    Returns:
        CPU tensor of the default float dtype with shape
        ``(num_heads * batch, NT, NT)``. Rows ``num_heads*b`` through
        ``num_heads*b + num_heads - 1`` all hold batch item ``b``'s mask.
    """
    num_heads = n_head if num_heads is None else num_heads
    nr = NR if nr is None else nr
    out_dtype = torch.get_default_dtype()

    Hcomplex = H[:, 0:nr, 0:NT] + 1j * H[:, nr:, 0:NT]
    # Gram matrix G[b, i, j] = h_i^H h_j for all column pairs at once; its
    # diagonal holds the squared column norms. This replaces the original
    # O(batch * NT^2) Python loops, which recomputed each norm NT^2 times.
    gram = torch.bmm(torch.conj(torch.transpose(Hcomplex, 2, 1)), Hcomplex)
    sq_norms = torch.diagonal(gram, dim1=1, dim2=2).real
    denom = torch.sqrt(sq_norms.unsqueeze(2) * sq_norms.unsqueeze(1))
    # G is Hermitian, so |G[i, j]| == |G[j, i]| == |h_j^H h_i|. A zero column
    # gives 0/0 = nan, which never compares below the threshold (matches original).
    corr = (torch.abs(gram) / denom).to(device='cpu', dtype=out_dtype)

    off_diag = ~torch.eye(NT).bool().unsqueeze(0)
    blocked = off_diag & (corr < min_value)
    mask = torch.zeros(corr.shape, dtype=out_dtype).masked_fill(blocked, float('-inf'))
    return mask.repeat_interleave(num_heads, dim=0)

In [None]:
# Ad-hoc single mini-validation run of the current `model` (notebook state).
# NT and rho are not read by the calls below; mini_validation iterates over
# validtn_NT_list instead.
NT = 6
rho = 0.6

# NOTE(review): this 7-argument call only matches the generate_big_validtn_data
# defined in the first training cell; later cells redefine it as
# (generator, batch_size) and (generator, batch_size, QR, Cu), so this cell
# raises TypeError if run after them.
mini_validation_dict = generate_big_validtn_data(generator, mini_validtn_batch_size, corr_flag, None, batch_corr, rho_low, rho_high)
device = 'cuda'
criterion = nn.CrossEntropyLoss().to(device=device)
mini_validation(model, mini_validation_dict, 0, device, criterion, save_to_file)
del mini_validation_dict

In [None]:
validtn_batch_size = 1000
validtn_iter = 500

# Correlation values to evaluate. This previously read [0.60, 0.6]. The
# duplicate collapses to a single dict key in result_dict/validtn_data_dict
# but was still iterated twice, so each round's running-mean update ran twice
# with the same divisor and discarded the earlier sample.
corr_list = np.asarray([0.60])
from collections import defaultdict

def generate_big_validtn_data(generator, batch_size):
    """Draw one fixed-SNR validation batch for every (NT, rho, SNR) combination.

    Returns:
        ``{int(NT): {rho: {snr: (H, y, j_indices, noise_sigma)}}}`` over the
        global ``validtn_NT_list``, ``corr_list`` and ``snrdb_list``.
    """
    def _draw(nt, rho, snr):
        # single-SNR batch: min and max SNR are the same value
        return generator.give_batch_data(int(nt), snr_db_min=snr, snr_db_max=snr, batch_size=batch_size, correlated_flag=corr_flag, rho=rho)

    collected = {int(nt): {rho: {} for rho in corr_list} for nt in validtn_NT_list}
    for nt in validtn_NT_list:
        for rho in corr_list:
            for snr in snrdb_list[nt]:
                H, y, j_idx, sigma = _draw(nt, rho, snr)
                collected[int(nt)][rho][snr] = (H, y, j_idx, sigma)
    return collected

def validate_model_given_data(model, validtn_H, validtn_y, validtn_j_indices, validtn_noise_sigma, device):
    """Return ``accuracy`` of the model's last output on one batch (no gradients).

    Unlike the earlier variant, inputs are moved to ``device`` without a dtype
    cast; callers are expected to pass float tensors.
    """
    with torch.no_grad():
        moved = [t.to(device=device) for t in (validtn_H, validtn_y, validtn_noise_sigma)]
        final_out = model.forward(*moved)[-1].to(device='cpu')
        accr = accuracy(final_out, validtn_j_indices)
        del moved, final_out
    return accr


def _vector_accuracy(model, H, y, j_indices, noise_sigma, device):
    """Fraction of samples whose whole predicted index row matches ``j_indices``."""
    with torch.no_grad():  # no autograd graph during evaluation
        out = model.forward(H.to(device=device).float(), y.to(device=device).float(), noise_sigma.to(device=device).float())[-1]
        pred = out.permute(1, 2, 0).argmax(dim=1)
        n_samples = pred.shape[0]
        # A sample is wrong if any of its entries differs from the target.
        wrong = (pred != j_indices.to(device=device)).reshape(n_samples, -1).any(dim=1)
        n_wrong = int(wrong.sum().item())
    return (n_samples - n_wrong) / n_samples


def validate_model(model, generator, device, save_result=True):
    """Estimate per-(NT, rho, SNR) vector accuracy over ``validtn_iter`` rounds.

    Each round generates a fresh validation set with
    ``generate_big_validtn_data``. A sample counts as correct only if every
    decided symbol in it matches the target (see ``_vector_accuracy``). The
    stored value is the running mean over rounds. After each round the
    results are optionally pickled to ``validtn_filename`` and printed per NT.

    The earlier version had several problems:
    - it ran a second, discarded forward pass;
    - the scored forward pass ran outside ``torch.no_grad()``;
    - it looped over samples in Python;
    - duplicate entries in ``validtn_NT_list``/``corr_list`` updated the same
      running mean several times per round;
    - it printed a hard-coded ``result_dict[6]``.
    """
    result_dict = {int(NT):{rho:defaultdict(float) for rho in corr_list} for NT in validtn_NT_list}
    for it in range(validtn_iter):
        validtn_data_dict = generate_big_validtn_data(generator, validtn_batch_size)
        # Iterate the de-duplicated dict keys so each running mean gets exactly
        # one update per round.
        for nt, per_rho in result_dict.items():
            for rho, per_snr in per_rho.items():
                for snr in snrdb_list[nt]:
                    big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma = validtn_data_dict[nt][rho][snr]
                    accr = _vector_accuracy(model, big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma, device)
                    per_snr[snr] = per_snr[snr] + (accr - per_snr[snr]) / float(it + 1.0)

        if (save_result):
            with open(validtn_filename, 'wb') as handle:
                pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
            print('Intermediate Test results saved at : ', validtn_filename)
        for nt in result_dict:
            print(f'Big Validtn result, Acc for {nt} : ', result_dict[nt])

def test(model, generator, device):
    """Put ``model`` in eval mode and run validate_model without saving results."""
    model.eval()  # evaluation-mode layers for testing
    validate_model(model, generator, device, save_result=False)



# Re-run the test sweep on the `model` already in notebook state; the lines
# that would rebuild and reload it are commented out below.
generator = sample_generator(validtn_batch_size, mod_n, NR)
device = 'cuda'
# model = iterative_classifier(d_model, n_head, nhid, nlayers, mod_n, NR, d_transmitter_encoding, generator.real_QAM_const, generator.imag_QAM_const, generator.constellation, device, dropout)
# model = model.to(device=device)

# checkpoint = torch.load(model_filename)
# model.load_state_dict(checkpoint['model_state_dict'])
# print('*******Successfully loaded pre-trained model*********** from directory : ', model_filename)

test(model, generator, device)
# NOTE: this banner prints after test() has already finished.
print('******************************** Now Testing **********************************************')

In [None]:
import torch
import numpy as np
import torch.nn as nn
import pickle
import os
import collections

from sample_generator import sample_generator
from iterative_classifier import iterative_classifier
from matrix_models import *

# Parameters
NR = 12
NT_list = np.arange(6,7)  # a single transmitter count (6)
NT_prob = NT_list/NT_list.sum()
mod_n = 4  # constellation size
d_transmitter_encoding = NR
d_model = 512
n_head = 4
nhid = d_model*4
nlayers = 16
dropout = 0.0

epoch_size = 5000
train_iter = 130*epoch_size

# Batch sizes for training and validation sets
train_batch_size = 256
mini_validtn_batch_size = 256
validtn_batch_size = 5000
validtn_iter = 2000

learning_rate = 1e-4
QR = True
corr_flag = True
batch_corr = True
rho_low = 0.55
rho_high = 0.75

# Previously [6, 6]. The duplicate regenerated and re-evaluated identical data
# and made running-mean validators update each entry twice per round. With a
# single entry, factor_list becomes [1/11], so mini_validation's weighted loss
# sum is unchanged.
validtn_NT_list = np.asarray([6])
snrdb_list = {6:np.arange(10.0, 21.0), 32:np.arange(16.0, 27.0)}
factor_list = (validtn_NT_list/validtn_NT_list.sum())/snrdb_list[6].size

model_filename = './validtn_results/model.pth'
curr_accr = './validtn_results/curr_accr.txt'
load_pretrained_model = False
save_interim_model = False
save_to_file = False

def get_snr_range(NT):
    """Return ``(snr_low, snr_high)`` in dB: a 10 dB window starting at 6 + 5*NT/16."""
    snr_low = NT * (5.0 / 16.0) + 6.0
    return snr_low, snr_low + 10.0

def accuracy(out, j_indices):
    """Return symbol accuracy as a float32 tensor.

    Hard decisions are the argmax over dim 1 after permuting ``out`` by
    ``(1, 2, 0)``. The match count is divided by ``j_indices.numel()``.
    """
    total = j_indices.numel()
    preds = out.permute(1, 2, 0).argmax(dim=1)
    correct = (preds == j_indices).sum().to(dtype=torch.float32)
    return correct / total

def loss_function(criterion, out, j_indices):
    """Apply ``criterion`` jointly to every layer's output.

    Args:
        criterion: loss module such as ``nn.CrossEntropyLoss``.
        out: sequence of per-layer outputs. Each is presumably
            ``(NT, batch, classes)`` (inferred from the permute) - TODO confirm.
        j_indices: target indices, presumably ``(batch, NT)``.

    Returns:
        The scalar loss over all layers stacked along the batch dimension.
    """
    # Layer-major stacking along the batch axis -> (layers*batch, classes, NT).
    stacked = torch.cat(out, dim=1).permute(1, 2, 0)
    # Tile the targets once per layer output; len(out) is always the right
    # count, whereas the global nlayers can disagree with the model.
    targets = j_indices.repeat(len(out), 1)
    return criterion(stacked, targets)

def validate_model_given_data(model, validtn_H, validtn_y, validtn_j_indices, validtn_noise_sigma, device, criterion=None):
    """Evaluate ``model`` on one batch without gradients.

    Inputs are moved to ``device`` unchanged; the caller casts them to float.

    Returns:
        ``(accuracy, loss)``. ``loss`` is the Python float from
        ``loss_function``, or ``None`` when no ``criterion`` is given.
    """
    with torch.no_grad():
        model_out = model.forward(
            validtn_H.to(device=device),
            validtn_y.to(device=device),
            validtn_noise_sigma.to(device=device),
        )

        loss_value = None
        if criterion:
            j_dev = validtn_j_indices.to(device=device)
            loss_value = loss_function(criterion, model_out, j_dev).item()
            validtn_j_indices = j_dev.to(device='cpu')
            del j_dev

        last_layer = model_out[-1].to(device='cpu')
        accr = accuracy(last_layer, validtn_j_indices)
        del model_out, last_layer
    return accr, loss_value

def mini_validation(model, mini_validation_dict, i, device, criterion=None, save_to_file=True):
    """Evaluate ``model`` on the pre-generated mini-validation set.

    Args:
        model: network under evaluation.
        mini_validation_dict: ``{int(NT): {snr: (H, y, j_indices, noise_sigma)}}``.
        i: training iteration, written alongside the results when saving.
        device: device used for the forward passes.
        criterion: optional loss module; when given, a weighted loss is returned.
        save_to_file: if true, write ``(i, result_dict)`` to the global ``curr_accr`` path.

    Returns:
        The ``factor_list``-weighted loss sum when ``criterion`` is given,
        otherwise ``None``.
    """
    result_dict = {int(NT):{} for NT in validtn_NT_list}
    loss_list = []
    for index,NT in enumerate(validtn_NT_list):
        for snr in snrdb_list[NT]:
            big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma = mini_validation_dict[NT][snr]
            accr, loss = validate_model_given_data(model, big_validtn_H.float(), big_validtn_y.float(), big_validtn_j_indices, big_noise_sigma.float(), device, criterion)
            result_dict[NT][snr] = accr
            # Without a criterion the loss is None; weighting it used to raise TypeError.
            if loss is not None:
                loss_list.append(loss*factor_list[index])

    # Print every evaluated NT instead of a hard-coded result_dict[6].
    for nt in result_dict:
        print(f'Validtn result, Accr for {nt} : ', result_dict[nt])
    if (save_to_file):
        with open(curr_accr, 'w') as f:
            print((i, result_dict), file=f)
        print('Saved intermediate validation results at : ', curr_accr)

    if (criterion):
        return np.sum(loss_list)

def generate_big_validtn_data(generator, batch_size, QR, Cu):
    """Draw one fixed-SNR validation batch per (NT, SNR) using the QR/Cu channel path.

    ``generator.give_batch_data`` returns five values here. The third one
    (named ``x`` in the original code) is discarded, and
    ``(H, y, j_indices, noise_sigma)`` is stored under ``result[int(NT)][snr]``.
    """
    result = {int(nt): {} for nt in validtn_NT_list}
    for nt in validtn_NT_list:
        for snr in snrdb_list[nt]:
            H, y, _x, j_idx, sigma = generator.give_batch_data(int(nt), snr_db_min=snr, snr_db_max=snr, batch_size=batch_size, QR = QR, Cu = Cu)
            result[int(nt)][snr] = (H, y, j_idx, sigma)
    return result

def save_model_func(model, optimizer):
    """Checkpoint model and optimizer state dicts to the global ``model_filename``."""
    state = dict(model_state_dict=model.state_dict(), optimizer_state_dict=optimizer.state_dict())
    torch.save(state, model_filename)
    print('******Model Saved********** at directory : ', model_filename)


def train(Cu, model, optimizer, lr_scheduler, generator , device='cpu'):
    """Train ``model`` on channels produced by ``createQR(Cu, train_batch_size)``.

    A fixed mini-validation set is generated once. The channel batch is
    redrawn every 50 iterations. Every ``epoch_size`` iterations the model is
    validated and ``lr_scheduler`` is stepped with the weighted validation loss.

    Args:
        Cu: channel-model argument forwarded to ``createQR`` and the generator
            (its semantics live in ``matrix_models``, not visible here).
        model: network whose ``forward(H, y, noise_sigma)`` output feeds ``loss_function``.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: scheduler whose ``step`` takes the validation loss.
        generator: sample generator providing ``give_batch_data_Hinput``.
        device: device for tensors and the loss module.
    """
    mini_validation_dict = generate_big_validtn_data(generator, mini_validtn_batch_size, QR = QR, Cu = Cu)
    # Fix loss criterion
    criterion = nn.CrossEntropyLoss().to(device=device)
    model.train()
    epoch_count = 1

    def _draw_channels():
        # Only R is used as the channel; Q and HH from createQR are unused.
        _, R, _ = createQR(Cu, train_batch_size)
        return torch.tensor(R).to(device=device).double()

    # H (double) is what the sample generator receives on every iteration, and
    # H_model is the float32 copy fed to the network. Previously H was
    # overwritten by its float32 copy inside the loop, so the generator saw
    # double precision only on the first iteration after each redraw (and the
    # redraw itself omitted .double()).
    H = _draw_channels()
    H_model = H.float()

    for i in range(1, train_iter+1):
        print(i)  # NOTE(review): per-iteration debug print; very noisy over train_iter iterations
        # Randomly select number of transmitters
        NT = np.random.choice(NT_list, p=NT_prob)

        if (i%50==0):
            H = _draw_channels()
            H_model = H.float()

        y, x, j_indices, noise_sigma = generator.give_batch_data_Hinput(H, NT, snr_db_min=snrdb_list[NT][0], snr_db_max=snrdb_list[NT][-1], batch_size=train_batch_size)

        y = y.to(device=device).float()
        noise_sigma = noise_sigma.to(device=device).float()

        out = model.forward(H_model, y, noise_sigma)

        del x, y, noise_sigma

        j_indices = j_indices.to(device=device)
        loss = loss_function(criterion, out, j_indices)
        del j_indices, out
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_item = loss.item()
        del loss

        if (i%epoch_size==0):
            print('iteration number : ', i, 'Epoch : ', epoch_count, 'User : ', NT, 'loss : ', loss_item)
            print('Now validating')

            model.eval()
            mini_validtn_loss = mini_validation(model, mini_validation_dict, i, device, criterion, save_to_file)
            print('Mini validation loss : ', mini_validtn_loss)
            lr_scheduler.step(mini_validtn_loss)

            model.train()
#             if (save_interim_model):
#                 save_model_func(model, optimizer)

            epoch_count = epoch_count+1


In [None]:
# Build the generator/model for the QR channel experiment, load the test
# channel set (which provides Cu) and train.
NT = np.random.choice(NT_list, p=NT_prob)
generator = sample_generator(train_batch_size, mod_n, NT)
device = 'cuda'
model = iterative_classifier(d_model, n_head, nhid, nlayers, mod_n, NT, d_transmitter_encoding, generator.real_QAM_const, generator.imag_QAM_const, generator.constellation, device, dropout)
model = model.to(device=device)
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)

# Checkpoint resuming is not performed in this cell (load_pretrained_model is not consulted).
# Keyword arguments avoid depending on the positional order, which changed
# between torch releases (older versions took `verbose` as the 5th positional arg).
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.91, patience=0, threshold=0.0001,
    threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=True)

# `pickle` is imported at the top of this cell; the previous `pkl` name is not
# defined anywhere visible. NOTE: unpickling can execute arbitrary code, so
# only load trusted files.
with open('/home/nicoz/GNN_project/learning_based/Torch/Tests_sets/Test_set2', 'rb') as fp:
    R_test, Cu = pickle.load(fp)
train(Cu, model, optimizer, lr_scheduler, generator, device)
print('******************************** Now Testing **********************************************')


In [None]:
def accuracy(out, j_indices):
    """Return the fraction of correctly decided symbols as a Python float."""
    decided = torch.argmax(out.permute(1, 2, 0), dim=1)
    matches = torch.eq(decided, j_indices).sum().to(dtype=torch.float32)
    return matches.item() / decided.numel()

def bit_indices(indices, mod_n):
    """Concatenate ``indices // sqrt(mod_n)`` and ``indices % sqrt(mod_n)`` (int32) on the last dim."""
    root = np.sqrt(mod_n)
    return torch.cat(
        [(indices // root).to(dtype=torch.int32), (indices % root).to(dtype=torch.int32)],
        dim=-1,
    )

def sym_accuracy(out, j_indices):
    """Element-wise match rate between ``out`` and ``j_indices`` as a Python float."""
    total = out.numel()
    equal_count = (out == j_indices).sum().to(dtype=torch.float32).item()
    return equal_count / total

def bit_accuracy(out, j_indices):
    """Hard-decide ``out``, split both sides with ``bit_indices`` (global ``mod_n``) and compare."""
    hard = out.permute(1, 2, 0).argmax(dim=1)
    pred_parts = bit_indices(hard, mod_n)
    true_parts = bit_indices(j_indices, mod_n)
    return sym_accuracy(pred_parts, true_parts)

def validate_model_given_data(model, validtn_H, validtn_y, validtn_j_indices, validtn_noise_sigma, device):
    """Score the model's last output on one batch with ``accuracy`` (no gradients).

    ``H``, ``y`` and ``noise_sigma`` are moved to ``device`` and cast to float32.
    """
    with torch.no_grad():
        H_in = validtn_H.to(device=device).float()
        y_in = validtn_y.to(device=device).float()
        sigma_in = validtn_noise_sigma.to(device=device).float()
        per_layer = model.forward(H_in, y_in, sigma_in)
        final_layer = per_layer[-1].to(device='cpu')
        score = accuracy(final_layer, validtn_j_indices)
        del H_in, y_in, sigma_in, per_layer, final_layer
    return score


def validate_model(model, generator, device, save_result=True, Cu = None):
    """Average per-SNR accuracy over ``validtn_iter`` fresh QR-channel validation sets.

    Keeps ``{int(NT): {snr: running_mean_accuracy}}`` for every NT in
    ``validtn_NT_list``. Each round's accuracy is printed, and after each
    round the dict is optionally pickled to ``REMIMO_result2_testSet2``.

    The loop previously read the *global* ``NT``, initialised every NT from
    ``snrdb_list[6]``, and labelled results "Accr for 16". It now iterates the
    evaluated NTs and their own SNR grids.

    Args:
        model, generator, device: forwarded to the data/evaluation helpers.
        save_result: pickle intermediate results after every round.
        Cu: channel-model argument forwarded to ``generate_big_validtn_data``.
    """
    result_dict = {int(nt): {snr: float(0) for snr in snrdb_list[nt]} for nt in validtn_NT_list}
    for it in range(validtn_iter):
        validtn_data_dict = generate_big_validtn_data(generator, validtn_batch_size, QR = QR, Cu = Cu)
        for nt, per_snr in result_dict.items():
            for snr in snrdb_list[nt]:
                big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma = validtn_data_dict[nt][snr]
                accr = validate_model_given_data(model, big_validtn_H, big_validtn_y, big_validtn_j_indices, big_noise_sigma, device)
                # incremental mean over rounds
                per_snr[snr] = per_snr[snr] + (accr - per_snr[snr]) / float(it + 1.0)
                print(f'Big Validtn result, Accr for {nt} : ', accr)
        if (save_result):
            with open('REMIMO_result2_testSet2', 'wb') as handle:
                pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        for nt in result_dict:
            print(f'Big Validtn result, Accr for {nt} : ', result_dict[nt])


def test(model, generator, device, Cu = None):
    """Put ``model`` in eval mode and run validate_model with result saving enabled."""
    model.eval()  # evaluation-mode layers for testing
    validate_model(model, generator, device, save_result=True, Cu=Cu)

# Test the `model` currently in notebook state on the QR/Cu channel set.
# `Cu` comes from the test-set file loaded in the training cell; the lines
# that would rebuild/reload the model are commented out.
# corr_list = np.asarray([0.60, 0.70])
NT = np.random.choice(NT_list, p=NT_prob)
generator = sample_generator(validtn_batch_size, mod_n, NT)
# device = 'cuda'
# model = iterative_classifier(d_model, n_head, nhid, nlayers, mod_n, NT, d_transmitter_encoding, generator.real_QAM_const, generator.imag_QAM_const, generator.constellation, device, dropout)
# model = model.to(device=device)

# checkpoint = torch.load('re_mimo_localScatteringModel.pth')
# model.load_state_dict(checkpoint)
# print('*******Successfully loaded pre-trained model*********** from directory : ', 're_mimo_localScatteringModel.pth')

test(model, generator, device, Cu = Cu)
# NOTE: this banner prints after test() has already finished.
print('******************************** Now Testing **********************************************')

