In [None]:
# OFDM system with or without modulation and channel coding
#
# Adapted from https://nvlabs.github.io/sionna/examples/Sionna_tutorial_part3.html

import math
import os
gpu_num = 4 # Use "" to use the CPU
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_num}"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import sionna as sn
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
tf.random.set_seed(42)

import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt

from PIL import Image

from digcom import OFDMSystemRaw
from utils import normalize


def digital_simulation():
    """Compare the BER of uncoded digital OFDM with estimated vs. perfect CSI.

    Bits produced by ``OFDMSystemRaw`` in digital mode go through modulation,
    OFDM and the complex channel model. Channel coding is disabled
    (``no_coding=True``). The BER is simulated twice over the same Eb/N0 grid:
    - once with estimated CSI (``perfect_csi=False``, labelled "LS Estimation");
    - once with perfect CSI.

    Both curves are collected in a single ``PlotBER`` object, which is drawn
    at the end.
    """

    EBN0_DB_MIN = 0.0  # Minimum value of Eb/N0 [dB] for simulations
    EBN0_DB_MAX = 40.0  # Maximum value of Eb/N0 [dB] for simulations
    NVALS = 15  # Number of Eb/N0 points, evenly spaced from min to max (inclusive)
    BATCH_SIZE = 16

    ber_plots = sn.utils.PlotBER("OFDM over 3GPP UMa")

    # (legend, extra OFDMSystemRaw keyword arguments) for each BER curve.
    # Each model is constructed right before it is simulated, in this order.
    scenarios = (
        ("LS Estimation", {"perfect_csi": False, "show": True}),
        ("Perfect CSI", {"perfect_csi": True}),
    )
    for legend, model_kwargs in scenarios:
        model = OFDMSystemRaw(BATCH_SIZE, digital=True, no_coding=True, **model_kwargs)
        ber_plots.simulate(
            model,
            ebno_dbs=np.linspace(EBN0_DB_MIN, EBN0_DB_MAX, NVALS),
            batch_size=BATCH_SIZE,
            num_target_block_errors=100,  # stop a point early once 100 block errors occurred
            legend=legend,
            soft_estimates=True,  # model output is hard-decided by PlotBER
            max_mc_iter=100,  # at most 100 Monte-Carlo iterations (batches) per Eb/N0 point
            show_fig=False,
        )

    ber_plots()
    
def analog_simulation():
    """Pass data through raw (unmodulated) OFDM and the complex channel, then plot it.

    The model returns the transmitted samples ``x`` and their estimate
    ``x_hat``. For the first batch element, one scatter plot shows the real
    parts of ``x_hat`` against ``x`` and another shows the imaginary parts.
    """

    batch = 128
    ebno_db = 0.0

    system = OFDMSystemRaw(batch, digital=False, no_coding=True, perfect_csi=False)
    sent, received = system(batch, ebno_db)

    # Flatten everything belonging to the first batch element into 1-D
    sent_flat = tf.reshape(sent[0, :, :, :], shape=-1)
    received_flat = tf.reshape(received[0, :, :, :], shape=-1)

    panels = (
        ("Real x_hat vs x", tf.math.real),
        ("Imag x_hat vs x", tf.math.imag),
    )
    for heading, component in panels:
        plt.figure()
        plt.title(heading)
        plt.scatter(component(sent_flat).numpy(), component(received_flat).numpy())
    plt.show()

