# Mock Data for Compressed Matrices

In [1]:
import numpy as np
from tqdm import tqdm
import pickle

## Preparation

### Create Random Quantized Matrices

In [2]:
np.random.seed(202404151)
matrices = np.random.randn(100, 1024, 1024)

In [3]:
quantized_matrices = np.round(matrices * 8).astype(np.int8)
print(f'mins: {quantized_matrices.min(axis=(1, 2))}')
print(f'maxs: {quantized_matrices.max(axis=(1, 2))}')

mins: [-41 -38 -40 -38 -44 -38 -37 -41 -41 -37 -39 -40 -39 -37 -39 -40 -37 -41
 -40 -40 -39 -41 -38 -42 -39 -39 -42 -39 -41 -41 -38 -39 -40 -38 -38 -38
 -42 -41 -37 -44 -38 -38 -37 -39 -39 -36 -39 -42 -40 -40 -41 -38 -36 -41
 -40 -41 -42 -38 -38 -38 -42 -37 -39 -39 -38 -36 -39 -37 -38 -40 -43 -36
 -37 -37 -41 -39 -41 -37 -40 -42 -38 -37 -38 -37 -38 -42 -41 -42 -38 -42
 -40 -46 -39 -37 -42 -41 -37 -40 -38 -42]
maxs: [38 39 36 36 39 37 36 38 36 39 38 37 39 40 37 40 37 38 39 41 38 40 37 38
 36 36 39 39 36 39 37 37 37 38 39 39 41 41 37 36 37 38 35 37 37 40 42 36
 41 37 37 41 40 39 38 40 42 43 38 41 37 36 41 37 37 38 38 39 40 36 38 39
 39 40 39 38 39 37 38 40 39 43 43 44 36 42 44 39 43 39 37 37 38 43 40 38
 40 44 39 39]


### Create Entropy Models

In [4]:
values = []
counts = []
counts_12bit = []
entropies = []
cross_entropies_12bit = []
for quantized_matrix in tqdm(quantized_matrices):
    v, c = np.unique(quantized_matrix, return_counts=True)
    order = np.argsort(c)[::-1]
    v = v[order]
    c = c[order]
    values.append(v)
    counts.append(c)
    entropies.append(np.log2(quantized_matrix.size) - c @ np.log2(c) / quantized_matrix.size)
    c12bit = np.maximum(np.round(c / 256).astype(np.uint32), 1)
    excess = sum(c12bit) - (1 << 12)
    assert excess >= 0 and excess <= len(c)
    if excess != 0:
        assert c12bit[excess - 1] > 1
    c12bit[:excess] -= 1
    counts_12bit.append(c12bit)
    cross_entropies_12bit.append(12 - c12bit @ np.log2(c12bit) / (1 << 12))
entropies = np.array(entropies)
cross_entropies_12bit = np.array(cross_entropies_12bit)
overheads = cross_entropies_12bit - entropies
print(f'Maximum absolute overhead: {overheads.max():.4f}')
print(f'Maximum relative overhead: {(overheads * 100 / entropies).max():.2f} %')
entropies, cross_entropies_12bit

100%|██████████| 100/100 [00:02<00:00, 36.25it/s]

Maximum absolute overhead: 0.0471
Maximum relative overhead: 0.93 %





