This notebook performs a sanity check to make sure the baseline works as expected.

* Inputs: (CFO + Channel Interference + AWGN) signals
* Outputs: Estimated message bits

Tests (here "no CFO" means a negligible offset, OMEGA = 1/10000):
  * Baseline on inputs with no CFO and no channel interference (channel_len = 1)
  * Baseline on inputs with no CFO, channel_tap = 2
  * Baseline on inputs with CFO = 1/100, channel_tap = 2

## Environment Setup

In [1]:
# Make the parent of the current working directory importable so the project's
# local packages (e.g. `radioml`, imported below) can be found. This assumes the
# notebook is run from a sub-directory of the project root.
import multiprocessing as mp
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
import numpy as np
from radioml.models import Baseline
from radioml.metrics import get_ber_bler
from radioml.dataset import RadioDataGenerator

# For visualization
import matplotlib.pyplot as plt
import seaborn as sns
import pylab  # NOTE(review): appears unused in the cells shown here
sns.set_style("white")

## Define parameters

In [2]:
NUM_PACKETS  = 40    # packets generated per SNR point (batch size of the data generator)
PREAMBLE_LEN = 40    # preamble length passed to RadioDataGenerator
DATA_LEN     = 100   # data (payload) length passed to RadioDataGenerator
SNRs         = [0.0, 5.0, 7.0, 10.0, 12.0, 15.0]  # SNR sweep in dB (all floats)

## Define Baseline (MMSE, Classic Demodulation, Viterbi)


In [3]:
import scipy.signal as sig

class MMSEEqualizer(object):
    """Minimum Mean Squared Error (MMSE) linear equalizer.

    The equalizer is an FIR filter with ``equalizer_order`` taps centred on the
    current sample (``L = (order - 1) // 2`` taps on each side). ``update`` fits
    the taps by least squares so that filtering the received signal ``x``
    reproduces the reference signal ``y``; ``predict`` applies the taps.

    Args:
        equalizer_order: Number of filter taps; must be an odd integer >= 3.
        random_starts: If True and ``predict`` is called before ``update``, the
            taps are initialised randomly; otherwise they start at zero.
        update_rate: Step size stored as ``self.mu``. It is not used by the
            least-squares ``update`` in this class.
        learning_rate: Alias for ``update_rate``; takes precedence when given.

    Raises:
        ValueError: If ``equalizer_order`` is missing, even, or smaller than 3.
    """

    def __init__(self, equalizer_order=None, random_starts=True, update_rate=0.001,
                 learning_rate=None):
        if equalizer_order is None:
            raise ValueError("MMSE Equalizer: equalizer_order is missing")
        if equalizer_order % 2 == 0:
            raise ValueError("MMSE Equalizer: equalizer_order must be odd")
        if equalizer_order < 3:
            raise ValueError("MMSE Equalizer: equalizer_order must be at least 3")

        self.order = equalizer_order
        self.random_starts = random_starts
        self.h = None                   # filter taps; set by update() or lazily by predict()
        self.L = (self.order - 1) // 2  # number of taps on each side of the centre tap
        # Callers use both spellings for the step size; `learning_rate` wins if given.
        self.mu = update_rate if learning_rate is None else learning_rate

    def update(self, x, y):
        """Fit the taps by least squares so that filtering ``x`` approximates ``y``.

        Row ``i`` of the regression matrix holds ``x[i+L], ..., x[i-L]``
        (zero-padded at the edges), so the fit solves
        ``y[i] ~= sum_k h[k] * x[i + L - k]``.

        Args:
            x: Received samples, real or complex.
            y: Reference samples. Assumes ``len(y) <= len(x)`` — TODO confirm
               against callers.
        """
        # Zero-pad L samples on each side; a constant of 0 works for real and complex.
        x_padded = np.pad(np.asarray(x), self.L, mode='constant')
        # Reversed sliding windows: row i = x_padded[i + order - 1], ..., x_padded[i].
        rows = np.arange(len(y))[:, None] + np.arange(self.order - 1, -1, -1)
        A = x_padded[rows]
        self.h = np.linalg.lstsq(A, y, rcond=-1)[0]

    def predict(self, x):
        """Filter ``x`` with the current taps.

        If ``update`` has not been called yet, the taps are first initialised
        (randomly when ``random_starts`` is True, otherwise to zeros), with a
        complex dtype when ``x`` is complex.

        Returns:
            ``scipy.signal.convolve(x, h, mode="full")`` with the first ``L``
            samples dropped: ``len(x) + L`` samples, of which the first
            ``len(x)`` are aligned with ``x``.
        """
        if self.h is None:
            # np.iscomplexobj also recognises complex64 input, unlike isinstance(..., complex).
            is_complex = np.iscomplexobj(x)
            if self.random_starts:
                self.h = np.random.randn(self.order)
                if is_complex:
                    self.h = self.h + 1j * np.random.randn(self.order)
            else:
                # np.complex128 instead of np.complex_, which NumPy 2.0 removed.
                self.h = np.zeros(self.order,
                                  dtype=np.complex128 if is_complex else np.float32)
        # NOTE(review): the L trailing samples past len(x) are kept; callers may rely on it.
        return sig.convolve(x, self.h, mode="full")[self.L:]