def analog_image(path="kodim23.png", ebn0_db=10.0):
    """Transmit an image over raw (uncoded, unmodulated) OFDM and the complex channel.

    Steps:
    1. Scale the pixel values to [-amp, amp].
    2. Pack consecutive pairs of values into the real and imaginary parts of
       one complex sample.
    3. Send the samples through ``OFDMSystemRaw`` with estimated rather than
       perfect CSI.
    4. Unpack the received samples back into an image.

    The original image is shown first, then the reconstruction, and finally
    the PSNR of the reconstruction is printed.

    Args:
        path: Image file to transmit. It must decode to a
            ``(height, width, channels)`` array.
        ebn0_db: Eb/N0 of the simulated channel, in dB.
    """

    amp = 1.2  # pixels are mapped to [-amp, amp] (half-width of the sample square)

    # Read the image; the context manager closes the underlying file handle.
    with Image.open(path) as pil_img:
        img = np.array(pil_img, dtype=float)
    plt.imshow(normalize(img, 0.0, 255.0, 0.0, 1.0))
    plt.show()

    # Scale pixels to [-amp, amp]. If real and imaginary parts were uniform on
    # [-amp, amp], each would have variance amp**2 / 3, so the complex sample
    # power would be 2 * amp**2 / 3 = 0.96, i.e. approximately unit power.
    img = np.asarray(normalize(img, 0.0, 255.0, -amp, amp), dtype=np.float32)

    img_h, img_w, nch = img.shape
    num_symbols = img_h * img_w * nch  # number of real values to transmit

    fft_size = 192
    num_ofdm_symbols = 14
    pilot_ofdm_symbol_indices = [2, 11]
    num_nonpilot_ofdm_symbols = num_ofdm_symbols - len(pilot_ofdm_symbol_indices)
    # Complex data samples carried by one batch element
    num_data_samples = fft_size * num_nonpilot_ofdm_symbols

    # Each complex sample carries two real values, so each batch element holds
    # 2 * num_data_samples pixels; use as many elements as the image needs.
    BATCH_SIZE = math.ceil(num_symbols / (2 * num_data_samples))

    # Zero-pad the flattened image to fill the last batch element. The padded
    # length is a multiple of 2, so every real value has an imaginary partner.
    num_padded = BATCH_SIZE * 2 * num_data_samples
    inp_img = tf.pad(tf.reshape(img, [-1]), [[0, num_padded - num_symbols]])

    # Assign pairs of samples as the real/imaginary parts of a complex tensor
    inp = tf.complex(inp_img[::2], inp_img[1::2])

    # Reshape for Sionna processing (assuming 1 TX device with 1 antenna and 1 stream per antenna)
    num_ut = 1
    num_ut_ant = 1
    inp = tf.reshape(inp, [BATCH_SIZE, num_ut, num_ut_ant, num_data_samples])
    print("inp shape:", inp.shape)

    # Run simulation
    resource_grid_config = {
        "num_ofdm_symbols": num_ofdm_symbols,
        "fft_size": fft_size,
        "pilot_ofdm_symbol_indices": pilot_ofdm_symbol_indices,
    }
    model = OFDMSystemRaw(BATCH_SIZE, resource_grid_config, num_ut, num_ut_ant,
                          digital=False, perfect_csi=False, show=True)
    print("OFDM duration", model.RESOURCE_GRID.ofdm_symbol_duration)
    out = model.run_with_input(inp, ebn0_db)

    # Unpack: interleave real/imag parts back into the flat sequence
    # (re0, im0, re1, im1, ...). The assignments fail loudly if the model
    # returned a different number of samples than were sent.
    out = tf.reshape(out, [-1])
    out_img = np.empty(num_padded, dtype=np.float32)
    out_img[0::2] = tf.math.real(out).numpy()
    out_img[1::2] = tf.math.imag(out).numpy()
    out_img = out_img[:num_symbols].reshape([img_h, img_w, nch])  # drop padding

    # Channel noise can push values outside [-amp, amp]; clip explicitly for
    # display instead of relying on imshow's implicit clipping.
    plt.imshow(np.clip(normalize(out_img, -amp, amp, 0.0, 1.0), 0.0, 1.0))
    plt.show()

    # PSNR = 10 log10(peak**2 / MSE), where the peak is the dynamic range of
    # the compared signals. Both span [-amp, amp], so peak = 2 * amp. A peak
    # of 1.0 would understate the PSNR by 20 log10(2 * amp) ~= 7.6 dB.
    # MSE == 0 (perfect reconstruction) yields inf.
    mse = np.mean((img - out_img) ** 2, dtype=np.float64)
    psnr = 10 * np.log10((2 * amp) ** 2 / mse)
    print("PSNR: ", psnr, "db")
    

# Select which experiment(s) this cell runs; by default only the image
# transmission is executed.
RUN_DIGITAL_SIMULATION = False  # BER curves: LS estimation vs. perfect CSI
RUN_ANALOG_SIMULATION = False   # scatter plots of x_hat vs. x for raw OFDM
RUN_ANALOG_IMAGE = True         # send an image over raw OFDM and report PSNR

if RUN_DIGITAL_SIMULATION:
    digital_simulation()
if RUN_ANALOG_SIMULATION:
    analog_simulation()
if RUN_ANALOG_IMAGE:
    analog_image()