In [None]:
def sim_ber(mc_fun,
            ebno_dbs,
            batch_size,
            max_mc_iter,
            soft_estimates=False,
            num_target_bit_errors=None,
            num_target_block_errors=None,
            early_stop=True,
            graph_mode=None,
            verbose=True,
            forward_keyboard_interrupt=True,
            dtype=tf.complex64):
    """Simulates until target number of errors is reached and returns BER/BLER.

    The simulation continues with the next SNR point if either
    ``num_target_bit_errors`` bit errors or ``num_target_block_errors`` block
    errors is achieved. Further, it continues with the next SNR point after
    ``max_mc_iter`` batches of size ``batch_size`` have been simulated.

    Input
    -----
    mc_fun:
        Callable that yields the transmitted bits `b` and the
        receiver's estimate `b_hat` for a given ``batch_size`` and
        ``ebno_db``. It is invoked with exactly these two keyword arguments.
        Only the first two return values are used; further return values
        are ignored. If ``soft_estimates`` is True, `b_hat` is interpreted
        as logit.

    ebno_dbs: tf.float32
        A tensor containing SNR points to be evaluated. Internally cast to
        ``dtype.real_dtype``.

    batch_size: tf.int32
        Batch-size for evaluation.

    max_mc_iter: tf.int32
        Max. number of Monte-Carlo iterations per SNR point.

    soft_estimates: bool
        A boolean, defaults to False. If True, `b_hat`
        is interpreted as logit and an additional hard-decision is applied
        internally.

    num_target_bit_errors: tf.int32
        Defaults to None. Target number of bit errors per SNR point until
        the simulation continues to next SNR point.

    num_target_block_errors: tf.int32
        Defaults to None. Target number of block errors per SNR point
        until the simulation continues to next SNR point.

    early_stop: bool
        A boolean defaults to True. If True, the simulation stops after the
        first error-free SNR point (i.e., no block error occurred for this
        SNR point). All SNR points that were skipped this way are reported
        with a BER/BLER of 0.

    graph_mode: One of ["graph", "xla"], str
        A string describing the execution mode of ``mc_fun``.
        Defaults to `None`. In this case, ``mc_fun`` is executed as is.
        For "graph" and "xla", ``mc_fun`` is wrapped into a `tf.function`
        (with ``jit_compile`` enabled for "xla") unless it already is one.

    verbose: bool
        A boolean defaults to True. If True, the current progress will be
        printed.

    forward_keyboard_interrupt: bool
        A boolean defaults to True. If False, KeyboardInterrupts will be
        caught internally and not forwarded (e.g., will not stop outer loops).
        If False, the simulation ends and returns the intermediate simulation
        results. SNR points that have not been simulated yet are reported
        with a BER/BLER of -1.

    dtype: tf.complex64
        Datatype of the model / function to be used (``mc_fun``).

    Output
    ------
    (ber, bler) :
        Tuple:

    ber: tf.float64
        The bit-error rate (one entry per SNR point).

    bler: tf.float64
        The block-error rate (one entry per SNR point).

    Raises
    ------
    AssertionError
        If ``early_stop``, ``soft_estimates`` or ``verbose`` is not bool.

    AssertionError
        If ``dtype`` is not a complex dtype.

    AssertionError
        If ``graph_mode`` is neither `None` nor a str.

    TypeError
        If ``graph_mode`` is a str other than "graph" or "xla".

    Note
    ----
    The error counters are evaluated eagerly (``.numpy()`` is used for the
    progress output). To run the simulated model itself in graph mode,
    either decorate ``mc_fun`` with `@tf.function()` or select the
    corresponding ``graph_mode``.

    """

    # utility function to print progress
    def _print_progress(is_final, rt, idx_snr, idx_it, header_text=None):
        """Print summary of current simulation progress.

        Reads the error counters and the status array of the enclosing
        function; nothing is modified.

        Input
        -----
        is_final: bool
            A boolean. If True, the progress is printed into a new line.
        rt: float
            The runtime of the current SNR point in seconds.
        idx_snr: int
            Index of current SNR point.
        idx_it: int
            Current iteration index.
        header_text: list of str
            Elements will be printed instead of current progress, iff not None.
            Can be used to generate table header.
        """
        # set carriage return if not final step
        # ("\r" lets the next call overwrite the current line)
        if is_final:
            end_str = "\n"
        else:
            end_str = "\r"

        # prepare to print table header
        if header_text is not None:
            row_text = header_text
            end_str = "\n"
        else:
            # calculate intermediate ber / bler
            ber_np = (tf.cast(bit_errors[idx_snr], tf.float64)
                        / tf.cast(nb_bits[idx_snr], tf.float64)).numpy()
            ber_np = np.nan_to_num(ber_np) # avoid nan for first point
            bler_np = (tf.cast(block_errors[idx_snr], tf.float64)
                        / tf.cast(nb_blocks[idx_snr], tf.float64)).numpy()
            bler_np = np.nan_to_num(bler_np) # avoid nan for first point

            # load statuslevel
            # print current iter if simulation is still running
            if status[idx_snr]==0:
                status_txt = f"iter: {idx_it:.0f}/{max_mc_iter:.0f}"
            else:
                status_txt = status_levels[int(status[idx_snr])]

            # generate list with all elements to be printed
            row_text = [str(np.round(ebno_dbs[idx_snr].numpy(), 3)),
                        f"{ber_np:.4e}",
                        f"{bler_np:.4e}",
                        np.round(bit_errors[idx_snr].numpy(), 0),
                        np.round(nb_bits[idx_snr].numpy(), 0),
                        np.round(block_errors[idx_snr].numpy(), 0),
                        np.round(nb_blocks[idx_snr].numpy(), 0),
                        np.round(rt, 1),
                        status_txt]

        # pylint: disable=line-too-long, consider-using-f-string
        print("{: >9} |{: >11} |{: >11} |{: >12} |{: >12} |{: >13} |{: >12} |{: >12} |{: >10}".format(*row_text), end=end_str)


    # init table headers
    header_text = ["EbNo [dB]", "BER", "BLER", "bit errors",
                   "num bits", "block errors", "num blocks",
                   "runtime [s]", "status"]

    # replace status by text (list index == numeric status code)
    status_levels = ["not simulated", # status=0
            "reached max iter       ", # status=1; spacing for impr. layout
            "no errors - early stop", # status=2
            "reached target bit errors", # status=3
            "reached target block errors"] # status=4

    # check inputs for consistency
    assert isinstance(early_stop, bool), "early_stop must be bool."
    assert isinstance(soft_estimates, bool), "soft_estimates must be bool."
    assert dtype.is_complex, "dtype must be a complex type."
    assert isinstance(verbose, bool), "verbose must be bool."

    if graph_mode is None:
        graph_mode="default" # applies default graph mode
    assert isinstance(graph_mode, str), "graph_mode must be str."

    if graph_mode=="default":
        pass # nothing to do
    elif graph_mode=="graph":
        # avoid retracing -> check if mc_fun is already a function
        if not isinstance(mc_fun, tf.types.experimental.GenericFunction):
            mc_fun = tf.function(mc_fun,
                                 jit_compile=False,
                                 experimental_follow_type_hints=True)
    elif graph_mode=="xla":
        # avoid retracing -> check if mc_fun is already a function
        # (an existing function is re-wrapped if it is not jit-compiled)
        if not isinstance(mc_fun, tf.types.experimental.GenericFunction) or \
           not mc_fun.function_spec.jit_compile:
            mc_fun = tf.function(mc_fun,
                                 jit_compile=True,
                                 experimental_follow_type_hints=True)
    else:
        raise TypeError("Unknown graph_mode selected.")

    ebno_dbs = tf.cast(ebno_dbs, dtype.real_dtype)
    batch_size = tf.cast(batch_size, tf.int32)
    num_points = tf.shape(ebno_dbs)[0]

    # per-SNR-point counters; rebound to the scatter results in the loop
    bit_errors = tf.Variable(   tf.zeros([num_points], dtype=tf.int64),
                            dtype=tf.int64)
    block_errors = tf.Variable(  tf.zeros([num_points], dtype=tf.int64),
                            dtype=tf.int64)
    nb_bits = tf.Variable(  tf.zeros([num_points], dtype=tf.int64),
                            dtype=tf.int64)
    nb_blocks = tf.Variable(  tf.zeros([num_points], dtype=tf.int64),
                            dtype=tf.int64)

    # track status of simulation (early termination etc.)
    status = np.zeros(num_points)

    # measure runtime per SNR point
    # (holds the start timestamp while the point is running and the
    # elapsed time once a stopping condition was hit)
    runtime = np.zeros(num_points)

    # ensure num_target_errors is a tensor
    if num_target_bit_errors is not None:
        num_target_bit_errors = tf.cast(num_target_bit_errors, tf.int64)
    if num_target_block_errors is not None:
        num_target_block_errors = tf.cast(num_target_block_errors, tf.int64)

    # index of the SNR point in progress; stays None until the loop has
    # started such that the KeyboardInterrupt handler never reads an
    # unbound name
    i = None

    try:
        # simulate until a target number of errors is reached
        for i in tf.range(num_points):
            runtime[i] = time.perf_counter() # save start time
            iter_count = -1 # for print in verbose mode
            for ii in tf.range(max_mc_iter):

                iter_count += 1

                outputs = mc_fun(batch_size=batch_size, ebno_db=ebno_dbs[i])

                # assume first and second return value is b and b_hat
                # other returns are ignored
                b = outputs[0]
                b_hat = outputs[1]

                if soft_estimates:
                    b_hat = hard_decisions(b_hat)

                # count errors
                bit_e = count_errors(b, b_hat)
                block_e = count_block_errors(b, b_hat)

                # count total number of bits
                bit_n = tf.size(b)
                # a block is one slice along the last dimension of b
                block_n = tf.size(b[...,-1])

                # update variables
                bit_errors = tf.tensor_scatter_nd_add(  bit_errors, [[i]],
                                                    tf.cast([bit_e], tf.int64))
                block_errors = tf.tensor_scatter_nd_add(  block_errors, [[i]],
                                                tf.cast([block_e], tf.int64))
                nb_bits = tf.tensor_scatter_nd_add( nb_bits, [[i]],
                                                    tf.cast([bit_n], tf.int64))
                nb_blocks = tf.tensor_scatter_nd_add( nb_blocks, [[i]],
                                                tf.cast([block_n], tf.int64))

                # print progress summary
                if verbose:
                    # print summary header during first iteration
                    if i==0 and iter_count==0:
                        _print_progress(is_final=True,
                                        rt=0,
                                        idx_snr=0,
                                        idx_it=0,
                                        header_text=header_text)
                        # print separator after headline
                        print('-' * 135)

                    # evaluate current runtime
                    rt = time.perf_counter() - runtime[i]
                    # print current progress
                    _print_progress(is_final=False, idx_snr=i, idx_it=ii, rt=rt)


                # bit-error based stopping cond.
                if num_target_bit_errors is not None:
                    if tf.greater_equal(bit_errors[i], num_target_bit_errors):
                        status[i] = 3 # change internal status for summary
                        # stop runtime timer
                        runtime[i] = time.perf_counter() - runtime[i]
                        break # enough errors for SNR point have been simulated

                # block-error based stopping cond.
                if num_target_block_errors is not None:
                    if tf.greater_equal(block_errors[i],
                                        num_target_block_errors):
                        # stop runtime timer
                        runtime[i] = time.perf_counter() - runtime[i]
                        status[i] = 4 # change internal status for summary
                        break # enough errors for SNR point have been simulated

                # max iter have been reached -> continue with next SNR point
                if iter_count==max_mc_iter-1: # all iterations are done
                    # stop runtime timer
                    runtime[i] = time.perf_counter() - runtime[i]
                    status[i] = 1 # change internal status for summary

            # print results again AFTER last iteration / early stop (new status)
            if verbose:
                _print_progress(is_final=True,
                                idx_snr=i,
                                idx_it=iter_count,
                                rt=runtime[i])

            # early stop if no error occurred
            if early_stop: # only if early stop is active
                if block_errors[i]==0:
                    status[i] = 2 # change internal status for summary
                    if verbose:
                        print("\nSimulation stopped as no error occurred " \
                              f"@ EbNo = {ebno_dbs[i].numpy():.1f} dB.\n")
                    break

    # Stop if KeyboardInterrupt is detected and set remaining SNR points to -1
    except KeyboardInterrupt:

        # Raise Interrupt again to stop outer loops
        if forward_keyboard_interrupt:
            raise

        if i is None:
            # interrupted before the first SNR point was started
            print("\nSimulation stopped by the user before the first SNR point")
            first_skipped = 0
        else:
            print("\nSimulation stopped by the user " \
                  f"@ EbNo = {ebno_dbs[i].numpy()} dB")
            # the interrupted point keeps its partial counters
            first_skipped = i+1

        # overwrite remaining BER / BLER positions with -1
        # (errors=-1 and count=1 yield a ratio of exactly -1)
        for idx in range(first_skipped, num_points):
            bit_errors = tf.tensor_scatter_nd_update( bit_errors, [[idx]],
                                                    tf.cast([-1], tf.int64))
            block_errors = tf.tensor_scatter_nd_update( block_errors, [[idx]],
                                                    tf.cast([-1], tf.int64))
            nb_bits = tf.tensor_scatter_nd_update( nb_bits, [[idx]],
                                                    tf.cast([1], tf.int64))
            nb_blocks = tf.tensor_scatter_nd_update( nb_blocks, [[idx]],
                                                    tf.cast([1], tf.int64))

    # calculate BER / BLER
    ber = tf.cast(bit_errors, tf.float64) / tf.cast(nb_bits, tf.float64)
    bler = tf.cast(block_errors, tf.float64) / tf.cast(nb_blocks, tf.float64)

    # replace nans (from early stop)
    # (points that were never simulated have 0 errors / 0 bits)
    ber = tf.where(tf.math.is_nan(ber), tf.zeros_like(ber), ber)
    bler = tf.where(tf.math.is_nan(bler), tf.zeros_like(bler), bler)

    return ber, bler