# In[33]:
from scipy.special import comb

# In[34]:
# Perfect squares k*k for k = 0..8, i.e. every square from 0 up to 64.
SQUARES = {k * k for k in range(9)}

# In[37]:
def binary_representation(n: int) -> str:
    """
    Return the binary digits of ``n`` without the ``0b`` prefix.

    Non-negative inputs give the same string as ``bin(n)[2:]``
    (e.g. 5 -> "101", 0 -> "0"). Negative inputs keep their sign
    (-5 -> "-101"), whereas slicing ``bin(-5)`` would produce the
    malformed string "b101".
    """
    return format(n, "b")

def hamming_weight(n: int) -> int:
    """Return how many ``1`` digits appear in the binary representation of ``n``."""
    digits = binary_representation(n)
    return sum(1 for digit in digits if digit == "1")

def brute_force_get_filtered_hamming_count(bound: int, filter_set=None) -> int:
    """
    Reference implementation: count the integers in [0, bound] whose Hamming
    weight is in ``filter_set`` by checking every number individually.

    It visits all ``bound + 1`` numbers, so it is only practical for small
    bounds; it serves as an oracle for cross-checking
    ``get_filtered_hamming_count``.

    :param bound: inclusive upper end of the interval; a negative bound gives
        an empty range and therefore 0.
    :param filter_set: allowed Hamming weights; ``None`` or empty gives 0.
    :return: the number of matching integers.
    """
    if not filter_set:
        # Nothing can match, so skip the O(bound) scan entirely.
        return 0
    allowed = set(filter_set)  # build once for O(1) membership tests
    return sum(1 for i in range(bound + 1) if hamming_weight(i) in allowed)

def get_filtered_hamming_count(bound, filter_set=None) -> int:
    """
    Count the integers in the closed interval [0, bound] whose Hamming weight
    (number of set bits) is an element of ``filter_set``.

    The binary expansion of ``bound`` is scanned from the most significant bit
    down. Whenever bit ``pos`` of ``bound`` is 1, every number that matches
    ``bound`` above ``pos`` and has a 0 at ``pos`` is smaller than ``bound``,
    whatever its ``pos`` lower bits are. Exactly C(pos, w - ones_so_far) of
    those numbers have weight ``w``. These groups are pairwise disjoint and,
    together with ``bound`` itself, cover all of [0, bound].

    :param bound: inclusive upper end of the interval. A negative bound means
        the interval is empty, so 0 is returned.
    :param filter_set: iterable of allowed Hamming weights. ``None`` or an
        empty collection returns 0; duplicate weights are counted only once.
    :return: the exact count as an ``int``.

    Complexity: O(log2(bound) * len(filter_set)) exact binomial evaluations.
    Binomials are computed with ``exact=True`` because the floating-point
    default loses precision once C(n, k) exceeds 2**53, which already happens
    for 64-bit bounds. Truncating a float with ``int()`` could also drop a
    near-integer result by one.
    """
    import operator

    if not filter_set:
        return 0
    bound = operator.index(bound)  # accept any integer-like value, reject floats
    if bound < 0:
        return 0

    allowed = set(filter_set)
    count = 0
    ones_so_far = 0  # set bits of `bound` strictly above the current position
    for pos in range(bound.bit_length() - 1, -1, -1):
        if (bound >> pos) & 1:
            # Numbers sharing bound's higher bits but with a 0 here: the low
            # `pos` bits are free and must supply the remaining weight.
            for weight in allowed:
                needed = weight - ones_so_far
                if 0 <= needed <= pos:
                    count += int(comb(pos, needed, exact=True))
            ones_so_far += 1
    # After the scan, ones_so_far is the Hamming weight of `bound` itself.
    if ones_so_far in allowed:
        count += 1
    return count

# In[38]:
# Exhaustive (not random) small test: for every bound in [0, 1024], compare the
# combinatorial count with the brute-force reference. Any mismatch is printed
# as (bound, fast count, brute-force count).
for i in range(1024 + 1):
    us = get_filtered_hamming_count(i, SQUARES)
    them = brute_force_get_filtered_hamming_count(i, SQUARES)
    if us != them:
        print(i, us, them)