In [None]:
import sionna as sn
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [None]:
# Number of examples drawn per Monte-Carlo batch in every BER simulation below.
BATCH_SIZE: int = 500

In [None]:
class OFDMModel(tf.keras.Model):
    """End-to-end uplink OFDM link: 5G LDPC + 16-QAM over a 3GPP CDL-C channel.

    A single single-antenna UT sends one stream to a single BS. The receiver
    performs LS channel estimation (nearest-neighbour interpolation), LMMSE
    equalization, APP demapping and hard-output LDPC decoding.

    Args:
        num_bs_ant: Total number of BS receive antenna elements. The BS array
            is one row of dual-polarized (cross) elements, so it is built with
            ``num_bs_ant // 2`` columns. ``num_bs_ant`` must therefore be a
            positive even integer.

    Raises:
        ValueError: If ``num_bs_ant`` is not a positive even number.
    """

    def __init__(self, num_bs_ant):
        super().__init__()

        # The BS array is dual-polarized: every column carries two antenna
        # elements, so only even totals can be realized.
        if num_bs_ant < 2 or num_bs_ant % 2 != 0:
            raise ValueError(
                f"num_bs_ant must be a positive even integer, got {num_bs_ant!r}"
            )

        DELAY_SPREAD = 100e-9       # seconds
        DIRECTION = 'uplink'        # UT is the transmitter, BS the receiver
        CDL_MODEL = 'C'
        SPEED = 0.0                 # m/s (static UT)

        CARRIER_FREQUENCY = 1.9e9   # Hz

        self.NUM_BITS_PER_SYMBOL = 4    # 16-QAM
        self.CODERATE = 0.5

        self.NUM_UT = 1
        self.NUM_BS = 1
        self.NUM_UT_ANT = 1
        self.NUM_BS_ANT = int(num_bs_ant)
        self.NUM_STREAMS_PER_TX = self.NUM_UT_ANT

        # Single receiver (BS) served by the single transmitter (UT).
        self.RX_TX_ASSOCIATION = np.array([[1]])
        self.STREAM_MANAGEMENT = sn.mimo.StreamManagement(self.RX_TX_ASSOCIATION, self.NUM_STREAMS_PER_TX)

        self.resource_grid = sn.ofdm.ResourceGrid(
            num_ofdm_symbols=14,
            fft_size=76,
            subcarrier_spacing=30e3,
            num_tx=self.NUM_UT,
            num_streams_per_tx=self.NUM_STREAMS_PER_TX,
            cyclic_prefix_length=6,
            pilot_pattern="kronecker",
            pilot_ofdm_symbol_indices=[2, 11]
        )

        # Codeword length fills every data resource element of the grid.
        self.NUM_CODED_BITS = int(self.resource_grid.num_data_symbols * self.NUM_BITS_PER_SYMBOL)
        self.NUM_INFO_BITS = int(self.NUM_CODED_BITS * self.CODERATE)

        ut_array = sn.channel.tr38901.Antenna(
            polarization='single',
            polarization_type='V',
            antenna_pattern='omni',
            carrier_frequency=CARRIER_FREQUENCY
        )

        # Dual ('cross') polarization places two elements in each column, so
        # NUM_BS_ANT // 2 columns give exactly NUM_BS_ANT receive antennas.
        bs_array = sn.channel.tr38901.AntennaArray(
            num_rows=1,
            num_cols=self.NUM_BS_ANT // 2,
            polarization='dual',
            polarization_type='cross',
            antenna_pattern='38.901',
            carrier_frequency=CARRIER_FREQUENCY
        )

        self.cdl = sn.channel.tr38901.CDL(
            CDL_MODEL,
            DELAY_SPREAD,
            CARRIER_FREQUENCY,
            ut_array=ut_array,
            bs_array=bs_array,
            direction=DIRECTION,
            min_speed=SPEED
        )

        # One constellation object shared by mapper and demapper so both
        # sides are guaranteed to use the same symbol alphabet.
        constellation = sn.utils.Constellation('qam', self.NUM_BITS_PER_SYMBOL)

        self.binary_source = sn.utils.BinarySource()

        self.encoder = sn.fec.ldpc.LDPC5GEncoder(self.NUM_INFO_BITS, self.NUM_CODED_BITS)
        self.decoder = sn.fec.ldpc.LDPC5GDecoder(encoder=self.encoder, hard_out=True)

        self.mapper = sn.mapping.Mapper(constellation=constellation)
        self.demapper = sn.mapping.Demapper('app', constellation=constellation)

        self.rg_mapper = sn.ofdm.ResourceGridMapper(self.resource_grid)

        self.ls_est = sn.ofdm.LSChannelEstimator(resource_grid=self.resource_grid, interpolation_type='nn')

        self.lmmse_equ = sn.ofdm.LMMSEEqualizer(self.resource_grid, self.STREAM_MANAGEMENT)

        self.channel = sn.channel.OFDMChannel(
            self.cdl,
            self.resource_grid,
            add_awgn=True,
            normalize_channel=True,
        )

    @tf.function
    def __call__(self, batch_size, ebno_db):
        """Simulate one batch of uplink transmissions at a given Eb/N0.

        Args:
            batch_size: Number of independent codewords to simulate.
            ebno_db: Eb/N0 in dB (a scalar per call, as passed by
                ``PlotBER.simulate``).

        Returns:
            Tuple ``(bits, bits_hat)``: transmitted information bits and the
            decoder's hard decisions, both with shape
            ``[batch_size, NUM_UT, NUM_STREAMS_PER_TX, NUM_INFO_BITS]``.
        """
        # Passing the resource grid lets ebnodb2no account for the
        # pilot/cyclic-prefix overhead when converting Eb/N0 to noise power.
        no = sn.utils.ebnodb2no(
            ebno_db,
            self.NUM_BITS_PER_SYMBOL,
            self.CODERATE,
            self.resource_grid
        )

        # Transmitter: bits -> LDPC codewords -> QAM symbols -> OFDM grid.
        bits = self.binary_source([batch_size, self.NUM_UT, self.resource_grid.num_streams_per_tx, self.NUM_INFO_BITS])
        codewords = self.encoder(bits)
        qam_symbols = self.mapper(codewords)
        tx_grid = self.rg_mapper(qam_symbols)

        # CDL channel in the frequency domain plus AWGN of power `no`.
        rx_grid = self.channel([tx_grid, no])

        # Receiver: LS estimate -> LMMSE equalization -> LLRs -> hard bits.
        h_hat, err_var = self.ls_est([rx_grid, no])
        equalized_symbols, no_eff = self.lmmse_equ([rx_grid, h_hat, err_var, no])
        llr = self.demapper([equalized_symbols, no_eff])
        bits_hat = self.decoder(llr)

        return bits, bits_hat

