Skip to content

Commit

Permalink
Implement transform precoding (preliminary draft)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielschaeufele committed May 6, 2024
1 parent 8ad32bc commit e2e2d57
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 19 deletions.
3 changes: 2 additions & 1 deletion sionna/nr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from .pusch_dmrs_config import PUSCHDMRSConfig
from .pusch_pilot_pattern import PUSCHPilotPattern
from .pusch_precoder import PUSCHPrecoder
from .pusch_transform_precoder import PUSCHTransformPrecoder, PUSCHTransformDeprecoder
from .pusch_transmitter import PUSCHTransmitter
from .pusch_receiver import PUSCHReceiver
from .pusch_channel_estimation import PUSCHLSChannelEstimator
from .tb_config import TBConfig
from .utils import generate_prng_seq, select_mcs, calculate_tb_size
from .utils import generate_prng_seq, generate_low_papr_seq_type_1, select_mcs, calculate_tb_size
from .tb_encoder import TBEncoder
from .tb_decoder import TBDecoder
from .layer_mapping import LayerMapper, LayerDemapper
29 changes: 20 additions & 9 deletions sionna/nr/pusch_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pylint: disable=line-too-long

import numpy as np
from .utils import generate_prng_seq
from .utils import generate_prng_seq, generate_low_papr_seq_type_1
from .config import Config
from sionna import nr
from .utils import calculate_tb_size
Expand Down Expand Up @@ -233,7 +233,7 @@ def n_rnti(self, value):
assert value in range(65536), "n_rnti must be in [0, 65535]"
self._n_rnti = value

