# Mock Data for Compressed Matrices

In [1]:
import numpy as np
import constriction
from tqdm import tqdm
import pickle

## Preparation

### Create Random Quantized Matrices

In [2]:
np.random.seed(202404151)
matrices = np.random.randn(100, 1024, 1024)

In [3]:
quantized_matrices = np.round(matrices * 8).astype(np.int8)
print(f'mins: {quantized_matrices.min(axis=(1, 2))}')
print(f'maxs: {quantized_matrices.max(axis=(1, 2))}')

mins: [-41 -38 -40 -38 -44 -38 -37 -41 -41 -37 -39 -40 -39 -37 -39 -40 -37 -41
 -40 -40 -39 -41 -38 -42 -39 -39 -42 -39 -41 -41 -38 -39 -40 -38 -38 -38
 -42 -41 -37 -44 -38 -38 -37 -39 -39 -36 -39 -42 -40 -40 -41 -38 -36 -41
 -40 -41 -42 -38 -38 -38 -42 -37 -39 -39 -38 -36 -39 -37 -38 -40 -43 -36
 -37 -37 -41 -39 -41 -37 -40 -42 -38 -37 -38 -37 -38 -42 -41 -42 -38 -42
 -40 -46 -39 -37 -42 -41 -37 -40 -38 -42]
maxs: [38 39 36 36 39 37 36 38 36 39 38 37 39 40 37 40 37 38 39 41 38 40 37 38
 36 36 39 39 36 39 37 37 37 38 39 39 41 41 37 36 37 38 35 37 37 40 42 36
 41 37 37 41 40 39 38 40 42 43 38 41 37 36 41 37 37 38 38 39 40 36 38 39
 39 40 39 38 39 37 38 40 39 43 43 44 36 42 44 39 43 39 37 37 38 43 40 38
 40 44 39 39]


### Create Entropy Models

In [4]:
values = []
counts = []
counts_16bit = []
entropies = []
cross_entropies_16bit = []
for quantized_matrix in tqdm(quantized_matrices):
    v, c = np.unique(quantized_matrix, return_counts=True)
    order = np.argsort(c)[::-1]
    v = v[order]
    c = c[order]
    values.append(v)
    counts.append(c)
    entropies.append(np.log2(quantized_matrix.size) - c @ np.log2(c) / quantized_matrix.size)
    c16bit = np.maximum(np.round(c / 16).astype(np.int32), 1)

    # We let Probabilities add up to `(1<<16) - 1` rather than `1<<16` so that we can represent
    # the entire cdf in 16 bit values, which simplifies decoding. Note that this would (probably)
    # make bits-back coding invalid since the AnsCoder is no longer surjective, but we don't do
    # bits-back coding anyway.
    excess = sum(c16bit) - ((1 << 16) - 1)
    assert excess >= 0 and excess <= len(c)
    if excess != 0:
        assert c16bit[excess - 1] > 1
    c16bit[:excess] -= 1
    counts_16bit.append(c16bit)
    cross_entropies_16bit.append(16 - c16bit @ np.log2(c16bit) / (1 << 16))
entropies = np.array(entropies)
cross_entropies_16bit = np.array(cross_entropies_16bit)
overheads = cross_entropies_16bit - entropies
print(f'Maximum absolute overhead: {overheads.max():.4f}')
print(f'Maximum relative overhead: {(overheads * 100 / entropies).max():.2f} %')
entropies, cross_entropies_16bit

100%|██████████| 100/100 [00:03<00:00, 32.80it/s]

Maximum absolute overhead: 0.0028
Maximum relative overhead: 0.05 %