In [None]:
# 5-tap MMSE equalizer. The step size is passed as `update_rate`, the keyword
# declared by `MMSEEqualizer.__init__`.
mmse = MMSEEqualizer(equalizer_order=5,
                     random_starts=True, update_rate=0.01)

# Baseline receiver built around the MMSE equalizer, using QPSK modulation.
baseline = Baseline(equalizer=mmse, modulation_scheme='QPSK')

## Define benchmark function

In [None]:
def run_benchmark(radio, snr_range, omega, channel_len, model=None, num_packets=None):
    """Run the baseline receiver over a range of SNRs and report BER/BLER.

    For every SNR in ``snr_range`` one batch of packets is drawn from
    ``radio.end2end_data_generator``. The packets are decoded in parallel with a
    multiprocessing pool, and error rates are computed with ``get_ber_bler``.

    Args:
        radio: Data generator exposing ``end2end_data_generator`` and
            ``preamble_len`` (e.g. a ``RadioDataGenerator``).
        snr_range: Iterable of SNR values in dB.
        omega: Carrier frequency offset passed to the data generator.
        channel_len: Not used here (the channel length is fixed when ``radio``
            is built); kept for interface compatibility.
        model: Callable ``model(data, preamble, convolved_preamble)`` returning
            estimated message bits. Defaults to the notebook-level ``baseline``.
        num_packets: Packets per SNR. Defaults to the notebook-level ``NUM_PACKETS``.

    Returns:
        ``(bit_error_rates, block_error_rates)``: lists holding one
        ``[rate, rate]`` pair per SNR (``visualize_ber_bler`` plots column 1).
    """
    if model is None:
        model = baseline
    if num_packets is None:
        num_packets = NUM_PACKETS

    bit_error_rates, block_error_rates = [], []
    # One worker pool for the whole sweep rather than a new pool per SNR.
    with mp.Pool(mp.cpu_count()) as pool:
        for snr in snr_range:
            print('SNR_dB = %f' % snr)
            # Use the `omega` argument (not a global) so the caller controls the CFO.
            generator = radio.end2end_data_generator(omega, snr,
                                                     batch_size=num_packets)
            [preambles, corrupted_packets], [message_bits, _, _] = next(generator)

            # Reinterpret float arrays as complex samples; presumably the last
            # axis holds (real, imag) pairs — TODO confirm with the generator.
            preambles = preambles.view(complex)
            convolved_preamble = np.array(
                corrupted_packets[:, :radio.preamble_len, :]).view(complex)
            convolved_data = np.array(
                corrupted_packets[:, radio.preamble_len:, :]).view(complex)

            # Estimate message bits packet by packet, in parallel.
            estimated_bits = np.array(pool.starmap(
                model, zip(convolved_data, preambles, convolved_preamble)))

            ber, bler = get_ber_bler(estimated_bits, np.squeeze(message_bits, -1))

            print('\t[Baseline] Ber = {:.8f} | Bler ={:.8f} '.format(ber, bler))
            bit_error_rates.append([ber, ber])
            block_error_rates.append([bler, bler])

    return bit_error_rates, block_error_rates

