In [93]:
import random
from typing import Generator

FILE_PATH = "ml-1m/ratings.dat"



In [94]:
def get_sample(file_path: str) -> Generator[tuple[int, int], None, None]:
    #1::1193::5::978300760
    """
    UserID::MovieID::Rating::Timestamp

    - UserIDs range between 1 and 6040 
    - MovieIDs range between 1 and 3952
    - Ratings are made on a 5-star scale (whole-star ratings only)
    - Timestamp is represented in seconds since the epoch as returned by time(2)
    - Each user has at least 20 ratings

    Returns: Generator of tuples (UserID, MovieID)
    """
    
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            content = line.strip()
            if content:  
                parts = content.split("::")
                sample = (int(parts[0]), int(parts[1]))
                yield sample

In [95]:
def trailing_zeros(n: int) -> int:
    """Count trailing zeros in binary representation of n"""

    # skip the '0b1' prefix (as mentioned in reference video)
    binary_str = bin(n)[2:]
    total_zeros = 0
    for char in reversed(binary_str):
        if char == '0':
            total_zeros += 1
        else:
            break
    return total_zeros

def median(values: list[float]) -> float:
    """Return the median of a list of numbers."""
    n = len(values)
    if n == 0:
        raise ValueError("median() arg is an empty sequence")
    sorted_vals = sorted(values)
    mid = n // 2
    if n % 2 == 1:
        return sorted_vals[mid]
    else:
        return (sorted_vals[mid - 1] + sorted_vals[mid]) / 2
    
def trailing_zeros(n: int, register_bits: int) -> tuple[int,int]:
    """Count trailing zeros in binary representation of n"""

    # skip the '0b1' prefix (as mentioned in reference video)
    binary_str = bin(n)[2:]

    register_bits_str = binary_str[:register_bits] or '0'
    bitmap = binary_str[register_bits:]
    register_id = int(register_bits_str, 2)

    total_zeros = 0
    for char in reversed(bitmap):
        if char == '0':
            total_zeros += 1
        else:
            break
    return register_id, total_zeros

In [96]:
# 2 reg - 1 bit
# 4 reg - 2 bits
# 8 reg - 3 bits
# 16 reg - 4 bits

# FM uses 1 register, HLL uses multiple registers (for some reason not works with HLL case)
NUM_REGISTERS_BITS  = 0
NUM_REGISTERS = 2 ** NUM_REGISTERS_BITS
# at least R = log_2(N) -> N = 2^R, 1M elements - at least 20 binary str lenghth + NUM_REGISTERS_BITS for registers
binary_str_length = 20 + NUM_REGISTERS_BITS

print(f"Using {NUM_REGISTERS} registers ({NUM_REGISTERS_BITS} bits each)")
print(f"Total binary string length: {binary_str_length} bits")

Using 1 registers (0 bits each)
Total binary string length: 20 bits


In [97]:
def multiplicative_hash_function(x):
    '''Knuth's multiplicative hash function'''
    shift = (1 << binary_str_length) 
    mask = shift - 1
    return (x * 2654435761) & mask  

def default_hash_function(x: int) -> int:
    '''Python's built-in hash function'''

    # this gives 2**binary_str_length in binary (like 1000000...)
    shift = (1 << binary_str_length) 
    mask = shift - 1
    #return only the lower binary_str_length bits
    return hash(x) & mask

In [98]:
def get_alpha(num_registers: int) -> float:
    """Get the alpha constant based on the number of registers."""
    if num_registers >= 16 and num_registers < 32:
        alpha = 0.673
    elif num_registers >= 32 and num_registers < 64:
        alpha = 0.697
    elif num_registers >= 64 and num_registers < 128:
        alpha = 0.709
    elif num_registers >= 128:
        alpha = 0.7213 / (1 + 1.079 / num_registers)
    else:
        alpha = 0.5  # rough approximation for smaller m
    return alpha

In [99]:
register_table = {}
ten_percent_of_movies = set()

for entry in get_sample(FILE_PATH):
    user_hash = multiplicative_hash_function(entry[0])
    movie_hash = multiplicative_hash_function(entry[1])
  
    # can be any hash remainder from 0 to 9
    if movie_hash % 10 == 1:
        ten_percent_of_movies.add(entry[1]) 

    register_id, tz = trailing_zeros(user_hash, NUM_REGISTERS_BITS)
    if register_id not in register_table:
        register_table[register_id] = tz
    else:
        if tz > register_table[register_id]:
            register_table[register_id] = tz

# Estimate number of unique users
estimates = []
for reg_id, tz in register_table.items():
    estimate = 2 ** tz / 0.77351
    estimates.append(estimate)

final_estimate = sum(estimates)/len(estimates) #NUM_REGISTERS ** 2 / sum(1 / est for est in estimates)
print("Approx of total unique users: ", final_estimate)
print("Approx of 10% of unique movies: ", len(ten_percent_of_movies))


Approx of total unique users:  5295.342012385101
Approx of 10% of unique movies:  371


#### Compare with ground truth

In [100]:
unqiue_movies = set()
unique_users = set()
for entry in get_sample(FILE_PATH):
    unique_users.add(entry[0])
    unqiue_movies.add(entry[1])

print(f"Ground truth unique users: {len(unique_users)}")
print(f"Ground truth unique movies: {len(unqiue_movies)}")

print("10% of unique movies ", len(unqiue_movies) // 10)

Ground truth unique users: 6040
Ground truth unique movies: 3706
10% of unique movies  370