(array([5.04873129, 5.04657062, 5.04677307, 5.04884595, 5.0473163 ,
        5.0486622 , 5.04661374, 5.04849787, 5.04695842, 5.04865761,
        5.04702978, 5.04732894, 5.04798469, 5.04782882, 5.04794853,
        5.04891135, 5.04778863, 5.0474892 , 5.04845118, 5.04870344,
        5.04692647, 5.0466256 , 5.04761184, 5.04772784, 5.04902663,
        5.04909297, 5.0483021 , 5.0477679 , 5.04680325, 5.0470743 ,
        5.04774866, 5.04810709, 5.04799223, 5.04751116, 5.04805406,
        5.04894296, 5.04914948, 5.04813692, 5.04773768, 5.04766498,
        5.04820142, 5.04674641, 5.04612653, 5.04954575, 5.04798425,
        5.04911074, 5.04806961, 5.04883531, 5.04730741, 5.04820349,
        5.0475805 , 5.04728574, 5.0476444 , 5.04874619, 5.04859174,
        5.04892564, 5.04605025, 5.04941888, 5.04747965, 5.04812066,
        5.04752634, 5.04696448, 5.04841835, 5.0466188 , 5.04660691,
        5.04749418, 5.04658043, 5.04970472, 5.04751697, 5.0466669 ,
        5.0487524 , 5.04709681, 5.0476797 , 5.04

In [5]:
inv_vocabs = [{v: i for i, v in enumerate(vs)} for vs in values]

### Entropy Coder

In [6]:
class AnsCoder:
    def __init__(self, precision, word_size, compressed=[]):
        self.precision = precision
        self.word_size = word_size
        self.word_mask = (1 << word_size) - 1
        self.quantile_mask = (1 << precision) - 1
        self.bulk = compressed.copy()
        self.head = 0
        while len(self.bulk) != 0 and (self.head >> word_size) == 0:
            self.head = (self.head << word_size) | self.bulk.pop()

    def push(self, symbol, m):
        if (self.head >> (2 * self.word_size - self.precision)) >= m[symbol]:
            self.bulk.append(self.head & self.word_mask)
            self.head >>= self.word_size

        z = self.head % m[symbol] + sum(m[0:symbol])
        self.head //= m[symbol]
        self.head = (self.head << self.precision) | z

    def pop(self, m):
        z = self.head & self.quantile_mask
        self.head >>= self.precision
        for symbol, m_symbol in enumerate(m):
            if z >= m_symbol:
                z -= m_symbol
            else:
                break
        self.head = self.head * m_symbol + z
        if (self.head >> self.word_size) == 0 and len(self.bulk) != 0:
            self.head = (self.head << self.word_size) | self.bulk.pop()
        return symbol

    def get_compressed(self):
        compressed = self.bulk.copy()
        head = self.head
        while head != 0:
            compressed.append(head & self.word_mask)
            head >>= self.word_size
        return compressed

## Compression 1: independently compressed matrix entries, with non-interleaved matrices

### Compress matrices

In [7]:
# with open("100-compressed-matrices.pickle", "rb") as f:
#     d = pickle.load(f)
#     compressed_stream = d["compressed_stream"]
#     coder_offsets = d["coder_offsets"]
#     del d

In [None]:
compressed_stream = []
coder_offsets = np.zeros((len(matrices), matrices[0].size), dtype=np.uint32)
coders = [AnsCoder(16, 16, [1, 1]) for _ in range(matrices[0].size)]
for i in tqdm(reversed(range(len(matrices)))):
    model = counts_16bit[i]
    inv_vocab = inv_vocabs[i]
    for j, entry, coder in zip(range(quantized_matrices[i].size - 1, -1, -1), reversed(quantized_matrices[i].ravel()), coders):
        symbol_id = inv_vocab[entry]
        coder.push(symbol_id, model)
        if len(coder.bulk) == 1:
            compressed_stream.append(coder.bulk[0])
            coder.bulk = []
        coder_offsets[i, j] = len(compressed_stream)
    print(f'Encoded matrix {i}; words emitted so far: {coder_offsets[i, -1]}')

1it [00:10, 10.32s/it]

Encoded matrix 99; words emitted so far: 0


2it [00:16,  7.65s/it]

Encoded matrix 98; words emitted so far: 127


3it [00:22,  6.87s/it]

Encoded matrix 97; words emitted so far: 6467


4it [00:28,  6.56s/it]

Encoded matrix 96; words emitted so far: 253130


5it [00:33,  6.28s/it]

Encoded matrix 95; words emitted so far: 1049018


6it [00:39,  6.16s/it]

Encoded matrix 94; words emitted so far: 1063608


7it [00:45,  6.16s/it]

Encoded matrix 93; words emitted so far: 1278127


8it [00:52,  6.12s/it]

Encoded matrix 92; words emitted so far: 2023904


9it [00:58,  6.09s/it]

Encoded matrix 91; words emitted so far: 2116859


10it [01:04,  6.07s/it]

Encoded matrix 90; words emitted so far: 2298457


11it [01:10,  6.03s/it]

Encoded matrix 89; words emitted so far: 2896717


12it [01:15,  6.01s/it]

Encoded matrix 88; words emitted so far: 3167526


13it [01:22,  6.10s/it]

Encoded matrix 87; words emitted so far: 3322073


14it [01:28,  6.14s/it]

Encoded matrix 86; words emitted so far: 3811291


15it [01:34,  6.09s/it]

Encoded matrix 85; words emitted so far: 4202162


16it [01:40,  6.10s/it]

Encoded matrix 84; words emitted so far: 4349498


17it [01:46,  6.11s/it]

Encoded matrix 83; words emitted so far: 4757518


18it [01:52,  6.12s/it]

Encoded matrix 82; words emitted so far: 5199084


19it [02:00,  6.61s/it]

Encoded matrix 81; words emitted so far: 5379838


20it [02:06,  6.49s/it]

Encoded matrix 80; words emitted so far: 5725565


21it [02:13,  6.41s/it]

Encoded matrix 79; words emitted so far: 6173949


22it [02:19,  6.35s/it]

Encoded matrix 78; words emitted so far: 6409281


23it [02:25,  6.45s/it]

Encoded matrix 77; words emitted so far: 6709065


24it [02:32,  6.39s/it]

Encoded matrix 76; words emitted so far: 7142597


25it [02:38,  6.40s/it]

Encoded matrix 75; words emitted so far: 7430117


26it [02:45,  6.45s/it]

Encoded matrix 74; words emitted so far: 7703842


27it [02:51,  6.38s/it]

Encoded matrix 73; words emitted so far: 8112292


28it [02:57,  6.27s/it]

Encoded matrix 72; words emitted so far: 8440046


29it [03:03,  6.19s/it]

Encoded matrix 71; words emitted so far: 8705917


30it [03:09,  6.13s/it]

Encoded matrix 70; words emitted so far: 9085751


31it [03:15,  6.09s/it]

Encoded matrix 69; words emitted so far: 9439095


32it [03:21,  6.08s/it]

Encoded matrix 68; words emitted so far: 9712584


33it [03:27,  6.08s/it]

Encoded matrix 67; words emitted so far: 10064987


34it [03:33,  6.09s/it]

Encoded matrix 66; words emitted so far: 10431059


35it [03:39,  6.08s/it]

Encoded matrix 65; words emitted so far: 10719160


36it [03:45,  6.07s/it]

Encoded matrix 64; words emitted so far: 11050484


37it [03:51,  6.06s/it]

Encoded matrix 63; words emitted so far: 11418528


38it [03:57,  6.07s/it]

Encoded matrix 62; words emitted so far: 11723217


39it [04:03,  6.03s/it]

Encoded matrix 61; words emitted so far: 12040963


40it [04:10,  6.09s/it]

Encoded matrix 60; words emitted so far: 12404148


41it [04:16,  6.12s/it]

Encoded matrix 59; words emitted so far: 12724647


42it [04:22,  6.15s/it]

Encoded matrix 58; words emitted so far: 13035795


43it [04:28,  6.14s/it]

Encoded matrix 57; words emitted so far: 13389567


44it [04:34,  6.16s/it]

Encoded matrix 56; words emitted so far: 13722426


45it [04:40,  6.15s/it]

Encoded matrix 55; words emitted so far: 14032421


46it [04:47,  6.12s/it]

Encoded matrix 54; words emitted so far: 14377227


47it [04:53,  6.11s/it]

Encoded matrix 53; words emitted so far: 14716950


48it [04:59,  6.09s/it]

Encoded matrix 52; words emitted so far: 15030454


49it [05:05,  6.08s/it]

Encoded matrix 51; words emitted so far: 15367184


50it [05:11,  6.08s/it]

Encoded matrix 50; words emitted so far: 15709227


51it [05:17,  6.06s/it]

Encoded matrix 49; words emitted so far: 16027730


52it [05:23,  6.05s/it]

Encoded matrix 48; words emitted so far: 16358400


53it [05:29,  6.08s/it]

Encoded matrix 47; words emitted so far: 16700575


54it [05:35,  6.08s/it]

Encoded matrix 46; words emitted so far: 17024312


55it [05:41,  6.11s/it]

Encoded matrix 45; words emitted so far: 17351048


56it [05:48,  6.19s/it]

Encoded matrix 44; words emitted so far: 17692095


57it [05:54,  6.15s/it]

Encoded matrix 43; words emitted so far: 18020467


58it [06:00,  6.13s/it]

Encoded matrix 42; words emitted so far: 18345024


59it [06:06,  6.10s/it]

Encoded matrix 41; words emitted so far: 18683074


60it [06:12,  6.10s/it]

Encoded matrix 40; words emitted so far: 19014897


61it [06:18,  6.11s/it]

Encoded matrix 39; words emitted so far: 19339450


62it [06:24,  6.10s/it]

Encoded matrix 38; words emitted so far: 19674714


63it [06:30,  6.15s/it]

Encoded matrix 37; words emitted so far: 20009134


64it [06:37,  6.17s/it]

Encoded matrix 36; words emitted so far: 20334371


65it [06:43,  6.16s/it]

Encoded matrix 35; words emitted so far: 20666542


66it [06:49,  6.09s/it]

Encoded matrix 34; words emitted so far: 21001766


67it [06:55,  6.05s/it]

Encoded matrix 33; words emitted so far: 21329015


68it [07:01,  6.02s/it]

Encoded matrix 32; words emitted so far: 21659660


69it [07:06,  5.99s/it]

Encoded matrix 31; words emitted so far: 21994849


70it [07:12,  5.97s/it]

Encoded matrix 30; words emitted so far: 22323777


71it [07:18,  5.97s/it]

Encoded matrix 29; words emitted so far: 22652649


72it [07:24,  5.98s/it]

Encoded matrix 28; words emitted so far: 22987681


73it [07:30,  5.99s/it]

Encoded matrix 27; words emitted so far: 23317647


74it [07:36,  6.01s/it]

Encoded matrix 26; words emitted so far: 23646514


75it [07:43,  6.04s/it]

Encoded matrix 25; words emitted so far: 23980137


76it [07:49,  6.04s/it]

Encoded matrix 24; words emitted so far: 24311343


77it [07:55,  6.04s/it]

Encoded matrix 23; words emitted so far: 24640586


78it [08:01,  6.03s/it]

Encoded matrix 22; words emitted so far: 24972890


79it [08:07,  6.03s/it]

Encoded matrix 21; words emitted so far: 25305310


80it [08:13,  6.04s/it]

Encoded matrix 20; words emitted so far: 25634052


81it [08:19,  6.03s/it]

Encoded matrix 19; words emitted so far: 25965545


82it [08:25,  6.02s/it]

Encoded matrix 18; words emitted so far: 26298086


83it [08:31,  6.11s/it]

Encoded matrix 17; words emitted so far: 26628267


84it [08:37,  6.10s/it]

Encoded matrix 16; words emitted so far: 26959398


85it [08:43,  6.04s/it]

Encoded matrix 15; words emitted so far: 27291498


86it [08:49,  6.00s/it]

Encoded matrix 14; words emitted so far: 27621714


87it [08:55,  5.95s/it]

Encoded matrix 13; words emitted so far: 27952429


88it [09:01,  5.93s/it]

Encoded matrix 12; words emitted so far: 28284730


89it [09:06,  5.90s/it]

Encoded matrix 11; words emitted so far: 28615579


90it [09:12,  5.88s/it]

Encoded matrix 10; words emitted so far: 28946137


91it [09:18,  5.87s/it]

Encoded matrix 9; words emitted so far: 29277624


92it [09:24,  5.86s/it]

Encoded matrix 8; words emitted so far: 29609384


93it [09:30,  5.85s/it]

Encoded matrix 7; words emitted so far: 29939077


94it [09:36,  5.85s/it]

Encoded matrix 6; words emitted so far: 30271042


95it [09:41,  5.85s/it]

Encoded matrix 5; words emitted so far: 30602517


96it [09:48,  5.95s/it]

Encoded matrix 4; words emitted so far: 30932717


97it [09:54,  5.99s/it]

Encoded matrix 3; words emitted so far: 31264239


98it [10:00,  6.00s/it]

Encoded matrix 2; words emitted so far: 31595550


99it [10:06,  6.01s/it]

Encoded matrix 1; words emitted so far: 31926728


100it [10:12,  6.13s/it]

Encoded matrix 0; words emitted so far: 32257476





In [None]:
for coder in coders:
    assert len(coder.bulk) == 0
    compressed = coder.get_compressed()
    assert len(compressed) == 2
    compressed_stream += compressed
compressed_stream = np.array(compressed_stream, dtype=np.uint16)

In [None]:
bitrate = len(compressed_stream) * 16 / quantized_matrices.size
len(compressed_stream), bitrate, bitrate - cross_entropies_16bit.mean()

(34686512, 5.29274169921875, 0.24308005163829893)

In [None]:
with open("100-compressed-matrices.pickle", "wb") as f:
    pickle.dump({"compressed_stream": compressed_stream, "coder_offsets": coder_offsets}, f)

In [9]:
with open("100-compressed-matrices.pickle", "rb") as f:
    d = pickle.load(f)
    compressed_stream = d["compressed_stream"]
    coder_offsets = d["coder_offsets"]
    del d

### Serialize to a File

In [None]:
np.array([counts_16bit[0][:j].sum() for j in range(len(counts_16bit[0]) + 1)], dtype=np.uint16)

array([    0,  3268,  6515,  9741, 12901, 16054, 19106, 22148, 25042,
       27922, 30621, 33314, 35771, 38206, 40446, 42679, 44669, 46653,
       48401, 50118, 51617, 53092, 54372, 55639, 56711, 57770, 58643,
       59505, 60220, 60925, 61494, 62061, 62510, 62953, 63301, 63634,
       63903, 64164, 64364, 64555, 64701, 64845, 64949, 65051, 65127,
       65199, 65253, 65306, 65344, 65379, 65405, 65431, 65448, 65464,
       65476, 65487, 65495, 65502, 65507, 65512, 65515, 65517, 65519,
       65521, 65523, 65524, 65525, 65526, 65527, 65528, 65529, 65530,
       65531, 65532, 65533, 65534, 65535], dtype=uint16)

In [None]:
def serialize_matrix(i):
    cdf = np.array([counts_16bit[i][:j].sum() for j in range(len(counts_16bit[i]) + 1)], dtype=np.uint16)
    data_offset = np.array([len(cdf) + 1], dtype=np.uint16)
    start = coder_offsets[i, 0]
    end = start if i == 0 else coder_offsets[i - 1, 0]
    data = compressed_stream[start:end][::-1]
    return np.concatenate([data_offset, cdf, data])

In [None]:
serialized_sizes = np.array([serialize_matrix(i).size for i in range(len(quantized_matrices))], dtype=np.uint32)
serialized_sizes, serialized_sizes.max()

(array([    78, 331964, 330825, 331256, 331392, 331600, 330276, 331554,
        332041, 329771, 331841, 331565, 330636, 330927, 332377, 330795,
        330292, 332179, 331210, 330262, 332621, 331571, 328818, 332499,
        332380, 329321, 331284, 333704, 328943, 330046, 335109, 328947,
        329007, 335265, 330723, 327329, 335302, 332250, 325315, 334497,
        335341, 324633, 331897, 338129, 324636, 328450, 341126, 326815,
        323816, 342253, 330748, 318584, 342119, 336809, 313584, 339800,
        344888, 310072, 332939, 353851, 311226, 320576, 363263, 317824,
        304767, 368120, 331403, 288177, 366149, 352480, 273568, 353422,
        379911, 265950, 327832, 408528, 273805, 287598, 433610, 299864,
        235412, 448461, 345804, 180835, 441641, 408099, 147418, 390949,
        489296, 154628, 270886, 598337, 181678,  93033, 745859, 214599,
         14668, 795966, 246742,   6420], dtype=uint32),
 795966)

In [None]:
with open("100-compressed-matrices.bin", "wb") as f:
    initial_coder_heads = compressed_stream[coder_offsets[0, 0]:][::-1]
    initial_coder_heads.tofile(f)
    serialized_sizes_lo = (serialized_sizes & ((1 << 16) - 1)).astype(np.uint16)
    serialized_sizes_hi = (serialized_sizes >> 16).astype(np.uint16)
    serialized_sizes_le = np.vstack([serialized_sizes_lo, serialized_sizes_hi]).T.ravel()
    serialized_sizes_le.tofile(f)
    for i in range(len(quantized_matrices)):
        m = serialize_matrix(i)
        m.tofile(f)
    # NOTE: this doesn't serialize the final compressed data that would be necessary
    # to reconstruct the original coder states (which is currently always `[1, 0]`).

In [None]:
!ls -lh 100-compressed-matrices.bin

-rw-rw-r-- 1 robamler robamler 67M May 13 12:14 100-compressed-matrices.bin


## Debugging

In [52]:
# Checksum adapted from rust's FxHash, but restricted to 26 bit hashes so that the
# multiplication doesn't exceed the range of exactly representable integers in JavaScript.
def get_checksum(i):
    checksum = 0
    for value in quantized_matrices[i].ravel():
        value = inv_vocabs[i][value]
        checksum = ((checksum & 0x001f_ffff) << 5) | (checksum >> 21) # rotate
        checksum = (checksum ^ value) * 0x0322_0a95
        checksum = checksum & 0x03ff_ffff # truncate to 26 bit
    return checksum

In [53]:
for i in range(100):
    print(f'checksum of matrix {i}: {get_checksum(i)}')

checksum of matrix 0: 55445785
checksum of matrix 1: 15163511
checksum of matrix 2: 6664559
checksum of matrix 3: 15747500
checksum of matrix 4: 5573655
checksum of matrix 5: 27711411
checksum of matrix 6: 20047205
checksum of matrix 7: 10641525
checksum of matrix 8: 64606195
checksum of matrix 9: 65232545
checksum of matrix 10: 51903642
checksum of matrix 11: 57475160
checksum of matrix 12: 54761909
checksum of matrix 13: 30428408
checksum of matrix 14: 22313942
checksum of matrix 15: 33179881
checksum of matrix 16: 37189258
checksum of matrix 17: 16278593
checksum of matrix 18: 56573661
checksum of matrix 19: 41121083
checksum of matrix 20: 55929650
checksum of matrix 21: 43756132
checksum of matrix 22: 28003007
checksum of matrix 23: 56424187
checksum of matrix 24: 28589585
checksum of matrix 25: 37515944
checksum of matrix 26: 58783295
checksum of matrix 27: 57634150
checksum of matrix 28: 6008159
checksum of matrix 29: 16510335
checksum of matrix 30: 64138918
checksum of matrix 31