mapping: Implement learning symbolwise demapping in ConstellationDemapper
noc0lour committed Mar 25, 2024
1 parent a0431d1 commit cddbddd
Showing 2 changed files with 41 additions and 6 deletions.
36 changes: 30 additions & 6 deletions src/mokka/mapping/torch.py
@@ -240,6 +240,12 @@ def forward(self, b, *args, one_hot=False):
"""
Perform mapping of bitstrings.
By specifying one_hot=True, symbol-wise geometric constellation shaping is possible
by calculating the loss w.r.t. the MI after demapping.
:param b: input to ConstellationMapper, either bits or one-hot vectors
:param one_hot: specify whether b is a one-hot vector
:returns: complex symbol per bitstring of length `self.m`
"""
# Generate one-hot representations for m bits
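
For context, a minimal usage sketch of the one-hot path (the ConstellationMapper(m) constructor call is an assumption here; its exact arguments may differ):

import torch
from mokka.mapping.torch import ConstellationMapper

m = 2
mapper = ConstellationMapper(m)  # assumed constructor signature

# Feed symbol indices as one-hot vectors instead of bit patterns.
idx = torch.randint(0, 2**m, (8,))
b_onehot = torch.nn.functional.one_hot(idx, num_classes=2**m).float()
symbols = mapper(b_onehot, one_hot=True)  # one complex symbol per input row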
@@ -471,29 +477,44 @@ class ConstellationDemapper(torch.nn.Module):
:param m: Bits per symbol
:param depth: Number of hidden layers
:param width: Neurons per layer
:param with_logit: Output LLRs instead of bit probabilities
:param bitwise: Output bitwise LLRs/probabilities instead of symbol-wise ones
"""

def __init__(self, m, depth=3, width=128, with_logit=True, demod_extra_params=None):
def __init__(
self,
m,
depth=3,
width=128,
with_logit=True,
bitwise=True,
demod_extra_params=None,
):
"""Construct ConstellationDemapper."""
super(ConstellationDemapper, self).__init__()
self.with_logit = with_logit

self.register_buffer(
"demod_extra_params", torch.tensor(demod_extra_params or [])
"demod_extra_params", torch.as_tensor(demod_extra_params or [])
)
self.register_buffer("m", torch.tensor(m))
self.register_buffer("width", torch.tensor(width))
self.register_buffer("depth", torch.tensor(depth))
self.register_buffer("m", torch.as_tensor(m))
self.register_buffer("width", torch.as_tensor(width))
self.register_buffer("depth", torch.as_tensor(depth))
self.register_buffer("bitwise", torch.as_tensor(bitwise))

self.ReLU = torch.nn.LeakyReLU()
self.sigmoid = torch.nn.Sigmoid()
self.softmax = torch.nn.Softmax(dim=-1)  # explicit dim avoids the implicit-dim warning

self.demaps = torch.nn.ModuleList()
input_width = 2 + len(demod_extra_params or [])
self.demaps.append(torch.nn.Linear(input_width, width))
for d in range(depth - 2):
self.demaps.append(torch.nn.Linear(width, width))
self.demaps.append(torch.nn.Linear(width, m))
if self.bitwise:
self.demaps.append(torch.nn.Linear(width, m))
else:
self.demaps.append(torch.nn.Linear(width, 2**m))

for demap in self.demaps:
torch.nn.init.xavier_normal_(demap.weight)
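
The branch above sizes the output head to match the demapping mode: bitwise demapping emits one LLR per bit (m outputs), symbol-wise demapping emits one logit per constellation point (2**m outputs). A standalone sketch of the same sizing rule:

import torch

m, width = 4, 128
for bitwise in (True, False):
    head = torch.nn.Linear(width, m if bitwise else 2**m)
    print(bitwise, head.out_features)  # True -> 4 bit LLRs, False -> 16 symbol logits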
@@ -525,6 +546,8 @@ def forward(self, y, *args):
logit = -1 * logit
if self.with_logit:
return logit
if not self.bitwise:
return self.softmax(logit)
return self.sigmoid(logit)

@staticmethod
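
With bitwise=False and with_logit=True, forward returns one logit per constellation point, which supports the MI-based training mentioned in the mapper docstring. A hedged sketch of such a loss (the tensors here are stand-ins, not this module's API): for uniform symbols, the MI in bits can be estimated as m minus the cross-entropy in bits.

import torch

m = 4
logits = torch.randn(8, 2**m, requires_grad=True)  # stand-in demapper output
tx_idx = torch.randint(0, 2**m, (8,))               # transmitted symbol indices

ce_nats = torch.nn.functional.cross_entropy(logits, tx_idx)
mi_bits = m - ce_nats / torch.log(torch.tensor(2.0))  # MI estimate, uniform symbols
loss = -mi_bits  # maximizing MI trains mapper and demapper end to end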
@@ -567,6 +590,7 @@ class ClassicalDemapper(torch.nn.Module):
:param noise_sigma: $\\sigma$ for the Gaussian assumption
:param constellation: PyTorch tensor of complex constellation symbols
:param optimize: Use $\\sigma$ as a trainable parameter
:param bitwise: Perform demapping bitwise, returning LLRs
:param p_symbols: PyTorch tensor with symbol probabilities
"""

11 changes: 11 additions & 0 deletions src/mokka/utils/bitops/torch.py
@@ -47,3 +47,14 @@ def idx2bits(idxs, num_bits_per_symbol):
2,
)
return bits


def onehot_to_idx(onehot):
"""
Convert a one-hot vector to the corresponding index.
:param onehot: Input one-hot vectors
:returns: tensor of indices
"""
idxs = torch.arange(onehot.size(1)).unsqueeze(0)
return torch.sum(idxs * onehot, dim=1)
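
A round-trip example for the helper above:

import torch
from mokka.utils.bitops.torch import onehot_to_idx

onehot = torch.eye(4)[torch.tensor([2, 0, 3])]  # one-hot rows for indices 2, 0, 3
print(onehot_to_idx(onehot))  # tensor([2., 0., 3.]) (dtype follows the input)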
