In [1]:
import ccdm

In [1]:
# constant_composition_dm_tf.py
import math
from typing import List, Sequence
import tensorflow as tf
import ccdm
import numpy as np
from ccdm_initialize import initialize

# ---------------------------
# Helper combinatorics utils
# ---------------------------
def nchoosek_log2(n: int, k: int) -> float:
    """log2(n choose k) with safe handling (uses math.comb)."""
    k = min(k, n - k)
    if k < 0:
        return float('-inf')
    if k == 0:
        return 0.0
    # math.comb available in Python 3.8+
    res = 0.0
    # accumulate as in original C++ to avoid huge intermediate factorials
    for i in range(1, k + 1):
        res += math.log((n - (k - i)) / i, 2.0)
    return res

def nchooseks_log2(n: int, k_list: Sequence[int]) -> float:
    """Sum of log2( n choose k_i ), subtracting k_i from n after each."""
    total = 0.0
    n_local = n
    for ki in k_list:
        total += nchoosek_log2(n_local, ki)
        n_local -= ki
    return total

# ---------------------------
# Data classes (simple Python)
# ---------------------------
class CodeCandidate:
    def __init__(self, lower: float, upper: float, prob: float, symbols: List[int]):
        self.lower = lower
        self.upper = upper
        self.prob = prob
        self.symbols = list(symbols)

class SourceInterval:
    def __init__(self, lower: float = 0.0, upper: float = 1.0):
        self.lower = lower
        self.upper = upper

class CodeCandidateList:
    def __init__(self, k: int = 0):
        self.k = k
        self.list: List[CodeCandidate] = []

# ---------------------------
# Core routines
# ---------------------------
def update_src_interval(src_interval: SourceInterval, src_probability: Sequence[float], src_symbol: int):
    """Update source interval based on binary source probabilities (1-based symbol)."""
    # src_probability expected length 2 for this coder (binary source)
    new_border = src_interval.lower + (src_interval.upper - src_interval.lower) * src_probability[0]
    if src_symbol == 0:
        src_interval.upper = new_border
    else:
        src_interval.lower = new_border

def update_code_candidates(cc_list: CodeCandidateList, n_i: Sequence[int]):
    """Rebuild code candidates from remaining composition n_i (0-based counts)."""
    cc_list.list.clear()
    k = cc_list.k
    n_total = int(sum(n_i))
    sum_p = 0.0
    if n_total == 0:
        return
    for i in range(k):
        p_i = float(n_i[i]) / float(n_total) if n_total > 0 else 0.0
        lower = sum_p
        sum_p += p_i
        upper = 1.0 if (i == k - 1) else sum_p
        candidate = CodeCandidate(lower=lower, upper=upper, prob=p_i, symbols=[i + 1])
        cc_list.list.append(candidate)

def find_identified_code_candidate_index(src_interval: SourceInterval, cc_list: CodeCandidateList, n_i: Sequence[int]) -> int:
    """Return index of a code candidate whose lower bound lies inside source interval and count > 0."""
    for idx, cc in enumerate(cc_list.list):
        if src_interval.lower <= cc.lower < src_interval.upper and n_i[idx] != 0:
            return idx
    return -1

def update_N_i(n_i: List[int], symbol_list: Sequence[int]):
    """Decrease counts for each symbol in symbol_list (symbols are 1-based)."""
    for sym in symbol_list:
        n_i[sym - 1] -= 1

def find_code_interval_from_candidates(cc_list: CodeCandidateList, search_list: List[int]) -> SourceInterval:
    """Return interval for exact candidate whose symbols equal search_list."""
    for cc in cc_list.list:
        if cc.symbols == search_list:
            return SourceInterval(lower=cc.lower, upper=cc.upper)
    # default empty interval (shouldn't usually happen)
    return SourceInterval(0.0, 0.0)