In [None]:
# Seed TensorFlow and NumPy before any model is built or simulated so the
# Monte-Carlo BER results are reproducible on "Restart & Run All".
# NOTE(review): assumes Sionna draws its randomness from TF's global RNG
# (and NumPy) — confirm for the installed Sionna version.
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Collects every simulated BER curve; rendered together in the last cell.
ber_plots = sn.utils.PlotBER('OFDM')

# Eb/N0 sweep in dB: -5 ... 10 in 1 dB steps, shared by all simulations.
EBNO_DBS = np.linspace(-5, 10, 16)

# Three otherwise identical links that differ only in the BS antenna count.
model_1 = OFDMModel(2)
model_2 = OFDMModel(4)
model_3 = OFDMModel(8)


In [None]:
# The resource-grid configuration does not depend on num_bs_ant, so the grid
# of model_1 is representative of every model. Binding the returned figure
# (instead of a trailing ';') keeps its repr from being echoed.
resource_grid_figure = model_1.resource_grid.show()

In [None]:
# BER vs. Eb/N0 for model_1; the curve is stored (show_fig=False) and drawn
# with the others in the final cell. soft_estimates=False because the LDPC
# decoder is built with hard_out=True and already returns bit decisions.
ber_plots.simulate(
    model_1,
    batch_size=BATCH_SIZE,
    ebno_dbs=EBNO_DBS,
    num_target_bit_errors=1000,
    legend=f'{model_1.NUM_BS_ANT} BS ANTENNAS',
    soft_estimates=False,
    max_mc_iter=100,
    show_fig=False
);

In [None]:
# BER vs. Eb/N0 for model_2; the curve is stored (show_fig=False) and drawn
# with the others in the final cell. soft_estimates=False because the LDPC
# decoder is built with hard_out=True and already returns bit decisions.
ber_plots.simulate(
    model_2,
    batch_size=BATCH_SIZE,
    ebno_dbs=EBNO_DBS,
    num_target_bit_errors=1000,
    legend=f'{model_2.NUM_BS_ANT} BS ANTENNAS',
    soft_estimates=False,
    max_mc_iter=100,
    show_fig=False
);

In [None]:
# BER vs. Eb/N0 for model_3; the curve is stored (show_fig=False) and drawn
# with the others in the final cell. soft_estimates=False because the LDPC
# decoder is built with hard_out=True and already returns bit decisions.
ber_plots.simulate(
    model_3,
    batch_size=BATCH_SIZE,
    ebno_dbs=EBNO_DBS,
    num_target_bit_errors=1000,
    legend=f'{model_3.NUM_BS_ANT} BS ANTENNAS',
    soft_estimates=False,
    max_mc_iter=100,
    show_fig=False
);

In [None]:
# Render every BER curve accumulated by the ber_plots.simulate(...) calls
# above (they ran with show_fig=False) in one comparison figure.
ber_plots()