[Neuralink Compression Challenge](https://content.neuralink.com/compression-challenge/README.html)

content.neuralink.com/compression-challenge/data.zip is one hour of raw electrode recordings from a Neuralink implant.

This Neuralink is implanted in the motor cortex of a non-human primate, and recordings were made while playing a video game, like this.

Compression is essential: N1 implant generates ~200Mbps of eletrode data (1024 electrodes @ 20kHz, 10b resolution) and can transmit ~1Mbps wirelessly.
So > 200x compression is needed.
Compression must run in real time (< 1ms) at low power (< 10mW, including radio).

Neuralink is looking for new approaches to this compression problem, and exceptional engineers to work on it.
If you have a solution, email compression@neuralink.com

Leaderboard

Name	Compression ratio	Compressed size	./encode size	./decode size
zip	2.2	63M	231K	480K

Task

Build executables ./encode and ./decode which pass eval.sh. This verifies compression is lossless and measures compression ratio.

Your submission will be scored on the compression ratio it achieves on a different set of electrode recordings.
Bonus points for optimizing latency and power efficiency

Submit with source code and build script. Should at least build on Linux.

Data

$ ls -lah data/
total 143M
193K 0052503c-2849-4f41-ab51-db382103690c.wav
193K 006c6dd6-d91e-419c-9836-c3f320da4f25.wav
...

Uncompressed monochannel WAV files.
5 seconds per file.

## Load Data

In [1]:
! wget https://content.neuralink.com/compression-challenge/data.zip

--2024-05-22 02:52:28--  https://content.neuralink.com/compression-challenge/data.zip
Resolving content.neuralink.com (content.neuralink.com)... 13.35.116.119, 13.35.116.99, 13.35.116.66, ...
Connecting to content.neuralink.com (content.neuralink.com)|13.35.116.119|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 65414611 (62M) [application/zip]
Saving to: ‘data.zip’


2024-05-22 02:52:30 (65.6 MB/s) - ‘data.zip’ saved [65414611/65414611]



In [2]:
! wget https://content.neuralink.com/compression-challenge/eval.sh

--2024-05-22 02:52:30--  https://content.neuralink.com/compression-challenge/eval.sh
Resolving content.neuralink.com (content.neuralink.com)... 13.35.116.119, 13.35.116.99, 13.35.116.66, ...
Connecting to content.neuralink.com (content.neuralink.com)|13.35.116.119|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1248 (1.2K) [application/x-sh]
Saving to: ‘eval.sh’


2024-05-22 02:52:30 (1.08 GB/s) - ‘eval.sh’ saved [1248/1248]



In [3]:
! unzip data.zip

Archive:  data.zip
   creating: data/
  inflating: data/102b47d9-371e-412a-8995-0dc6115ab2bb.wav  
  inflating: data/2eef5d4d-93d1-4c0e-9d23-0989abaa34d0.wav  
  inflating: data/fa2c5efb-cc0d-4292-ab99-91d345cf17d9.wav  
  inflating: data/0458e9fc-6403-427e-afec-6a659104399a.wav  
  inflating: data/30dee5fe-ded7-4978-9480-e40155e7b060.wav  
  inflating: data/3953a230-d130-40e9-9dc4-068dda9bcef1.wav  
  inflating: data/760ba446-aae7-4136-922c-9351c97504b8.wav  
  inflating: data/2b1627e1-85a5-4155-ba52-4400e036b034.wav  
  inflating: data/8559aba4-3f0b-45e3-add2-fcda2f9d586b.wav  
  inflating: data/4036d06b-fc56-47ca-8a2f-d7e1c3f3d9a0.wav  
  inflating: data/43758ea3-e5d2-4636-ad9c-ff5bfaa2914a.wav  
  inflating: data/9397fd43-e0df-4c75-b27e-04c3ed2f0fb8.wav  
  inflating: data/741e1978-11fb-4f4a-bf25-c43c239c226c.wav  
  inflating: data/fea48b69-ccde-439e-9e7c-f06c72832c52.wav  
  inflating: data/bc9115d8-2363-4159-ae41-295a9129a9aa.wav  
  inflating: data/3ac7abb5-1e3e-4852-bfc8-e3f28

## Binary encoding of samples and cardinality

Documentation says that sample size is 10bit.
However, WAV files are encoded with sample size 16bit.
This leads to believe that 6 bits are not utilized.

- WAV: 16bit integer PCM, int16, -32768, +32767
- real: 10bit, max uint10 = 2^10 - 1 = 1024 - 1 = 0b0011_1111_1111  0x3FF


Moreover, basic approach is to construct frequency based conding scheme for possible binary numbers.

In [None]:
import os
from scipy.io import wavfile

def get_sample_unique_counts():
  sample_count = dict()
  num_samples_total = 0

  for fname in os.listdir("data"):
      sample_rate, samples = wavfile.read(os.path.join("data", fname))

      for q in samples:
          num_samples_total += 1
          if not q in sample_count:
              sample_count[q] = 0
          sample_count[q] +=1

  return sample_count, num_samples_total

sample_unique_counts, num_samples_total = get_sample_unique_counts()
print("num_samples_unique", len(sample_unique_counts), "num_samples_total", num_samples_total)

num_samples_unique 1023 num_samples_total 73383917


In [None]:
import numpy as np

In [None]:
def print_samples_binary(sample_unique_count: dict[int, int]):
  vals = sorted([(k,v) for k,v in sample_unique_count.items()], key=lambda x: x[0])
  for k,v in vals:
    print(np.binary_repr(k, width=16), v)

print_samples_binary(sample_unique_counts)

1000000000000000 118472
1000000001000001 615
1000000010000001 704
1000000011000001 1168
1000000100000001 386
1000000101000001 744
1000000110000001 510
1000000111000001 1372
1000001000000001 366
1000001001000001 657
1000001010000001 858
1000001011000001 1569
1000001100000001 294
1000001101000001 881
1000001110000001 624
1000001111000001 1192
1000010000000001 352
1000010001000010 597
1000010010000010 647
1000010011000010 1140
1000010100000010 357
1000010101000010 735
1000010110000010 514
1000010111000010 1418
1000011000000010 184
1000011001000010 500
1000011010000010 587
1000011011000010 1078
1000011100000010 186
1000011101000010 674
1000011110000010 440
1000011111000010 746
1000100000000010 109
1000100001000011 391
1000100010000011 584
1000100011000011 962
1000100100000011 361
1000100101000011 623
1000100110000011 383
1000100111000011 1097
1000101000000011 286
1000101001000011 489
1000101010000011 554
1000101011000011 1072
1000101100000011 200
1000101101000011 561
1000101110000011 406
1

In [None]:
def print_samples_binary_hist(sample_unique_count: dict[int, int]):
  vals = sorted([(k,v) for k,v in sample_unique_count.items()], key=lambda x: x[1], reverse=True)
  for k,v in vals:
    print(np.binary_repr(k, width=16), v)

print_samples_binary_hist(sample_unique_counts)

0000010111100000 2189442
0000001011100000 2120370
0000000111011111 1838285
0000010011100000 1793017
0000011011100001 1732594
0000010010100000 1702413
0000001010100000 1617331
0000011010100001 1431003
0000000011011111 1408201
0000010000100000 1387878
0000001111100000 1381060
0000001110100000 1357027
0000010100100000 1319413
1111111011100000 1315300
0000001101100000 1301035
0000010110100000 1279820
0000001000100000 1256113
0000000010011111 1242625
0000010101100000 1241704
1111110111011111 1166578
0000000110011111 1158477
0000000101011111 1086565
0000000100011111 1083359
0000100010100001 1055203
0000100011100001 1035108
0000001100100000 994240
0000100000100001 967623
0000011101100001 964102
0000010001100000 962278
0000011110100001 938184
1111111010100000 937266
0000011000100001 917226
0000100111100001 916668
0000001001100000 903822
0000011001100001 825225
0000000000011111 786808
0000011100100001 780927
0000100100100001 746474
0000101011100010 733402
1111110011011111 687800
000000000101111

In [None]:
def get_bit_set_frequency(sample_unique_counts: dict[int, int]) -> dict[int, int]:
    count = dict()
    for k,v in sample_unique_counts.items():
        for i in range(16):
          b = (1 << i)
          if (k & b) != 0:
            if not b in count:
              count[b] = 0
            count[b] += v
    return count

print_samples_binary(get_bit_set_frequency(sample_unique_counts))

0000000000000001 35773168
0000000000000010 28036378
0000000000000100 22345292
0000000000001000 21581388
0000000000010000 21546732
0000000000100000 51894771
0000000001000000 38256480
0000000010000000 42825444
0000000100000000 34605379
0000001000000000 33979101
0000010000000000 35428955
0000100000000000 28878321
0001000000000000 18362515
0010000000000000 16857586
0100000000000000 16828773
1000000000000000 16767546


In [None]:
def samples_to_bits(samples: np.array):
  v = np.zeros((samples.shape[0], 16), dtype=np.uint8)
  for i, sample in enumerate(samples):
    for j in range(16):
      b = (1 << j)
      if (sample & b) != 0:
        v[i,j] = 1
  return v

fname = "d40b3d0a-21fd-42a8-a0bd-a38a431e9401.wav"
sample_rate, samples = wavfile.read(os.path.join("data", fname))
samples_bits = samples_to_bits(samples)
for q in samples_bits[300:500]:
  print(q)

[0 1 0 0 0 1 0 1 0 1 0 1 0 0 0 0]
[1 0 0 0 0 1 1 0 0 0 0 1 0 0 0 0]
[1 0 0 0 0 1 1 0 1 0 0 1 0 0 0 0]
[1 0 0 0 0 1 0 1 0 0 0 1 0 0 0 0]
[1 0 0 0 0 1 0 1 1 0 0 1 0 0 0 0]
[1 0 0 0 0 1 1 0 1 1 1 0 0 0 0 0]
[0 0 0 0 0 1 1 1 1 0 1 0 0 0 0 0]
[1 0 0 0 0 1 1 1 0 1 1 0 0 0 0 0]
[1 0 0 0 0 1 1 0 1 1 1 0 0 0 0 0]
[0 1 0 0 0 1 1 1 1 1 0 1 0 0 0 0]
[0 1 0 0 0 1 1 0 0 0 1 1 0 0 0 0]
[1 0 0 0 0 1 1 0 1 1 1 0 0 0 0 0]
[0 0 0 0 0 1 0 1 0 0 1 0 0 0 0 0]
[1 0 0 0 0 1 1 1 0 0 0 1 0 0 0 0]
[1 0 0 0 0 1 1 0 1 0 0 1 0 0 0 0]
[1 0 0 0 0 1 1 1 1 0 0 1 0 0 0 0]
[0 1 0 0 0 1 1 1 1 1 0 1 0 0 0 0]
[1 0 0 0 0 1 0 1 0 0 0 1 0 0 0 0]
[0 1 0 0 0 1 0 1 0 1 0 1 0 0 0 0]
[0 1 0 0 0 1 0 1 0 1 0 1 0 0 0 0]
[1 0 0 0 0 1 1 0 0 0 0 1 0 0 0 0]
[1 0 0 0 0 1 0 0 1 0 0 1 0 0 0 0]
[0 0 0 0 0 1 0 0 1 0 1 0 0 0 0 0]
[1 0 0 0 0 1 1 1 0 1 1 0 0 0 0 0]
[1 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0]
[0 1 0 0 0 1 1 0 0 0 1 1 0 0 0 0]
[1 0 0 0 0 1 1 0 1 0 0 1 0 0 0 0]
[0 1 0 0 0 1 0 1 1 1 0 1 0 0 0 0]
[1 0 0 0 0 1 0 0 1 0 0 1 0 0 0 0]
[1 0 0 0 0 1 0

In [None]:
def len_sequence_same_per_bit(samples: np.array):
  counts = [[dict() for q in range(16)], [dict() for q in range(16)]]
  prev_bit = [0 for q in range(16)]
  count = [-1 for q in range(16)]

  for s in samples:
    for i,b in enumerate(s):
      if b == prev_bit[i]:
        if count[i] == -1:
          count[i] = 0
        count[i] += 1
      else:
        if count[i] > 0:
          if not count[i] in counts[b][i]:
            counts[b][i][count[i]] = 0
          counts[b][i][count[i]] += 1
          count[i] = 1
          prev_bit[i] = b

  return counts

cont_seq_lens = len_sequence_same_per_bit(samples_bits)

print("cont seq of zeroes lengths count:")
for i in range(16):
  print(i, sorted([(k,v) for k,v in cont_seq_lens[0][i].items()], key=lambda x: x[0]))

print("cont seq of ones lengths count:")
for i in range(16):
  print(i, sorted([(k,v) for k,v in cont_seq_lens[1][i].items()], key=lambda x: x[0]))

cont seq of zeroes lengths count:
0 [(1, 9433), (2, 4505), (3, 2415), (4, 1414), (5, 848), (6, 537), (7, 306), (8, 208), (9, 151), (10, 88), (11, 68), (12, 44), (13, 21), (14, 26), (15, 8), (16, 9), (17, 6), (18, 5), (19, 2), (20, 3), (22, 1), (23, 2), (24, 1), (25, 1)]
1 [(1, 6223), (2, 2430), (3, 1103), (4, 622), (5, 385), (6, 239), (7, 161), (8, 113), (9, 69), (10, 58), (11, 43), (12, 41), (13, 22), (14, 34), (15, 17), (16, 11), (17, 20), (18, 11), (19, 12), (20, 13), (21, 5), (22, 10), (23, 5), (24, 9), (25, 7), (26, 5), (27, 2), (28, 8), (29, 2), (30, 4), (32, 1), (33, 3), (34, 2), (35, 1), (36, 2), (37, 4), (38, 1), (40, 2), (41, 1), (42, 1), (45, 1), (46, 2), (48, 1), (49, 1), (50, 1), (61, 1), (63, 1), (65, 1)]
2 [(1, 4137), (2, 1550), (3, 658), (4, 353), (5, 210), (6, 122), (7, 72), (8, 47), (9, 23), (10, 22), (11, 14), (12, 15), (13, 7), (14, 9), (15, 4), (16, 3), (17, 3), (18, 2), (19, 3), (20, 2), (21, 2), (24, 3), (25, 2), (26, 1), (28, 2), (33, 1), (36, 1), (40, 1), (42, 

In [None]:
import plotly.graph_objects as go

for b in [0,1]:
  fig = go.Figure()
  fig.update_layout(title=f'total number of bits in sequences of "{b}" by length', xaxis_title='len', yaxis_title='count')

  for i in range(16):
    x = []
    y = []

    for l, count in cont_seq_lens[b][i].items():
      x.append(l)
      y.append(count * l)

    fig.add_trace(go.Scatter(x=x, y=y, mode ='markers', name=f"bit{i}"))

  fig.show()

In [None]:
def num_bits_in_seq_longer_than(counts: list[dict[int, int]], min_len: int = 16) -> int:
  sum = 0
  for q in counts:
    for l, count in q.items():
      if l >= min_len:
        sum += count * l
  return sum

nb = 16
num_total_bits = 16 * len(samples)
num_bits_in_seq_longer_than_nb_0 = num_bits_in_seq_longer_than(cont_seq_lens[0], min_len=nb)
num_bits_in_seq_longer_than_nb_1 = num_bits_in_seq_longer_than(cont_seq_lens[1], min_len=nb)

print("num_total_bits", num_total_bits)
print(f"num_bits_in_seq_longer_than({nb}) for zeroes:", num_bits_in_seq_longer_than_nb_0, num_bits_in_seq_longer_than_nb_0 / num_total_bits)
print(f"num_bits_in_seq_longer_than({nb}) for ones:", num_bits_in_seq_longer_than_nb_1, num_bits_in_seq_longer_than_nb_1 / num_total_bits)

num_total_bits 1579936
num_bits_in_seq_longer_than(16) for zeroes: 85607 0.05418384035809046
num_bits_in_seq_longer_than(16) for ones: 553834 0.3505420472728009


In [None]:
def changed_bits(samples: np.array):
  counts = [0]

  for i, s in enumerate(samples):
    if i == 0:
      continue

    num_diff = 0
    for j in range(16):
      if samples[i-1][j] != s[j]:
        num_diff += 1

    counts.append(num_diff)

  return np.array(counts)

samples_changed_bits = changed_bits(samples_bits)

import plotly.express as px
df = px.data.tips()
fig = px.histogram(samples_changed_bits, histnorm='probability density')
fig.show()

In [None]:
from collections import Counter
import itertools

def get_changed_masks(samples):
  for i, s in enumerate(samples):
    if i == 0:
      continue
    m = samples[i-1] ^ s
    yield (m, samples[i-1] & m, s & m)

def count_all_changed_masks():
    for fname in os.listdir("data"):
      sample_rate, samples = wavfile.read(os.path.join("data", fname))
      for s in get_changed_masks(samples):
        yield s

changed_masks = Counter(count_all_changed_masks())
print("num_diff1_masks", len(changed_masks))
for k,v in sorted([(k,v) for k,v in changed_masks.items()], key=lambda x: x[1], reverse=True)[:2000]:
  print(np.binary_repr(k[0], width=16), np.binary_repr(k[1], width=16), np.binary_repr(k[2], width=16), v, v/num_samples_total)

num_diff1_masks 24065
0000000000000000 0000000000000000 0000000000000000 5779830 0.07876153571905953
0000000001000000 0000000000000000 0000000001000000 2516619 0.03429387668145324
0000000001000000 0000000001000000 0000000000000000 2514429 0.03426403363014814
0000000010000000 0000000000000000 0000000010000000 2175660 0.029647640640387184
0000000010000000 0000000010000000 0000000000000000 2174659 0.029634000049356863
0000000100000000 0000000000000000 0000000100000000 1829770 0.024934210039510428
0000000100000000 0000000100000000 0000000000000000 1823583 0.02484990001283251
0000000110000000 0000000010000000 0000000100000000 1339228 0.01824961183252183
0000000110000000 0000000100000000 0000000010000000 1335634 0.018200636523667713
0000000011000000 0000000000000000 0000000011000000 1132390 0.015431037838985891
0000000011000000 0000000011000000 0000000000000000 1129387 0.015390116065894929
0000000011000000 0000000001000000 0000000010000000 1015144 0.01383333081007382
0000000011000000 0000000

In [None]:
selected_changed_masks = sorted([(k,v) for k,v in changed_masks.items()], key=lambda x: x[1], reverse=True)[:1028]
len(selected_changed_masks), sum(v for k,v in selected_changed_masks), sum(v for k,v in selected_changed_masks) / num_samples_total

(1028, 72417813, 0.9868349355077353)

In [None]:
selected_changed_masks = [((-1,-1,-1),144835626)] + selected_changed_masks

In [None]:
from collections import Counter
from collections import defaultdict
import itertools

def get_changed_masks_4b(samples):
  for i, s in enumerate(samples):
    if i == 0:
      continue
    m = samples[i-1] ^ s
    m0 = (m & 0xF000, s & m & 0xF000)
    m1 = (m & 0x0F00, s & m & 0x0F00)
    m2 = (m & 0x00F0, s & m & 0x00F0)
    m3 = (m & 0x000F, s & m & 0x000F)
    yield (m0, m1, m2, m3)

def count_all_changed_masks_4b():
    num_total = 0
    count_m0 = defaultdict(int)
    count_m1 = defaultdict(int)
    count_m2 = defaultdict(int)
    count_m3 = defaultdict(int)

    for fname in ["d40b3d0a-21fd-42a8-a0bd-a38a431e9401.wav"]:#os.listdir("data"):
      sample_rate, samples = wavfile.read(os.path.join("data", fname))
      for m0,m1,m2,m3 in get_changed_masks_4b(samples):
        num_total += 1
        count_m0[m0] += 1
        count_m1[m1] += 1
        count_m2[m2] += 1
        count_m3[m3] += 1

    return num_total, (count_m0, count_m1, count_m2, count_m3)

num_total, changed_masks_4b = count_all_changed_masks_4b()

for i in range(4):
  changed_masks = changed_masks_4b[i]
  print(f"num_diff1_4bmasks: {i}", len(changed_masks))
  for k,v in sorted([(k,v) for k,v in changed_masks.items()], key=lambda x: x[1], reverse=True):
    print(np.binary_repr(k[0], width=16), np.binary_repr(k[1], width=16), v, v/num_total)

num_diff1_4bmasks: 0 5
0000000000000000 0000000000000000 90673 0.9182540888146236
1111000000000000 1111000000000000 3615 0.03660944857967492
1111000000000000 0000000000000000 3615 0.03660944857967492
0001000000000000 0001000000000000 421 0.004263507013013317
0001000000000000 0000000000000000 421 0.004263507013013317
num_diff1_4bmasks: 1 81
0000000000000000 0000000000000000 17589 0.17812547470758014
0000000100000000 0000000100000000 7756 0.07854574915185579
0000000100000000 0000000000000000 7649 0.07746214998227759
0000001000000000 0000000000000000 5581 0.05651931743379412
0000001000000000 0000001000000000 5517 0.05587118335105575
0000001100000000 0000000100000000 4577 0.046351714010835994
0000001100000000 0000001000000000 4577 0.046351714010835994
0000011000000000 0000010000000000 3170 0.03210289128563472
0000011000000000 0000001000000000 3141 0.03180920552939389
0000111000000000 0000011000000000 1980 0.020051648184718215
0000111000000000 0000100000000000 1939 0.019636437287963947
0000

In [None]:
"""Huffman encoding and decoding. Requires Python >= 3.7."""
from __future__ import annotations

from collections import Counter

from heapq import heapify
from heapq import heappush
from heapq import heappop

from itertools import chain
from itertools import islice

from typing import BinaryIO
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Tuple


LEFT_BIT = "0"
RIGHT_BIT = "1"
WORD_SIZE = 8  # Assumed to be a multiple of 8.
READ_SIZE = WORD_SIZE // 8
P_EOF = 1 << WORD_SIZE


class Node:
    """Huffman tree node."""

    def __init__(
        self,
        weight: int,
        symbol: Optional[int] = None,
        left: Optional[Node] = None,
        right: Optional[Node] = None,
    ):
        self.weight = weight
        self.symbol = symbol
        self.left = left
        self.right = right

    def is_leaf(self) -> bool:
        """Return `True` if this node is a leaf node, or `False` otherwise."""
        return self.left is None and self.right is None

    def __lt__(self, other: Node) -> bool:
        return self.weight < other.weight


def huffman_tree(weights: Dict[int, int]) -> Node:
    """Build a prefix tree from a map of symbol frequencies."""
    heap = [Node(v, k) for k, v in weights.items()]
    heapify(heap)

    # Pseudo end-of-file with a weight of 1.
    heappush(heap, Node(1, P_EOF))

    while len(heap) > 1:
        left, right = heappop(heap), heappop(heap)
        node = Node(weight=left.weight + right.weight, left=left, right=right)
        heappush(heap, node)

    return heappop(heap)


def huffman_table(tree: Node) -> Dict[int, str]:
    """Build a table of prefix codes by visiting every leaf node in `tree`."""
    codes: Dict[int, str] = {}

    def walk(node: Optional[Node], code: str = ""):
        if node is None:
            return

        if node.is_leaf():
            assert node.symbol
            codes[node.symbol] = code
            return

        walk(node.left, code + LEFT_BIT)
        walk(node.right, code + RIGHT_BIT)

    walk(tree)
    return codes

def _decode(bits: Iterable[str], tree: Node) -> Iterable[int]:
    node = tree

    for bit in bits:
        if bit == LEFT_BIT:
            assert node.left
            node = node.left
        else:
            assert node.right
            node = node.right

        if node.symbol == P_EOF:
            break

        if node.is_leaf():
            assert node.symbol
            yield node.symbol
            node = tree  # Back to the top of the tree.

In [None]:
tree = huffman_tree(sample_unique_counts)
table = huffman_table(tree)
print("len huffman table", len(table))
print(f"Symbol Code\n------ ----")
for k, v in sorted(table.items(), key=lambda x: len(x[1])):
    print(np.binary_repr(k, width=16), v)

len huffman table 1024
Symbol Code
------ ----
0000010010100000 00000
0000011011100001 00001
0000010011100000 00011
0000000111011111 00101
0000001011100000 01101
0000010111100000 10000
0000001001100000 001000
0000100111100001 001001
0000011000100001 001100
1111111010100000 001110
0000011110100001 001111
0000010001100000 010000
0000011101100001 010001
0000100000100001 010010
0000001100100000 010100
0000100011100001 010111
0000100010100001 011100
0000000100011111 011110
0000000101011111 011111
0000000110011111 100011
1111110111011111 100100
0000010101100000 101000
0000000010011111 101001
0000001000100000 101011
0000010110100000 101100
0000001101100000 101110
1111111011100000 101111
0000010100100000 110000
0000001110100000 110011
0000001111100000 110100
0000010000100000 110110
0000000011011111 110111
0000011010100001 111000
0000001010100000 111110
0000101111100010 0001010
1111101011011111 0101011
1111111000100000 0101100
1111111001100000 0110000
1111110100011111 0110010
1111101111011111 0

In [None]:
freq = [(2786, 34), (1504, 34), (1761, 30), (2529, 28), (2273, 27), (2978, 24), (2209, 24), (2401, 23), (2722, 22), (1248, 20), (2081, 19), (1889, 17), (2337, 17), (1184, 16), (2914, 16), (2465, 15), (736, 15), (1697, 14), (2593, 14), (3042, 14), (1953, 13), (2145, 12), (3106, 11), (1825, 11), (1440, 11), (2017, 10), (864, 10), (1056, 10), (1633, 10), (1569, 10), (1312, 10), (1376, 10), (3170, 9), (3234, 9), (992, 8), (2850, 7), (3490, 7), (928, 7), (2658, 7), (3298, 7), (800, 6), (3426, 6), (3554, 6), (1120, 5), (287, 4), (415, 4), (479, 4), (544, 3), (3362, 3), (3683, 3), (3618, 2), (672, 2), (351, 2), (-545, 2), (223, 2), (4003, 2), (3811, 2), (3747, 2), (608, 1), (4131, 1), (95, 0), (159, 0), (-160, 0), (3939, 0), (4259, 0), (4451, 0), (3875, 0)]

In [None]:
tree = huffman_tree({v[0]: v[1] for v in freq})
table = huffman_table(tree)
print("len huffman table", len(table))
print(f"Symbol Code\n------ ----")
for k, v in sorted(table.items(), key=lambda x: len(x[1])):
    print(k, v)

len huffman table 68
Symbol Code
------ ----
2786 0011
1504 0100
2465 00000
1184 00010
2914 00011
1889 00101
2337 01010
2081 01100
1248 10000
2722 10010
2401 10101
2209 10110
2978 11000
2273 11010
2529 11011
1761 11111
992 000011
3234 010110
3170 011010
1312 011011
1633 011100
864 011101
1056 011110
1376 011111
2017 100010
1569 100011
1825 100110
3106 100111
1440 101000
2145 101001
1953 110010
2593 111000
1697 111001
3042 111011
736 111101
479 0000100
415 0000101
287 0101110
1120 0101111
3554 1011110
3426 1011111
800 1100110
3298 1100111
3490 1110100
2850 1110101
2658 1111000
928 1111001
3811 00100000
4003 00100010
223 00100011
-545 00100100
351 00100101
672 00100110
3618 00100111
3683 10111001
3362 10111010
544 10111011
608 001000010
256 101110000
3747 101110001
4131 0010000111
3939 001000011000
-160 0010000110010
159 0010000110011
95 0010000110100
3875 0010000110101
4451 0010000110110
4259 0010000110111


In [None]:
from collections import Counter
import itertools

def get_is_match_prev(samples, n_prev: int = 256):
  for i, s in enumerate(samples):
    if i < n_prev + 1:
      yield 0
      continue

    is_match = False
    for j in range(n_prev):
      if s == samples[i - n_prev - j]:
        is_match = True

    yield 1 if is_match else 0

def get_is_match_prev_one(fname: str, n_prev: int = 256):
    sample_rate, samples = wavfile.read(os.path.join("data", fname))
    for (s, is_match) in zip(samples, get_is_match_prev(samples, n_prev)):
      yield s, is_match

def count_is_match_prev_one(s):
  sum = 0
  total = 0
  for _, is_match in s:
    total += 1
    sum += is_match
  return sum, total

def lengths_is_match_prev_one(s):
  lengths0 = []
  lengths1 = []

  prev = -1
  count = 0
  for _, is_match in s:
    if prev == -1:
      prev = is_match
      count = 1
      continue

    if is_match == prev:
      count += 1
    else:
      if is_match == 1:
        lengths1.append(count)
      else:
        lengths0.append(count)
      prev = is_match
      count = 1

  return np.array(lengths0), np.array(lengths1)

n_prev = 128
print("n_prev", n_prev)
df_is_match = list(get_is_match_prev_one("d40b3d0a-21fd-42a8-a0bd-a38a431e9401.wav", n_prev=n_prev))

num_matched, total = count_is_match_prev_one(df_is_match)
print("num_matched", num_matched, "num_total", num_total, "ratio_matched", num_matched / num_total)
for i,(s,v) in enumerate(df_is_match[:1000]):
  print(i, np.binary_repr(s, width=16), v)

n_prev 128
num_matched 72571 num_total 98745 ratio_matched 0.7349334143500936
0 0000101010100010 0
1 0000100110100001 0
2 0000111000100010 0
3 0000101100100010 0
4 0000101010100010 0
5 0000101011100010 0
6 0000100000100001 0
7 0000010010100000 0
8 0000100011100001 0
9 0000110000100010 0
10 0000101110100010 0
11 0000101110100010 0
12 0000110000100010 0
13 0000100010100001 0
14 0000100111100001 0
15 0000110000100010 0
16 0000100111100001 0
17 0000011111100001 0
18 0000011101100001 0
19 0000100011100001 0
20 0000100111100001 0
21 0000100000100001 0
22 0000011101100001 0
23 0000011010100001 0
24 0000011110100001 0
25 0000010011100000 0
26 0000001101100000 0
27 0000011011100001 0
28 0000011100100001 0
29 0000011010100001 0
30 0000100000100001 0
31 0000110001100010 0
32 0000101110100010 0
33 0000101101100010 0
34 0000100101100001 0
35 0000100110100001 0
36 0000101000100001 0
37 0000101110100010 0
38 0000101110100010 0
39 0000101100100010 0
40 0000101111100010 0
41 0000010000100000 0
42 00000

In [None]:
import math
import plotly.express as px

for n_prev in [2 ** n for n in range(12)]:
  df_is_match = list(get_is_match_prev_one("d40b3d0a-21fd-42a8-a0bd-a38a431e9401.wav", n_prev=n_prev))
  num_matched, total = count_is_match_prev_one(df_is_match)
  compression_ratio = 16 / ((1 - num_matched / num_total) * (16 + 1) + num_matched / num_total * (math.log2(n_prev) + 1))

  print(0)
  fig = px.histogram(lengths_is_match_prev_one(df_is_match)[0])#, histnorm='probability density')
  fig.show()

  print(1)
  fig = px.histogram(lengths_is_match_prev_one(df_is_match)[1])#, histnorm='probability density')
  fig.show()

  print("n_prev", n_prev, "num_matched", num_matched, "num_total", num_total, "ratio_matched", num_matched / num_total, "compression_ratio", compression_ratio)

0


1


n_prev 1 num_matched 4862 num_total 98745 ratio_matched 0.04923793609803028 compression_ratio 0.986911516403862
0


1


n_prev 2 num_matched 7624 num_total 98745 ratio_matched 0.0772089726062079 compression_ratio 1.0099820687142214
0


1


n_prev 4 num_matched 13020 num_total 98745 ratio_matched 0.13185477745708643 compression_ratio 1.055824537134494
0


1


n_prev 8 num_matched 21972 num_total 98745 ratio_matched 0.22251253228011544 compression_ratio 1.134161600368693
0


1


n_prev 16 num_matched 35479 num_total 98745 ratio_matched 0.35929920502303914 compression_ratio 1.2609933459279425
0


1


n_prev 32 num_matched 51057 num_total 98745 ratio_matched 0.5170590915995746 compression_ratio 1.414383396088584
0


1


n_prev 64 num_matched 64191 num_total 98745 ratio_matched 0.6500683578915388 compression_ratio 1.5239087344647484
0


1


n_prev 128 num_matched 72571 num_total 98745 ratio_matched 0.7349334143500936 compression_ratio 1.5405947777043194
0


1


n_prev 256 num_matched 80718 num_total 98745 ratio_matched 0.8174388576636792 compression_ratio 1.5295651845591292
0


1


n_prev 512 num_matched 87634 num_total 98745 ratio_matched 0.8874778469795939 compression_ratio 1.483176825221291
0


1


n_prev 1024 num_matched 90606 num_total 98745 ratio_matched 0.9175755734467568 compression_ratio 1.391964434388901
0


1


n_prev 2048 num_matched 94105 num_total 98745 ratio_matched 0.9530102790014684 compression_ratio 1.3077292366778686


In [27]:
import os
from scipy.io import wavfile

fname = "b4a354ca-8194-4459-b711-0fd099b117e8.wav  "
sample_rate, samples = wavfile.read(os.path.join("data", fname))

num_uncompressed_bytes = len(samples) * 16
print("num_uncompressed_bytes", num_uncompressed_bytes)

FileNotFoundError: [Errno 2] No such file or directory: 'data/b4a354ca-8194-4459-b711-0fd099b117e8.wav  '

In [36]:
import numpy as np
import math

class Encoder():
  def __init__(self, cache_size: int = 128, max_dictionary_use_len: int = 1024, min_dictionary_use_len: int = 8, log_first_n_samples: int = 1000):
    self.cache_size = cache_size
    self.max_dictionary_use_len = max_dictionary_use_len
    self.log_first_n_samples = log_first_n_samples
    self.min_dictionary_use_len = min_dictionary_use_len
    self.cache = dict()
    self.buffer = []

  def __flush(self):
    if len(self.buffer) == 0:
      return

    if len(self.buffer) < self.min_dictionary_use_len:
      for s in self.buffer:
          q = "0" + np.binary_repr(s, width=16)
          self.__log(f"{q}: <- {np.binary_repr(s, width=16)}: {len(q)/16:.2f} raw sample (buffer_len({len(self.buffer)}))")
          yield q

    marker = "1" + np.binary_repr(len(self.buffer), width=int(math.log2(self.max_dictionary_use_len)))
    self.__log(f"{marker}: flush buffer, next n({len(self.buffer)}) samples are encoded with dictionary")
    yield marker

    for s in self.buffer:
      q = np.binary_repr(self.__get_cache_idx(s), width=int(math.log2(self.cache_size)))
      self.cache[s] += 1
      self.__log(f"{q}: <- {np.binary_repr(s, width=16)}: {len(q)/16:.2f}")
      yield q

    self.buffer = []

  def __evict_from_cache(self):
    # least frequently used cache eviction policy
    min_k, min_v = None, None
    for k,v in self.cache.items():
      if min_v is None or v < min_v:
        min_k, min_v = k, v
    del self.cache[min_k]

  def __add_to_cache(self, v: np.int16):
    if len(self.cache) > self.cache_size:
      self.__evict_from_cache()
    self.cache[v] = 0

  def __get_cache_idx(self, s: np.int16) -> int:
    for i,(k,v) in enumerate(sorted([(k,v) for k,v in self.cache.items()], key=lambda x: x[1], reverse=True)):
      if s == k:
        return i
    return None

  def __log(self, s):
    if self.log_first_n_samples < 0:
      return
    self.log_first_n_samples -= 1
    print(s)

  def encode(self, samples: np.array):
    # we do not transfer dictionary, decoder has same algorithm, decoder reconstructs dictionary on its own from same sequence
    for i,s in enumerate(samples):
        if s in self.cache:
          if len(self.buffer) >= self.max_dictionary_use_len:
            for q in self.__flush():
              yield q

          self.buffer.append(s)
          continue

        for q in self.__flush():
          yield q

        self.__add_to_cache(s)

        raw_sample = "0" + np.binary_repr(s, width=16)
        self.__log(f"{raw_sample}: <- {np.binary_repr(s, width=16)}: {len(raw_sample)/16:.2f} raw sample")
        yield raw_sample

encoded = list(Encoder(cache_size=2**7,max_dictionary_use_len=2**15, min_dictionary_use_len=8).encode(samples))
compressed_bytes = sum(len(s) for s in encoded)
print("compressed_bytes", compressed_bytes, "compression_ratio", num_uncompressed_bytes / compressed_bytes)

00000011110100001: <- 0000011110100001: 1.06 raw sample
00000100000100001: <- 0000100000100001: 1.06 raw sample
00000011101100001: <- 0000011101100001: 1.06 raw sample
00000011100100001: <- 0000011100100001: 1.06 raw sample
00000011011100001: <- 0000011011100001: 1.06 raw sample
00000100001100001: <- 0000100001100001: 1.06 raw sample
00000011011100001: <- 0000011011100001: 1.06 raw sample (buffer_len(2))
00000100000100001: <- 0000100000100001: 1.06 raw sample (buffer_len(2))
1000000000000010: flush buffer, next n(2) samples are encoded with dictionary
0000100: <- 0000011011100001: 0.44
0000010: <- 0000100000100001: 0.44
00000100010100001: <- 0000100010100001: 1.06 raw sample
00000100010100001: <- 0000100010100001: 1.06 raw sample (buffer_len(1))
1000000000000001: flush buffer, next n(1) samples are encoded with dictionary
0000110: <- 0000100010100001: 0.44
00000100100100001: <- 0000100100100001: 1.06 raw sample
00000100000100001: <- 0000100000100001: 1.06 raw sample (buffer_len(5))
000

In [None]:
import numpy as np
import math

class Encoder2():
  def __init__(self, cache_size: int = 128, max_dictionary_use_len: int = 1024, min_dictionary_use_len: int = 8, log_first_n_samples: int = 1000):
    self.cache_size = cache_size
    self.max_dictionary_use_len = max_dictionary_use_len
    self.log_first_n_samples = log_first_n_samples
    self.min_dictionary_use_len = min_dictionary_use_len
    self.cache = dict()
    self.buffer = []

  def __flush(self):
    if len(self.buffer) == 0:
      return

    if len(self.buffer) < self.min_dictionary_use_len:
      for s in self.buffer:
          q = "0" + np.binary_repr(s, width=16)
          self.__log(f"{q}: <- {np.binary_repr(s, width=16)}: {len(q)/16:.2f} raw sample (buffer_len({len(self.buffer)}))")
          yield q

    marker = "1" + np.binary_repr(len(self.buffer), width=int(math.log2(self.max_dictionary_use_len)))
    self.__log(f"{marker}: flush buffer, next n({len(self.buffer)}) samples are encoded with dictionary")
    yield marker

    for s in self.buffer:
      q = np.binary_repr(self.__get_cache_idx(s), width=int(math.log2(self.cache_size)))
      self.cache[s] += 1
      self.__log(f"{q}: <- {np.binary_repr(s, width=16)}: {len(q)/16:.2f}")
      yield q

    self.buffer = []

  def __evict_from_cache(self):
    # least frequently used cache eviction policy
    min_k, min_v = None, None
    for k,v in self.cache.items():
      if min_v is None or v < min_v:
        min_k, min_v = k, v
    del self.cache[min_k]

  def __add_to_cache(self, v: np.int16):
    if len(self.cache) > self.cache_size:
      self.__evict_from_cache()
    self.cache[v] = 0

  def __get_cache_idx(self, s: np.int16) -> int:
    for i,(k,v) in enumerate(sorted([(k,v) for k,v in self.cache.items()], key=lambda x: x[1], reverse=True)):
      if s == k:
        return i
    return None

  def __log(self, s):
    if self.log_first_n_samples < 0:
      return
    self.log_first_n_samples -= 1
    print(s)

  def encode(self, samples: np.array):
    # we do not transfer dictionary, decoder has same algorithm, decoder reconstructs dictionary on its own from same sequence
    for i,s in enumerate(samples):
        if len(self.buffer) >= self.max_dictionary_use_len:
          for q in self.__flush():
            yield q

        self.buffer.append(s)

    for q in self.__flush():
      yield q

    for q in self.__flush():
      yield q

encoded = list(Encoder2(cache_size=2**7,max_dictionary_use_len=2**15, min_dictionary_use_len=8).encode(samples))
compressed_bytes = sum(len(s) for s in encoded)
print("compressed_bytes", compressed_bytes, "compression_ratio", num_uncompressed_bytes / compressed_bytes)

## Observations Log
- all values do indeed fit into 10bits, there are only 1023 distinct values
- 1024 electrodes (set/not-set electrode data per sample) is already compressed into some 10bit number, which of 1024 electrodes is firing, we do not know based on 10bit number
- there are 1024 differnt values for samples
- 10bit values are not lower-bits, for some reason higher bits are also set. this can mean there is either: A) reserved full space of 16bits or B) there is some structure to bits arrangement already
- value of 0 is not used. looks like it is reserved
- ~most frequent samples have contigious series of zeroes~
- ~most frequent samples have many zeroes~
- ~there are no unused bits among 16bit samples~
- ~bits are set equally frequently, there is not "dead"-bits in samples~
- to achieve 200x compression, reducing single sample 16bit to 10bit or lower would not do it. we need to compress cross-sample information
- ~there is significant continuity in bits across samples in single sequence of samples. akin to columnar data types, column based compression may be useful. maybe even just transposing data and compressing that can be very significant reduction in size.~
- ~max bits sequence is ~2K (both, zeroes and ones)~
- ~by doing "column based same value count encoding for sequences of 16+" 30% of bits can be removed~
- ~consecutive samples are not very different, many samples have only 4 bits difference. (note: does not make sense to repeat whole almost-previous sample all over again)~
- in single file, out of 90K samples, therea are only 1.5K different transitions between samples. in whole dataset there is only 25K different masks. out of 630M samples, that is 0.00003968253968 of total samples. unlikely this is coincidence. this simple heuristic highlights the fact of possible transitions. there is strong fundamental causality between neuron spikes. certain neruons spike only before other neurons. this causality in simplest form is encoded in "possible" transitions.
- there is 1K transitions that are used more than 1K times, which is 98% of samples
- ~calculating differences from previous in words of 4 does not work at all for some of 16bits~
- bits in number not meaningful, this is again due to copmression of 1028 into 10bit number already. what we can actually do is check if this number is the same or not the same compare to others. information is likely encoded in ordering of these numbers. likely they either: A) repeat patterns (oscilations?); B) transition in certian graph of possible transitions (oscilations?);. Some basic herusitic on most recently repeated values may work well.
- keeping cache of last N bit words, and passing either new word or index in cache, is best at N=128 and gives 1.54 compression ratio

## References

* in WAV each sample is bit value
* http://tiny.systems/software/soundProgrammer/WavFormatDocs.pdf
* https://docs.python.org/3/library/wave.html
* https://github.com/go-audio/wav
* provided files wav encode in 16bit per sample (even though doc says 10bit resolution)
* `scipy.io.wavfile.read` does not support 10bit resolution
* https://en.wikipedia.org/wiki/Variable-length_code
* https://en.wikipedia.org/wiki/Prefix_code
* https://rosettacode.org/wiki/Huffman_coding#Python
* https://golang.google.cn/src/compress/bzip2/huffman.go
* https://iopscience.iop.org/article/10.1088/1741-2552/acf5a4