def finalize_code_symbols(src_interval: SourceInterval, cc_list: CodeCandidateList, n_i: List[int]) -> List[int]:
    """If an identified single code candidate exists, finalize symbols (append remaining counts)."""
    cc_index = find_identified_code_candidate_index(src_interval, cc_list, n_i)
    if cc_index == -1:
        return []
    cc = cc_list.list[cc_index]
    symbols_new = list(cc.symbols)  # copy
    update_N_i(n_i, cc.symbols)
    # append remaining counts in order 1..k
    for i in range(cc_list.k):
        for _ in range(n_i[i]):
            symbols_new.append(i + 1)
    return symbols_new

def check_for_output_and_rescale(src_interval: SourceInterval, cc_list: CodeCandidateList, n_i: List[int]) -> List[int]:
    """
    Check whether src_interval is fully contained in any code candidate interval.
    If yes, rescale the source interval wrt that candidate, update counts and candidates,
    and repeat until no such containing candidate exists.
    Returns the concatenated list of produced code symbols during the process.
    """
    produced_symbols: List[int] = []

    # find initial candidate that contains the entire source interval
    chosen_cc = None
    for cc in cc_list.list:
        if src_interval.lower >= cc.lower and src_interval.upper <= cc.upper:
            chosen_cc = cc
            break

    while chosen_cc is not None:
        # rescale source interval to candidate's sub-interval
        interval_width = (chosen_cc.upper - chosen_cc.lower)
        if interval_width <= 0:
            # numerical safeties
            break
        src_interval.lower = (src_interval.lower - chosen_cc.lower) / interval_width
        src_interval.upper = (src_interval.upper - chosen_cc.lower) / interval_width
        # clamp tiny overshoots
        if src_interval.upper > 1.0:
            src_interval.upper = 1.0

        # consume candidate symbols
        produced_symbols.extend(chosen_cc.symbols)
        update_N_i(n_i, chosen_cc.symbols)

        # rebuild candidate list with new counts
        update_code_candidates(cc_list, n_i)

        # find next chosen candidate
        chosen_cc = None
        for cc in cc_list.list:
            if src_interval.lower >= cc.lower and src_interval.upper <= cc.upper:
                chosen_cc = cc
                break

    return produced_symbols

# ---------------------------
# Main encoder function
# ---------------------------
def encode_constant_composition_arithmetic_matcher(src_symbols: Sequence[int],
                                                   n_total: int,
                                                   n_i_vect: Sequence[int],
                                                   src_prob: Sequence[float] = (0.5, 0.5)) -> tf.Tensor:
    """
    Args:
      src_symbols: sequence/iterable of source symbols (each either 0 or 1 -- 1-based used in original C++: symbol==0 or 1)
                   In the original C++ psrc = {0.5, 0.5} and src_symbol is 0/1.
      n_total: total number of output symbols (n)
      n_i_vect: list of k integers specifying composition counts (length k); sum(n_i_vect) must equal n_total
      src_prob: binary source probabilities, default .5/.5
    Returns:
      Tensor of output code symbols (dtype tf.int32), 1-based symbol indices as in C++ implementation.
    """
    # convert inputs to Python lists for control flow
    k = len(n_i_vect)
    n_i = list(n_i_vect)  # mutable copy

    # initialize source interval and code-candidate list
    src_interval = SourceInterval(0.0, 1.0)
    cc_list = CodeCandidateList(k=k)
    update_code_candidates(cc_list, n_i)

    code_symbols: List[int] = []

    # m (bits) computed in original but unused for core symbol generation; we keep it for parity if desired
    m_bits = nchooseks_log2(n_total, list(n_i_vect))

    # loop through source symbols, update interval, and try to output symbols
    for s in src_symbols:
        # s is expected to be 0 or 1. If user gives 1-based, convert to 1-based? original expects 0 or 1.
        # The C++ call used psrc[2] = {0.5, 0.5}; updateSrcInterval(..., src_symbol)
        # where src_symbol is *it from src_symbols, so they were 0 or 1.
        src_sym = int(s)
        update_src_interval(src_interval, src_prob, src_sym)
        new_syms = check_for_output_and_rescale(src_interval, cc_list, n_i)
        if new_syms:
            code_symbols.extend(new_syms)

    # finalize
    fin_syms = finalize_code_symbols(src_interval, cc_list, n_i)
    code_symbols.extend(fin_syms)

    return tf.convert_to_tensor(code_symbols, dtype=tf.int32)

