### Neuralink Challenge - Part 1 

Generate an exhaustive lookup table of all possible sample combinations, each stored alongside its corresponding (truncated) hash value.

The table is generated in parallel across multiple worker processes and saved to a SQLite database.

In [1]:
import os
import hashlib
import wave
import numpy as np
import sqlite3
from multiprocessing import Pool, cpu_count, Manager

# Function to generate a truncated SHA-256 hash for a given segment
def generate_short_hash(segment, hash_size_bytes):
    """Return a truncated SHA-256 digest of ``segment`` as a hex string.

    Args:
        segment: Bytes-like object to hash.
        hash_size_bytes: Number of leading digest bytes to keep. The
            result contains two hex characters per kept byte.

    Returns:
        Hex string of the first ``hash_size_bytes`` bytes of the digest.
    """
    hasher = hashlib.sha256()
    hasher.update(segment)
    full_digest = hasher.digest()
    short_digest = full_digest[:hash_size_bytes]
    return short_digest.hex()

# Function to pack samples into appropriate byte segments
def pack_samples(samples, bit_depth):
    packed_value = np.uint64(0)
    for i, sample in enumerate(samples):
        packed_value |= np.uint64(sample) << np.uint64(bit_depth * (len(samples) - i - 1))
    num_bytes = (bit_depth * len(samples) + 7) // 8
    packed_bytes = packed_value.tobytes()[:num_bytes]  # Convert to bytes
    return packed_bytes

# Function to save a chunk of the lookup table to the SQLite database
def save_chunk_to_db(chunk, db_name, semaphore):
    """Insert a batch of ``hash -> segment`` entries into the SQLite table.

    The ``lookup`` table is created on first use. Rows whose hash is already
    present are skipped (``INSERT OR IGNORE``), so the first segment stored
    for a given hash is the one that is kept.

    Args:
        chunk: Mapping of hash string to packed segment bytes.
        db_name: Path of the SQLite database file.
        semaphore: Context manager held for the whole database operation so
            that callers sharing it do not write at the same time.
    """
    with semaphore:
        conn = sqlite3.connect(db_name)
        try:
            c = conn.cursor()
            c.execute('''CREATE TABLE IF NOT EXISTS lookup (hash TEXT PRIMARY KEY, segment BLOB)''')
            c.executemany('INSERT OR IGNORE INTO lookup (hash, segment) VALUES (?, ?)', chunk.items())
            conn.commit()
        finally:
            # Close the connection even when an insert or the commit fails,
            # so a failing batch does not leak a connection.
            conn.close()

# Function to generate part of the lookup table
def generate_lookup_table_chunk(args):
    """Generate and store the lookup entries for one contiguous index range.

    Each flat index in ``[chunk_start, chunk_end)`` is decoded into
    ``segment_size`` sample values (each in ``[0, 2 ** bit_depth)``), packed
    into bytes, hashed, and buffered. The buffer is written to the database
    periodically and once more at the end of the range.

    Args:
        args: Tuple ``(bit_depth, segment_size, chunk_start, chunk_end,
            db_name, hash_size_bytes, semaphore, progress_update_interval)``.
            It is a single tuple so the function can be used with
            ``Pool.map``. ``progress_update_interval`` is the number of
            processed items between database writes.
    """
    (bit_depth, segment_size, chunk_start, chunk_end, db_name,
     hash_size_bytes, semaphore, progress_update_interval) = args

    # An interval below 1 (e.g. 0 for a very small chunk) would make the
    # modulo below raise ZeroDivisionError; flush after every item instead.
    flush_every = max(1, progress_update_interval)

    # Buffer keyed by hash: within one batch, a later segment with the same
    # hash replaces an earlier one.
    lookup_table_chunk = {}
    num_values_per_sample = 2 ** bit_depth
    # The shape is the same for every index, so build it once.
    shape = (num_values_per_sample,) * segment_size

    def unravel_index(idx, shape):
        """Convert a flat index into per-dimension indices (last varies fastest)."""
        out = []
        for dim in reversed(shape):
            out.append(idx % dim)
            idx //= dim
        return tuple(reversed(out))

    for idx in range(chunk_start, chunk_end):
        indices = unravel_index(idx, shape)
        segment = pack_samples(indices, bit_depth)
        hash_key = generate_short_hash(segment, hash_size_bytes)
        lookup_table_chunk[hash_key] = segment

        # Flush every flush_every items, and always after the final item.
        if (idx - chunk_start + 1) % flush_every == 0 or idx == chunk_end - 1:
            save_chunk_to_db(lookup_table_chunk, db_name, semaphore)
            lookup_table_chunk.clear()