##  Baseline with No CFO, No Intersymbol Interference (Channel_len = 1)

In [None]:
# Scenario 1: single-tap channel (no intersymbol interference) and a negligible
# carrier frequency offset of 1/10000, treated here as "no CFO".
CHANNEL_LEN = 1
OMEGA = 1/10000
radio = RadioDataGenerator(DATA_LEN, PREAMBLE_LEN, CHANNEL_LEN,
                           modulation_scheme='QPSK')
bers, blers = run_benchmark(radio, SNRs, OMEGA, CHANNEL_LEN)

##  Baseline with No CFO, Intersymbol Interference  (Channel_len = 2)

In [None]:
# Scenario 2: two-tap channel (intersymbol interference) with a negligible
# carrier frequency offset of 1/10000, treated here as "no CFO".
CHANNEL_LEN = 2
OMEGA = 1/10000
radio = RadioDataGenerator(DATA_LEN, PREAMBLE_LEN, CHANNEL_LEN,
                           modulation_scheme='QPSK')
bers1, blers1 = run_benchmark(radio, SNRs, OMEGA, CHANNEL_LEN)

##  Baseline with CFO = 1/100,  Intersymbol Interference  (Channel_len = 2)

In [None]:
# Scenario 3: two-tap channel (intersymbol interference) plus a carrier
# frequency offset of 1/100.
CHANNEL_LEN = 2
OMEGA = 1/100
radio = RadioDataGenerator(DATA_LEN, PREAMBLE_LEN, CHANNEL_LEN,
                           modulation_scheme='QPSK')
bers2, blers2 = run_benchmark(radio, SNRs, OMEGA, CHANNEL_LEN)

## Visualize Results

In [None]:
def visualize_ber_bler(ax1, ax2, ber_logs, bler_logs, snr_range, title):
    """Plot one BER curve on ``ax1`` and one BLER curve on ``ax2``.

    Args:
        ax1: Matplotlib axes for the bit error rate.
        ax2: Matplotlib axes for the block error rate.
        ber_logs: Per-SNR rows as returned by ``run_benchmark``; column 1 is plotted.
        bler_logs: Per-SNR rows as returned by ``run_benchmark``; column 1 is plotted.
        snr_range: SNR values (dB) for the x-axis, one per row of the logs.
        title: Legend label for both curves.
    """
    snr_min, snr_max = np.min(snr_range), np.max(snr_range)
    panels = ((ax1, ber_logs, 'Bit Error Rate (BER)'),
              (ax2, bler_logs, 'Block Error Rate (BLER)'))
    for ax, logs, ylabel in panels:
        ax.plot(snr_range, np.asarray(logs)[:, 1], label=title)
        ax.set_xlabel('SNR (in dB)')
        ax.set_ylabel(ylabel)
        # Same grid on both panels so BER and BLER plots read consistently.
        ax.grid(True, 'both')
        ax.set_xlim(snr_min, snr_max)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
fig.suptitle('Evaluate Baseline (MMSE + Classic Demod + Viterbi) on different signals',
             fontsize='x-large')

# One (BER log, BLER log, legend label) triple per benchmark scenario, in run order.
experiments = [
    (bers, blers, 'Inputs 1 (No CFO, No Intersymbol Interference)'),
    (bers1, blers1, 'Inputs 2 (No CFO,  Channel  = 2)'),
    (bers2, blers2, 'Inputs 3 (CFO=1/100, channel_tap = 2)'),
]
for ber_log, bler_log, label in experiments:
    visualize_ber_bler(ax1, ax2, ber_log, bler_log, SNRs, label)

# Log-scale y axes on both panels; legends at loc=3 (lower left).
for axis in (ax1, ax2):
    axis.semilogy()
ax1.legend(loc=3)
_ = ax2.legend(loc=3)