# ---------------------------
# Example / quick test
# ---------------------------
if __name__ == "__main__":
    # small example: composition of 3 symbols across k=3 symbols: counts [1,1,1] (n=3)
    # use a short pseudo source stream (binary symbols 0/1)
    # src = [0, 1, 0, 1, 1, 0]  # arbitrary
    # n_total = 3
    # n_i_vect = [1, 1, 1]  # three output symbols, each one of symbol indices {1,2,3}

    # out = encode_constant_composition_arithmetic_matcher(src, n_total, n_i_vect)
    # print("Encoded output (1-based symbols):", out.numpy())

    # small example: C++ implementation
    # use a short pseudo source stream (binary symbols 0/1)
    pOpt = np.array([0.537577302140556,0.322026673986155,0.115556317270981,0.024839706602308])
    n = 256
    [p_quant,num_info_bits,n_i] = initialize(pOpt,256)
    src_symbols = np.random.randint(0, 2, size=num_info_bits)
    code_symbols = ccdm.encode(src_symbols, n, n_i)

    # out = encode_constant_composition_arithmetic_matcher(src, n_total, n_i_vect)
    print("Encoded output (1-based symbols):", code_symbols)

    out = encode_constant_composition_arithmetic_matcher(src_symbols, n, n_i)
    print("Encoded output (1-based symbols) tf:", out.numpy())


2025-10-18 11:08:58.331430: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760774938.351371 2940517 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760774938.357654 2940517 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1760774938.374497 2940517 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1760774938.374513 2940517 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1760774938.374515 2940517 computation_placer.cc:177] computation placer alr

