In [1]:
from collections import defaultdict
from typing import DefaultDict, Dict, List, Tuple

# Universe size: elements are drawn from {0, ..., N-1} = {0..24}.
# NOTE: the functions below take their own `N` parameter (default 25) and do
# not read this constant; pass it explicitly (N=N) to keep them in sync.
N = 25  # numbers 0..24


def precompute_B_by_sum(N: int = 25) -> Dict[int, List[int]]:
    """
    Group every 4-element subset of {0, ..., N-1} by the sum of its elements.

    Each subset is encoded as an int bitmask (bit i set <=> i is in the subset).
    Within a bucket, subsets appear in lexicographic order of their sorted
    elements. The result is a ``defaultdict(list)``; a sum with no subsets is
    simply absent (use ``.get`` to avoid inserting an empty bucket).
    """
    buckets: DefaultDict[int, List[int]] = defaultdict(list)

    for i in range(N - 3):
        for j in range(i + 1, N - 2):
            for k in range(j + 1, N - 1):
                # The first three elements are fixed for the innermost loop,
                # so their sum and mask are computed once here.
                partial_sum = i + j + k
                partial_mask = (1 << i) | (1 << j) | (1 << k)
                for m in range(k + 1, N):
                    buckets[partial_sum + m].append(partial_mask | (1 << m))

    return buckets


def count_and_store_cases(
    disjoint: bool = True,
    A_ordered: bool = False,
    store_solutions: bool = True,
    N: int = 25,
) -> Tuple[int, List[Tuple[int, int]]]:
    """
    Count pairs (A, B), A a 3-subset and B a 4-subset of {0, ..., N-1},
    whose element sums are equal.

    Parameters
    ----------
    disjoint : if True, only pairs where A and B share no element are counted;
        otherwise overlapping pairs count too.
    A_ordered : if True, each unordered A is weighted by 3! = 6 in the count.
        B is always treated as an unordered 4-set.
    store_solutions : if True, also collect every matching pair.
    N : size of the universe {0, ..., N-1}.

    Returns
    -------
    (count, solutions) where ``solutions`` is a list of ``(maskA, maskB)``
    bitmask pairs (empty when ``store_solutions`` is False). Each pair is
    stored once per unordered A, so with ``A_ordered=True`` the list holds
    ``count // 6`` entries rather than one entry per ordered A.
    """
    B_by_sum = precompute_B_by_sum(N)
    weight = 6 if A_ordered else 1
    total = 0
    found: List[Tuple[int, int]] = []

    for x in range(N - 2):
        for y in range(x + 1, N - 1):
            for z in range(y + 1, N):
                # .get (not []) so the defaultdict is not grown with empty buckets.
                candidates = B_by_sum.get(x + y + z)
                if not candidates:
                    continue
                maskA = (1 << x) | (1 << y) | (1 << z)
                if disjoint:
                    matches = [mB for mB in candidates if not (maskA & mB)]
                else:
                    matches = candidates
                total += weight * len(matches)
                if store_solutions:
                    found.extend((maskA, mB) for mB in matches)

    return total, found


def mask_to_tuple(mask: int, N: int = 25) -> Tuple[int, ...]:
    """
    Decode a subset bitmask into the ascending tuple of its element indices.

    Bit i of ``mask`` set means element i is in the subset, e.g.
    ``mask_to_tuple(0b10110) == (1, 2, 4)`` and ``mask_to_tuple(0) == ()``.

    ``N`` is accepted for call-site symmetry with the other helpers but is not
    used: the decoded indices are bounded only by ``mask.bit_length()``.

    Raises
    ------
    ValueError
        If ``mask`` is negative. Python ints behave as unbounded two's
        complement, so a negative mask has infinitely many set bits and the
        lowest-set-bit loop below would never terminate.
    """
    if mask < 0:
        raise ValueError(f"mask must be non-negative, got {mask}")
    out = []
    remaining = mask
    while remaining:
        lowest = remaining & -remaining  # isolate the lowest set bit
        out.append(lowest.bit_length() - 1)
        remaining ^= lowest  # clear it; bits are visited low-to-high
    return tuple(out)


if __name__ == "__main__":
    # Default interpretation: A and B disjoint, A an unordered 3-set.
    # Expected count: 281630 (a later cell cross-checks this by brute force).
    count, sols = count_and_store_cases(
        disjoint=True, A_ordered=False, store_solutions=True, N=N
    )

    print("count:", count)
    print("stored solutions:", len(sols))

    # Show up to five decoded examples (A as 3-set, B as 4-set). Slicing
    # avoids an IndexError when fewer than five solutions were stored
    # (e.g. store_solutions=False or a small N).
    for i, (maskA, maskB) in enumerate(sols[:5]):
        A = mask_to_tuple(maskA, N)
        B = mask_to_tuple(maskB, N)
        print(f"{i}: A={A} sum={sum(A)}   B={B} sum={sum(B)}")

count: 281630
stored solutions: 281630
0: A=(0, 1, 13) sum=14   B=(2, 3, 4, 5) sum=14
1: A=(0, 1, 14) sum=15   B=(2, 3, 4, 6) sum=15
2: A=(0, 1, 15) sum=16   B=(2, 3, 4, 7) sum=16
3: A=(0, 1, 15) sum=16   B=(2, 3, 5, 6) sum=16
4: A=(0, 1, 16) sum=17   B=(2, 3, 4, 8) sum=17


In [2]:
import itertools
import numpy as np

In [3]:
# Brute-force cross-check of count_and_store_cases(disjoint=True, A_ordered=False).
# combinations() yields each unordered 3-set A exactly once, instead of
# generating all 3! orderings and keeping only the strictly decreasing one.
# B ranges over every 4-set drawn from the numbers not in A (so A and B are
# disjoint). Builtin sum() replaces np.sum, which is far slower on tiny tuples.
brute_count = 0
for A in itertools.combinations(range(25), 3):
    target = sum(A)  # loop-invariant for the inner loop
    remaining = [x for x in range(25) if x not in A]
    for B in itertools.combinations(remaining, 4):
        if sum(B) == target:
            brute_count += 1
print(brute_count)

281630
