Structural Optimization of Factor Graphs for Symbol Detection via Continuous Clustering and Machine Learning
---
Lukas Rapp, Luca Schmid, Andrej Rode, and Laurent Schmalen

We propose a novel method to optimize the structure of factor graphs for graph-based inference. As an example inference task, we consider symbol detection on linear inter-symbol interference channels.  The factor graph framework has the potential to yield low-complexity symbol detectors. However, the sum-product algorithm on cyclic factor graphs is suboptimal and its performance is highly sensitive to the underlying graph. Therefore, we optimize the structure of the underlying factor graphs in an end-to-end manner using machine learning. For that purpose, we transform the structural optimization into a clustering problem of low-degree factor nodes that incorporates the known channel model into the optimization. Furthermore, we study the combination of this approach with neural belief propagation, yielding near-maximum a posteriori symbol detection performance for specific channels.

The full paper [1] is available on arXiv.

[1] L. Rapp, L. Schmid, A. Rode, and L. Schmalen, “Structural Optimization of Factor Graphs for Symbol Detection via Continuous Clustering and Machine Learning,” 2023, arXiv:2211.11406. [Online]. Available: https://arxiv.org/abs/2211.11406.

For questions, discussions or improvements of this code, feel free to contact us (Lukas Rapp, first.last@kit.edu) or open an issue/create a pull request.

---

This work has received funding in part from the European Research Council (ERC) under the European Union’s Horizon 2020 research and innovation programme (grant agreement No. 101001899) and in part from the German Federal Ministry of Education and Research (BMBF) within the project Open6GHub (grant agreement 16KISK010).

---

Copyright (c) 2022-2023 Lukas Rapp - Communications Engineering Lab (CEL), Karlsruhe Institute of Technology (KIT)

<sup> Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

<sup> The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

<sup> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

In [None]:
# Import external libraries
import numpy as np
import torch as t
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
from torch.nn import functional as F
from torch import nn
from dataclasses import dataclass
import factor_graph as fg

## Sec. 2.2. Symbol Detection

In [None]:
# Define impulse responses
h_paper = [0.407, 0.1, 0.815, 0.1, 0.407]
h_proakis_c = [0.227, 0.460, 0.688, 0.460, 0.227]

# Define the channel model as defined in equation (1) in Sec. 2.2
class FiniteMemoryChannel:
    """
    Describes the channel model in Sec. 2.2, equation (1), i.e., the cyclic convolution with the finite impulse response h
    followed by AWGN.
    """
    def __init__(self, h: t.Tensor, snr_db: float, device='cpu', normalize_taps=True):
        """
        @param h: Channel impulse response
        @param snr_db: SNR of AWGN
        @param device: PyTorch device
        @param normalize_taps: Determines if the impulse response is normalized so that its energy is 1.
        """

        self.h = h.to(device)
        if normalize_taps:
            self.h = self.h / t.sqrt(t.sum(t.abs(self.h) ** 2))

        self.memory = t.numel(self.h) - 1
        self.snr_lin = 10 ** (snr_db / 10)
        self.device = device

    def apply(self, x: t.Tensor) -> t.Tensor:
        """
        Applies channel model equation (1) to x
        @param x: transmitted symbols
        @return: y: received distorted symbols
        """

        # Cyclic convolution of x and h: F.conv1d first yields the linear (non-cyclic) convolution;
        # its tail of length self.memory is then wrapped around and added to the first samples.
        y_not_cyclic = (F.conv1d(x.real[:, None, :], t.flip(self.h, dims=[0])[None, None, :], padding=self.memory)
                        + 1.0j * F.conv1d(x.imag[:, None, :], t.flip(self.h, dims=[0])[None, None, :], padding=self.memory))[:, 0, :]

        if self.memory > 0:
            y = y_not_cyclic[:, :-self.memory]
            y[:, :self.memory] += y_not_cyclic[:, -self.memory:]
        else:
            y = y_not_cyclic

        # Add AWGN
        y += t.randn(y.shape, dtype=t.cfloat, device=self.device) / np.sqrt(self.snr_lin)
        return y
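
The wrap-around step in `apply` can be checked on a small example. The following is a minimal pure-Python sketch (no PyTorch, no noise) that compares the direct definition of the cyclic convolution with the "linear convolution, then fold the tail back" approach. The function names `cyclic_convolution` and `linear_then_wrap` and the toy values of `x` and `h` are chosen for illustration only and do not appear in the notebook.

```python
def cyclic_convolution(x, h):
    """Direct definition: y[k] = sum_l h[l] * x[(k - l) mod K]."""
    K = len(x)
    return [sum(h[l] * x[(k - l) % K] for l in range(len(h))) for k in range(K)]


def linear_then_wrap(x, h):
    """Linear convolution of length K + memory whose tail is added to the first samples."""
    K, memory = len(x), len(h) - 1
    full = [0.0] * (K + memory)
    for k, x_k in enumerate(x):
        for l, h_l in enumerate(h):
            full[k + l] += h_l * x_k
    y = full[:K]
    for i in range(memory):
        y[i] += full[K + i]  # fold the tail back onto the start of the block
    return y


x = [1, -1, 1, 1]        # toy BPSK block (illustrative values)
h = [0.5, 0.25, 0.125]   # toy impulse response (illustrative values)
y = cyclic_convolution(x, h)
y_wrapped = linear_then_wrap(x, h)
print(y, y_wrapped)
```

Both functions return the same sequence, which is the property the slicing and in-place addition in `apply` rely on.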

In [None]:
# Equation (4): Factors of the UFG

def calc_F_I_potentials_UFG(y: t.Tensor, channel: FiniteMemoryChannel, constellation: t.Tensor) -> tuple[t.Tensor, t.Tensor]:
    """
    Calculates the factors (in log-domain) of the UFG using equation (4) in the paper
    @param y: Received distorted symbols (equation (1))
    @param channel: FiniteMemoryChannel object representing the channel impulse response and the awgn noise
    @param constellation: tensor containing the set of constellation points that each VN can take
    @return: tuple (F_log, I_log) of tensors
    F_log[batch_idx, k, value_idx] = log(F_k(constellation[value_idx])) where F_k is the factor in equation (4)
    I_log[l - 1, value_idx_1, value_idx_2]
        = log(I_l(constellation[value_idx_1], constellation[value_idx_2])) where I_l is the factor in equation (4)
    (I_log does not depend on y and therefore has no batch dimension)
    """

    h = channel.h

    # Auxiliary 1-dim tensor to calculate F and I (q[l] corresponds to q_{l+1} in equation (4) of the paper)
    q = t.zeros(len(h), device=h.device)
    for l in range(len(h)):
        if l == 0:
            q[l] = t.sum(h * t.conj(h))
        else:
            q[l] = t.sum(h[l:] * t.conj(h[:-l]))

    # Calculates the correlation between h and y (first term in the exponential) of F_k
    # correlation_h_and_y_(not_)cyclic are tensors with indices [batch_idx, k] where "k" corresponds to the index of F_k
    # To speed up the calculation, we use the convolution of Pytorch resulting in a non-cyclic convolution,
    # and add the effect of the cyclic convolution in the next step resulting in "correlation_h_and_y_cyclic"
    correlation_h_and_y_not_cyclic = (F.conv1d(y.real[:, None, :], channel.h[None, None, :], padding=channel.memory) +
                                      1.0j * F.conv1d(y.imag[:, None, :], channel.h[None, None, :], padding=channel.memory))[:, 0, :]
    correlation_h_and_y_cyclic = correlation_h_and_y_not_cyclic[:, channel.memory:]
    correlation_h_and_y_cyclic[:, -channel.memory:] += correlation_h_and_y_not_cyclic[:, :channel.memory]

    # F_potentials as defined in the docstring above
    F_potentials = 2 * channel.snr_lin * t.real(
        correlation_h_and_y_cyclic[:, :, None] * t.conj(constellation[None, None, :])
        - q[0] / 2 * t.abs(constellation[None, None, :]) ** 2
    )

    # I_potentials as defined in the docstring above
    I_potentials = -2 * channel.snr_lin * t.real(
        q[1:, None, None] * t.conj(constellation[None, :, None]) * constellation[None, None, :])

    return F_potentials, I_potentials
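
To make the role of the auxiliary vector `q` more concrete, here is a small pure-Python sketch of the same two computations: the autocorrelation of the impulse response and the pairwise log-potential derived from one of its entries. The helper names `autocorrelation` and `pairwise_log_potential`, the toy impulse response, and the choice `snr_lin = 1` are assumptions made for illustration.

```python
def autocorrelation(h):
    """q[l] = sum_k h[k + l] * conj(h[k]) for l = 0, ..., len(h) - 1."""
    return [
        sum(h[k + l] * complex(h[k]).conjugate() for k in range(len(h) - l))
        for l in range(len(h))
    ]


def pairwise_log_potential(q_l, snr_lin, constellation):
    """Entry [i][j] is -2 * snr_lin * Re(q_l * conj(c_i) * c_j), as in I_potentials."""
    return [
        [-2 * snr_lin * (q_l * complex(a).conjugate() * b).real for b in constellation]
        for a in constellation
    ]


h = [1, 2, 3]                      # toy real-valued impulse response (illustrative)
q = autocorrelation(h)             # q[0] is the energy of h
I_1 = pairwise_log_potential(q[1], snr_lin=1.0, constellation=[1, -1])
print(q, I_1)
```

For BPSK the pairwise potential only depends on whether the two symbols agree, which is why the resulting 2x2 table is symmetric with equal diagonal entries.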


## Sec. 3.1. Factor Node Containers

The next cell defines a factor graph class that stores the FN containers used for clustering. The implementation follows Sec. 3.1.

In [None]:
# The first two data classes describe the FNs of degree 1 and 2 in the basic factor graph (UFG). Each stores the list of FN containers to which the FN can be connected, similar to the set M_m in equation (7).

@dataclass
class FNDeg2:
    vn_idx_0: int # Index of the first VN with which the FN is connected
    vn_idx_1: int # Index of the second VN with which the FN is connected
    vn_distance: int # Distance between both VNs with which the FN is connected: vn_distance = |vn_idx_1 - vn_idx_0|

    # List of all FN containers with which the FN can be connected:
    # Each entry (a, (b, c)) is an FN container (idx: a) to which the FN can be connected at slots b and c
    connectable_fn_container: list[tuple[int, tuple[int, int]]]

@dataclass
class FNDeg1:
    vn_idx_0: int # Index of the VN with which the FN is connected

    # List of all FN containers with which the FN can be connected
    # Each entry  [a, b] is an FN container (idx: a) at which the FN can be connected at slot b
    connectable_fn_container: list[tuple[int, int]]


class ContainerFactorGraph(fg.FactorGraph):
    """
    Extends the general factor graph class by containers of a fixed degree in which FNs of degree 1 and 2 of the basic factor graph for symbol detection
    (here Ungerboeck-based factor graph (UFG)) can be clustered
    """

    # List of all FN containers: Each entry is a tuple containing the indices of the VNs with which the FN container is connected
    fn_container_list = None

    def __init__(self, block_len: int, number_constellation_points: int, max_span_fn_container: int, channel_length: int,
                 fn_container_degree: int, device: t.device):
        """
        @param block_len: Length of the symbol sequence for symbol detection ("K" in the paper)
        @param number_constellation_points: Number of different constellation points each symbol can take
        @param max_span_fn_container: Maximal span of the fn containers as defined in the paper Sec 3.1
        @param channel_length: Length of the channel impulse response
        @param fn_container_degree: degree of all fn containers
        @param device: PyTorch device
        """
        self.fn_container_degree = fn_container_degree

        if fn_container_degree == 3:
            self.fn_container_list = self._create_degree_3_fn_containers(block_len, max_span_fn_container)
        elif fn_container_degree == 4:
            self.fn_container_list = self._create_degree_4_fn_containers(block_len, max_span_fn_container)
        else:
            raise NotImplementedError(f'degree {fn_container_degree} fn container not implemented.')

        self.fn_degree_2_list = self._create_fns_degree_2(channel_length, block_len)
        self.fn_degree_1_list = self._create_fns_degree_1(block_len)

        # Using the fn_lists, calculate the parameters which parameterize this FG and create a general FG with them.
        biadjacency, fn_start_slots = self._create_fg_parameters(block_len)
        super().__init__(biadjacency, number_constellation_points, device, False, fn_start_slots)

    def _create_degree_4_fn_containers(self, block_len: int, max_span_fn_container: int) -> list[tuple]:
        """
        Create all the degree 4 fn containers with maximal span max_span_fn_container that can be connected with the VNs
        representing the symbols
        @param block_len: Length of the symbol sequence for symbol detection ("K" in the paper)
        @param max_span_fn_container: Maximal span of the fn containers as defined in the paper Sec 3.1
        @return: List of FN containers: Each entry is a tuple containing the indices of the VNs with which the FN container is connected
        """

        fn_container_list = []
        for vn_idx_0 in range(block_len):
            for local_idx_1 in range(1, max_span_fn_container - 2):
                for local_idx_2 in range(local_idx_1 + 1, max_span_fn_container - 1):
                    for local_idx_3 in range(local_idx_2 + 1, max_span_fn_container):
                        vn_idx_1 = (vn_idx_0 + local_idx_1) % block_len
                        vn_idx_2 = (vn_idx_0 + local_idx_2) % block_len
                        vn_idx_3 = (vn_idx_0 + local_idx_3) % block_len

                        fn_container_list.append((vn_idx_0, vn_idx_1, vn_idx_2, vn_idx_3))

        return fn_container_list

    def _create_degree_3_fn_containers(self, block_len: int, max_span_fn_container: int) -> list[tuple]:
        """
        Create all the degree 3 fn containers with maximal span max_span_fn_container that can be connected with the VNs
        (#block_len) representing the symbols
        @param block_len: Length of the symbol sequence for symbol detection ("K" in the paper)
        @param max_span_fn_container: Maximal span of the fn containers as defined in the paper Sec 3.1
        @return: List of FN containers: Each entry is a tuple containing the indices of the VNs with which the FN container is connected
        """

        fn_container_list = []
        for vn_idx_0 in range(block_len):
            for local_idx_1 in range(1, max_span_fn_container - 1):
                for local_idx_2 in range(local_idx_1 + 1, max_span_fn_container):
                    vn_idx_1 = (vn_idx_0 + local_idx_1) % block_len
                    vn_idx_2 = (vn_idx_0 + local_idx_2) % block_len

                    fn_container_list.append((vn_idx_0, vn_idx_1, vn_idx_2))

        return fn_container_list

    def _create_fns_degree_2(self, channel_length: int, block_len: int) -> list[FNDeg2]:
        """
        Create all FNs of degree 2 of the UFG that can be connected with the VNs (#block_len) representing the symbols
        @param channel_length: Length of the channel impulse response
        @param block_len: Length of the symbol sequence for symbol detection ("K" in the paper)
        @return: List of FNDeg2 of all FNs of degree 2 which can be connected with the VNs of the FG
        """

        fn_deg2_list = []
        for vn_distance in range(1, channel_length):
            for vn_idx_0 in range(block_len):
                vn_idx_1 = (vn_idx_0 + vn_distance) % block_len

                local_fn_degX_list = []
                for slot_1 in range(1, self.fn_container_degree):
                    for slot_0 in range(slot_1):
                        local_fn_degX_list.extend([(fn_degX_Idx, (slot_0, slot_1))
                            for fn_degX_Idx, fn_degX in enumerate(self.fn_container_list)
                            if (vn_idx_0 == fn_degX[slot_0] and vn_idx_1 == fn_degX[slot_1])])

                fn_deg2_list.append(FNDeg2(vn_idx_0, vn_idx_1, vn_distance, local_fn_degX_list))

        return fn_deg2_list

    def _create_fns_degree_1(self, block_len: int) -> list[FNDeg1]:
        """
        Create all FNs of degree 1 of the UFG that can be connected with the VNs (#block_len) representing the symbols
        @param block_len: Length of the symbol sequence for symbol detection ("K" in the paper)
        @return: List of FNDeg1 of all FNs of degree 1 which can be connected with the VNs of the FG
        """

        fn_deg1_list = []
        for vn_idx_0 in range(block_len):
            local_fn_degX_list = []
            for slot_fn_degX in range(self.fn_container_degree):
                local_fn_degX_list.extend([(fn_degX_idx, slot_fn_degX) for fn_degX_idx, fn_degX in enumerate(self.fn_container_list) if vn_idx_0 == fn_degX[slot_fn_degX]])
            fn_deg1_list.append(FNDeg1(vn_idx_0, local_fn_degX_list))
        return fn_deg1_list

    def _create_fg_parameters(self, block_len: int) -> tuple[t.Tensor, t.Tensor]:
        """
        Calculate parameters needed to create FG
        @param block_len: Length of the symbol sequence for symbol detection ("K" in the paper)
        @return: tuple (biadjacency, fn_start_slots)
        """

        biadjacency = t.zeros(block_len, len(self.fn_container_list), dtype=t.long)
        fn_start_slots = t.zeros(len(self.fn_container_list), dtype=t.long)
        for idx, vn_list in enumerate(self.fn_container_list):
            for i in range(self.fn_container_degree - 1):
                if vn_list[i + 1] < vn_list[i]:
                    fn_start_slots[idx] = i + 1
                    break

            for vn_idx in vn_list:
                biadjacency[vn_idx, idx] = 1

        return biadjacency, fn_start_slots
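
The container enumeration and the lookup of connectable containers can be tried out without the `factor_graph` module. The sketch below generalizes the two nested-loop methods with `itertools.combinations` and reproduces the slot search for one degree-2 FN. The function names `create_fn_containers` and `connectable_containers` and the small parameters are illustrative assumptions, not part of the notebook.

```python
from itertools import combinations


def create_fn_containers(block_len, max_span, degree):
    """All containers of the given degree whose VNs lie within max_span (cyclic indices)."""
    containers = []
    for vn_idx_0 in range(block_len):
        for offsets in combinations(range(1, max_span), degree - 1):
            containers.append(
                (vn_idx_0,) + tuple((vn_idx_0 + o) % block_len for o in offsets)
            )
    return containers


def connectable_containers(containers, vn_idx_0, vn_idx_1):
    """Entries (container_idx, (slot_0, slot_1)) into which a degree-2 FN fits."""
    degree = len(containers[0])
    result = []
    for slot_1 in range(1, degree):
        for slot_0 in range(slot_1):
            result.extend(
                (idx, (slot_0, slot_1))
                for idx, c in enumerate(containers)
                if c[slot_0] == vn_idx_0 and c[slot_1] == vn_idx_1
            )
    return result


small = create_fn_containers(block_len=6, max_span=3, degree=3)
fits = connectable_containers(small, vn_idx_0=1, vn_idx_1=2)
large = create_fn_containers(block_len=30, max_span=5, degree=3)
print(small, fits, len(large))
```

With this construction, the number of containers is `block_len` times the binomial coefficient "`max_span - 1` choose `degree - 1`", so the container list grows quickly with the span.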

## Sec. 3.2 and 3.3. Continuous Clustering of FNs

The following class ClusteredFactorGraph implements continuous clustering and calculates the factors of the clustered FN containers (equation (6)).
These factors are later used in the SPA in the FactorGraph class.

Since the practical implementation of "factors" differs from the mathematical definition, we give a brief overview of how factors are handled in the FactorGraph class.
The FactorGraph class applies the SPA in the log-domain. The local function of an FN is therefore also stored in the log-domain; we call this representation the "potential" of the FN in the following.
The potential $t$ of an FN $f(x_1, ..., x_n)$ is an n-dimensional PyTorch tensor of shape [len(self.constellation), ..., len(self.constellation)] whose entries store the logarithm of all values that $f(\cdot)$ can take:
$$
t(i_1, ..., i_n) = \log\big(f(\text{self.constellation}[i_1], ..., \text{self.constellation}[i_n])\big), \qquad \text{for all } (i_1, ..., i_n) \in \{0, ..., \text{len(self.constellation)} - 1\}^n
$$
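
As a small illustration of this definition, the following pure-Python sketch tabulates the potential of a toy degree-2 factor over the BPSK constellation. A dictionary keyed by index tuples stands in for the n-dimensional tensor; the factor `f` and the helper `log_potential_table` are hypothetical and only serve this example.

```python
import math
from itertools import product

constellation = [1, -1]


def f(x_1, x_2):
    """Toy factor (illustrative only): favors equal neighboring symbols."""
    return math.exp(-abs(x_1 - x_2) ** 2)


def log_potential_table(factor, constellation, n):
    """Maps every index tuple (i_1, ..., i_n) to log(factor(c[i_1], ..., c[i_n]))."""
    return {
        idx: math.log(factor(*(constellation[i] for i in idx)))
        for idx in product(range(len(constellation)), repeat=n)
    }


potential = log_potential_table(f, constellation, n=2)
for idx, value in potential.items():
    print(idx, value)
```

The table has `len(constellation) ** n` entries, which is the reason the degree of the FN containers is kept small.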

In [None]:
class ClusteredFactorGraph(nn.Module):
    """
    Implements the continuous clustering of the FNs of the basic factor graph and manages the weights of the NBP
    """

    def __init__(self, block_len: int, constellation: t.Tensor, max_span_fn_container: int, channel_length: int,
                 fn_container_degree: int, bp_iterations: int, device: t.device):
        """
        @param block_len: Length of the symbol sequence for symbol detection ("K" in the paper)
        @param constellation: tensor containing the set of constellation points that each VN can take
        @param max_span_fn_container: Maximal span of the fn containers as defined in the paper Sec 3.1
        @param channel_length: Length of the channel impulse response
        @param fn_container_degree: degree of all fn containers
        @param bp_iterations: Number of belief propagation iterations that will be executed
        @param device: PyTorch device
        """

        super().__init__()

        self.bp_iterations = bp_iterations

        self.constellation = constellation
        self.block_len = block_len
        self.device = device

        self.graph = ContainerFactorGraph(block_len, len(constellation), max_span_fn_container,
                                          channel_length, fn_container_degree, device)

        # Weights for clustering of FNs of degree 1 (Called "beta_ij" in the paper. See Sec. 3.3 of the paper equation (8).)
        self.clustering_weights_fn_deg_1 = []
        for idx, fn in enumerate(self.graph.fn_degree_1_list):
            weights = nn.Parameter(t.randn(len(fn.connectable_fn_container), requires_grad=True, device=device))
            self.register_parameter(f'deg1_{idx}', weights)
            self.clustering_weights_fn_deg_1.append(weights)

        # Weights for clustering of FNs of degree 2 (Called "beta_ij" in the paper. See Sec. 3.3 of the paper equation (8).)
        self.beta_fn_deg_2 = []
        for idx, fn in enumerate(self.graph.fn_degree_2_list):
            weights = nn.Parameter(t.randn(len(fn.connectable_fn_container), requires_grad=True, device=device))
            self.register_parameter(f'deg2_{idx}', weights)
            self.beta_fn_deg_2.append(weights)

        # Weights for NBP
        self.weights_vn_incoming_messages = \
            nn.Parameter(t.ones((bp_iterations, self.graph.vn_connections.shape[0], self.graph.vn_degree_max), device=device))
        self.weights_fn_incoming_messages = \
            nn.Parameter(t.ones((bp_iterations, self.graph.fn_connections.shape[0], self.graph.fn_degree_max), device=device))

        self.graph.set_nbp_weights(self.weights_vn_incoming_messages, self.weights_fn_incoming_messages)

    def forward(self, y: t.Tensor, channel: FiniteMemoryChannel) -> t.Tensor:
        """
        Executes the sum-product algorithm on the factor graph and returns the log beliefs of the VNs
        @param y: Received symbol sequence with noise (called y in the paper, see equation (1)).
        @param channel: Channel with which the transmitted symbols are distorted.
        @return: Logarithmic belief for each VN: tensor with indices (batch, vn_idx, message_values)
        """

        batch_size = y.shape[0]

        # Calculates the factors F_k(x_k), I_l(x_k, x_{k+l}) of equation (4) in the paper in log-domain where
        # F_log[batch_idx, k, value_idx] = log(F_k(self.constellation[value_idx])) and
        # I_log[l - 1, value_idx_1, value_idx_2] = log(I_l(self.constellation[value_idx_1], self.constellation[value_idx_2]))
        # (I_log does not depend on y and therefore has no batch dimension)
        F_log, I_log = calc_F_I_potentials_UFG(y, channel, self.constellation)

        # Calculate the potentials of the clustered factor graph
        log_potentials_clustered_fg = t.zeros((len(self.graph.fn_container_list), batch_size,) +
            (len(self.constellation),) * self.graph.fn_container_degree, device=self.device)

        # Implements equation (7) for FNs of degree 2 (the variables i, j, and m correspond to the variables in equation (7)):
        # For all FNs f of degree 2 in the graph, the loop iterates over all FN containers that are connected with the FN f
        # and adds the potential of f weighted by the corresponding clustering coefficient to the potential of the FN container.
        for i, fn_i_data in enumerate(self.graph.fn_degree_2_list):
            alpha_i = F.softmax(self.beta_fn_deg_2[i], 0)

            # Iterate over all FN containers with which the FN_i can be connected.
            for j in range(len(fn_i_data.connectable_fn_container)):
                (m, (fn_slot_vn_1, fn_slot_vn_2)) = fn_i_data.connectable_fn_container[j]

                # Adds the potential of the current FN of degree 2 (I_log) weighted by alpha_i[j]
                # (Since the factors of the FNs are in log-domain the multiplication in equation (7) becomes an addition
                # and the exponentiation with alpha becomes a multiplication with alpha)
                log_potentials_clustered_fg[m] += \
                    alpha_i[j] * I_log[
                        [fn_i_data.vn_distance - 1, None]
                        + fg.create_indexing(self.graph.fn_container_degree, fn_slot_vn_1, fn_slot_vn_2)
                        ]

        # Implements equation (7) for FNs of degree 1 (similar to the loop above for FNs of degree 2)
        for fn_deg_1_idx, fn_deg_1_data in enumerate(self.graph.fn_degree_1_list):
            alpha_fn_deg_1 = F.softmax(self.clustering_weights_fn_deg_1[fn_deg_1_idx], 0)

            # Adds the potential of the current FN of degree 1 (F_log) weighted by alpha_fn_deg_1
            # (Since the factors of the FNs are in log-domain the multiplication in equation (7) becomes an addition
            # and the exponentiation with alpha becomes a multiplication with alpha)
            for j in range(len(fn_deg_1_data.connectable_fn_container)):
                (m, fn_slot_vn_1) = fn_deg_1_data.connectable_fn_container[j]

                log_potentials_clustered_fg[m] += \
                    alpha_fn_deg_1[j] * F_log[
                        [slice(None), fn_deg_1_idx] + fg.create_indexing(self.graph.fn_container_degree, fn_slot_vn_1)
                        ]

        self.graph.load_potentials(log_potentials_clustered_fg, batch_size)
        return self.graph.sum_product_algorithm(self.bp_iterations, True)
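
The effect of the softmax in `forward` can be illustrated in isolation. Because the clustering coefficients of one FN sum to one, distributing its log-potential over several containers leaves the overall factorization unchanged: the contributions add up to the original log-potential. The sketch below shows this with plain Python; the values of `beta` and `log_potential` are arbitrary illustrative numbers.

```python
import math


def softmax(beta):
    """Numerically stable softmax over a list of real-valued weights."""
    m = max(beta)
    exps = [math.exp(b - m) for b in beta]
    total = sum(exps)
    return [e / total for e in exps]


log_potential = -3.0          # log-domain value of one FN entry (illustrative)
beta = [0.2, -1.0, 0.7]       # trainable weights, one per connectable container

alpha = softmax(beta)
contributions = [a * log_potential for a in alpha]   # share added to each container
print(alpha, contributions, sum(contributions))

# A strongly peaked weight vector approaches a hard assignment to one container.
alpha_hard = softmax([10.0, -10.0, -10.0])
print(alpha_hard)
```

In this view, training the weights moves each FN continuously between its candidate containers, and a hard clustering corresponds to the limit of a one-hot coefficient vector.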

## Sec. 3.3: Training

In [None]:
def simulate_detection(batch_size: int, block_len: int, constellation: t.Tensor,
                       channel: FiniteMemoryChannel, clustered_factor_graph: ClusteredFactorGraph,
                       device: t.device) -> tuple[t.Tensor, t.Tensor]:
    """
    Maps random bits to symbols, distorts them using the channel model, applies symbol detection,
    and calculates the resulting BER and loss.
    @param batch_size: Number of symbol sequences that are simulated in parallel
    @param block_len: Length of the symbol sequence for symbol detection ("K" in the paper)
    @param constellation: tensor containing the set of constellation points that each VN can take
    @param channel: Channel with which the transmitted symbols are distorted.
    @param clustered_factor_graph: Clustered factor graph model which is used for symbol detection.
    @param device: PyTorch device
    @return: tuple (ber, loss) with the bit error rate and the soft-BER loss
    """
    # Section 2.2
    tx_bits = t.randint(2, (batch_size, block_len))
    x = constellation[tx_bits]
    # Apply channel model equation (1)
    y = channel.apply(x)

    # Approximates marginalization of equation (2) by sum-product algorithm on factor graph
    x_beliefs = clustered_factor_graph(y, channel)
    x_hat = t.argmax(x_beliefs, 2)

    # Transforms beliefs of sum-product algo. into probabilities.
    x_probs = t.sigmoid(-x_beliefs[..., 0])  # equals 1 / (1 + t.exp(x_beliefs[..., 0]))

    # Calculates soft-ber (loss) and ber
    loss = t.sum(t.where(tx_bits.to(device) == 1, 1 - x_probs, x_probs)) / batch_size
    ber = t.count_nonzero(x_hat.cpu() != tx_bits) / tx_bits.numel()

    return ber, loss
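
The soft-BER loss can be reproduced on a handful of numbers. The sketch below assumes that the log-domain belief of constellation index 1 is normalized to zero, so that the probability of bit 1 reduces to a sigmoid of the negated belief of index 0; this normalization is an assumption about the `factor_graph` module that the notebook does not state explicitly. Unlike `simulate_detection`, which divides the summed error probabilities by the batch size, this sketch averages per bit, which differs only by the constant factor `block_len`. The function names and the belief values are illustrative.

```python
import math


def prob_bit_one(log_belief_0, log_belief_1=0.0):
    """P(bit = 1) from two log-domain beliefs (index 1 assumed normalized to zero)."""
    return 1.0 / (1.0 + math.exp(log_belief_0 - log_belief_1))


def soft_and_hard_ber(tx_bits, log_beliefs_0):
    """Returns (soft BER, hard BER), both averaged per bit."""
    p_one = [prob_bit_one(b) for b in log_beliefs_0]
    # Soft error: probability mass the detector assigns to the wrong bit value
    soft_errors = [1 - p if bit == 1 else p for bit, p in zip(tx_bits, p_one)]
    # Hard error: the more likely bit value differs from the transmitted bit
    hard_errors = [int((p > 0.5) != (bit == 1)) for bit, p in zip(tx_bits, p_one)]
    n = len(tx_bits)
    return sum(soft_errors) / n, sum(hard_errors) / n


tx_bits = [0, 1, 1, 0]
log_beliefs_0 = [2.0, -3.0, 0.5, -1.0]   # illustrative detector outputs
soft_ber, hard_ber = soft_and_hard_ber(tx_bits, log_beliefs_0)
print(soft_ber, hard_ber)
```

The soft BER is differentiable with respect to the beliefs, which makes it usable as a training loss, whereas the hard BER changes only when a decision flips.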

In [None]:
# Simulation parameters
block_len = 30
spa_iterations = 10
snr_db = 10
constellation = t.tensor([1, -1], dtype=t.cfloat, device=device)
channel_taps = t.tensor(h_paper)
max_span_fn_container = len(channel_taps)
fn_container_degree = 3

# Training parameters
learning_rate = 0.001
batch_size = 10000
batches = 60000

channel = FiniteMemoryChannel(channel_taps, snr_db, device)
model = ClusteredFactorGraph(block_len, constellation, max_span_fn_container, len(channel_taps), fn_container_degree, spa_iterations, device)
optimizer = t.optim.Adam(model.parameters(), learning_rate)

# Training loop
model.train()
for batch_idx in range(batches):
    ber, loss = simulate_detection(batch_size, block_len, constellation, channel, model, device)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f'BER: {ber.item()} [{batch_idx+1}/{batches}]')

In [None]:
# Evaluation
snr_list = np.arange(0, 13, 1)
ber_list = np.zeros(snr_list.shape[0], dtype=float)

batch_size = 100

model.eval()
with t.no_grad():
    for snr_idx, snr_db in enumerate(snr_list):
        channel = FiniteMemoryChannel(channel_taps, snr_db, device)

        ber_list[snr_idx], _ = simulate_detection(batch_size, block_len, constellation, channel, model, device)

print(f" Evaluation for Eb/N0: {snr_list}\nBER: {ber_list}")
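
When inspecting the evaluated BER values, an interference-free reference can serve as a sanity check. The sketch below computes the BPSK error rate on a pure AWGN channel under the noise scaling used in `FiniteMemoryChannel` (complex noise with total variance `1 / snr_lin`, unit-energy symbols). Treating this as a baseline for the ISI channel is an assumption of this example and not a result from the paper; the function name `bpsk_awgn_ber` is illustrative.

```python
import math


def bpsk_awgn_ber(snr_db):
    """BPSK bit error rate without ISI: Q(sqrt(2 * snr_lin)) = 0.5 * erfc(sqrt(snr_lin))."""
    snr_lin = 10 ** (snr_db / 10)
    return 0.5 * math.erfc(math.sqrt(snr_lin))


reference = [bpsk_awgn_ber(snr_db) for snr_db in range(0, 13)]
for snr_db, ber in zip(range(0, 13), reference):
    print(f"{snr_db:2d} dB: {ber:.3e}")
```

A detector operating on a channel with inter-symbol interference is not expected to fall below this curve, so measured values far below it would hint at a simulation error such as too few simulated bits.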