In [48]:
from itertools import product
import itertools
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathos.multiprocessing import ProcessingPool as Pool
import scipy
import os

import itertools
import math
from math import floor, log2, sqrt
from numpy.random import randn, rand
from scipy.signal import lfilter
from numpy.fft import ifft, fft


class iterParams:
    """Iterable parameter sweep over one simulation setting.

    All settings start from the constructor values. Starting an iteration
    overwrites the attribute selected by ``iter_metric`` with the first value
    of its sweep range; every step then yields the current settings as a
    20-tuple and advances the swept attribute by one step.

    Tuple order: (BER_limit_dB, EbNo_dB, nSym, SJR_dB, jammer_type, M, N, Ncp,
    n, k, L, ICSI, Interleaver, detection_threshold, J_power_max_dB,
    J_power_avg_dB, n_MC, n_training, n_test, comm_mode).

    Sweeps (both ends inclusive):
        'BER_limit'            BER_limit_dB          10 .. 40,   step 2
        'SNR'                  EbNo_dB               10 .. 40,   step 2
        'detection_threshold'  detection_threshold   0.1 .. 0.7, step 0.1
        'J_power_max_dB'       J_power_max_dB        0 .. 40,    step 2
        'J_power_avg_dB'       J_power_avg_dB        10 .. 40,   step 2
        'n_training'           n_training            50 .. 1000, step 50

    Raises:
        ValueError: on iteration, if ``iter_metric`` is not one of the above.
    """

    # iter_metric -> (swept attribute, first value, last value (inclusive), step)
    _SWEEPS = {
        "BER_limit": ("BER_limit_dB", 10, 40, 2),
        "SNR": ("EbNo_dB", 10, 40, 2),
        "detection_threshold": ("detection_threshold", 0.1, 0.7, 0.1),
        "J_power_max_dB": ("J_power_max_dB", 0, 40, 2),
        "J_power_avg_dB": ("J_power_avg_dB", 10, 40, 2),
        "n_training": ("n_training", 50, 1000, 50),
    }

    def __init__(self,
                 iter_metric: str = 'BER_limit',
                 nSym: int = 100,
                 SJR_dB: int = 20,
                 jammer_type: str = "random",
                 M: int = 2,
                 N: int = 48,
                 Ncp: int = 16,
                 n: int = 4,
                 k: int = 2,
                 L: int = 5,
                 ICSI: int = 0,
                 Interleaver: int = 0,
                 EbNo_dB: int = 20,
                 detection_threshold = 0.3,
                 BER_limit_dB = 30,
                 J_power_max_dB: int = 0,
                 J_power_avg_dB: int = 20,
                 n_MC: int = 500,
                 n_training: int = 500,
                 n_test: int = 500,
                 comm_mode: str = 'OFDM-IM'):
        self.iter_metric = iter_metric  # which setting to sweep (key of _SWEEPS)
        self.nSym = nSym  # Number of OFDM Symbols to transmit
        self.SJR_dB = SJR_dB  # Signal to jamming ratio
        self.jammer_type = jammer_type
        self.M = M  # Modulation order (2,4,8, 16)
        self.N = N  # FFT size or total number of subcarriers
        self.Ncp = Ncp  # Number of symbols allocated to cyclic prefix
        self.n = n  # #subcarriers in a subblock
        self.k = k  # #active subcarriers in a subblock
        self.L = L  # Channel order
        self.ICSI = ICSI  # 1--> imperfect CSI, 0--> perfect CSI
        self.Interleaver = Interleaver  # 0 --> off, 1 --> on
        self.EbNo_dB = EbNo_dB
        self.detection_threshold = detection_threshold
        self.BER_limit_dB = BER_limit_dB
        self.J_power_max_dB = J_power_max_dB
        self.J_power_avg_dB = J_power_avg_dB
        self.n_MC = n_MC
        self.n_training = n_training
        self.n_test = n_test
        self.comm_mode = comm_mode

    def _sweep(self):
        """Return ``(attribute, start, stop, step)`` for ``iter_metric``."""
        try:
            return self._SWEEPS[self.iter_metric]
        except KeyError:
            raise ValueError(f"unknown iter_metric: {self.iter_metric!r}") from None

    def _as_tuple(self):
        """Snapshot the current settings in the order listed in the class docstring."""
        return (self.BER_limit_dB, self.EbNo_dB, self.nSym, self.SJR_dB, self.jammer_type,
                self.M, self.N, self.Ncp, self.n, self.k, self.L, self.ICSI, self.Interleaver,
                self.detection_threshold, self.J_power_max_dB, self.J_power_avg_dB,
                self.n_MC, self.n_training, self.n_test, self.comm_mode)

    def __iter__(self):
        attr, start, _, _ = self._sweep()
        setattr(self, attr, start)
        return self

    def __next__(self):
        attr, _, stop, step = self._sweep()
        current = getattr(self, attr)
        # Compare against stop + step/2 instead of stop so that accumulated
        # floating-point error (0.1 + 0.1 + ...) can neither add nor drop a point.
        if current >= stop + step / 2:
            raise StopIteration
        point = self._as_tuple()
        setattr(self, attr, current + step)
        return point


def unpackbits(x, num_bits):
    """Expand each integer of ``x`` into its ``num_bits`` lowest bits, most-significant bit first.

    Args:
        x: integer array-like of any shape.
        num_bits: number of bits produced per element.

    Returns:
        Integer array of shape ``x.shape + (num_bits,)`` holding 0/1 values.

    Raises:
        ValueError: if ``x`` has a floating-point dtype.
    """
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.floating):
        raise ValueError("numpy data type needs to be int-like")
    xshape = list(x.shape)
    x = x.reshape([-1, 1])
    # mask[j] = 2**j, so the bits come out least-significant first ...
    mask = 2 ** np.arange(num_bits, dtype=x.dtype).reshape([1, num_bits])
    bits = (x & mask).astype(bool).astype(int).reshape(xshape + [num_bits])
    # ... and are reversed along the bit axis only. Flipping every axis would
    # also reorder the elements of multi-element inputs.
    return np.flip(bits, axis=-1)