Encoded output (1-based symbols): [1, 2, 1, 2, 1, 1, 1, 3, 2, 1, 1, 1, 3, 4, 3, 2, 2, 3, 2, 2, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 2, 1, 2, 3, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 1, 3, 1, 2, 1, 1, 2, 3, 1, 1, 2, 2, 1, 2, 1, 1, 1, 2, 1, 1, 1, 3, 1, 1, 2, 2, 3, 1, 1, 3, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 3, 2, 4, 2, 2, 1, 1, 3, 2, 1, 2, 1, 3, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 2, 1, 1, 1, 2, 1, 1, 3, 4, 2, 1, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 3, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 2, 3, 1, 1, 2, 1, 2, 1, 3, 3, 1, 2, 1, 1, 2, 1, 1, 1, 3, 2, 1, 2, 2, 1, 2, 1, 4, 1, 1, 2, 1, 1, 3, 1, 1, 1, 2, 2, 1, 3, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 3, 3, 2, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 4, 2, 3, 1, 2, 1, 1, 1, 3, 2, 1, 1, 3, 1, 1, 2, 3, 1, 1, 4, 2]
Encoded output (1-based symbols) tf: [1 2 1 2 1 1 1 3 2 1 1 1 3 4 3 2 2 3 2 2 1 1 2 2 1 1 1 1 1 2 1 1 1 1 2 2 1
 2 3 2 1 1 1 1 1 2 1 2 1 3 1 2 1 3 1 2 1 1 2 3 1 1 2 2 1 2 1 1 1 2 1 1 1 3
 1 1 2 2 3

I0000 00:00:1760774941.510156 2940517 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3244 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB MIG 1g.5gb, pci bus id: 0000:03:00.0, compute capability: 8.0


In [14]:
import tensorflow as tf

@tf.function
def batch_ccdm_encode(src_symbols,        # shape [B, T], dtype int32, values 0 or 1
                      n_total,            # scalar int (total output length n)
                      n_i_vect,           # shape [k], dtype int32, composition counts
                      src_prob=(0.5, 0.5) # binary source probabilities (python tuple ok)
                     ):
    """
    Batch CCDM encoder (tensor only).
    Returns: tf.Tensor shape [B, n_total], dtype tf.int32, 1-based symbol indices.
    Raises: tf.errors.InvalidArgumentError if final lengths != n_total or no finalize candidate found.
    """

    src_prob = tf.convert_to_tensor(src_prob, dtype=tf.float64)  # length 2
    src_prob0 = tf.cast(src_prob[0], tf.float64)

    src_symbols = tf.convert_to_tensor(src_symbols, dtype=tf.int32)
    n_i_vect = tf.convert_to_tensor(n_i_vect, dtype=tf.int32)
    n_total = tf.convert_to_tensor(n_total, dtype=tf.int32)

    B = tf.shape(src_symbols)[0]
    T = tf.shape(src_symbols)[1]
    k = tf.shape(n_i_vect)[0]

    # initialize per-batch mutable state
    # n_i per batch: shape [B, k], init with broadcast of n_i_vect
    n_i = tf.cast(tf.broadcast_to(n_i_vect[tf.newaxis, :], [B, k]), tf.int32)

    # compute initial candidate bounds from n_i
    def compute_bounds_from_n_i(n_i_batch):
        # n_i_batch shape [k], dtype int32
        n_total_batch = tf.cast(tf.reduce_sum(n_i_batch), tf.float64)
        n_i_f = tf.cast(n_i_batch, tf.float64)
        # avoid division by zero (shouldn't happen until finalize)
        p = tf.where(n_total_batch > 0.0, n_i_f / n_total_batch, tf.zeros_like(n_i_f))
        lower = tf.cumsum(p, exclusive=True)
        upper = tf.cumsum(p)
        # ensure last upper = 1.0 if any
        upper = tf.where(tf.range(k, dtype=tf.int32) == (k - 1), tf.ones_like(upper), upper)
        return lower, upper

    # vectorized bounds
    lower_bounds, upper_bounds = tf.map_fn(lambda r: compute_bounds_from_n_i(r),
                                           n_i,
                                           dtype=(tf.float64, tf.float64))

    # source intervals per batch
    src_lower = tf.zeros([B], tf.float64)
    src_upper = tf.ones([B], tf.float64)

    # output buffer: shape [B, n_total], fill with zeros initially
    output = tf.zeros([B, n_total], dtype=tf.int32)
    # next write position per batch
    write_ptr = tf.zeros([B], dtype=tf.int32)

    # helper to find chosen candidate index per batch: first idx with lower <= src_lower < upper and n_i>0
    def choose_candidate_indices(src_lower, src_upper, lower_bounds, upper_bounds, n_i):
        # masks shape [B, k]
        # condition: src_lower >= lower_bounds && src_upper <= upper_bounds AND n_i > 0
        cond1 = tf.greater_equal(tf.expand_dims(src_lower, 1), lower_bounds)  # [B,k]
        cond2 = tf.less_equal(tf.expand_dims(src_upper, 1), upper_bounds)     # [B,k]
        cond3 = tf.greater(n_i, 0)
        mask = tf.logical_and(tf.logical_and(cond1, cond2), cond3)
        # For each batch, compute smallest index where mask is True. If none True, return -1.
        large = tf.cast(k, tf.int32)
        idxs = tf.where(mask, tf.tile(tf.range(k)[tf.newaxis, :], [B, 1]), large)
        min_idx = tf.reduce_min(idxs, axis=1)  # [B]
        chosen_idx = tf.where(tf.equal(min_idx, large), tf.fill([B], -1), min_idx)
        return chosen_idx, mask

    # inner loop that checks for output and rescales repeatedly until none batch has a containing candidate
    # We'll implement it as a tf.while_loop where each iteration processes all batches which currently have a candidate.

    # We'll create the outer time-loop below. First, create a step function that:
    # - updates src interval from one source symbol (vectorized)
    # - then runs the inner "rescale loop" (vectorized) to append produced symbols

    # tf.range arrays used many times
    range_k = tf.range(k, dtype=tf.int32)
    range_k_f64 = tf.cast(range_k, tf.float64)

    # Build while body for the inner "while a candidate exists" loop
    def inner_rescale_loop(body_state):
        # state: (src_lower, src_upper, lower_bounds, upper_bounds, n_i, output, write_ptr)
        src_lower, src_upper, lower_bounds, upper_bounds, n_i, output, write_ptr = body_state

        chosen_idx, mask = choose_candidate_indices(src_lower, src_upper, lower_bounds, upper_bounds, n_i)
        # boolean whether any batch has candidate
        any_chosen = tf.reduce_any(tf.greater_equal(chosen_idx, 0))

        def some_chosen_branch():
            # For batches with chosen_idx >=0 apply vectorized updates.
            chosen_present = tf.greater_equal(chosen_idx, 0)  # [B] bool

            # gather candidate lower/upper for chosen idx. For batches with idx=-1, we gather zeros and then mask them.
            # prepare indices for gather_nd
            batch_idxs = tf.reshape(tf.range(B, dtype=tf.int32), [B, 1])  # [B,1]
            gather_idx = tf.concat([batch_idxs, tf.reshape(tf.maximum(chosen_idx, 0), [B,1])], axis=1)  # [B,2]
            chosen_lower = tf.gather_nd(lower_bounds, gather_idx)  # [B]
            chosen_upper = tf.gather_nd(upper_bounds, gather_idx)  # [B]

            # only use chosen_lower/upper for batches where chosen_present
            chosen_lower = tf.where(chosen_present, chosen_lower, tf.zeros_like(chosen_lower))
            chosen_upper = tf.where(chosen_present, chosen_upper, tf.ones_like(chosen_upper))

            # compute interval width (avoid zero)
            width = chosen_upper - chosen_lower
            width_safe = tf.where(width <= 0.0, tf.ones_like(width), width)

            # rescale src interval for chosen batches
            new_src_lower = (src_lower - chosen_lower) / width_safe
            new_src_upper = (src_upper - chosen_lower) / width_safe
            new_src_upper = tf.where(new_src_upper > 1.0, tf.ones_like(new_src_upper), new_src_upper)

            # update n_i: decrement count for chosen_idx where present
            # Build one-hot for chosen_idx: shape [B, k], values 1 at chosen idx, 0 elsewhere. For idx=-1, zero vector.
            chosen_idx_pos = tf.maximum(chosen_idx, 0)
            one_hot = tf.one_hot(chosen_idx_pos, depth=k, dtype=tf.int32)  # [B,k]
            one_hot = one_hot * tf.cast(tf.expand_dims(tf.cast(chosen_present, tf.int32), 1), tf.int32)

            n_i_new = n_i - one_hot  # shape [B,k]

            # write produced symbol (chosen_idx+1) into output at write_ptr for batches where chosen_present
            produced_symbol = (chosen_idx + 1)  # -1 -> 0 but we mask
            produced_symbol = tf.where(chosen_present, produced_symbol, tf.zeros_like(produced_symbol))

            # For writing, create mask for each batch to do single-element write:
            # output shape [B, n_total], we update column at index write_ptr[b] for each batch b that has chosen_present
            # Build indices for scatter_nd_update
            write_positions = tf.where(chosen_present, write_ptr, tf.zeros_like(write_ptr))  # use 0 for others
            scatter_idx = tf.stack([tf.range(B, dtype=tf.int32), write_positions], axis=1)  # [B,2]
            # But we should only update for chosen_present. So form updates where chosen_present.
            updates = tf.where(chosen_present, produced_symbol, tf.zeros_like(produced_symbol))
            output = tf.tensor_scatter_nd_update(output, scatter_idx, updates)

            # increment write_ptr only for chosen_present
            write_ptr = write_ptr + tf.cast(chosen_present, tf.int32)

            # recompute candidate bounds from n_i_new
            def compute_bounds_for_all(n_i_full):
                # n_i_full shape [B,k] int32 -> returns lower_bounds, upper_bounds [B,k] float64
                n_total_batch = tf.cast(tf.reduce_sum(n_i_full, axis=1, keepdims=True), tf.float64)  # [B,1]
                n_i_f = tf.cast(n_i_full, tf.float64)  # [B,k]
                p = tf.where(n_total_batch > 0.0, n_i_f / n_total_batch, tf.zeros_like(n_i_f))
                lower = tf.cumsum(p, axis=1, exclusive=True)
                upper = tf.cumsum(p, axis=1)
                # ensure last upper = 1.0
                last_mask = tf.equal(tf.range(k, dtype=tf.int32), k-1)
                last_mask = tf.cast(last_mask, tf.float64)[tf.newaxis, :]
                upper = tf.where(tf.equal(last_mask, 1.0), tf.ones_like(upper), upper)
                return lower, upper

            lower_bounds_new, upper_bounds_new = compute_bounds_for_all(n_i_new)

            # For batches that didn't have chosen candidate, keep src_lower, src_upper unchanged and n_i unchanged
            src_lower = tf.where(chosen_present, new_src_lower, src_lower)
            src_upper = tf.where(chosen_present, new_src_upper, src_upper)
            n_i = tf.where(tf.expand_dims(chosen_present, 1), n_i_new, n_i)
            lower_bounds = tf.where(tf.expand_dims(chosen_present, 1), lower_bounds_new, lower_bounds)
            upper_bounds = tf.where(tf.expand_dims(chosen_present, 1), upper_bounds_new, upper_bounds)

            return (src_lower, src_upper, lower_bounds, upper_bounds, n_i, output, write_ptr)
        # no chosen candidate
        return (src_lower, src_upper, lower_bounds, upper_bounds, n_i, output, write_ptr)

    # Now define the outer time loop body
    # state for outer loop: t, src_lower, src_upper, lower_bounds, upper_bounds, n_i, output, write_ptr
    t0 = tf.constant(0, tf.int32)
    state0 = (t0, src_lower, src_upper, lower_bounds, upper_bounds, n_i, output, write_ptr)

    def outer_cond(state):
        t, *_ = state
        return tf.less(t, T)

    def outer_body(state):
        t, src_lower, src_upper, lower_bounds, upper_bounds, n_i, output, write_ptr = state
        # symbol at time t for all batches
        sym_t = tf.cast(src_symbols[:, t], tf.int32)  # [B], 0 or 1

        # Update source interval per batch based on binary probability
        # new_border = lower + (upper-lower) * psrc[0]
        new_border = src_lower + (src_upper - src_lower) * src_prob0
        # If src_symbol == 0 => upper = new_border, else lower = new_border
        is_zero = tf.equal(sym_t, 0)
        src_upper = tf.where(is_zero, new_border, src_upper)
        src_lower = tf.where(is_zero, src_lower, new_border)

        # After updating interval, run inner rescale loop until no chosen candidate exists
        inner_state = (src_lower, src_upper, lower_bounds, upper_bounds, n_i, output, write_ptr)

        def inner_cond(inner_state):
            src_lower_i, src_upper_i, lb_i, ub_i, n_i_i, out_i, wp_i = inner_state
            chosen_idx_i, _ = choose_candidate_indices(src_lower_i, src_upper_i, lb_i, ub_i, n_i_i)
            return tf.reduce_any(tf.greater_equal(chosen_idx_i, 0))

        def inner_body(inner_state):
            return inner_rescale_loop(inner_state)

        src_lower, src_upper, lower_bounds, upper_bounds, n_i, output, write_ptr = tf.while_loop(
            inner_cond, inner_body, inner_state,
            maximum_iterations= n_total * 2  # safe upper bound
        )

        # advance time
        t = t + 1
        return (t, src_lower, src_upper, lower_bounds, upper_bounds, n_i, output, write_ptr)

    print(tf.while_loop(outer_cond, outer_body, state0, maximum_iterations=T))
    # _, src_lower, src_upper, lower_bounds, upper_bounds, n_i, output, write_ptr = tf.while_loop(outer_cond, outer_body, state0, maximum_iterations=T)
    # print(final_state)

    # Unpack final state after processing all source symbols
     # = final_state

    # FINALIZE step for each batch: find identified candidate index (if any), then produce:
    # resultSymbols = cc.symbols (i.e., chosen_idx+1) and then append remaining counts of each symbol
    chosen_idx, mask = choose_candidate_indices(src_lower, src_upper, lower_bounds, upper_bounds, n_i)
    # chosen_idx gives candidate with lower in interval? The C++ used findIdentifiedCodeCandidateIndex which checks:
    # if src_interval.lowerBound <= cc.lowerBound  && cc.lowerBound < src_interval.upperBound && n_i[i] != 0
    # Our choose_candidate_indices uses lower<=src_lower and src_upper <= upper. For finalization we need the other condition.
    # So compute identified candidate differently:
    # cond_identified: src_lower <= lower_bounds < src_upper and n_i>0
    cond_identified = tf.logical_and(tf.greater_equal(tf.expand_dims(src_lower,1), tf.zeros_like(lower_bounds)),
                                     tf.zeros_like(lower_bounds) + 0)  # dummy to later build
    # Build mask for identified
    cond_low_ge = tf.less_equal(lower_bounds, tf.expand_dims(src_upper,1))  # lower <= src_upper (not exact)
    # Let's compute explicitly:
    identified_mask = tf.logical_and(
        tf.greater_equal(tf.expand_dims(src_upper,1), lower_bounds + 0.0),  # src_upper >= cc.lower
        tf.greater_equal(lower_bounds + 0.0, tf.expand_dims(src_lower,1))   # cc.lower >= src_lower
    )
    identified_mask = tf.logical_and(identified_mask, tf.greater(n_i, 0))
    # find smallest index per batch
    large = tf.cast(k, tf.int32)
    idxs_id = tf.where(identified_mask, tf.tile(tf.range(k)[tf.newaxis,:], [B,1]), large)
    min_idx_id = tf.reduce_min(idxs_id, axis=1)
    chosen_final_idx = tf.where(tf.equal(min_idx_id, large), tf.fill([B], -1), min_idx_id)

    # For robustness: if any chosen_final_idx == -1 -> error
    any_missing = tf.reduce_any(tf.equal(chosen_final_idx, -1))
    def raise_error():
        # throw TF error
        msg = "CCDM finalize: no identified code candidate for at least one batch (cannot finalize)."
        return tf.debugging.assert_less_equal(1, 0, message=msg)  # intentional assert fail
    def continue_finalize():
        return tf.no_op()

    tf.cond(any_missing, lambda: raise_error(), lambda: continue_finalize())

    # Now write final symbols per batch: first the chosen_final_idx + 1 (the cc.symbols)
    produced_symbol_final = chosen_final_idx + 1  # shape [B]

    # scatter into output at write_ptr if chosen_final present (it is)
    scatter_idx = tf.stack([tf.range(B, dtype=tf.int32), write_ptr], axis=1)
    output = tf.tensor_scatter_nd_update(output, scatter_idx, produced_symbol_final)
    write_ptr = write_ptr + 1

    # Then append remaining counts for each symbol i (1..k)
    # For each symbol i, repeat symbol i+1 n_i[b,i] times for each batch b, writing sequentially at write_ptr
    # We'll perform this by iterating i from 0..k-1 using a small tf.while_loop (k is typically modest).
    i0 = tf.constant(0, dtype=tf.int32)
    cond_k = lambda i, *_: tf.less(i, k)

    def body_k(i, output, write_ptr, n_i):
        # per-batch count to append
        counts = n_i[:, i]  # [B]
        sym_val = tf.fill([B], i + 1)  # [B], 1-based
        # For each batch, we need to write counts[b] copies of sym_val[b] at increasing write_ptr.
        # We'll loop per repetition across all batches using another loop up to max_count (max across batches).
        max_count = tf.reduce_max(counts)

        j0 = tf.constant(0, dtype=tf.int32)
        def cond_j(j, output, write_ptr, counts):
            return tf.less(j, max_count)
        def body_j(j, output, write_ptr, counts):
            # mask for batches where counts > j
            mask = tf.greater(counts, j)  # [B]
            write_positions = tf.where(mask, write_ptr, tf.zeros_like(write_ptr))
            scatter_idx = tf.stack([tf.range(B, dtype=tf.int32), write_positions], axis=1)
            updates = tf.where(mask, sym_val, tf.zeros_like(sym_val))
            output = tf.tensor_scatter_nd_update(output, scatter_idx, updates)
            write_ptr = write_ptr + tf.cast(mask, tf.int32)
            return j + 1, output, write_ptr, counts
        _, output, write_ptr, _ = tf.while_loop(cond_j, body_j, (j0, output, write_ptr, counts), maximum_iterations=n_total)
        # set counts to zero for this symbol now
        n_i = tf.tensor_scatter_nd_update(n_i, tf.reshape(tf.range(B, dtype=tf.int32), [-1,1]), n_i)  # no-op to keep shape
        return i + 1, output, write_ptr, n_i

    _, output, write_ptr, n_i = tf.while_loop(cond_k, body_k, (i0, output, write_ptr, n_i),
                                             maximum_iterations=k)

    # Final check: write_ptr must equal n_total for all batches
    all_done = tf.reduce_all(tf.equal(write_ptr, n_total))
    tf.debugging.assert_equal(all_done, True, message="CCDM failed: produced symbol count != n_total for some batch.")

    return output


In [15]:
pOpt = np.array([0.537577302140556,0.322026673986155,0.115556317270981,0.024839706602308])
n = 256
[p_quant,num_info_bits,n_i] = initialize(pOpt,256)
n_i = np.array(n_i, dtype=np.int32)
src_symbols = np.reshape(np.random.randint(0, 2, size=num_info_bits, dtype=np.int32),(1,num_info_bits))
code_symbols = batch_ccdm_encode(src_symbols, n, n_i)
# def batch_ccdm_encode(src_symbols,        # shape [B, T], dtype int32, values 0 or 1
#                       n_total,            # scalar int (total output length n)
#                       n_i_vect,           # shape [k], dtype int32, composition counts
#                       src_prob=(0.5, 0.5) # binary source probabilities (python tuple ok)
#                      ):

TypeError: in user code:

    File "/tmp/rjayarat/5282792/ipykernel_2940517/856147342.py", line 209, in batch_ccdm_encode  *
        print(tf.while_loop(outer_cond, outer_body, state0, maximum_iterations=T))

    TypeError: outer_factory.<locals>.inner_factory.<locals>.tf__batch_ccdm_encode.<locals>.outer_cond() takes 1 positional argument but 8 were given


In [4]:
code_symbols

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 2,
 1,
 2,
 2,
 2,
 2,
 2,
 1,
 1,
 2,
 3,
 2,
 1,
 2,
 2,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 2,
 1,
 2,
 1,
 1,
 1,
 4,
 2,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 1,
 1,
 4,
 1,
 1,
 1,
 3,
 2,
 1,
 1,
 1,
 1,
 1,
 4,
 3,
 2,
 2,
 1,
 1,
 1,
 2,
 3,
 3,
 1,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 1,
 2,
 1,
 1,
 2,
 2,
 2,
 1,
 3,
 3,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 3,
 1,
 1,
 2,
 2,
 1,
 3,
 1,
 2,
 1,
 3,
 1,
 1,
 2,
 2,
 1,
 2,
 2,
 2,
 3,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 2,
 3,
 2,
 2,
 2,
 2,
 2,
 1,
 4,
 3,
 1,
 2,
 1,
 2,
 1,
 1,
 3,
 4,
 1,
 3,
 2,
 1,
 3,
 1,
 3,
 3,
 2,
 1,
 1,
 1,
 2,
 2,
 2,
 1,
 2,
 2,
 1,
 3,
 1,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 2,
 2,
 3,
 1,
 1,
 2,
 1,
 1,
 1,
 3,
 1,
 3,
 1,
 1,
 3,
 3,
 1,
 2,
 1,
 1,
 2,
 2,
 1,
 1,
 1,
 2,
 2,
 2,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 1,
 3,
 1,
 1,
 2,
 1,
 2,
 3,
 2,
 1,
 2,
 2,
 1,
 1,
 1,
 2,
 1,
 2,
 1,
 1,
 1,
 2,
 1,
 3,
 3,
 1,
 2,
 1,
 2,
 2,
 1,
 3,
 3,
 4,
 1,