(array([5.04873129, 5.04657062, 5.04677307, 5.04884595, 5.0473163 ,
        5.0486622 , 5.04661374, 5.04849787, 5.04695842, 5.04865761,
        5.04702978, 5.04732894, 5.04798469, 5.04782882, 5.04794853,
        5.04891135, 5.04778863, 5.0474892 , 5.04845118, 5.04870344,
        5.04692647, 5.0466256 , 5.04761184, 5.04772784, 5.04902663,
        5.04909297, 5.0483021 , 5.0477679 , 5.04680325, 5.0470743 ,
        5.04774866, 5.04810709, 5.04799223, 5.04751116, 5.04805406,
        5.04894296, 5.04914948, 5.04813692, 5.04773768, 5.04766498,
        5.04820142, 5.04674641, 5.04612653, 5.04954575, 5.04798425,
        5.04911074, 5.04806961, 5.04883531, 5.04730741, 5.04820349,
        5.0475805 , 5.04728574, 5.0476444 , 5.04874619, 5.04859174,
        5.04892564, 5.04605025, 5.04941888, 5.04747965, 5.04812066,
        5.04752634, 5.04696448, 5.04841835, 5.0466188 , 5.04660691,
        5.04749418, 5.04658043, 5.04970472, 5.04751697, 5.0466669 ,
        5.0487524 , 5.04709681, 5.0476797 , 5.04

In [5]:
inv_vocabs = [{v: i for i, v in enumerate(vs)} for vs in values]

### Entropy Coder

In [6]:
class AnsCoder:
    def __init__(self, precision, word_size, compressed=[]):
        self.precision = precision
        self.word_size = word_size
        self.word_mask = (1 << word_size) - 1
        self.quantile_mask = (1 << precision) - 1
        self.bulk = compressed.copy()
        self.head = 0
        while len(self.bulk) != 0 and (self.head >> word_size) == 0:
            self.head = (self.head << word_size) | self.bulk.pop()

    def push(self, symbol, m):
        if (self.head >> (2 * self.word_size - self.precision)) >= m[symbol]:
            self.bulk.append(self.head & self.word_mask)
            self.head >>= self.word_size

        z = self.head % m[symbol] + sum(m[0:symbol])
        self.head //= m[symbol]
        self.head = (self.head << self.precision) | z

    def pop(self, m):
        z = self.head & self.quantile_mask
        self.head >>= self.precision
        for symbol, m_symbol in enumerate(m):
            if z >= m_symbol:
                z -= m_symbol
            else:
                break
        self.head = self.head * m_symbol + z
        if (self.head >> self.word_size) == 0 and len(self.bulk) != 0:
            self.head = (self.head << self.word_size) | self.bulk.pop()
        return symbol

    def get_compressed(self):
        compressed = self.bulk.copy()
        head = self.head
        while head != 0:
            compressed.append(head & self.word_mask)
            head >>= self.word_size
        return compressed

## Compression 1: independently compressed matrix entries, with non-interleaved matrices

### Compress the matrices

Uncomment the following cell if you've already executed the entropy coding part in a previous run and saved the results to a file, and you now only want to play with the multiplexing part.
Then, don't execute any more cells in this section (as they would overwrite what you just loaded from the pickle file) and instead jump directly to the next section ("Serialize to a file").

In [8]:
# with open("100-compressed-matrices.pickle", "rb") as f:
#     d = pickle.load(f)
#     compressed_stream = d["compressed_stream"]
#     coder_offsets = d["coder_offsets"]
#     del d

The following cell is expensive, but only because the entropy coders are implemented in python, which makes them excruciatingly slow.
Once we've settled on precise coder settings (`precision` and `word_size`), we should move the entropy coding part to a compiled language.

In [9]:
compressed_stream = []
coder_offsets = np.zeros((len(matrices), matrices[0].size), dtype=np.uint32)
coders = [AnsCoder(12, 16, [1, 1]) for _ in range(matrices[0].size)]
for i in tqdm(reversed(range(len(matrices)))):
    model = counts_12bit[i]
    inv_vocab = inv_vocabs[i]
    for j, entry, coder in zip(range(quantized_matrices[i].size - 1, -1, -1), reversed(quantized_matrices[i].ravel()), coders):
        symbol_id = inv_vocab[entry]
        coder.push(symbol_id, model)
        if len(coder.bulk) == 1:
            compressed_stream.append(coder.bulk.pop())
        coder_offsets[i, j] = len(compressed_stream)
    print(f'Encoded matrix {i}; words emitted so far: {coder_offsets[i, -1]}')

1it [00:03,  3.49s/it]

Encoded matrix 99; words emitted so far: 0


2it [00:06,  3.47s/it]

Encoded matrix 98; words emitted so far: 0


3it [00:10,  3.49s/it]

Encoded matrix 97; words emitted so far: 6442


4it [00:14,  3.57s/it]

Encoded matrix 96; words emitted so far: 256628


5it [00:17,  3.52s/it]

Encoded matrix 95; words emitted so far: 1048779


6it [00:21,  3.53s/it]

Encoded matrix 94; words emitted so far: 1061698


7it [00:24,  3.58s/it]

Encoded matrix 93; words emitted so far: 1274851


8it [00:28,  3.58s/it]

Encoded matrix 92; words emitted so far: 2035984


9it [00:31,  3.59s/it]

Encoded matrix 91; words emitted so far: 2114211


10it [00:35,  3.61s/it]

Encoded matrix 90; words emitted so far: 2294396


11it [00:39,  3.60s/it]

Encoded matrix 89; words emitted so far: 2910784


12it [00:42,  3.57s/it]

Encoded matrix 88; words emitted so far: 3164586


13it [00:46,  3.59s/it]

Encoded matrix 87; words emitted so far: 3316701


14it [00:49,  3.59s/it]

Encoded matrix 86; words emitted so far: 3820266


15it [00:53,  3.61s/it]

Encoded matrix 85; words emitted so far: 4203805


16it [00:57,  3.61s/it]

Encoded matrix 84; words emitted so far: 4343774


17it [01:00,  3.61s/it]

Encoded matrix 83; words emitted so far: 4762721


18it [01:04,  3.58s/it]

Encoded matrix 82; words emitted so far: 5207028


19it [01:07,  3.58s/it]

Encoded matrix 81; words emitted so far: 5373775


20it [01:11,  3.58s/it]

Encoded matrix 80; words emitted so far: 5727251


21it [01:15,  3.56s/it]

Encoded matrix 79; words emitted so far: 6184444


22it [01:18,  3.55s/it]

Encoded matrix 78; words emitted so far: 6404354


23it [01:22,  3.59s/it]

Encoded matrix 77; words emitted so far: 6708664


24it [01:25,  3.59s/it]

Encoded matrix 76; words emitted so far: 7152621


25it [01:29,  3.60s/it]

Encoded matrix 75; words emitted so far: 7428625


26it [01:33,  3.59s/it]

Encoded matrix 74; words emitted so far: 7701460


27it [01:36,  3.59s/it]

Encoded matrix 73; words emitted so far: 8120839


28it [01:40,  3.57s/it]

Encoded matrix 72; words emitted so far: 8441861


29it [01:43,  3.58s/it]

Encoded matrix 71; words emitted so far: 8702378


30it [01:47,  3.56s/it]

Encoded matrix 70; words emitted so far: 9093003


31it [01:50,  3.54s/it]

Encoded matrix 69; words emitted so far: 9444023


32it [01:54,  3.53s/it]

Encoded matrix 68; words emitted so far: 9709115


33it [01:57,  3.54s/it]

Encoded matrix 67; words emitted so far: 10069518


34it [02:01,  3.53s/it]

Encoded matrix 66; words emitted so far: 10437342


35it [02:04,  3.54s/it]

Encoded matrix 65; words emitted so far: 10716954


36it [02:08,  3.58s/it]

Encoded matrix 64; words emitted so far: 11052761


37it [02:12,  3.58s/it]

Encoded matrix 63; words emitted so far: 11425705


38it [02:15,  3.59s/it]

Encoded matrix 62; words emitted so far: 11722858


39it [02:19,  3.68s/it]

Encoded matrix 61; words emitted so far: 12041996


40it [02:23,  3.67s/it]

Encoded matrix 60; words emitted so far: 12410758


41it [02:26,  3.67s/it]

Encoded matrix 59; words emitted so far: 12726113


42it [02:30,  3.65s/it]

Encoded matrix 58; words emitted so far: 13035500


43it [02:34,  3.64s/it]

Encoded matrix 57; words emitted so far: 13395635


44it [02:37,  3.62s/it]

Encoded matrix 56; words emitted so far: 13725241


45it [02:41,  3.61s/it]

Encoded matrix 55; words emitted so far: 14032246


46it [02:44,  3.62s/it]

Encoded matrix 54; words emitted so far: 14382222


47it [02:48,  3.62s/it]

Encoded matrix 53; words emitted so far: 14721411


48it [02:52,  3.63s/it]

Encoded matrix 52; words emitted so far: 15030577


49it [02:55,  3.63s/it]

Encoded matrix 51; words emitted so far: 15370724


50it [02:59,  3.66s/it]

Encoded matrix 50; words emitted so far: 15714205


51it [03:03,  3.67s/it]

Encoded matrix 49; words emitted so far: 16029060


52it [03:06,  3.68s/it]

Encoded matrix 48; words emitted so far: 16360831


53it [03:10,  3.68s/it]

Encoded matrix 47; words emitted so far: 16706084


54it [03:14,  3.66s/it]

Encoded matrix 46; words emitted so far: 17026551


55it [03:17,  3.66s/it]

Encoded matrix 45; words emitted so far: 17353267


56it [03:21,  3.65s/it]

Encoded matrix 44; words emitted so far: 17697071


57it [03:25,  3.66s/it]

Encoded matrix 43; words emitted so far: 18023446


58it [03:28,  3.66s/it]

Encoded matrix 42; words emitted so far: 18346814


59it [03:32,  3.67s/it]

Encoded matrix 41; words emitted so far: 18687526


60it [03:36,  3.68s/it]

Encoded matrix 40; words emitted so far: 19018871


61it [03:39,  3.67s/it]

Encoded matrix 39; words emitted so far: 19341563


62it [03:43,  3.67s/it]

Encoded matrix 38; words emitted so far: 19679067


63it [03:47,  3.68s/it]

Encoded matrix 37; words emitted so far: 20013158


64it [03:51,  3.73s/it]

Encoded matrix 36; words emitted so far: 20336871


65it [03:54,  3.72s/it]

Encoded matrix 35; words emitted so far: 20670222


66it [03:58,  3.69s/it]

Encoded matrix 34; words emitted so far: 21006192


67it [04:02,  3.68s/it]

Encoded matrix 33; words emitted so far: 21332328


68it [04:05,  3.68s/it]

Encoded matrix 32; words emitted so far: 21663189


69it [04:09,  3.66s/it]

Encoded matrix 31; words emitted so far: 21999552


70it [04:13,  3.65s/it]

Encoded matrix 30; words emitted so far: 22327116


71it [04:16,  3.65s/it]

Encoded matrix 29; words emitted so far: 22655949


72it [04:20,  3.63s/it]

Encoded matrix 28; words emitted so far: 22991943


73it [04:23,  3.62s/it]

Encoded matrix 27; words emitted so far: 23321390


74it [04:27,  3.60s/it]

Encoded matrix 26; words emitted so far: 23649220


75it [04:31,  3.61s/it]

Encoded matrix 25; words emitted so far: 23984877


76it [04:34,  3.60s/it]

Encoded matrix 24; words emitted so far: 24315658


77it [04:38,  3.60s/it]

Encoded matrix 23; words emitted so far: 24644002


78it [04:41,  3.64s/it]

Encoded matrix 22; words emitted so far: 24977313


79it [04:45,  3.68s/it]

Encoded matrix 21; words emitted so far: 25309423


80it [04:49,  3.69s/it]

Encoded matrix 20; words emitted so far: 25637462


81it [04:53,  3.72s/it]

Encoded matrix 19; words emitted so far: 25969721


82it [04:56,  3.70s/it]

Encoded matrix 18; words emitted so far: 26302485


83it [05:00,  3.71s/it]

Encoded matrix 17; words emitted so far: 26632280


84it [05:04,  3.74s/it]

Encoded matrix 16; words emitted so far: 26963759


85it [05:08,  3.77s/it]

Encoded matrix 15; words emitted so far: 27295997


86it [05:11,  3.75s/it]

Encoded matrix 14; words emitted so far: 27625988


87it [05:15,  3.72s/it]

Encoded matrix 13; words emitted so far: 27956972


88it [05:19,  3.68s/it]

Encoded matrix 12; words emitted so far: 28289047


89it [05:22,  3.65s/it]

Encoded matrix 11; words emitted so far: 28620248


90it [05:26,  3.70s/it]

Encoded matrix 10; words emitted so far: 28950098


91it [05:30,  3.68s/it]

Encoded matrix 9; words emitted so far: 29282317


92it [05:33,  3.67s/it]

Encoded matrix 8; words emitted so far: 29613855


93it [05:37,  3.65s/it]

Encoded matrix 7; words emitted so far: 29943590


94it [05:41,  3.63s/it]

Encoded matrix 6; words emitted so far: 30275837


95it [05:44,  3.62s/it]

Encoded matrix 5; words emitted so far: 30607617


96it [05:48,  3.64s/it]

Encoded matrix 4; words emitted so far: 30937585


97it [05:52,  3.70s/it]

Encoded matrix 3; words emitted so far: 31268964


98it [05:55,  3.66s/it]

Encoded matrix 2; words emitted so far: 31600871


99it [05:59,  3.64s/it]

Encoded matrix 1; words emitted so far: 31931008


100it [06:02,  3.63s/it]

Encoded matrix 0; words emitted so far: 32261992





In [10]:
for coder in tqdm(coders):
    assert len(coder.bulk) == 0
    head = coder.head
    compressed_stream.append(head & coder.word_mask)
    head >>= coder.word_size
    compressed_stream.append(head & coder.word_mask)

compressed_stream = np.array(compressed_stream, dtype=np.uint16)

  0%|          | 0/1048576 [00:00<?, ?it/s]

100%|██████████| 1048576/1048576 [00:00<00:00, 1706337.17it/s]


In [11]:
bitrate = len(compressed_stream) * 16 / quantized_matrices.size
len(compressed_stream), bitrate, bitrate - cross_entropies_12bit.mean()

(34691299, 5.293472137451172, np.float64(0.20617493810071696))

In [12]:
with open("100-compressed-matrices.pickle", "wb") as f:
    pickle.dump({"compressed_stream": compressed_stream, "coder_offsets": coder_offsets}, f)

### Serialize to a file

We generate a binary file by concatenating the following bit strings:

- `coder_heads`: a concatenation of 1024*1024 unsigned 32-bit integers (in little-endian byte order), each representing the initial `head` state of an entropy coder.
  Each entropy coder corresponds to a single matrix element, and their initial `head` states are stored in row-major order.
- `compressed_sizes`: a concatenation of 100 unsigned 32-bit integers (in little-endian byte order), each holding the total size (in multiples of 16 bits) of `matrices[i]`, defined below, where `i ∈ {0, 1, ..., 99}`.
- `matrices`: a concatenation of 100 compressed representations of the matrices; here, the compressed representation of matrix `i` (for `i ∈ {0, 1, ..., 99}`) is a concatenation of `compressed_sizes[i]` unsigned 16-bit integers (in little-endian byte order), constructed by concatenating the following:
  - `data_offset`: a single unsigned 16-bit integer pointing to the `data` section below.
  - `cdf`: a concatenation of `data_offset - 1` unsigned 16-bit integers (in little-endian byte order), representing the cumulative distribution function (CDF) of the entropy model used for encoding the current matrix.
    The CDF is stored in 12-bit precision as a sequence of non-increasing integers, where the first integer is always `0` and the last integer is always `2^12 = 4096`.
    This defines a CDF over `data_offset - 2` discrete symbols.
    Admittedly, the first and last entry of the CDF are redundant (because they always have values `0` and `4096`, respectively), but having them in the compressed data allows us to avoid an additional branch in the decoder implementation.
  - `data`: a (possibly empty) sequence of unsigned 16-bit integers that represent compressed bits.
    The decoder will use these to refill any of the entropy coders that need refilling after decoding the current matrix element.
    For each matrix element, `data` contains either zero or one unsigned 16-bit integer, and they are concatenated in row-major order.
    It's the decoder's job to figure out which entropy coder needs refilling, and to refill each one with the correct chunk of compressed data.

In [14]:
def serialize_matrix(i):
    cdf = np.array([counts_12bit[i][:j].sum() for j in range(len(counts_12bit[i]) + 1)], dtype=np.uint16)
    data_offset = np.array([len(cdf) + 1], dtype=np.uint16)
    start = coder_offsets[i, 0]
    end = start if i == 0 else coder_offsets[i - 1, 0]
    data = compressed_stream[start:end][::-1]
    return np.concatenate([data_offset, cdf, data])

In [15]:
serialized_sizes = np.array([serialize_matrix(i).size for i in range(len(quantized_matrices))], dtype=np.uint32)
serialized_sizes, serialized_sizes.max()

(array([    78, 332235, 331061, 330215, 331988, 331457, 330044, 331859,
        332323, 329813, 331619, 332297, 329928, 331279, 332151, 331064,
        330067, 332317, 331558, 329876, 332844, 332337, 328115, 332189,
        333387, 328422, 330859, 335738, 327906, 329527, 336071, 328908,
        327643, 336439, 330939, 326216, 336048, 333430, 323791, 334168,
        337581, 322772, 331419, 340791, 323447, 326453, 343883, 326795,
        320546, 345331, 331849, 314936, 343557, 340226, 309246, 339266,
        350058, 307082, 329686, 360214, 309465, 315432, 368840, 319216,
        297231, 373020, 335886, 279688, 367901, 360480, 265171, 351098,
        390702, 260596, 321100, 419457, 272915, 276082, 444035, 304390,
        219990, 457270, 353554, 166826, 444383, 419026, 140051, 383617,
        503643, 152196, 253879, 616465, 180265,  78305, 761215, 213233,
         12997, 792229, 250265,   6522], dtype=uint32),
 np.uint32(792229))

In [16]:
with open("100-compressed-matrices.bin", "wb") as f:
    initial_coder_heads = compressed_stream[coder_offsets[0, 0]:][::-1]
    initial_coder_heads.tofile(f)
    serialized_sizes_lo = (serialized_sizes & ((1 << 16) - 1)).astype(np.uint16)
    serialized_sizes_hi = (serialized_sizes >> 16).astype(np.uint16)
    serialized_sizes_le = np.vstack([serialized_sizes_lo, serialized_sizes_hi]).T.ravel()
    serialized_sizes_le.tofile(f)
    for i in range(len(quantized_matrices)):
        m = serialize_matrix(i)
        m.tofile(f)
    # NOTE: this doesn't serialize the final compressed data that would be necessary
    # to reconstruct the original coder states (which is currently always `[1, 0]`).

In [17]:
!ls -lh 100-compressed-matrices.bin

-rw-rw-r-- 1 robamler robamler 67M Feb 28 16:42 100-compressed-matrices.bin


## Debugging

In [18]:
# Checksum adapted from rust's FxHash, but restricted to 26 bit hashes so that the
# multiplication doesn't exceed the range of exactly representable integers in JavaScript.
def get_checksum(i, cap=None):
    checksum = 0
    m = quantized_matrices[i].ravel()
    if cap is not None:
        m = m[:cap]
    for value in m:
        value = inv_vocabs[i][value]
        checksum = ((checksum & 0x001f_ffff) << 5) | (checksum >> 21) # rotate
        checksum = (checksum ^ value) * 0x0322_0a95
        checksum = checksum & 0x03ff_ffff # truncate to 26 bit
    return checksum

In [23]:
for i in range(100):
    print(f'checksum of matrix {i}: {get_checksum(i)}')

checksum of matrix 0: 64364773
checksum of matrix 1: 64612196
checksum of matrix 2: 60176446
checksum of matrix 3: 10715593
checksum of matrix 4: 37212993
checksum of matrix 5: 19062629
checksum of matrix 6: 44602562
checksum of matrix 7: 7827136
checksum of matrix 8: 64606195
checksum of matrix 9: 57267776
checksum of matrix 10: 2963741
checksum of matrix 11: 15677458
checksum of matrix 12: 60359594
checksum of matrix 13: 63603614
checksum of matrix 14: 49520088
checksum of matrix 15: 24331276
checksum of matrix 16: 41146240
checksum of matrix 17: 22772352
checksum of matrix 18: 24032544
checksum of matrix 19: 47519280
checksum of matrix 20: 24127898
checksum of matrix 21: 35338592
checksum of matrix 22: 23628479
checksum of matrix 23: 37124762
checksum of matrix 24: 65280212
checksum of matrix 25: 3816441
checksum of matrix 26: 59399301
checksum of matrix 27: 65826698
checksum of matrix 28: 57103306
checksum of matrix 29: 52778725
checksum of matrix 30: 54218222
checksum of matrix 31