Low-complexity Near-optimum Symbol Detection Based on Neural Enhancement of Factor Graphs
---
Luca Schmid and Laurent Schmalen

We consider the application of the factor graph framework for symbol detection on linear inter-symbol interference channels. Based on the Ungerboeck observation model, a detection algorithm with appealing complexity properties can be derived. However, since the underlying factor graph contains cycles, the sum-product algorithm (SPA) yields a suboptimal algorithm. In this paper, we develop and evaluate efficient strategies to improve the performance of the factor graph-based symbol detection by means of neural enhancement. In particular, we consider neural belief propagation and generalizations of the factor nodes as an effective way to mitigate the effect of cycles within the factor graph. By applying a generic preprocessor to the channel output, we propose a simple technique to vary the underlying factor graph in every SPA iteration. Using this dynamic factor graph transition, we intend to preserve the extrinsic nature of the SPA messages which is otherwise impaired due to cycles. Simulation results show that the proposed methods can massively improve the detection performance, even approaching the maximum a posteriori performance for various transmission scenarios, while preserving a complexity which is linear in both the block length and the channel memory.

The full paper [1] is available on arXiv and IEEE Xplore.

This Jupyter notebook shows an exemplary implementation and usage of the GAP algorithm. The code closely follows the description in [1]. The comments within the source code assume prior knowledge of message passing on factor graphs and familiarity with the ideas and notation of [1]. Specifically, this notebook contains:
* A generic implementation of the GAP algorithm that processes a batch of data blocks in parallel. The class is implemented in PyTorch and can run on CPUs or GPUs.
* An example training and evaluation procedure with configurable parameters.
* Some helper classes (for bit-to-symbol mappings, bit-metric decoding, etc.) and loss functions (BMI and BER).

[1] Schmid, Luca, and Laurent Schmalen. “Low-Complexity Near-Optimum Symbol Detection Based on Neural Enhancement of Factor Graphs.” ArXiv:2203.16417 [Cs, Eess, Math], March 30, 2022. http://arxiv.org/abs/2203.16417.

For questions, discussions or improvements of this code, feel free to contact us (Luca Schmid, first.last@kit.edu) or open an issue/create a pull request. Have fun!

---

This work has received funding in part from the European Research Council (ERC) under the European Union’s Horizon 2020 research and innovation programme (grant agreement No. 101001899) and in part from the German Federal Ministry of Education and Research (BMBF) within the project Open6GHub (grant agreement 16KISK010).

---

Copyright (c) 2021-2022 Luca Schmid - Communications Engineering Lab (CEL), Karlsruhe Institute of Technology (KIT)

<sup> Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

<sup> The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

<sup> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

In [1]:
# Import external libraries
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as func

# Run on the GPU if one is available.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

In [2]:
# Define some common symbol constellations for digital baseband transmission.
bpsk_mapping = t.tensor([1.0, -1.0], dtype=t.cfloat)
qpsk_mapping = 1/np.sqrt(2) * t.tensor([-1-1j, -1+1j, 1-1j, 1+1j], dtype=t.cfloat)
qam16_mapping = 1/np.sqrt(10) * t.tensor([-3-3j, -3-1j, -3+3j, -3+1j, # uses Gray coding
                                          -1-3j, -1-1j, -1+3j, -1+1j,
                                          +3-3j, +3-1j, +3+3j, +3+1j,
                                          +1-3j, +1-1j, +1+3j, +1+1j ], dtype=t.cfloat)

# Define impulse responses of some example channels.
ProakisA = [0.04, -0.05, 0.07, -0.21, -0.5, 0.72, 0.36, 0.0, 0.21, 0.03, 0.07]
ProakisB = [0.407, 0.815, 0.407]
ProakisC = [0.227, 0.460, 0.688, 0.460, 0.227]
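The simulations below assume that the constellation has unit average symbol energy (see the docstring of `detect`). As a quick sanity check, the following NumPy sketch re-declares the QPSK and 16-QAM mappings and the Proakis impulse responses from the cell above and verifies that the constellations are normalized, that their labelings are Gray (nearest neighbours differ in exactly one bit), and that the channels have approximately unit energy. The helper names `avg_energy` and `is_gray` are illustrative and not part of the notebook.

```python
import numpy as np

# Mappings from the cell above; the list index is the integer value of the bit label (MSB first).
qpsk = 1 / np.sqrt(2) * np.array([-1-1j, -1+1j, 1-1j, 1+1j])
qam16 = 1 / np.sqrt(10) * np.array([-3-3j, -3-1j, -3+3j, -3+1j,
                                    -1-3j, -1-1j, -1+3j, -1+1j,
                                    +3-3j, +3-1j, +3+3j, +3+1j,
                                    +1-3j, +1-1j, +1+3j, +1+1j])
channels = {
    "ProakisA": [0.04, -0.05, 0.07, -0.21, -0.5, 0.72, 0.36, 0.0, 0.21, 0.03, 0.07],
    "ProakisB": [0.407, 0.815, 0.407],
    "ProakisC": [0.227, 0.460, 0.688, 0.460, 0.227],
}

def avg_energy(mapping):
    """Average symbol energy E_s = mean(|c|^2) for equiprobable symbols."""
    return float(np.mean(np.abs(mapping) ** 2))

def is_gray(mapping, tol=1e-9):
    """True if every pair of nearest-neighbour symbols differs in exactly one label bit."""
    dist = np.abs(mapping[:, None] - mapping[None, :])
    d_min = dist[dist > tol].min()
    for i in range(len(mapping)):
        for j in range(i + 1, len(mapping)):
            if abs(dist[i, j] - d_min) < tol and bin(i ^ j).count("1") != 1:
                return False
    return True

for name, c in [("QPSK", qpsk), ("16-QAM", qam16)]:
    print(name, "E_s =", round(avg_energy(c), 6), "Gray:", is_gray(c))
for name, h in channels.items():
    print(name, "energy =", round(float(np.sum(np.square(h))), 4))
```

With the tap values above, all three Proakis channels have an energy within about 0.5 % of one.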

In [3]:
# Constellation helper class.
class constellation:
    """
    Helper functions for an arbitrary complex constellation, defined by a bit-to-symbol mapping.
    """

    def __init__(self, mapping, device):
        """
        :param mapping: t.Tensor which contains the constellation symbols, 
            sorted according to their binary representation (MSB first).
        :param device: Device on which the tensors are allocated, e.g., "cuda", or "cpu".
        """
        assert len(mapping.shape) == 1 # mapping should be a 1-dim tensor
        self.mapping = mapping.to(device)
         
        self.M = t.numel(mapping) # M is number of constellation symbols
        assert self.M > 1
        self.m = np.log2(self.M).astype(int) # m=ld(M) is number of bits
        assert self.m == np.log2(self.M) # Assert that log2(M) is integer
        
        # Helpers for bit-metric decoding etc.
        self.mask = 2 ** t.arange(self.m - 1, -1, -1).to(device)
        self.sub_consts = t.stack([t.stack([t.arange(self.M).reshape(2**(i+1),-1)[::2].flatten(), t.arange(self.M).reshape(2**(i+1),-1)[1::2].flatten()]) for i in range(self.m)]).to(device)
        self.device = device
        
    def map(self, bits):
        """
        Maps a bit sequence to a sequence of constellation symbols.
        The operation is applied to the last axis of bits, whose length must be a multiple of m;
        the output sequence has length bits.shape[-1] / m. Leading dimensions (e.g., multiple
        sequences at once) are preserved.
        """
        # Assert that the length of the bit sequence is a multiple of m.
        in_shape = bits.shape
        assert in_shape[-1] % self.m == 0
        # Reshape and convert bits to decimal and use decimal number as index for mapping.
        return self.mapping[t.sum(self.mask * bits.reshape(in_shape[:-1] + (-1, self.m)), -1)]
    
    def bit_metric_decoder(self, symbol_apps):
        """
        Receives a sequence of symbol probabilities/beliefs. 
        For each symbol, an M-dim tensor indicates the logarithmic probability for each of the M possible 
        constellation symbols.
        The bit metric decoder calculates the bit LLRs for each of the m bits for each symbol.
        """
        assert len(symbol_apps.shape) >= 2 # dim -2: symbol sequence, dim -1: M log APPs
        assert symbol_apps.shape[-1] == self.M

        # For each of the m bits, repartition the M APPs into two subsets regarding the respective bit.
        # The output vector has shape (..., m, 2, M/2).
        if self.M > 2:
            subset_probs = t.index_select(symbol_apps,-1, self.sub_consts.flatten()).view(symbol_apps.shape[:-1] + self.sub_consts.shape)
            # Sum up probabilities of all subsets (in log domain).
            bitwise_apps = self.jacobian_sum(subset_probs, dim=-1)
            # Compute LLR.
            LLR = (bitwise_apps[...,0] - bitwise_apps[...,1]).flatten(start_dim = -2)
        else: # M==2 -> we only have one binary channel -> bit-metric decoder is trivial
            LLR = symbol_apps[...,0] - symbol_apps[...,1]
            
        assert symbol_apps.shape[:-2] == LLR.shape[:-1]
        assert symbol_apps.shape[-2]*self.m == LLR.shape[-1]
        assert not t.isinf(LLR).any()
        return LLR
    
    def jacobian_sum(self, msg, dim):
        """
        Computes ln(e^a_1 + e^a_2 + ... + e^a_M) along dimension dim of msg, where (a_1, ..., a_M) are the
        entries along that dimension, by applying the Jacobian algorithm (also called log-sum-exp operation).
        The maximum is subtracted before exponentiation for numerical stability.
        """
        assert msg.shape[dim] > 1
        msg_max = t.max(msg, dim=dim, keepdim=True)[0]
        return msg_max.squeeze(dim) + t.log(t.sum(t.exp(msg - msg_max), dim=dim))
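For reference, the bit-metric decoder can also be written without the precomputed index table `sub_consts`. The NumPy sketch below (the function names are illustrative, not part of the notebook) assumes the same conventions as the `constellation` class: labels are read MSB first, and the LLR of bit i is ln P(b_i=0) - ln P(b_i=1), obtained by a log-sum-exp over the symbols whose label has b_i = 0 and b_i = 1, respectively.

```python
import numpy as np

def logsumexp(a, axis=-1):
    """Numerically stable ln(sum(exp(a))) along `axis` (Jacobian logarithm)."""
    a_max = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(a_max, axis=axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))

def bits_to_indices(bits, m):
    """Group bits of shape (..., n*m) into integer labels of shape (..., n), MSB first."""
    weights = 2 ** np.arange(m - 1, -1, -1)
    return np.sum(bits.reshape(bits.shape[:-1] + (-1, m)) * weights, axis=-1)

def bit_metric_decoder(log_apps, m):
    """Map symbol log-APPs of shape (..., n, M) to bit LLRs of shape (..., n*m)."""
    labels = np.arange(2 ** m)
    llrs = []
    for i in range(m):
        bit = (labels >> (m - 1 - i)) & 1  # value of bit i (MSB first) in each label
        llrs.append(logsumexp(log_apps[..., bit == 0]) - logsumexp(log_apps[..., bit == 1]))
    # Stack to (..., n, m) and flatten the last two axes into one bit sequence.
    llr = np.stack(llrs, axis=-1)
    return llr.reshape(llr.shape[:-2] + (-1,))

# Usage: one QPSK symbol whose label 00 carries probability 0.7.
log_apps = np.log(np.array([[0.7, 0.1, 0.1, 0.1]]))
print(bit_metric_decoder(log_apps, m=2))              # both bits: ln(0.8 / 0.2) = ln 4
print(bits_to_indices(np.array([1, 0, 1, 1]), m=2))   # labels [2, 3]
```

The explicit bit masks trade the cached index table for readability; the notebook's `sub_consts` precomputes the same two subsets per bit once so that the PyTorch version can gather them with a single `index_select`.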

In [4]:
# Implementation of the GAP algorithm
class GAP(nn.Module):
    """
    Symbol detection algorithm for linear ISI channels based on the sum-product algorithm (SPA) 
    on a generalized factor graph. For details, see [1].

    Notes: * All computations are carried out in the log domain.
           * In contrast to the very general description in [1], we drastically reduce the number of
             trainable parameters (the nn.Parameter objects, associated with the Boolean flags weight_Inm, 
             weight_F and nbp) by choosing them to be independent of the index 'k', i.e., along the dimension of the 
             block length (which can typically be quite large). From our experience, this introduces only a minor
             performance degradation but significantly slims down the training and gives space to increase other, 
             more relevant parameters, like B or S.
    """
    def __init__(self, block_len, channel_taps, constellation, 
                       pfilter, branches=1, stages=1, iters=5, 
                       device='cpu',
                       weight_Inm=False, weight_F=False, learn_pfilter=False, 
                       nbp=False, weight_priors=False
                       ):
        """
        :param block_len: Number of transmit symbols per information block.
        :param channel_taps: L+1 channel taps of the impulse response of the linear ISI channel with memory L.
        :param constellation: Constellation object, defining a mapping of M constellation symbols on m = ld(M) bits.
        :param pfilter: Preprocessing filter. The filter must have at least the length of the
            channel filter. If it is longer, the number of additional taps compared to the channel
            must be even. By choosing a filter matched to the channel, the Ungerboeck observation model is applied.
            If more than 1 stage and/or branch is used, you can input multiple (initial) pfilters
            in the shape [stages, branches, pfilter length]. 
            Otherwise, the single given pfilter is used for all stages/branches.
        :param branches: Number of parallel branches (called B in [1]). The resulting outputs of the different 
            branches are combined after each stage according to eq. (14) in [1]. 
            Each branch may use a different pfilter.
        :param stages: Number of dynamic factor graph transitions (called S in [1]). Each stage s uses the output
            of the previous stage as input to initialize its messages.
            Each stage s may use a different pfilter.
        :param iters: Number of iterations (called N' in [1]) for the SPA on the factor graph for each stage and branch.
            One iteration contains a variable2factor (V2F) node update and a factor2variable (F2V) update
            with a fully parallel (flooding) schedule.
        :param device: Device on which the tensors are allocated, e.g., "cuda", or "cpu".
        :param weight_Inm: Boolean flag. If True, the factors I_kl(c_k,c_l) are weighted with trainable
            parameters. See Sec. IV.A in [1].
        :param weight_F: Boolean flag. If True, the factors F_k(c_k) are weighted with trainable params.
            See Sec. IV.A in [1].
        :param learn_pfilter: Boolean flag. If True, the preprocessing filters P are learned. 
            They are initialized with the given impulse responses, given by the tensor pfilter. 
            Compare IV.B in [1]. If False, the filters, given by the tensor pfilter, are used but not adapted/optimized.
        :param nbp: Boolean flag. If True, all messages are parametrized with a multiplicative weight. 
            Compare IV.A in [1].
        :param weight_priors: If True, the priors which are fed to the individual equalizers
            are weighted with a parametrizable weight. Compare IV.B in [1].
        """
        super(GAP, self).__init__()
        # Process input params.
        self.device = device # device on which the computations are carried out
        assert iters > 0
        self.iters = iters # number of iterations of the SPA in each stage/branch
        assert stages > 0
        self.stages = stages # number of factor graph transitions (called S in [1])
        assert branches > 0
        self.branches = branches # number of parallel detectors (called B in [1])

        # channel specifics
        assert len(channel_taps.shape) == 1 and len(channel_taps) > 1 # implementation-specific restrictions
        self.h = channel_taps # channel impulse response
        self.l = len(channel_taps) - 1 # memory of the channel
        self.const = constellation # constellation object
        self.n = block_len # number of transmit symbols N
        self.k = block_len + self.l # number of receive symbols K (length of y in [1])
        
        # preprocessing filter
        if len(pfilter.shape) == 1: # only one pfilter for all branches and stages is given
            pfilters = pfilter.repeat(stages,branches,1)
        else:
            pfilters = pfilter
            assert len(pfilters.shape) == 3
            assert pfilters.shape[:2] == (stages, branches,)
        self.p_len = pfilters.shape[-1]
        assert self.p_len >= len(channel_taps) # This is only a restriction for this specific implementation.
        assert (self.p_len - len(channel_taps)) % 2 == 0
        self.os_extend = (self.p_len - len(channel_taps)) // 2
        # L2 is the new memory of the overall filter (channel + preprocessor). This is equivalent to L+L_p in [1].
        self.l2 = self.l + self.os_extend 
        # Register the preprocessing filter as a tensor or as a parameter, depending on whether it is optimized or fixed.
        if learn_pfilter:
            # individual (trainable) filter for each stage and branch
            self.p = t.nn.Parameter(pfilters.to(self.device))
        else:
            # constant preprocessing filter
            self.p = pfilters.to(self.device) 

        # Boolean flags, specifying which parameters are optimized.
        self.learn_pfilter = learn_pfilter
        self.weight_Inm = weight_Inm
        self.weight_F = weight_F
        self.nbp = nbp
        self.weight_priors = weight_priors

        # Initial computation of the G matrix and the Inm factors (both independent of observed data).
        # They only need to be recomputed if the preprocessing filter is adapted.
        self.Gnn, Gnm = self.compute_G()
        self.Inm = self.compute_Inm(Gnm)

        # Init nn.Parameters for trainable parameters.
        if self.weight_Inm:
            # One scalar weight for the Inm factor per iteration, stage, branch and port (2 L2 ports).
            self.Inm_weight = t.nn.Parameter(t.ones((iters,stages,branches, 2*self.l2), device=device))

        if weight_F:
            # Two scalar weights for the F factor per iteration (+ initialization), stage and branch.
            self.F_weight = t.nn.Parameter(t.ones((self.iters+1, stages, branches, 2), device=device))

        if nbp:
            # Individual weights per iteration and port, but not for each of the N symbols.
            self.v2f_weights = t.nn.Parameter(t.ones((iters,stages, branches, 2*self.l2), device=device))
            self.f2v_weights = t.nn.Parameter(t.ones((iters,stages, branches, 2*self.l2), device=device))

        if self.weight_priors:
            # One scalar weight per iteration (+ initialization), stage and branch, to weight the influence of the previous output.
            self.prior_weight = t.nn.Parameter(t.ones((iters+1, stages, branches), device=device))

    def compute_G(self):
        """
        Computes the matrix G = PH.
        Since the matrix multiplication represents a convolution, G has a band structure.
        Gnn outputs the diagonal value of G (all diagonal elements are equal) 
        and Gnm outputs one band (=row without zeros) of G.
        
        :returns: Tensor Gnn of shape [stages, branches].
        :returns: Tensor Gnm of shape [stages, branches, 2 L2].
        
        Note: At the time this class was written, func.conv1d did not support complex convolution.
              If you have a complex-valued channel and/or preprocessing filter, you need to adapt this function
              (either convolve the real and imaginary parts separately, or use complex convolution, if supported by torch).
        """
        # Convolve P and H to compute Gnm. 
        conv = func.conv1d(t.flip(self.p, dims=[-1]).view(self.branches*self.stages,1,self.p_len), 
                           self.h.repeat(1,1,1), 
                           padding=self.l).view(self.stages,self.branches,2*self.l2 + 1)
        
        # The convolution of P and H gives the values of the bands of G. 
        Gnm = t.zeros((self.stages,self.branches,2*self.l2), device=self.device)
        Gnm[:,:, : self.l2] = conv[:,:, : self.l2]
        Gnn = conv[:,:, self.l2].to(self.device)
        Gnm[:,:, self.l2:] = conv[:,:, self.l2+1 : ] 
        return Gnn, Gnm

    def compute_Inm(self, Gnm):
        """
        Compute parts of the factors I_kl(c_k,c_l): Inm = -Re[Gnm c_m conj(c_n)], where c_m indexes the last
        and c_n the second-to-last dimension of the output.
        The Es/N0 weight is multiplied in each batch separately, as Es/N0 (and possibly trainable weights)
        may vary over the iterations/training steps.

        :returns: Tensor of shape [stages, branches, 2 L2, M, M].
        """
        return -t.real(Gnm[:,:, :,None,None] * 
                self.const.mapping[None,None,None,None,:] * t.conj(self.const.mapping[None,None,None,:,None]))

    def forward(self, y, EsN0_lin, priors=None):
        """
        Accept channel observation y and apply symbol detection algorithm, as described in [1, Algorithm 1].
        The basic structure is:
        * Apply SPA on all parallel branches of one stage.
            * Each branch builds up an individual factor graph, by applying its preprocessing filter to the observation y.
            * Iteratively apply F2V and V2F update rule and pass messages between VNs and FNs (on each branch independently).
        * Merge results of all branches and pass results to the next stage.
        
        :param y: Channel observation of shape [batch_size, K].
        :param EsN0_lin: Es/N0 in linear domain, for each batch individually. Shape [batch_size].
        :returns: Logarithmic beliefs of each symbol. Shape [batch_size, N, M].
        """
        assert y.shape[0] == EsN0_lin.shape[0] # Check for equal batch sizes.
        
        # Compute G (if pfilter is learned, G may change within the lifetime of the detector).
        self.Gnn, Gnm = self.compute_G()
        self.Inm = self.compute_Inm(Gnm)
        
        # Apply preprocessing filters of all stages/branches in parallel.
        x = self.compute_x(y) # x has shape [batch_size, stages, branches, n].

        # Apply equalizer stages serially.
        for stage_i in range(self.stages):
            # Compute initial V2F message, based on (possibly preprocessed) observation.
            v2f_msg = self.compute_init_msg(x, EsN0_lin, priors, stage_i) # shape [batch_size, branches, N, 2L_2, M]
            # Iteratively update factor nodes and variable nodes, based on SPA update rules.
            for i in range(self.iters):
                f2v_msg = self.FN_update(v2f_msg, EsN0_lin, iteration=i, stage=stage_i)
                v2f_msg, indiv_apps = self.VN_update(f2v_msg, x, EsN0_lin, iteration=i, stage=stage_i, priors=priors)
            
            # Combine individual outputs of all branches as input for next stage.
            assert indiv_apps.shape == (y.shape[0], self.branches, self.n, self.const.M)
            # Normalize to true logarithmic probabilities.
            normed_apps = indiv_apps - (self.jacobian_sum(indiv_apps, dim=-1))[...,None] 
            combi = t.sum(normed_apps, dim=1) # Combine APPs of all branches.
            # Normalize again after combi.
            priors = combi - (self.jacobian_sum(combi, dim=-1))[...,None] 
        return priors

    def FN_update(self, v2f_msg, EsN0_lin, iteration, stage):
        """
        Receives v2f messages. A row holds all messages outgoing from a VN.
        1.) Resort, so that a row holds all (future) ingoing messages to one FN.
        2.) Apply FN update rule (add Inm and apply the Jacobian algorithm (log-sum-exp)).

        :param v2f_msg: Messages from VNs to FNs in the shape [batch_size, branches, N, 2 L2, M].
        :param iteration: Current iteration (relevant for iteration-specific parameters).
        :param stage: Current stage (relevant for stage-specific parameters).
        :returns: Updated messages from factor to variable nodes of shape [batch_size, branches, N, 2 L2, M].
        """
        assert v2f_msg.shape[1:] == (self.branches, self.n, 2*self.l2, self.const.M)
        #1.) Resort, so that all incoming messages to one FN are in the same dimension -2.
        v2f = t.zeros_like(v2f_msg) # same dtype and device as the incoming messages
        for i in range(self.l2):
            v2f[:,:,self.l2-i:,i,:] = v2f_msg[:,:,:self.n-self.l2+i,-1-i,:]
            v2f[:,:,:self.n-self.l2+i,-1-i,:] = v2f_msg[:,:,self.l2-i:,i,:]

        #2.) Weight v2f messages (if enabled), add factor and marginalize out one dimension.
        if self.weight_Inm:
            Inm = self.Inm[stage,:,:,:,:] * self.Inm_weight[iteration, stage,:,:,None,None]
        else:
            Inm = self.Inm[stage]
        if self.nbp:
            core = (self.v2f_weights[iteration,stage,None,:,None,:,None] * v2f)[:,:,:,:,None,:] + \
                    EsN0_lin[:,None,None,None,None,None] * Inm[None,:,None,:,:,:]
        else: # no weights
            core = v2f[:,:,:,:,None,:] + \
                   EsN0_lin[:,None,None,None,None,None] * Inm[None,:,None,:,:,:] 
        del v2f

        # After expansion and addition of the (Es/N0-scaled) Inm factors, core has the shape [batch, branch, n, 2L_2, M1, M2].
        # Marginalize out the last dimension from the incident message.
        f2v = self.jacobian_sum(core, dim=-1)
        del core

        # Normalize, so that maximum value is 0 (not an actual normalization to probabilities, only for numerical stability).
        f2v_normed = f2v - (t.max(f2v, dim=-1)[0])[...,None] 
        del f2v

        # Weight f2v messages (NBP).
        if self.nbp:
            return f2v_normed * self.f2v_weights[iteration,None,stage,:,None,:,None]
        else:
            return f2v_normed

    def VN_update(self, f2v_msg, x, EsN0_lin, iteration, stage, priors):
        """
        Sum up all incoming messages (also msg from F) -> single beliefs.
        For outgoing (extrinsic) messages, subtract intrinsic message, respectively.
        
        :param f2v_msg: Messages from factor to variable nodes of shape [batch_size, branches, N, 2 L2, M].
            Each dimension -2 holds all messages incoming to one specific variable node.
        :param x: Preprocessed observation of shape [batch_size, stages, branches, N].
        :param EsN0_lin: Es/N0 in linear domain, for each batch individually. Shape [batch_size].
        :param iteration: Current iteration (relevant for iteration-specific parameters).
        :param stage: Current stage (relevant for stage-specific parameters).
        :param priors: Prior information (either a priori information about statistics, or info from previous stage).
        
        :returns: Updated messages from variable to factor nodes of shape [batch_size, branches, N, 2 L2, M].
        :returns: Current beliefs about each variable node in each branch of shape [batch_size, branches, N, M].
        """
        assert f2v_msg.shape[1:] == (self.branches, self.n, 2*self.l2, self.const.M)
        # Sum up all incoming messages.
        beliefs = t.sum(f2v_msg, dim=-2) + self.compute_F(x, EsN0_lin, iteration=iteration+1, stage=stage)
        # Add the prior information (optionally weighted).
        if priors is not None:
            assert priors.shape == (beliefs.shape[0], self.n, self.const.M)
            if self.weight_priors:
                beliefs += self.prior_weight[iteration+1, stage,None,:,None,None] * priors[:,None,:,:]
            else:
                beliefs += priors[:,None,:,:]
        v2f_msg = beliefs[:,:,:,None,:] - f2v_msg # beliefs has shape [batch, branch, N, M].
        
        return v2f_msg, beliefs

    def compute_x(self, y):
        """
        Apply preprocessor (e.g., matched filter), for each branch and stage.
        
        :param y: Channel observation of shape [batch_size, K].
        :returns: x of shape [batch_size, stages, branches, N(= K-L = block length)]
        
        Note: At the time this class was written, func.conv1d did not support complex convolution.
        """
        batch_size = y.shape[0]

        # Apply the preprocessing filter (e.g., the matched filter) to the observation.
        assert len(y.shape) == 2 and y.shape[1] == self.k

        x =(1.0 * func.conv1d(y.real.view(batch_size,1,self.k), t.flip(self.p, dims=[-1]).view(self.branches*self.stages,1,-1), padding=self.os_extend) + 
            1.0j* func.conv1d(y.imag.view(batch_size,1,self.k), t.flip(self.p, dims=[-1]).view(self.branches*self.stages,1,-1), padding=self.os_extend)).view(batch_size,self.stages, self.branches,self.n)
        return x

    def compute_F(self, x, EsN0_lin, iteration, stage):
        """
        Compute F factor (depending on the channel observation x, the SNR, the constellation and Gnn).
        
        :param x: Preprocessed observation of shape [batch_size, stages, branches, N].
        :param EsN0_lin: Es/N0 in linear domain, for each batch individually. Shape [batch_size].
        :param iteration: Current iteration (relevant for iteration-specific parameters).
        :param stage: Current stage (relevant for stage-specific parameters).
        
        :returns: factor F in log domain of shape [batch_size, branches, N, M]
        """
        batch_size = x.shape[0]
        if self.weight_F:
            F = EsN0_lin[:,None,None,None] * t.real((self.F_weight[iteration,None,stage,:,None,0] * x[:,stage,:,:])[...,None] * t.conj(self.const.mapping)[None,None,None,:] - 
                    (self.F_weight[iteration,stage,:,None,1] * ((self.Gnn/2)[stage,:,None] * (t.abs(self.const.mapping)**2)[None,:]))[None,:,None,:] )
        else:
            F = EsN0_lin[:,None,None,None] * t.real(x[:,stage,:,:,None] * t.conj(self.const.mapping)[None,None,None,:] - 
                    ((self.Gnn/2)[stage,:,None] * (t.abs(self.const.mapping)**2)[None,:])[None,:,None,:] )
        return F.view(batch_size, self.branches, self.n, self.const.M)

    def compute_init_msg(self, x, EsN0_lin, priors, stage):
        """ 
        Compute messages to initialize the message passing.
        Compute F and broadcast F+priors (both optionally weighted) to all 2*L2 ports. 
        """
        F = self.compute_F(x, EsN0_lin, iteration=0, stage=stage).view((-1, self.branches,self.n, 1, self.const.M))
        # F has shape [batch_size, branches, n, 1, M]
        batch_size = F.shape[0]
        if priors is None: # Without priors, only the channel observation is used (uniformly distributed symbols assumed).
            return F.repeat(1,1,1,2*self.l2,1) # Repeat msg as broadcast to all factor nodes.
        else:
            assert priors.shape == (batch_size, self.n, self.const.M)
            if self.weight_priors:
                return (F + (self.prior_weight[0,None,stage,:,None,None,None] * priors[:,None,:,None,:])) \
                        .repeat(1,1,1,2*self.l2,1) # Repeat msg as broadcast to all factor nodes.
            else:
                return (F + priors[:,None,:,None,:]) \
                        .repeat(1,1,1,2*self.l2,1) # Repeat msg as broadcast to all factor nodes.

    def jacobian_sum(self, msg, dim):
        """
        Computes ln(e^a_1 + e^a_2 + ... + e^a_M) along dimension dim of msg, where (a_1, ..., a_M) are the
        entries along that dimension, by applying the Jacobian algorithm (also called log-sum-exp operation).
        The maximum is subtracted before exponentiation for numerical stability.
        """
        assert msg.shape[dim] > 1
        msg_max = t.max(msg, dim=dim, keepdim=True)[0]
        return msg_max.squeeze(dim) + t.log(t.sum(t.exp(msg - msg_max), dim=dim))

    def save_weights(self, filename: str):
        """
        Helper function to save all trainable parameters.
        :param filename: Path prefix (as character string) for the parameter files, e.g., a directory path ending with a separator.
        """
        if self.weight_Inm:
            t.save(self.Inm_weight, filename+"Inm_weight.pt")
        if self.weight_F:
            t.save(self.F_weight, filename+"F.pt")
        if self.learn_pfilter:
            t.save(self.p, filename+"pfilter.pt")
        if self.nbp:
            t.save(self.v2f_weights, filename+"v2f.pt")
            t.save(self.f2v_weights, filename+"f2v.pt")
        if self.weight_priors:
            t.save(self.prior_weight, filename+"prior_weight.pt")
        
    def load_weights(self, filename: str):
        """
        Helper function to load all trainable parameters from files.
        :param filename: Path prefix (as character string) of the parameter files, e.g., a directory path ending with a separator.
        """
        if self.weight_Inm:
            self.Inm_weight = t.nn.Parameter(t.load(filename+"Inm_weight.pt", map_location=self.device))
        if self.weight_F:
            self.F_weight = t.nn.Parameter(t.load(filename+"F.pt", map_location = self.device))
        if self.learn_pfilter:
            self.p = t.nn.Parameter(t.load(filename+"pfilter.pt", map_location = self.device))
        if self.nbp:
            self.v2f_weights = t.nn.Parameter(t.load(filename+"v2f.pt", map_location = self.device))
            self.f2v_weights = t.nn.Parameter(t.load(filename+"f2v.pt", map_location = self.device))
        if self.weight_priors:
            self.prior_weight = t.nn.Parameter(t.load(filename+"prior_weight.pt", map_location = self.device))
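To make the quantities computed by `compute_G`, `compute_x` and `compute_F` concrete, the following NumPy sketch works through the Ungerboeck observation model for a real-valued channel (Proakis B) and the matched filter p = flip(h) without extra preprocessor taps. It is an illustration under these assumptions, not a replacement for the batched PyTorch code, and the helper names `preprocess` and `factor_F` are hypothetical. Without noise, the preprocessed observation equals x = G c with the banded Toeplitz matrix G = PH = H^T H, whose main diagonal is Gnn and whose off-diagonal bands hold Gnm; the log-domain factor F_k(c) = Es/N0 · Re{x_k c* - Gnn/2 |c|^2} is then evaluated as in `compute_F`.

```python
import numpy as np

h = np.array([0.407, 0.815, 0.407])        # Proakis B channel, memory L = 2
L = len(h) - 1
p = h[::-1]                                # matched filter (real-valued channel)
bpsk = np.array([1.0, -1.0])

# One band of G = P H: the convolution of preprocessor and channel (length 2 L2 + 1).
g = np.convolve(p, h)
L2 = (len(g) - 1) // 2
Gnn = g[L2]                                # main diagonal of G
Gnm = np.concatenate([g[:L2], g[L2 + 1:]]) # the 2 L2 off-diagonal entries

def preprocess(y, p, L):
    """Filter y (length K = N + L) with p and keep the N outputs aligned with the symbols."""
    ext = (len(p) - (L + 1)) // 2          # extra taps per side (os_extend in GAP)
    return np.convolve(np.pad(y, ext), p, mode="valid")

def factor_F(x, Gnn, esn0, mapping):
    """Log-domain factor F_k(c) = Es/N0 * Re{x_k conj(c) - Gnn/2 |c|^2}, shape (N, M)."""
    return esn0 * np.real(x[:, None] * np.conj(mapping)[None, :]
                          - Gnn / 2 * np.abs(mapping)[None, :] ** 2)

# Noiseless example: N = 8 random BPSK symbols through the channel.
rng = np.random.default_rng(0)
N = 8
c = rng.choice(bpsk, size=N)
y = np.convolve(c, h)                      # K = N + L received samples
x = preprocess(y, p, L)

# Banded Toeplitz G (N x N): G[n, j] = g[n - j + L2] for |n - j| <= L2.
G = sum(g[k + L2] * np.eye(N, k=-k) for k in range(-L2, L2 + 1))
print(np.allclose(x, G @ c))               # noiseless Ungerboeck model: x = G c
print(factor_F(x, Gnn, esn0=2.0, mapping=bpsk).shape)
```

For the matched filter the band g is the autocorrelation of h and hence symmetric, so the ordering of Gnm does not matter in this example; for asymmetric preprocessing filters, the index convention of `compute_G` has to be followed.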

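The log-domain update rules in `FN_update`, `VN_update` and the branch combination in `forward` reduce to a few lines for a single factor and a single variable node. The NumPy sketch below (hypothetical helper names; no batching, no trainable weights) mirrors them under these simplifications: a factor I_kl adds Es/N0 · I to the incoming message and marginalizes it with a log-sum-exp, the variable node subtracts each incoming message from the total belief to obtain extrinsic outgoing messages, and the branch outputs are combined as a sum of normalized log-APPs, followed by renormalization.

```python
import numpy as np

def logsumexp(a, axis=-1):
    """Numerically stable ln(sum(exp(a))) along `axis`."""
    a_max = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(a_max, axis=axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))

def fn_update(v2f, G_kl, esn0, mapping):
    """F2V message to c_k from one factor I_kl, given the V2F message (length M) from c_l."""
    # I[a, b] = -Re{G_kl * c_b * conj(c_a)}: row a = recipient symbol, column b = marginalized symbol
    I = -np.real(G_kl * mapping[None, :] * np.conj(mapping)[:, None])
    f2v = logsumexp(v2f[None, :] + esn0 * I, axis=-1)
    return f2v - f2v.max()  # shift so that the maximum is 0 (numerical stability only)

def vn_update(f2v_in, F, prior=None):
    """f2v_in: (ports, M) incoming messages; F: (M,) channel factor.
    Returns the extrinsic outgoing messages (ports, M) and the belief (M,)."""
    belief = f2v_in.sum(axis=0) + F
    if prior is not None:
        belief = belief + prior
    return belief[None, :] - f2v_in, belief

def combine_branches(log_apps):
    """Sum normalized log-APPs of shape (B, N, M) over the branches and renormalize."""
    normed = log_apps - logsumexp(log_apps, axis=-1)[..., None]
    combi = normed.sum(axis=0)
    return combi - logsumexp(combi, axis=-1)[..., None]

bpsk = np.array([1.0, -1.0])
print(fn_update(np.array([0.0, 0.0]), G_kl=0.5, esn0=2.0, mapping=bpsk))   # uninformative in and out
print(fn_update(np.array([0.0, -10.0]), G_kl=0.5, esn0=2.0, mapping=bpsk))
```

Adding normalized log-APPs over the branches corresponds to a normalized product of the branch APPs, i.e., the branch outputs are treated as independent pieces of evidence about each symbol.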
In [5]:
# Helper function for simulations
def detect(EbN0_dB_min, EbN0_dB_max):
    """
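    Randomly generates a batch of data blocks, simulates the channel and runs the equalization algorithm.
    The Es/N0 of each block is drawn from a uniform distribution (in the linear domain) between the values
    corresponding to EbN0_dB_min and EbN0_dB_max, assuming that the average power of the applied
    constellation is normalized to 1. Uses the global variables const, channel, l, block_len, batch_size,
    device and the detector eq.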
    
    :returns: beliefs of algorithm (approximated APPs).
    :returns: label bits
    """
    # Compute Es/N0 (lin) from Eb/N0 (dB).
    EsN0_lin_min = 10 ** (EbN0_dB_min / 10) * const.m
    EsN0_lin_max = 10 ** (EbN0_dB_max / 10) * const.m
    
    # Simulate channel.
    bits = t.randint(2, size=(batch_size, block_len*const.m)).to(device) # Generate random bits.
    tx = const.map(bits) # Map bits to symbols.
    # Convolve with channel.
    rx =(      t.nn.functional.conv1d(tx.real.view(batch_size,1,block_len), t.flip(channel, dims=[0]).view(1,1,-1), padding=l) + \
         1.0j* t.nn.functional.conv1d(tx.imag.view(batch_size,1,block_len), t.flip(channel, dims=[0]).view(1,1,-1), padding=l)).view(batch_size, block_len+l)

    # Draw the noise level for each block of the batch (uniformly distributed in the linear Es/N0 domain).
    EsN0_lin_batches = ((EsN0_lin_max - EsN0_lin_min) * t.rand(batch_size) + EsN0_lin_min).to(device)
    rx += t.randn((batch_size, block_len+l), dtype=t.cfloat, device=device) / t.sqrt(EsN0_lin_batches[:,None])
    # Run factor graph equalizer. 
    return eq(rx, EsN0_lin_batches), bits
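The channel simulation in `detect()` relies on two conventions that are easy to get wrong. First, `conv1d` computes a cross-correlation, hence the flipped taps. Second, complex `t.randn` samples have unit total variance, so dividing by sqrt(Es/N0) yields the noise power N0 = 1/(Es/N0) for Es = 1. The standalone sketch below checks both against `np.convolve` and an empirical noise power. The Proakis B taps [0.407, 0.815, 0.407], the small block length and the BPSK stand-in for `const.map` are assumptions for illustration only.

```python
import math

import numpy as np
import torch as t

t.manual_seed(0)

# Assumed example values; the notebook uses its own channel, block_len and batch_size.
channel = t.tensor([0.407, 0.815, 0.407])  # Proakis B taps (assumed)
l = len(channel) - 1                        # channel memory
block_len, batch_size, m = 8, 2, 1          # BPSK: m = 1 bit per symbol

# BPSK symbols as a stand-in for const.map(bits).
tx = 1.0 - 2.0 * t.randint(2, size=(batch_size, block_len)).float()

# conv1d computes a cross-correlation, so flipping the taps turns it into a
# convolution; padding=l yields the full output of length block_len + l.
rx = t.nn.functional.conv1d(tx.view(batch_size, 1, block_len),
                            t.flip(channel, dims=[0]).view(1, 1, -1),
                            padding=l).view(batch_size, block_len + l)

# Reference: full linear convolution of every block with the channel taps.
reference = np.stack([np.convolve(row, channel.numpy()) for row in tx.numpy()])

# Es/N0 from Eb/N0 and the resulting noise power (Es = 1, so N0 = 1 / (Es/N0)).
EbN0_dB = 10.0
EsN0_lin = 10 ** (EbN0_dB / 10) * m
noise = t.randn(100_000, dtype=t.cfloat) / math.sqrt(EsN0_lin)
noise_power = noise.abs().pow(2).mean().item()  # close to 1 / EsN0_lin
```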

**Configure the GAP algorithm and the transmission scenario here:**
* Transmission scenario: constellation, channel impulse response, block length
* GAP parameters: 
  * Number of branches B, stages S and SPA iterations 
  * Preprocessing filter(s), e.g., matched filter in case of the Ungerboeck observation model
  * Boolean flags to individually enable or disable trainable parameters
  
  
Note: to recreate the original factor graph-based detection algorithm based on the Ungerboeck model [2] (which serves as a basis for the GAP algorithm), parametrize the GAP algorithm with B=S=1, use a matched filter as preprocessor and disable all Boolean flags.

[2] Colavolpe, Giulio, Dario Fertonani, and Amina Piemontese. “SISO Detection over Linear Channels with Linear Complexity in the Number of Interferers.” IEEE Journal of Selected Topics in Signal Processing 5, no. 8 (December 2011): 1475–85. https://doi.org/10.1109/JSTSP.2011.2168943.
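As a sketch of the parametrization described in the note, the keyword arguments below use the names of the `GAP(...)` call in the next cell. The channel taps are assumed Proakis B values, and it has not been verified here that `GAP` accepts exactly this combination, so treat it as an illustration rather than a tested configuration.

```python
import torch as t

channel = t.tensor([0.407, 0.815, 0.407])  # Proakis B taps (assumed values)

# Keyword arguments for the notebook's GAP class that should reproduce the
# baseline detector of [2]: B = S = 1, matched filter, no trainable parts.
ungerboeck_kwargs = dict(
    branches=1,
    stages=1,
    pfilter=t.flip(channel, dims=[0]),  # matched filter
    weight_Inm=False,
    weight_F=False,
    learn_pfilter=False,
    nbp=False,
    weight_priors=False,
)

# Usage in the notebook (needs GAP, const, iters, block_len and device):
# eq = GAP(block_len=block_len, channel_taps=channel, constellation=const,
#          iters=iters, device=device, **ungerboeck_kwargs)
```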

In [6]:
# Simulation parameters
# Transmission scenario
const = constellation(bpsk_mapping, device) # Select constellation.
channel = t.tensor(ProakisB, device=device) # Select channel.
l = len(channel)-1 # channel memory
block_len = 100 # Number of information symbols per transmission block

# Algorithm parameters
branches = 1 # Number of parallel branches the algorithm uses (called B in [1])
stages = 1 # Number of serial stages / dynamic factor graph transitions (called S in [1])
iters = 6 # SPA iterations per stage and branch (called N' in [1])
# Choose preprocessing filter(s). The original Ungerboeck model uses a matched filter; the GAP algorithm can also apply
# generic (optionally trainable) filters, one per stage and branch.
p_filter = t.flip(channel, dims=[0]) # matched filter
#p_filter = t.randn((stages, branches, 7), device=device) # (Initial) preprocessing filter

# Instantiate the equalizer. Specify which parts of the algorithm are optimized in the following training procedure
# by setting the respective flags.
eq = GAP(block_len = block_len, channel_taps = channel, constellation = const,
         pfilter=p_filter, branches = branches, stages = stages, iters=iters,
         device=device,
         weight_Inm=True, weight_F=True, learn_pfilter=False, nbp=True, weight_priors=True
         )
# Load parameters with eq.load_weights() here, if desired. Otherwise they are initialized with 1.0.
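Why the matched filter `p_filter = t.flip(channel, dims=[0])` leads to the Ungerboeck observation model can be checked numerically: filtering the noiseless channel output with the matched filter is the same as convolving the symbols with the autocorrelation of the channel taps. The sketch assumes that the preprocessor is applied as a convolution in the same way `detect()` applies the channel, and it uses assumed Proakis B taps and a hypothetical helper `conv_full`.

```python
import torch as t


def conv_full(x, taps):
    """Full linear convolution of a 1-D sequence via conv1d (flipped taps, padding len(taps) - 1)."""
    return t.nn.functional.conv1d(x.view(1, 1, -1),
                                  t.flip(taps, dims=[0]).view(1, 1, -1),
                                  padding=len(taps) - 1).view(-1)


t.manual_seed(0)
h = t.tensor([0.407, 0.815, 0.407], dtype=t.float64)  # assumed channel taps
l = len(h) - 1
p_filter = t.flip(h, dims=[0])  # matched filter, as chosen in the cell above

tx = 1.0 - 2.0 * t.randint(2, size=(10,)).to(t.float64)  # BPSK symbols
rx = conv_full(tx, h)          # noiseless channel output
z = conv_full(rx, p_filter)    # matched-filter (preprocessor) output

# After the matched filter, the equivalent channel is the autocorrelation of h
# (lags -l, ..., l); its center tap is the channel energy.
autocorr = conv_full(h, p_filter)
z_equiv = conv_full(tx, autocorr)
```

For these taps the autocorrelation is approximately [0.166, 0.663, 0.996, 0.663, 0.166]: it has 2l+1 taps and is symmetric about the channel energy at its center.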


In [7]:
# Define some objective functions to evaluate detection performance and for training.

def BER(log_app, label_bits, constellation):
    """
    Computes the (hard-decision) bit error rate (BER) for a soft-output equalizer with given labeled data.

    :param log_app: A posteriori probabilities of the symbols in the logarithmic domain.
    :param label_bits: The actually transmitted bits (the labels of the data batch).
    :param constellation: Constellation object providing m, M and the bit metric decoder.
    :returns: 1-element tensor (scalar) with the BER of the labeled data set, averaged over the complete batch.
    """
    # Assert matching shapes of result data and label data.
    assert log_app.shape[-2] == label_bits.shape[-1]/constellation.m
    assert log_app.shape[-1] == constellation.M # The last shape should be the log-probabilities of the symbols.
    llrs = constellation.bit_metric_decoder(log_app)
    return t.sum(t.where(llrs > 0, 0, 1) != label_bits) / label_bits.numel() # Return bit error rate.

def BMI(log_app, label_bits, constellation):
    """
    Computes the bitwise mutual information (BMI). See Sec. IV.D in [1] for details.
    Note that if the input probabilities (log_app) are mismatched, this BMI estimate can be negative.

    :param log_app: Logarithmic APP estimates of the symbol detector.
    :param label_bits: The actually transmitted bits.
    :param constellation: Constellation object providing m, M and the bit metric decoder.
    :returns: BMI (scalar value)
    """
    # Apply bit metric decoder to log APPs of symbols to get bit-wise LLRs.
    assert log_app.shape[-2] == label_bits.shape[-1]/constellation.m
    assert log_app.shape[-1] == constellation.M # The last shape should be the log-probabilities of the symbols.
    llrs = constellation.bit_metric_decoder(log_app)
    return constellation.m * (1 - t.mean(1/np.log(2) * (t.clamp((2*label_bits-1) * llrs, 0) + t.log(1+t.exp(-t.abs((2*label_bits-1) * llrs))))))
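`BER()` makes hard decisions from the bit LLRs with `t.where(llrs > 0, 0, 1)`, i.e., a positive LLR decides for bit 0, which corresponds to LLRs of the form log P(b=0)/P(b=1). The hypothetical helper below applies the same rule directly to a hand-made LLR tensor, bypassing `constellation.bit_metric_decoder`, to make this convention explicit.

```python
import torch as t


def ber_from_llrs(llrs, label_bits):
    """Hard-decision BER with the convention of BER(): LLR > 0 decides for bit 0."""
    hard_bits = (llrs <= 0).to(label_bits.dtype)  # same as t.where(llrs > 0, 0, 1)
    return t.sum(hard_bits != label_bits) / label_bits.numel()


llrs = t.tensor([[3.2, -0.5, 1.1, -4.0]])
label_bits = t.tensor([[0, 1, 1, 1]])
ber = ber_from_llrs(llrs, label_bits)  # third decision is wrong: 1 error in 4 bits
```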


**Configure the training parameters here.**

If you don't want to optimize any parameters, you can skip this cell.

Note: 
* You might need to adapt the batch_size and the number of batches, depending on your specific transmission scenario and GAP parametrization, to enable an effective optimization.
* This implementation of the BMI does not correct mismatched LLR values. The BMI can thus become negative if the LLRs are not based on true probabilities but, e.g., on approximations of probabilities. There are ways to correct this (see [1, Sec. IV.D]); for training purposes, however, this implementation of the BMI is sufficient.
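The note on negative BMI values can be made concrete with the formula used in `BMI()`. The hypothetical helper `bmi_from_llrs` below applies that formula directly to bit LLRs (bypassing the bit metric decoder, with m = 1 as for BPSK). All-zero LLRs give a BMI of 0, large LLRs with the correct sign approach 1, and large LLRs with the wrong sign, i.e., strongly mismatched beliefs, give a negative value.

```python
import math

import torch as t


def bmi_from_llrs(llrs, label_bits, m=1):
    """BMI estimate as computed in BMI(), applied directly to bit LLRs."""
    x = (2 * label_bits - 1) * llrs
    # max(x, 0) + log(1 + exp(-|x|)) is a numerically stable form of log(1 + exp(x)).
    softplus = t.clamp(x, 0) + t.log(1 + t.exp(-t.abs(x)))
    return m * (1 - t.mean(softplus / math.log(2)))


label_bits = t.tensor([0, 1, 0, 1, 1, 0])
sign = 1 - 2 * label_bits  # +1 for bit 0, -1 for bit 1

uninformative = bmi_from_llrs(t.zeros(6), label_bits)       # LLRs carry no information
confident_right = bmi_from_llrs(20.0 * sign, label_bits)   # large LLRs, correct sign
confident_wrong = bmi_from_llrs(-20.0 * sign, label_bits)  # large LLRs, wrong sign
```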

In [8]:
# Training parameters
EbN0_db_training_min = 10 # lower bound for uniform distribution of Eb/N0 in dB during the training
EbN0_db_training_max = 10 # upper bound for uniform distribution of Eb/N0 in dB during the training
batch_size = 100 # number of parallel blocks per training batch
learning_rate = 0.001 # Learning rate of the optimizer.
optimizer = t.optim.Adam(eq.parameters(), lr=learning_rate) # optimizer for parameter optimization
batches = 1000 # number of batches to train on, i.e., the number of gradient descent steps

# training loop, optimize parameters w.r.t. the BMI
for batch_no in range(batches):
    loss = - BMI(*(detect(EbN0_db_training_min, EbN0_db_training_max)), const) # minimize the negative BMI (i.e., maximize the BMI).
    # Optimization step (stochastic gradient descent, here with the Adam optimizer).
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Print training progress regularly.
    if batch_no % 100 == 0:
        print(f"BMI: {-loss:>7f}  [{batch_no:>7d}/{batches:>7d}]")

# (optionally) save the optimization result to files
#eq.save_weights(f"path/to/directory/param_")

BMI: -1.255365  [      0/   1000]
BMI: 0.411927  [    100/   1000]
BMI: 0.742328  [    200/   1000]
BMI: 0.794135  [    300/   1000]
BMI: 0.852713  [    400/   1000]
BMI: 0.882576  [    500/   1000]
BMI: 0.918109  [    600/   1000]
BMI: 0.928771  [    700/   1000]
BMI: 0.932534  [    800/   1000]
BMI: 0.945980  [    900/   1000]


**Configure the evaluation of the GAP algorithm here:**

Note:
* You might need to adapt the batch_size and the number of batches, depending on your specific transmission scenario and GAP parametrization, to obtain reliable estimates of the BMI and BER.
* This implementation of the BMI does not correct mismatched LLR values. This BMI estimate can thus become negative if the LLRs are not based on true probabilities but, e.g., on approximations of probabilities. This can be corrected by scaling the LLR values (see [1, Sec. IV.D] for a more detailed discussion). To keep this notebook simple, we do not include this mismatch correction here.
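One simple way to illustrate the scaling correction mentioned above is to search for a scalar factor s that maximizes the BMI of s·L on labeled data. This grid search is an assumption for illustration and not necessarily the procedure described in [1, Sec. IV.D]. The sketch uses synthetic LLRs of a BPSK transmission over an AWGN channel that are deliberately made four times too confident, so the best factor should be close to 1/4.

```python
import math

import torch as t


def bmi_from_llrs(llrs, label_bits):
    """Per-bit BMI as in the notebook, using softplus for log(1 + exp(x))."""
    x = (2 * label_bits - 1) * llrs
    return 1 - t.mean(t.nn.functional.softplus(x)) / math.log(2)


t.manual_seed(0)
label_bits = t.randint(2, size=(20_000,))
sign = 1.0 - 2.0 * label_bits  # +1 for bit 0, -1 for bit 1

# Exact LLRs 2y / sigma2 of BPSK over real AWGN, then made 4x overconfident.
sigma2 = 0.5
true_llrs = 2 * (sign + math.sqrt(sigma2) * t.randn(20_000)) / sigma2
mismatched = 4.0 * true_llrs

# Grid search over the scaling factor.
scales = t.linspace(0.05, 2.0, 40)
bmis = t.stack([bmi_from_llrs(s * mismatched, label_bits) for s in scales])
best_scale = scales[t.argmax(bmis)].item()
```

In the notebook, the same search could be run on the LLRs returned by `const.bit_metric_decoder` for a labeled batch, ideally choosing the scale on data that is independent of the batch used for reporting.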

In [9]:
# eval config
EbN0_dB_range = t.arange(0,13,2) # Range of Eb/N0 values (in dB) that we evaluate
batch_size = 10**2 # number of parallel blocks for the evaluation

# evaluate the optimized GAP algorithm w.r.t. the BER and BMI for some Eb/N0 values
BER_over_EbN0 = t.empty(len(EbN0_dB_range), device=device) # prepare results tensor
BMI_over_EbN0 = t.empty(len(EbN0_dB_range), device=device) # prepare results tensor
with t.no_grad():
    for i,EbN0_dB in enumerate(EbN0_dB_range):
        BMI_over_EbN0[i] = BMI(*(detect(EbN0_dB, EbN0_dB)), const)
        BER_over_EbN0[i] = BER(*(detect(EbN0_dB, EbN0_dB)), const)

print(f" Evaluation for Eb/N0: {EbN0_dB_range}\nBMI: {BMI_over_EbN0} \nBER: {BER_over_EbN0}")
        

 Evaluation for Eb/N0: tensor([ 0,  2,  4,  6,  8, 10, 12])
BMI: tensor([0.4331, 0.5472, 0.6628, 0.7753, 0.8736, 0.9342, 0.9648],
       device='cuda:0') 
BER: tensor([0.1743, 0.1440, 0.0927, 0.0675, 0.0279, 0.0152, 0.0110],
       device='cuda:0')
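To judge whether `batch_size` is large enough (see the note before the evaluation cell), it helps to count the bits behind each estimate: with 100 blocks of 100 BPSK symbols, every BER value above is based on 10^4 bits. The sketch computes a rough normal-approximation (Wald) confidence interval for the printed 12 dB value; the helper `ber_interval` is hypothetical.

```python
import math

# Evaluation settings of the notebook (BPSK, so m = 1 bit per symbol).
batch_size, block_len, m = 100, 100, 1
n_bits = batch_size * block_len * m  # bits behind each BER value


def ber_interval(ber, n_bits, z=1.96):
    """Normal-approximation (Wald) ~95% confidence interval for a BER estimate."""
    half_width = z * math.sqrt(ber * (1 - ber) / n_bits)
    return max(ber - half_width, 0.0), ber + half_width


ber_12dB = 0.0110                    # printed BER at Eb/N0 = 12 dB
n_errors = round(ber_12dB * n_bits)  # number of observed bit errors
low, high = ber_interval(ber_12dB, n_bits)
```

Because detection errors on ISI channels tend to occur in bursts, the bit errors are not independent, so this interval is optimistic; it also becomes unreliable when only a few errors are observed.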