# Function to build an exhaustive lookup table for a given segment size using multiple CPUs
def build_exhaustive_lookup_table_parallel(bit_depth, segment_size, db_name, hash_size_bytes, num_workers=None, progress_update_interval=0.05):
    """Build the exhaustive lookup table using a pool of worker processes.

    The ``(2 ** bit_depth) ** segment_size`` possible segments are split into
    ``num_workers`` contiguous index ranges, one per worker. The last range
    is extended to absorb any remainder.

    Args:
        bit_depth: Number of bits per sample.
        segment_size: Number of samples per segment.
        db_name: Path of the SQLite database file to write to.
        hash_size_bytes: Number of SHA-256 digest bytes kept per hash.
        num_workers: Number of worker processes; defaults to ``cpu_count()``.
        progress_update_interval: Fraction of a worker's chunk to process
            between database writes.
    """
    if num_workers is None:
        num_workers = cpu_count()

    num_values_per_sample = 2 ** bit_depth
    total_combinations = num_values_per_sample ** segment_size
    chunk_size = total_combinations // num_workers
    # Clamp to at least 1: for small chunks the product truncates to 0,
    # which the workers would then use as a modulo divisor.
    update_interval = max(1, int(progress_update_interval * chunk_size))

    # Half-open [start, end) index range for each worker.
    bounds = [(i * chunk_size, (i + 1) * chunk_size) for i in range(num_workers)]
    # Ensure the last chunk covers all remaining data.
    bounds[-1] = (bounds[-1][0], total_combinations)

    # The context manager shuts the manager's server process down once the
    # workers are done instead of leaving it running.
    with Manager() as manager:
        # A manager-backed semaphore can be passed to pool workers as an
        # argument; it serialises their database writes.
        semaphore = manager.Semaphore(1)

        args = [
            (bit_depth, segment_size, start, end, db_name, hash_size_bytes, semaphore, update_interval)
            for start, end in bounds
        ]

        # Use a pool of workers to generate the lookup table in parallel.
        with Pool(num_workers) as pool:
            pool.map(generate_lookup_table_chunk, args)

# Function to encode the wav data using the exhaustive lookup table
def encode_wav_data_exhaustive(wav_data, bit_depth, segment_size, db_name, hash_size_bytes, semaphore):
    """Encode sample data as a list of hashes found in the lookup table.

    ``wav_data`` is cut into consecutive segments of ``segment_size``
    samples. Trailing samples that do not fill a whole segment are ignored.
    Each segment is packed and hashed, and the hash is looked up in the
    ``lookup`` table.

    Args:
        wav_data: Sized, sliceable sequence of sample values.
        bit_depth: Number of bits per sample used when packing.
        segment_size: Number of samples per segment.
        db_name: Path of the SQLite database holding the lookup table.
        hash_size_bytes: Number of SHA-256 digest bytes kept per hash.
        semaphore: Context manager held around each database query.

    Returns:
        Tuple ``(encoded_data, segment_found)``. ``segment_found`` has one
        boolean per segment. ``encoded_data`` holds the hash of each segment
        that was found, so it is shorter than ``segment_found`` whenever a
        segment is missing from the table.
    """
    encoded_data = []
    segment_found = []
    num_segments = len(wav_data) // segment_size

    conn = sqlite3.connect(db_name)
    try:
        c = conn.cursor()
        for i in range(num_segments):
            samples = wav_data[i * segment_size:(i + 1) * segment_size]
            segment = pack_samples(samples, bit_depth)
            hash_key = generate_short_hash(segment, hash_size_bytes)
            with semaphore:
                c.execute('SELECT hash FROM lookup WHERE hash=?', (hash_key,))
                result = c.fetchone()
            if result:
                encoded_data.append(result[0])
                segment_found.append(True)
            else:
                segment_found.append(False)
    finally:
        # Close the connection even if packing or a query raises.
        conn.close()

    return encoded_data, segment_found
 

In [None]:
# IPython shell escape: delete the existing database file.
!rm nl_bit15_seg2_hash3.db

In [None]:
# Lookup-table configuration
bit_depth = 15        # bits per sample (2 ** 15 possible values each)
segment_size = 2      # samples per segment
hash_size_bytes = 3   # truncated SHA-256 hash size in bytes (3 bytes = 24 bits)
db_name = 'nl_bit15_seg2_hash3.db'

# Generate every possible segment in parallel and store the results in SQLite
build_exhaustive_lookup_table_parallel(
    bit_depth=bit_depth,
    segment_size=segment_size,
    db_name=db_name,
    hash_size_bytes=hash_size_bytes,
    progress_update_interval=0.05,
)


In [None]:
def count_rows_in_lookup_table(db_name):
    """Return the number of rows in the ``lookup`` table.

    Args:
        db_name: Path of the SQLite database file.

    Returns:
        Row count of the ``lookup`` table as an integer.

    Raises:
        sqlite3.OperationalError: If the query fails, e.g. because the
            ``lookup`` table does not exist.
    """
    conn = sqlite3.connect(db_name)
    try:
        c = conn.cursor()
        c.execute('SELECT COUNT(*) FROM lookup')
        return c.fetchone()[0]
    finally:
        # Close the connection even when the query raises.
        conn.close()

# Evaluate the row count of the lookup table in db_name (shown as the cell's output in a notebook).
count_rows_in_lookup_table(db_name)