def FUN_OFDM_IM_Jammer_Sim(EbNo_dB, nSym, SJR_dB, jamming_mode, M, N, Ncp, n, k, L,
                           ICSI, Interleaver, detection_threshold):
    """Monte-Carlo bit-error count for OFDM with index modulation (OFDM-IM) under jamming.

    Each OFDM symbol is split into g = N // n subblocks. In every subblock the
    first p1 bits select which k of the n subcarriers are active, and the
    remaining k*log2(M) bits select the M-ary symbols placed on them. The
    symbol goes through a random (L+1)-tap complex Gaussian channel, complex
    AWGN, and a frequency-domain complex Gaussian jammer. Detection is an
    exhaustive minimum-distance (ML) search per subblock.

    Args:
        EbNo_dB: bit-energy-to-noise ratio in dB.
        nSym: number of OFDM symbols simulated.
        SJR_dB: signal-to-jamming ratio in dB.
        jamming_mode: 'BJ' jams every subcarrier; 'PBJ' the first half of the
            subcarrier indices; 'Edge' the first and last N // 4 indices.
            Any other value disables jamming.
        M: modulation order (2, 4, 8 or 16).
        N: FFT size (number of subcarriers).
        Ncp: cyclic-prefix length.
        n: subblock size.
        k: active subcarriers per subblock.
        L: channel order (L + 1 taps).
        ICSI: 1 adds estimation noise to the channel used by the detector.
        Interleaver, detection_threshold: unused here; kept for interface
            compatibility.

    Returns:
        (number of bit errors, number of transmitted bits).

    Raises:
        ValueError: for an unsupported M or (n, k) pair.
    """
    g = int(N / n)  # number of subblocks
    K = int(k * g)  # number of active subcarriers per OFDM symbol
    total_length = Ncp + N  # samples per OFDM symbol including the cyclic prefix
    used = g * n  # subcarriers covered by subblocks; any remainder stays empty

    constmap = _ofdm_im_constellation(M)
    p1, scComb = _ofdm_im_subcarrier_combinations(n, k)

    c = 2 ** p1  # number of active-subcarrier patterns per subblock
    p2 = int(k * np.log2(M))  # bits carried by the modulated symbols of a subblock
    p = p1 + p2  # bits per subblock
    m = g * p  # bits per OFDM symbol
    Eb = (N + Ncp) / m  # bit energy (cyclic-prefix energy included)

    # --- Jammer profile: d[i] = 1 marks a jammed subcarrier ---------------------
    d = np.zeros(N)
    if jamming_mode == 'Edge':
        rho = 0.5
        edge = max(1, N // 4)  # 12 for the default N = 48
        d[:edge] = 1
        d[N - edge:] = 1
    elif jamming_mode == 'BJ':
        rho = 1
        d[:] = 1
    elif jamming_mode == 'PBJ':
        rho = 0.5
        d[:int(N * rho)] = 1
    else:
        rho = 0  # no jamming

    Nj_t = 10 ** (-SJR_dB / 10) * Eb  # time-domain jamming variance
    if rho != 0:
        Nj_f = Nj_t * (K / N / rho)  # frequency-domain jamming variance
        # Per-subcarrier jammer amplitude: total power K * Nj_t spread evenly
        # over the jammed subcarriers.
        jam_gain = sqrt(K * Nj_t) * d / sqrt(np.sum(d ** 2))
    else:
        Nj_f = 0
        jam_gain = None

    # --- Symbol space: row r is the subblock sent for the p-bit word r ---------
    n_per_pattern = M ** k
    # All k-tuples of constellation indices in lexicographic order.
    perm = np.array(list(itertools.product(range(M), repeat=k)))
    symbol_space = np.zeros((c * n_per_pattern, n), dtype=np.complex128)
    for ll in range(c):
        symbol_space[ll * n_per_pattern:(ll + 1) * n_per_pattern, scComb[ll] - 1] = constmap[perm]

    bit_weights = 2 ** np.arange(p - 1, -1, -1)  # MSB-first bit word -> row index
    No_t = 10 ** (-EbNo_dB / 10) * Eb  # time-domain noise variance
    No_f = K / N * No_t  # frequency-domain noise variance

    n_err = 0
    for _ in range(nSym):
        data = np.random.randint(2, size=m)
        tx_idx = data.reshape(g, p) @ bit_weights

        X_Freq = np.zeros((1, N), dtype=np.complex128)
        X_Freq[0, :used] = symbol_space[tx_idx].ravel()

        # IFFT and cyclic prefix
        x_time = N / sqrt(K) * ifft(X_Freq, n=N, axis=-1)
        x_time_cp = np.hstack((x_time[:, N - Ncp:N], x_time))

        # Channel: unit-variance complex Gaussian noise (scaled by sqrt(No_t) below),
        # (L+1)-tap complex Gaussian channel and frequency-domain jammer.
        noise = sqrt(0.5) * (randn(1, total_length) + 1j * randn(1, total_length))
        h_time = sqrt(0.5 / (L + 1)) * (randn(L + 1, 1) + 1j * randn(L + 1, 1))
        if jam_gain is None:
            jam_f = 0
        else:
            jam_f = sqrt(0.5) * (randn(1, N) + 1j * randn(1, N)) * jam_gain

        y_time = lfilter(h_time.ravel(), 1, x_time_cp) + sqrt(No_t) * noise

        # Receiver: drop the cyclic prefix, FFT, add the jammer in frequency domain.
        Y_fre = sqrt(K) / N * fft(y_time[:, Ncp:N + Ncp], n=N, axis=-1) + jam_f

        H_fre = fft(h_time.T, n=N)
        if ICSI == 1:
            H_fre = H_fre + sqrt(0.5 * (No_f + Nj_f)) * (randn(1, N) + 1j * randn(1, N))

        # ML detection: nearest candidate subblock for every subblock at once.
        Y_blocks = Y_fre[0, :used].reshape(g, n)
        H_blocks = H_fre[0, :used].reshape(g, n)
        metric = np.sum(
            np.abs(Y_blocks[:, None, :] - H_blocks[:, None, :] * symbol_space[None, :, :]) ** 2,
            axis=2)
        rx_idx = np.argmin(metric, axis=1)

        # Bit errors = number of differing bits between the sent and detected p-bit words.
        n_err += sum(bin(int(t) ^ int(r)).count("1") for t, r in zip(tx_idx, rx_idx))

    return n_err, nSym * m


def _ofdm_im_constellation(M):
    """Return the unit-average-energy M-ary constellation used by OFDM-IM (M in 2, 4, 8, 16)."""
    if M == 2:
        return sqrt(1) * np.array([-1, 1])
    if M == 4:  # Gray mapping
        return sqrt(1 / 2) * np.array([1 + 1j, -1 + 1j, 1 - 1j, - 1 - 1j])
    if M == 8:
        return sqrt(1 / 5) * np.array([-3 / sqrt(2) + 1j * 3 / sqrt(2), 3 / sqrt(2) + 1j * 3 / sqrt(2),
                                       -3 / sqrt(2) - 1j * 3 / sqrt(2), 3 / sqrt(2) - 1j * 3 / sqrt(2), 1j, 1, -1,
                                       -1j])
    if M == 16:  # Gray mapping
        return sqrt(1 / 10) * np.array([-3 - 3j, - 3 - 1j, - 3 + 3j, - 3 + 1j, - 1 - 3j,
                                        - 1 - 1j, - 1 + 3j, - 1 + 1j, 3 - 3j, 3 - 1j, 3 + 3j, 3 + 1j, 1 - 3j,
                                        1 - 1j, 1 + 3j, 1 + 1j])
    raise ValueError("bad M value")


def _ofdm_im_subcarrier_combinations(n, k):
    """Return (p1, scComb): index bits per subblock and the 1-based active-subcarrier patterns."""
    table = {
        (4, 2): (2, [[1, 2], [2, 3], [3, 4], [1, 4]]),
        (4, 3): (2, [[1, 2, 3], [2, 3, 4], [1, 3, 4], [1, 2, 4]]),
        (6, 2): (3, [[1, 2], [1, 3], [1, 6],
                     [2, 3], [3, 4], [3, 5],
                     [5, 6], [4, 6]]),
        (6, 3): (4, [[1, 2, 3], [1, 2, 4], [1, 2, 5], [1, 2, 6],
                     [1, 3, 4], [1, 4, 5], [1, 4, 6], [1, 5, 6],
                     [2, 3, 4], [2, 3, 5], [2, 3, 6], [2, 5, 6],
                     [3, 4, 5], [3, 4, 6], [3, 5, 6], [4, 5, 6]]),
        (1, 1): (0, [[1]]),
    }
    try:
        p1, combos = table[(n, k)]
    except KeyError:
        raise ValueError("bad n,k value") from None
    return p1, np.array(combos)


def FUN_NOISE_MOD_Jammer_Sim(EbNo_dB, nSym, SJR_dB, jamming_mode, M, N, Ncp, n, k, L,
                             ICSI, Interleaver, detection_threshold):
    """Simulate on-off "noise modulation" over nSym OFDM symbols and count bit errors.

    One random bit word is drawn once and sent unchanged in every one of the
    nSym symbols. Bit 1 switches a subcarrier on (amplitude 1) and bit 0
    leaves it off. Each symbol sees a fresh (L+1)-tap complex Gaussian
    channel, complex AWGN, and a frequency-domain complex Gaussian jammer.
    The receiver decides each bit by comparing the variance of that
    subcarrier across the nSym received symbols with the midpoint between
    the smallest and largest per-subcarrier variance.

    The detector makes one decision per subcarrier, so the bit word must have
    exactly N bits. That requires M=2 and n=k=1.

    Args:
        EbNo_dB: bit-energy-to-noise ratio in dB.
        nSym: number of OFDM symbols observed.
        SJR_dB: signal-to-jamming ratio in dB.
        jamming_mode: 'BJ' jams every subcarrier; 'PBJ' the first half of the
            subcarrier indices; 'Edge' the first and last N // 4 indices.
            Any other value disables jamming.
        M, N, Ncp, n, k, L: modulation order, FFT size, cyclic-prefix length,
            subblock size, active subcarriers per subblock, channel order.
        ICSI, Interleaver, detection_threshold: unused (the threshold is
            derived from the received data); kept for interface compatibility.

    Returns:
        (number of bit errors, number of bits in the word). The word is
        counted once, not once per symbol.
    """
    g = int(N / n)  # number of subblocks
    K = int(k * g)  # number of active subcarriers
    total_length = Ncp + N  # samples per OFDM symbol including the cyclic prefix

    constmap = sqrt(1) * np.array([0, 1])  # on-off keying
    p1 = 0
    scComb = np.array([[1]])

    c = 2 ** p1  # number of active-subcarrier patterns per subblock
    p2 = int(k * np.log2(M))  # bits carried by modulation in a subblock
    p = p1 + p2  # bits per subblock
    m = g * p  # bits per OFDM symbol
    Eb = (N + Ncp) / m  # bit energy

    # --- Jammer profile: d[i] = 1 marks a jammed subcarrier ---------------------
    d = np.zeros(N)
    if jamming_mode == 'Edge':
        edge = max(1, N // 4)  # 12 for the default N = 48
        d[:edge] = 1
        d[N - edge:] = 1
    elif jamming_mode == 'BJ':
        d[:] = 1
    elif jamming_mode == 'PBJ':
        d[:int(N * 0.5)] = 1

    Nj_t = 10 ** (-SJR_dB / 10) * Eb  # time-domain jamming variance
    if d.any():
        # Total jamming power K * Nj_t spread evenly over the jammed subcarriers.
        jam_gain = sqrt(K * Nj_t) * d / sqrt(np.sum(d ** 2))
    else:
        jam_gain = None  # no jamming

    # --- Symbol space: row r is the subblock sent for the p-bit word r ---------
    n_per_pattern = M ** k
    perm = np.array(list(itertools.product(range(M), repeat=k)))  # lexicographic k-tuples
    symbol_space = np.zeros((c * n_per_pattern, n), dtype=np.complex128)
    for ll in range(c):
        symbol_space[ll * n_per_pattern:(ll + 1) * n_per_pattern, scComb[ll] - 1] = constmap[perm]

    # The same bit word is sent in every symbol, so build the transmit signal once.
    data = np.random.randint(2, size=m)
    bit_weights = 2 ** np.arange(p - 1, -1, -1)  # MSB-first bit word -> row index
    tx_idx = data.reshape(g, p) @ bit_weights
    X_Freq = np.zeros((1, N), dtype=np.complex128)
    X_Freq[0, :g * n] = symbol_space[tx_idx].ravel()
    x_time = N / sqrt(K) * ifft(X_Freq, n=N, axis=-1)
    x_time_cp = np.hstack((x_time[:, N - Ncp:N], x_time))

    No_t = 10 ** (-EbNo_dB / 10) * Eb  # time-domain noise variance
    Y_total = np.zeros((nSym, N), dtype=np.complex128)
    for sym in range(nSym):
        # Unit-variance complex Gaussian noise (scaled by sqrt(No_t) below).
        noise = sqrt(0.5) * (randn(1, total_length) + 1j * randn(1, total_length))
        h_time = sqrt(0.5 / (L + 1)) * (randn(L + 1, 1) + 1j * randn(L + 1, 1))
        if jam_gain is None:
            jam_f = 0
        else:
            jam_f = sqrt(0.5) * (randn(1, N) + 1j * randn(1, N)) * jam_gain

        y_time = lfilter(h_time.ravel(), 1, x_time_cp) + sqrt(No_t) * noise
        Y_fre = sqrt(K) / N * fft(y_time[:, Ncp:N + Ncp], n=N, axis=-1) + jam_f
        Y_total[sym] = Y_fre[0]

    # --- Energy detection: active subcarriers show a larger variance ------------
    var_array = Y_total.var(axis=0)
    threshold = (var_array.min() + var_array.max()) / 2
    data_hat = var_array > threshold
    n_err = np.count_nonzero(data != data_hat)

    return n_err, m


class Channel:
    """Transmitter/jammer environment used by the RL agent and the benchmark policies.

    The observable state is the tuple ``(jamming_mode, SJR_dB)``. ``step``
    simulates one transmission with the chosen action, lets the jammer react
    through ``jammer_move``, and returns ``(next_state, reward, throughput, BER)``.

    Actions are strings: 'idle', 'noisemod', or a three-digit code 'nkM'
    giving the subblock size n, active subcarriers per subblock k, and
    modulation order M for OFDM-IM.
    """

    def __init__(self, BER_limit_dB, EbNo_dB, nSym, SJR_dB, jammer_type, M, N, Ncp, n, k, L,
                 ICSI, Interleaver, detection_threshold, J_power_max_dB, J_power_avg_dB, comm_mode):
        self.action_space = self._action_space_for(comm_mode)
        self.n_actions = len(self.action_space)
        self.caught_flag = False  # set by step() for idle/noisemod; consumed by jammer_move()
        self.mode_change_flag = False  # 'covert' jammer: alternates power step / mode switch
        self.undetected_counter = 0  # 'covert' jammer: steps since last caught
        self.detected_counter = 0  # reactive/probabilistic jammers: uncaught steps
        self.J_power_counter = 0
        self.J_power_avg_lin = (10 ** (-J_power_avg_dB / 10))

        self.BER_limit_lin = (10 ** (-BER_limit_dB / 10))
        self.EbNo_dB = EbNo_dB
        self.nSym = nSym
        self.J_power_avg_dB = J_power_avg_dB
        # NOTE: the jammer starts at its average power; the SJR_dB argument is ignored.
        self.SJR_dB = J_power_avg_dB
        self.jammer_type = jammer_type
        self.jamming_mode = {'aggressive_bj': 'BJ',
                             'aggressive_pbj': 'PBJ',
                             'aggressive_edge': 'Edge'}.get(jammer_type, 'BJ')
        self.M = M
        self.N = N
        self.Ncp = Ncp
        self.n = n
        self.k = k
        self.L = L
        self.ICSI = ICSI
        self.Interleaver = Interleaver
        self.detection_threshold = detection_threshold
        self.J_power_max_dB = J_power_max_dB
        self.comm_mode = comm_mode

    @staticmethod
    def _action_space_for(comm_mode):
        """Return a fresh action list for ``comm_mode``; raise ValueError if it is unknown."""
        ofdm = ['112', '114', '118']
        ofdm_im = ofdm + ['422', '424', '428',
                          '432', '434', '438',
                          '622', '624', '628',
                          '632', '634', '638']
        if comm_mode in ('OFDM-IM', 'OFDM-IM-BER-policy', 'OFDM-IM-SJR-policy', 'OFDM-IM-random-policy'):
            return ofdm_im + ['noisemod', 'idle']
        if comm_mode in ('OFDM', 'OFDM-BER-policy', 'OFDM-SJR-policy', 'OFDM-random-policy'):
            return ofdm + ['noisemod', 'idle']
        if comm_mode == 'OFDM-IM-without-noisemod':
            return ofdm_im + ['idle']
        if comm_mode == 'OFDM-without-noisemod':
            return ofdm + ['idle']
        raise ValueError(f"unknown comm_mode: {comm_mode!r}")

    def benchmark_transmitter_move(self, BER):
        """Pick an action with the fixed (non-learning) policy selected by ``comm_mode``.

        - random policy: uniform random action.
        - BER policy: random when ``BER`` is NaN. Otherwise the first action
          when BER <= limit, 'noisemod' up to 20x the limit, 'idle' above that.
        - SJR policy: based on the current ``SJR_dB``. First action at >= 15 dB,
          'noisemod' in [5, 15) dB, 'idle' below 5 dB.
        - any other mode: uniform random action.

        Raises:
            ValueError: if the SJR policy sees a non-comparable (NaN) SJR.
        """
        mode = self.comm_mode
        if mode in ('OFDM-IM-random-policy', 'OFDM-random-policy'):
            return np.random.choice(self.action_space)
        if mode in ('OFDM-IM-BER-policy', 'OFDM-BER-policy'):
            if np.isnan(BER):  # initialization / no measurement
                return np.random.choice(self.action_space)
            if BER <= self.BER_limit_lin:
                return self.action_space[0]
            if BER <= self.BER_limit_lin * 20:
                return 'noisemod'
            return 'idle'
        if mode in ('OFDM-IM-SJR-policy', 'OFDM-SJR-policy'):
            if self.SJR_dB >= 15:
                return self.action_space[0]
            if 15 > self.SJR_dB >= 5:
                return 'noisemod'
            if self.SJR_dB < 5:
                return 'idle'
            raise ValueError("bad SJR value")
        return np.random.choice(self.action_space)

    def jammer_move(self):
        """Advance the jammer by one step according to ``jammer_type``.

        - covert: when caught, alternately raises SJR by 5 dB or switches
          BJ <-> PBJ. After 10 uncaught steps it lowers SJR by 5 dB.
        - random: SJR drawn from {15, 20, 25, 30, 35} dB and mode from {BJ, PBJ}.
        - reactive_* / probabilistic: when caught, SJR = 500 dB (effectively
          silent). Otherwise SJR returns to J_power_avg_dB (probabilistic: +3/0/-3 dB
          with probabilities 1/6, 3/6, 2/6) and an uncaught counter increases;
          reactive_combined also switches BJ <-> PBJ.
        - other types (e.g. aggressive_*): no change.
        After every 3 counted uncaught steps, SJR is lowered by 5 dB.
        """
        jt = self.jammer_type
        if jt == 'covert':
            if self.caught_flag:
                if self.mode_change_flag:
                    self.jamming_mode = 'PBJ' if self.jamming_mode == 'BJ' else 'BJ'
                else:
                    self.SJR_dB += 5
                self.mode_change_flag = not self.mode_change_flag
                self.caught_flag = False
            else:
                self.undetected_counter += 1
                if self.undetected_counter >= 10:
                    self.SJR_dB -= 5
                    self.undetected_counter = 0

        elif jt == 'random':
            self.SJR_dB = np.random.choice(np.arange(15, 36, 5))
            self.jamming_mode = np.random.choice(['BJ', 'PBJ'])

        elif jt in ('reactive_bj', 'reactive_pbj', 'reactive_combined', 'probabilistic'):
            if self.caught_flag:
                self.SJR_dB = 500
                self.caught_flag = False
            else:
                if jt == 'probabilistic':
                    self.SJR_dB = \
                        np.random.choice([self.J_power_avg_dB + 3, self.J_power_avg_dB, self.J_power_avg_dB - 3], 1,
                                         p=[1 / 6, 3 / 6, 2 / 6])[0]
                else:
                    self.SJR_dB = self.J_power_avg_dB
                self.detected_counter += 1
                if jt == 'reactive_combined':
                    self.jamming_mode = 'PBJ' if self.jamming_mode == 'BJ' else 'BJ'

        if self.detected_counter >= 3:
            self.SJR_dB -= 5
            self.detected_counter = 0

    def step(self, action):
        """Simulate one transmission with ``action`` and let the jammer react.

        Returns:
            (s_, reward, throughput, BER):
            - s_ is the new (jamming_mode, SJR_dB) state.
            - reward is 0 for 'idle'; correctly received bits divided by nSym
              when the BER is below the limit; -N otherwise.
            - throughput is max(0, reward).
            - BER is NaN for 'idle'.
        """
        kind = action[0]
        nErr, nBits = 0, 1

        if kind == 'i':  # Idle
            self.caught_flag = True
        elif kind == 'n':  # Noise Modulation
            nErr, nBits = FUN_NOISE_MOD_Jammer_Sim(EbNo_dB=self.EbNo_dB,
                                                   nSym=self.nSym,
                                                   SJR_dB=self.SJR_dB,
                                                   jamming_mode=self.jamming_mode,
                                                   M=2,
                                                   N=self.N,
                                                   Ncp=self.Ncp,
                                                   n=1,
                                                   k=1,
                                                   L=self.L,
                                                   ICSI=self.ICSI,
                                                   Interleaver=self.Interleaver,
                                                   detection_threshold=self.detection_threshold)
            self.caught_flag = True
        else:  # OFDM-IM, action = 'nkM'
            n, k, M = int(action[0]), int(action[1]), int(action[2])
            nErr, nBits = FUN_OFDM_IM_Jammer_Sim(EbNo_dB=self.EbNo_dB,
                                                 nSym=self.nSym,
                                                 SJR_dB=self.SJR_dB,
                                                 jamming_mode=self.jamming_mode,
                                                 M=M,
                                                 N=self.N,
                                                 Ncp=self.Ncp,
                                                 n=n,
                                                 k=k,
                                                 L=self.L,
                                                 ICSI=self.ICSI,
                                                 Interleaver=self.Interleaver,
                                                 detection_threshold=self.detection_threshold)

        self.jammer_move()
        s_ = (self.jamming_mode, self.SJR_dB)

        # reward function
        if kind == 'i':
            reward = 0
            BER = np.nan
        else:
            BER = nErr / nBits
            reward = (nBits - nErr) / self.nSym if BER < self.BER_limit_lin else -self.N
        throughput = max(0, reward)
        return s_, reward, throughput, BER


class QLearningTable:
    """Tabular Q-learning agent with linearly decaying epsilon-greedy exploration.

    The Q-table is a pandas DataFrame indexed by state, with one float column
    per action. Rows are added lazily and initialised to 0.
    """

    def __init__(self, actions, exploration_decay_rate, learning_rate=0.01, reward_decay=0.9, e_greedy=1):
        self.actions = actions  # list of action labels; also the Q-table columns
        self.edr = exploration_decay_rate  # subtracted from epsilon after every choose_action
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy  # probability of taking a uniformly random action
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def choose_action(self, observation):
        """Return an epsilon-greedy action for ``observation``, then decay epsilon (floor 0.001).

        Ties among greedy actions are broken uniformly at random.
        """
        self.check_state_exist(observation)
        if np.random.uniform() > self.epsilon:
            # choose best action
            state_action = self.q_table.loc[observation, :]
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # choose random action
            action = np.random.choice(self.actions)

        self.epsilon = max(0.001, self.epsilon - self.edr)
        return action

    def learn(self, s, a, r, s_):
        """Apply one Q-learning update: Q(s, a) += lr * (r + gamma * max Q(s_, .) - Q(s, a)).

        Both ``s`` and ``s_`` are registered first, so ``learn`` also works for a
        state that ``choose_action`` has not seen.
        """
        self.check_state_exist(s)
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)

    def check_state_exist(self, state):
        """Append a zero-initialised row for ``state`` if it is not in the Q-table yet."""
        if state not in self.q_table.index:
            df_temp = pd.DataFrame(data=np.zeros((1, len(self.actions))), index=[state],
                                   columns=self.actions, dtype=np.float64)
            self.q_table = pd.concat([self.q_table, df_temp])


def training(params):
    """Run ``n_MC`` independent simulations for one sweep point and log per-step metrics.

    ``params`` is the 20-tuple (BER_limit_dB, EbNo_dB, nSym, SJR_dB, jammer_type,
    M, N, Ncp, n, k, L, ICSI, Interleaver, detection_threshold, J_power_max_dB,
    J_power_avg_dB, n_MC, n_training, n_test, comm_mode).

    RL modes ('OFDM-IM', 'OFDM' and their '-without-noisemod' variants): a fresh
    Q-learning agent first interacts for ``n_training`` unrecorded steps. It then
    runs ``n_test`` recorded steps and keeps learning during them.

    Benchmark modes (BER / random / SJR policies): ``n_test`` recorded steps with
    actions from ``Channel.benchmark_transmitter_move``. Each call receives the
    BER of the previous step (NaN at the start of every run).

    Returns:
        (avg_throughput, avg_BER): arrays of shape (n_test, n_MC). Column ``run``
        holds the per-step values of that Monte-Carlo run. Both stay zero for a
        comm_mode outside the two groups.
    """
    (BER_limit_dB, EbNo_dB, nSym, SJR_dB, jammer_type, M, N, Ncp, n, k, L, ICSI,
     Interleaver, detection_threshold, J_power_max_dB, J_power_avg_dB,
     n_MC, n_training, n_test, comm_mode) = params

    rl_modes = ('OFDM-IM', 'OFDM', 'OFDM-IM-without-noisemod', 'OFDM-without-noisemod')
    benchmark_modes = ('OFDM-IM-BER-policy', 'OFDM-BER-policy',
                       'OFDM-IM-random-policy', 'OFDM-random-policy',
                       'OFDM-IM-SJR-policy', 'OFDM-SJR-policy')

    throughput_log = np.zeros((n_test, n_MC))
    ber_log = np.zeros((n_test, n_MC))

    for run in range(n_MC):
        env = Channel(BER_limit_dB, EbNo_dB, nSym, SJR_dB, jammer_type, M, N, Ncp, n, k, L, ICSI, Interleaver,
                      detection_threshold, J_power_max_dB, J_power_avg_dB, comm_mode)
        agent = QLearningTable(exploration_decay_rate=1 / n_training, actions=env.action_space)

        if comm_mode in rl_modes:
            state_key = str((env.jamming_mode, env.SJR_dB))

            # Warm-up phase: learn without recording.
            for _ in range(n_training):
                action = agent.choose_action(state_key)
                next_state, reward, _, _ = env.step(action)
                next_key = str(next_state)
                agent.learn(state_key, action, reward, next_key)
                state_key = next_key

            # Evaluation phase: keep learning and record every step.
            for t in range(n_test):
                action = agent.choose_action(state_key)
                next_state, reward, throughput, ber = env.step(action)
                next_key = str(next_state)
                agent.learn(state_key, action, reward, next_key)
                state_key = next_key
                throughput_log[t, run] = throughput
                ber_log[t, run] = ber

        elif comm_mode in benchmark_modes:
            last_ber = np.nan
            for t in range(n_test):
                action = env.benchmark_transmitter_move(last_ber)
                _, _, throughput, last_ber = env.step(action)
                throughput_log[t, run] = throughput
                ber_log[t, run] = last_ber

    return throughput_log, ber_log


def process_results(results_ofdmim_, results_ofdm_,
                    results_ofdmim_without_noisemod_, results_ofdm_without_noisemod_,
                    results_ofdmim_BER_policy_, results_ofdm_BER_policy_,
                    results_ofdmim_SJR_policy_, results_ofdm_SJR_policy_):
    """Reduce raw per-sweep-point results to mean throughput and mean BER.

    Each argument is a sequence with one entry per sweep point. ``entry[0]`` is
    a throughput array and ``entry[1]`` a BER array; NaN BER values are
    ignored when averaging.

    Returns:
        A tuple of eight dicts, in argument order, each shaped
        ``{"avg_throughput": [...], "avg_BER": [...]}`` with one value per entry.
    """
    def summarise(points):
        points = list(points)
        return {
            "avg_throughput": [point[0].mean() for point in points],
            "avg_BER": [np.nanmean(point[1]) for point in points],
        }

    raw_groups = (results_ofdmim_, results_ofdm_,
                  results_ofdmim_without_noisemod_, results_ofdm_without_noisemod_,
                  results_ofdmim_BER_policy_, results_ofdm_BER_policy_,
                  results_ofdmim_SJR_policy_, results_ofdm_SJR_policy_)
    return tuple(summarise(group) for group in raw_groups)


def plot_results(x_label, y_label, results_ofdmim, results_ofdm,
                 results_ofdmim_without_noisemod, results_ofdm_without_noisemod,
                 results_ofdmim_BER_policy, results_ofdm_BER_policy,
                 results_ofdmim_SJR_policy, results_ofdm_SJR_policy,
                 x_values=None):
    """Plot one aggregated metric of all eight schemes against the swept parameter.

    Parameters
    ----------
    x_label : str
        Name of the swept parameter. Used as the x-axis label and, when
        ``x_values`` is omitted, to choose the default sweep grid.
    y_label : str
        Key read from every result dict, e.g. ``"avg_throughput"`` or
        ``"avg_BER"``. ``"avg_BER"`` is drawn on a logarithmic y-axis,
        everything else on a linear one.
    results_ofdmim, results_ofdm, results_ofdmim_without_noisemod,
    results_ofdm_without_noisemod, results_ofdmim_BER_policy,
    results_ofdm_BER_policy, results_ofdmim_SJR_policy, results_ofdm_SJR_policy : dict
        Per-scheme result dicts; ``results[y_label]`` must hold one value per
        x-axis point.
    x_values : sequence, optional
        Explicit x-axis values. When ``None`` (the default) the grid
        hard-coded for ``x_label`` is used: ``linspace(0.1, 0.7, 7)`` for
        ``"detection_threshold"``, ``range(50, 1001, 50)`` for
        ``"n_training"`` and ``range(10, 41, 2)`` for anything else.
        NOTE(review): these defaults must match the sweep that produced the
        results; pass ``x_values`` explicitly whenever they differ.
    """
    if x_values is None:
        if x_label == "detection_threshold":
            x_values = np.linspace(0.1, 0.7, num=7)
        elif x_label == "n_training":
            x_values = range(50, 1001, 50)
        else:
            x_values = range(10, 41, 2)

    # One (results dict, matplotlib format string, legend label) entry per curve.
    # Raw strings keep the literal backslash in "RL \ noisemod" without relying on
    # an invalid escape sequence (a SyntaxWarning since Python 3.12).
    # The OFDM "without noisemod" curve uses '-.' so it is no longer drawn with the
    # exact same 'k*--' style as the "OFDM (RL)" curve.
    curves = (
        (results_ofdmim, 'ko-', 'OFDM-IM (RL) *Proposed approach*'),
        (results_ofdm, 'k*--', 'OFDM (RL)'),
        (results_ofdmim_without_noisemod, 'ko--', r'OFDM-IM (RL \ noisemod)'),
        (results_ofdm_without_noisemod, 'k*-.', r'OFDM (RL \ noisemod)'),
        (results_ofdmim_BER_policy, 'ko:', 'OFDM-IM (fixed BER policy)'),
        (results_ofdm_BER_policy, 'k*:', 'OFDM (fixed BER policy)'),
        (results_ofdmim_SJR_policy, 'k^:', 'OFDM-IM (fixed SJR policy)'),
        (results_ofdm_SJR_policy, 'kx:', 'OFDM (fixed SJR policy)'),
    )

    # BER spans orders of magnitude, so it gets a log-scaled y-axis.
    draw = plt.semilogy if y_label == 'avg_BER' else plt.plot
    for results, fmt, label in curves:
        draw(x_values, results[y_label], fmt, label=label)

    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend(fontsize="7")
    plt.show()


def save_results(jammer_type, x_label, n_training, n_test, n_MC, results_ofdmim, results_ofdm,
                 results_ofdmim_without_noisemod, results_ofdm_without_noisemod,
                 results_ofdmim_BER_policy, results_ofdm_BER_policy,
                 results_ofdmim_SJR_policy, results_ofdm_SJR_policy,
                 base_dir='.'):
    """Write every scheme's result dict to its own MATLAB ``.mat`` file.

    The files go into the directory
    ``<base_dir>/<x_label>_<jammer_type>_<n_training>_<n_test>_<n_MC>``,
    which is created if needed. Existing files with the same names are
    overwritten.

    Parameters
    ----------
    jammer_type, x_label, n_training, n_test, n_MC :
        Run settings; converted with ``str`` to build the directory name.
    results_ofdmim, ..., results_ofdm_SJR_policy : dict
        Result dicts passed straight to ``scipy.io.savemat`` (one variable
        per key).
    base_dir : str, optional
        Parent directory for the output folder (default: current directory).

    Returns
    -------
    str
        Path of the directory the files were written to.
    """
    # Import the submodule explicitly: on older SciPy releases a bare
    # ``import scipy`` does not make ``scipy.io`` available.
    from scipy.io import savemat

    folder = '_'.join(map(str, (x_label, jammer_type, n_training, n_test, n_MC)))
    path = os.path.join(base_dir, folder)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(path, exist_ok=True)

    outputs = (
        ('ofdmim', results_ofdmim),
        ('ofdm', results_ofdm),
        ('ofdmim_without_noisemod', results_ofdmim_without_noisemod),
        ('ofdm_without_noisemod', results_ofdm_without_noisemod),
        ('ofdmim_BER_policy', results_ofdmim_BER_policy),
        ('ofdm_BER_policy', results_ofdm_BER_policy),
        ('ofdmim_SJR_policy', results_ofdmim_SJR_policy),
        ('ofdm_SJR_policy', results_ofdm_SJR_policy),
    )
    for name, results in outputs:
        savemat(os.path.join(path, name + '.mat'), results)
    return path




In [None]:

## MAIN
## Driver cell: runs `training` for all eight communication schemes over the sweep
## selected by `iter_metric`, aggregates the outputs with `process_results`, and
## writes them to .mat files with `save_results`.
##
## iterParams options (defaults as declared in iterParams.__init__):
## iter_metric: str = 'BER_limit'  -- swept parameter: 'BER_limit', 'SNR',
##                                    'detection_threshold', 'J_power_avg_dB', 'n_training'
## nSym: int = 100                 -- number of OFDM symbols to transmit
## SJR_dB: int = 20                -- signal-to-jamming ratio
## jammer_type: str = "random"     -- "covert", "reactive_bj", "reactive_pbj",
##                                    "reactive_combined", "probabilistic", "random"
## M: int = 2                      -- modulation order (2, 4, 8, 16)
## N: int = 48                     -- FFT size / total number of subcarriers
## Ncp: int = 16                   -- cyclic-prefix length
## n: int = 4                      -- subcarriers per subblock
## k: int = 2                      -- active subcarriers per subblock
## L: int = 5                      -- channel order
## ICSI: int = 0                   -- 1 --> imperfect CSI, 0 --> perfect CSI
## Interleaver: int = 0            -- 0 --> off, 1 --> on
## EbNo_dB: int = 20
## detection_threshold = 0.3
## BER_limit_dB = 30
## J_power_max_dB: int = 0
## J_power_avg_dB: int = 20
## n_MC: int = 500
## n_training: int = 500
## n_test: int = 500
## comm_mode: str = 'OFDM-IM'      -- 'OFDM-IM', 'OFDM', 'OFDM-IM-without-noisemod',
##                                    'OFDM-without-noisemod', 'OFDM-IM-BER-policy',
##                                    'OFDM-BER-policy', 'OFDM-IM-SJR-policy', 'OFDM-SJR-policy'


## Swept parameter: 'BER_limit', 'SNR', 'detection_threshold', 'J_power_avg_dB', 'n_training'
## NOTE(review): plot_results picks its default x-axis grid from this name; keep the two in sync.
iter_metric = 'J_power_avg_dB'
n_training = 500
n_test = 500
n_MC = 200
jammer_type = "reactive_bj"
## "covert","reactive_bj","reactive_pbj","reactive_combined","probabilistic","random",

# One parameter object per scheme; all settings are shared except comm_mode.
iter_set_ofdmim = iterParams(iter_metric=iter_metric, comm_mode='OFDM-IM', jammer_type=jammer_type,
                             n_training=n_training, n_test=n_test, n_MC=n_MC)
iter_set_ofdm = iterParams(iter_metric=iter_metric, comm_mode='OFDM', jammer_type=jammer_type,
                           n_training=n_training,
                           n_test=n_test, n_MC=n_MC)
iter_set_ofdmim_without_noisemod = iterParams(iter_metric=iter_metric, comm_mode='OFDM-IM-without-noisemod',
                                              jammer_type=jammer_type, n_training=n_training, n_test=n_test,
                                              n_MC=n_MC)
iter_set_ofdm_without_noisemod = iterParams(iter_metric=iter_metric, comm_mode='OFDM-without-noisemod',
                                            jammer_type=jammer_type, n_training=n_training, n_test=n_test,
                                            n_MC=n_MC)
iter_set_ofdmim_BER_policy = iterParams(iter_metric=iter_metric, comm_mode='OFDM-IM-BER-policy',
                                        jammer_type=jammer_type, n_training=n_training, n_test=n_test, n_MC=n_MC)
iter_set_ofdm_BER_policy = iterParams(iter_metric=iter_metric, comm_mode='OFDM-BER-policy', jammer_type=jammer_type,
                                      n_training=n_training, n_test=n_test, n_MC=n_MC)
iter_set_ofdmim_SJR_policy = iterParams(iter_metric=iter_metric, comm_mode='OFDM-IM-SJR-policy',
                                        jammer_type=jammer_type, n_training=n_training, n_test=n_test, n_MC=n_MC)
iter_set_ofdm_SJR_policy = iterParams(iter_metric=iter_metric, comm_mode='OFDM-SJR-policy', jammer_type=jammer_type,
                                      n_training=n_training, n_test=n_test, n_MC=n_MC)


# pathos process pool with 30 workers; the eight schemes are mapped one after another.
# p.map(training, iter_set_*) iterates the iterParams object, presumably yielding one
# parameter set per sweep point -- NOTE(review): confirm against iterParams.__iter__.
with Pool(30) as p:
    results_ofdmim_ = p.map(training, iter_set_ofdmim)
    results_ofdm_ = p.map(training, iter_set_ofdm)
    results_ofdmim_without_noisemod_ = p.map(training, iter_set_ofdmim_without_noisemod)
    results_ofdm_without_noisemod_ = p.map(training, iter_set_ofdm_without_noisemod)
    results_ofdmim_BER_policy_ = p.map(training, iter_set_ofdmim_BER_policy)
    results_ofdm_BER_policy_ = p.map(training, iter_set_ofdm_BER_policy)
    results_ofdmim_SJR_policy_ = p.map(training, iter_set_ofdmim_SJR_policy)
    results_ofdm_SJR_policy_ = p.map(training, iter_set_ofdm_SJR_policy)

# Collapse each scheme's per-sweep-point outputs into {"avg_throughput": [...], "avg_BER": [...]};
# process_results reads item[0] (throughput, averaged) and item[1] (BER, NaN-ignoring mean).
results_ofdmim, results_ofdm, results_ofdmim_without_noisemod, results_ofdm_without_noisemod, results_ofdmim_BER_policy, results_ofdm_BER_policy, results_ofdmim_SJR_policy, results_ofdm_SJR_policy = process_results(
    results_ofdmim_, results_ofdm_,
    results_ofdmim_without_noisemod_, results_ofdm_without_noisemod_,
    results_ofdmim_BER_policy_, results_ofdm_BER_policy_,
    results_ofdmim_SJR_policy_, results_ofdm_SJR_policy_)

# Writes eight .mat files into ./<iter_metric>_<jammer_type>_<n_training>_<n_test>_<n_MC>/
save_results(jammer_type, iter_metric, n_training, n_test, n_MC, results_ofdmim, results_ofdm,
             results_ofdmim_without_noisemod, results_ofdm_without_noisemod,
             results_ofdmim_BER_policy, results_ofdm_BER_policy,
             results_ofdmim_SJR_policy, results_ofdm_SJR_policy)


In [None]:
# Average throughput of every scheme versus the swept parameter (linear y-axis).
plot_results(
    x_label=iter_metric,
    y_label="avg_throughput",
    results_ofdmim=results_ofdmim,
    results_ofdm=results_ofdm,
    results_ofdmim_without_noisemod=results_ofdmim_without_noisemod,
    results_ofdm_without_noisemod=results_ofdm_without_noisemod,
    results_ofdmim_BER_policy=results_ofdmim_BER_policy,
    results_ofdm_BER_policy=results_ofdm_BER_policy,
    results_ofdmim_SJR_policy=results_ofdmim_SJR_policy,
    results_ofdm_SJR_policy=results_ofdm_SJR_policy,
)

In [None]:
# Average BER of every scheme versus the swept parameter (plot_results uses a log y-axis for "avg_BER").
plot_results(
    x_label=iter_metric,
    y_label="avg_BER",
    results_ofdmim=results_ofdmim,
    results_ofdm=results_ofdm,
    results_ofdmim_without_noisemod=results_ofdmim_without_noisemod,
    results_ofdm_without_noisemod=results_ofdm_without_noisemod,
    results_ofdmim_BER_policy=results_ofdmim_BER_policy,
    results_ofdm_BER_policy=results_ofdm_BER_policy,
    results_ofdmim_SJR_policy=results_ofdmim_SJR_policy,
    results_ofdm_SJR_policy=results_ofdm_SJR_policy,
)