#---transform_precoding---#
#---precoding---#
@property
def precoding(self):
"""
Expand Down Expand Up @@ -518,7 +518,7 @@ def dmrs_grid(self):
This property returns for each configured DMRS port an empty
resource grid filled with DMRS signals as defined in
Section 6.4.1.1 [3GPP38211]. Not all possible options are implemented,
e.g., frequency hopping and transform precoding are not available.
e.g., frequency hopping is not available.
This property provides the *unprecoded* DMRS for each configured DMRS port.
Precoding might be applied to map the DMRS to the antenna ports. However,
Expand Down Expand Up @@ -546,15 +546,24 @@ def dmrs_grid(self):
# For every l_prime
for l_prime in self.l_prime:

# Compute c_init
l = l_bar + l_prime
c_init = self.c_init(l)

# Generate RNG
c = generate_prng_seq(2*self.num_subcarriers, c_init=c_init)
if self.transform_precoding:
if self.dmrs.n_sid is None:
n_id = self.carrier.n_cell_id
else:
n_id = self.dmrs.n_sid
r = generate_low_papr_seq_type_1(self.num_subcarriers // 2, n_id % 30, 0, 0)
print(r)
else:
# Compute c_init
c_init = self.c_init(l)

# Generate RNG
c = generate_prng_seq(2*self.num_subcarriers, c_init=c_init)

# Map to QAM
r = 1/np.sqrt(2)*((1-2*c[::2]) + 1j*(1-2*c[1::2]))
# Map to QAM
r = 1/np.sqrt(2)*((1-2*c[::2]) + 1j*(1-2*c[1::2]))

# For every port in the dmrs port set
for j_ind, _ in enumerate(self.dmrs.dmrs_port_set):
Expand Down Expand Up @@ -904,6 +913,7 @@ def show(self):
def check_config(self):
"""Test if the compound configuration is valid"""

# TODO: check transform precoding conditions
self.carrier.check_config()
self.dmrs.check_config()
if self.precoding=="codebook":
Expand Down Expand Up @@ -1038,6 +1048,7 @@ def check_pusch_configs(pusch_configs):
"num_antenna_ports" : pc.num_antenna_ports,
"precoding" : pc.precoding,
"precoding_matrices" : [],
"transform_precoding" : pc.transform_precoding,
"pusch_config" : pc,
"carrier_config" : pc.carrier,
"num_coded_bits" : pc.num_coded_bits,
Expand Down
68 changes: 68 additions & 0 deletions sionna/nr/pusch_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import tensorflow as tf

import sionna
from sionna.utils import split_dim
from .pusch_transform_precoder import PUSCHTransformDeprecoder


class LinearTransformPrecodingMimoDetector(sionna.mimo.detection.LinearDetector):
def __init__(self,
equalizer,
output,
demapping_method,
num_subcarriers,
constellation_type=None,
num_bits_per_symbol=None,
constellation=None,
hard_out=False,
dtype=tf.complex64,
**kwargs):
super().__init__(equalizer, output, demapping_method, constellation_type, num_bits_per_symbol, constellation,
hard_out, dtype, **kwargs)
self._transform_deprecoder = PUSCHTransformDeprecoder(num_subcarriers, dtype)

def call(self, inputs):
x_hat, no_eff = self._equalizer(*inputs)
x_transform_deprecoded = self._transform_deprecoder(x_hat)
z = self._demapper([x_transform_deprecoded, no_eff])

# Reshape to the expected output shape
num_streams = tf.shape(inputs[1])[-1]
if self._output == 'bit':
num_bits_per_symbol = self._constellation.num_bits_per_symbol
z = split_dim(z, [num_streams, num_bits_per_symbol], tf.rank(z) - 1)

return z


class LinearTransformPrecodingDetector(sionna.ofdm.detection.OFDMDetector):
def __init__(self,
equalizer,
output,
demapping_method,
resource_grid,
stream_management,
constellation_type=None,
num_bits_per_symbol=None,
constellation=None,
hard_out=False,
dtype=tf.complex64,
**kwargs):
# Instantiate the linear detector
detector = LinearTransformPrecodingMimoDetector(equalizer=equalizer,
output=output,
demapping_method=demapping_method,
num_subcarriers=resource_grid.num_effective_subcarriers,
constellation_type=constellation_type,
num_bits_per_symbol=num_bits_per_symbol,
constellation=constellation,
hard_out=hard_out,
dtype=dtype,
**kwargs)

super().__init__(detector=detector,
output=output,
resource_grid=resource_grid,
stream_management=stream_management,
dtype=dtype,
**kwargs)
15 changes: 15 additions & 0 deletions sionna/nr/pusch_dmrs_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,21 @@ def n_id(self, value):
assert e in list(range(65536)), "Each element of n_id must be in [0, 65535]"
self._n_id = value

#---n_sid---#
@property
def n_sid(self):
r"""
None (default), [0,...,1007] : DMRS scrambling identity for DFT-s-OFDM
:math:`n_\text{ID}^\text{PUSCH}`
"""
self._ifndef("n_sid", None)
return self._n_scid

@n_sid.setter
def n_sid(self, value):
assert value is None or (isinstance(value, int) and value in range(1008)), "n_sid must None or in [0, 1007]"
self._n_sid = value

#---n_scid---#
@property
def n_scid(self):
Expand Down
25 changes: 18 additions & 7 deletions sionna/nr/pusch_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sionna.ofdm import OFDMDemodulator, LinearDetector
from sionna.utils import insert_dims
from sionna.channel import time_to_ofdm_channel
from .pusch_detection import LinearTransformPrecodingDetector

class PUSCHReceiver(Layer):
# pylint: disable=line-too-long
Expand Down Expand Up @@ -197,14 +198,25 @@ def __init__(self,
# Use or create default MIMODetector
if mimo_detector is None:
# Default MIMO detector
self._mimo_detector = LinearDetector("lmmse", "bit", "maxlog",
pusch_transmitter.resource_grid,
self._stream_management,
"qam",
pusch_transmitter._num_bits_per_symbol,
dtype=dtype)
if pusch_transmitter._transform_precoding:
self._mimo_detector = LinearTransformPrecodingDetector("lmmse", "bit", "maxlog",
pusch_transmitter.resource_grid,
self._stream_management,
"qam",
pusch_transmitter._num_bits_per_symbol,
dtype=dtype)
else:
self._mimo_detector = LinearDetector("lmmse", "bit", "maxlog",
pusch_transmitter.resource_grid,
self._stream_management,
"qam",
pusch_transmitter._num_bits_per_symbol,
dtype=dtype)
else:
# User-provided MIMO detector
if pusch_transmitter._transform_precoding and not isinstance(mimo_detector,
LinearTransformPrecodingDetector):
print("WARNING: Using mimo detector which does not support transform precoding")
self._mimo_detector = mimo_detector

# Create LayerDemapper
Expand Down Expand Up @@ -248,7 +260,6 @@ def call(self, inputs):
if self._input_domain=="time":
h = time_to_ofdm_channel(h, self.resource_grid, self._l_min)


if self._w is not None:
# Reshape h to put channel matrix dimensions last
# [batch size, num_rx, num_tx, num_ofdm_symbols,...
Expand Down
34 changes: 34 additions & 0 deletions sionna/nr/pusch_transform_precoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import tensorflow as tf
from tensorflow.keras.layers import Layer


class PUSCHTransformPrecoder(Layer):
def __init__(self,
num_subcarriers,
dtype=tf.complex64,
**kwargs):
super().__init__(dtype=dtype, **kwargs)
self._num_subcarriers = num_subcarriers

def call(self, y):
orig_shape = tf.shape(y)
y_reshaped = tf.reshape(y, [-1, self._num_subcarriers])
y_transformed = tf.cast(tf.sqrt(1 / self._num_subcarriers), self._dtype) * tf.signal.fft(y_reshaped)
y_result = tf.reshape(y_transformed, orig_shape)
return y_result


class PUSCHTransformDeprecoder(Layer):
def __init__(self,
num_subcarriers,
dtype=tf.complex64,
**kwargs):
super().__init__(dtype=dtype, **kwargs)
self._num_subcarriers = num_subcarriers

def call(self, y):
orig_shape = tf.shape(y)
y_reshaped = tf.reshape(y, [-1, self._num_subcarriers])
y_transformed = tf.cast(tf.sqrt(float(self._num_subcarriers)), self._dtype) * tf.signal.ifft(y_reshaped)
y_result = tf.reshape(y_transformed, orig_shape)
return y_result
14 changes: 13 additions & 1 deletion sionna/nr/pusch_transmitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .pusch_config import PUSCHConfig, check_pusch_configs
from .pusch_pilot_pattern import PUSCHPilotPattern
from .pusch_precoder import PUSCHPrecoder
from .pusch_transform_precoder import PUSCHTransformPrecoder
from .tb_encoder import TBEncoder
from .layer_mapping import LayerMapper

Expand Down Expand Up @@ -172,6 +173,11 @@ def __init__(self,
pilot_pattern=self._pilot_pattern,
dtype=dtype)

# Create PUSCHTransformPrecoder
if self._transform_precoding:
self._transform_precoder = PUSCHTransformPrecoder(self.resource_grid.num_effective_subcarriers,
dtype=dtype)

# Create ResourceGridMapper
self._resource_grid_mapper = ResourceGridMapper(self._resource_grid,
dtype=dtype)
Expand Down Expand Up @@ -227,8 +233,14 @@ def call(self, inputs):
# Map to layers
x_layer = self._layer_mapper(x_map)

# (Optionally) apply PUSCH transform precoding (DFT-s-OFDM)
if self._transform_precoding:
x_trans_pre = self._transform_precoder(x_layer)
else:
x_trans_pre = x_layer

# Apply resource grid mapping
x_grid = self._resource_grid_mapper(x_layer)
x_grid = self._resource_grid_mapper(x_trans_pre)

# (Optionally) apply PUSCH precoding
if self._precoding=="codebook":
Expand Down
Loading

0 comments on commit e2e2d57

Please sign in to comment.