# Mock Data for Compressed Matrices

In [1]:
import numpy as np
import constriction
from tqdm import tqdm
import pickle

## Preparation

### Create Random Quantized Matrices

In [2]:
np.random.seed(202404151)
matrices = np.random.randn(100, 1024, 1024)

In [3]:
quantized_matrices = np.round(matrices * 8).astype(np.int8)
print(f'mins: {quantized_matrices.min(axis=(1, 2))}')
print(f'maxs: {quantized_matrices.max(axis=(1, 2))}')

mins: [-41 -38 -40 -38 -44 -38 -37 -41 -41 -37 -39 -40 -39 -37 -39 -40 -37 -41
 -40 -40 -39 -41 -38 -42 -39 -39 -42 -39 -41 -41 -38 -39 -40 -38 -38 -38
 -42 -41 -37 -44 -38 -38 -37 -39 -39 -36 -39 -42 -40 -40 -41 -38 -36 -41
 -40 -41 -42 -38 -38 -38 -42 -37 -39 -39 -38 -36 -39 -37 -38 -40 -43 -36
 -37 -37 -41 -39 -41 -37 -40 -42 -38 -37 -38 -37 -38 -42 -41 -42 -38 -42
 -40 -46 -39 -37 -42 -41 -37 -40 -38 -42]
maxs: [38 39 36 36 39 37 36 38 36 39 38 37 39 40 37 40 37 38 39 41 38 40 37 38
 36 36 39 39 36 39 37 37 37 38 39 39 41 41 37 36 37 38 35 37 37 40 42 36
 41 37 37 41 40 39 38 40 42 43 38 41 37 36 41 37 37 38 38 39 40 36 38 39
 39 40 39 38 39 37 38 40 39 43 43 44 36 42 44 39 43 39 37 37 38 43 40 38
 40 44 39 39]


### Create Entropy Models

In [20]:
values = []
counts = []
counts_16bit = []
entropies = []
cross_entropies_16bit = []
for quantized_matrix in tqdm(quantized_matrices):
    v, c = np.unique(quantized_matrix, return_counts=True)
    order = np.argsort(c)[::-1]
    v = v[order]
    c = c[order]
    values.append(v)
    counts.append(c)
    entropies.append(np.log2(quantized_matrix.size) - c @ np.log2(c) / quantized_matrix.size)
    c16bit = np.maximum(np.round(c / 16).astype(np.int32), 1)

    # We let Probabilities add up to `(1<<16) - 1` rather than `1<<16` so that we can represent
    # the entire cdf in 16 bit values, which simplifies decoding. Note that this would (probably)
    # make bits-back coding invalid since the AnsCoder is no longer surjective, but we don't do
    # bits-back coding anyway.
    excess = sum(c16bit) - ((1 << 16) - 1)
    assert excess >= 0 and excess <= len(c)
    if excess != 0:
        assert c16bit[excess - 1] > 1
    c16bit[:excess] -= 1
    counts_16bit.append(c16bit)
    cross_entropies_16bit.append(16 - c16bit @ np.log2(c16bit) / (1 << 16))
entropies = np.array(entropies)
cross_entropies_16bit = np.array(cross_entropies_16bit)
overheads = cross_entropies_16bit - entropies
print(f'Maximum absolute overhead: {overheads.max():.4f}')
print(f'Maximum relative overhead: {(overheads * 100 / entropies).max():.2f} %')
entropies, cross_entropies_16bit

100%|██████████| 100/100 [00:18<00:00,  5.40it/s]

Maximum absolute overhead: 0.0028
Maximum relative overhead: 0.05 %





(array([5.04873129, 5.04657062, 5.04677307, 5.04884595, 5.0473163 ,
        5.0486622 , 5.04661374, 5.04849787, 5.04695842, 5.04865761,
        5.04702978, 5.04732894, 5.04798469, 5.04782882, 5.04794853,
        5.04891135, 5.04778863, 5.0474892 , 5.04845118, 5.04870344,
        5.04692647, 5.0466256 , 5.04761184, 5.04772784, 5.04902663,
        5.04909297, 5.0483021 , 5.0477679 , 5.04680325, 5.0470743 ,
        5.04774866, 5.04810709, 5.04799223, 5.04751116, 5.04805406,
        5.04894296, 5.04914948, 5.04813692, 5.04773768, 5.04766498,
        5.04820142, 5.04674641, 5.04612653, 5.04954575, 5.04798425,
        5.04911074, 5.04806961, 5.04883531, 5.04730741, 5.04820349,
        5.0475805 , 5.04728574, 5.0476444 , 5.04874619, 5.04859174,
        5.04892564, 5.04605025, 5.04941888, 5.04747965, 5.04812066,
        5.04752634, 5.04696448, 5.04841835, 5.0466188 , 5.04660691,
        5.04749418, 5.04658043, 5.04970472, 5.04751697, 5.0466669 ,
        5.0487524 , 5.04709681, 5.0476797 , 5.04

In [21]:
inv_vocabs = [{v: i for i, v in enumerate(vs)} for vs in values]

### Entropy Coder

In [22]:
class AnsCoder:
    def __init__(self, precision, word_size, compressed=[]):
        self.precision = precision
        self.word_size = word_size
        self.word_mask = (1 << word_size) - 1
        self.quantile_mask = (1 << precision) - 1
        self.bulk = compressed.copy()
        self.head = 0
        while len(self.bulk) != 0 and (self.head >> word_size) == 0:
            self.head = (self.head << word_size) | self.bulk.pop()

    def push(self, symbol, m):
        if (self.head >> (2 * self.word_size - self.precision)) >= m[symbol]:
            self.bulk.append(self.head & self.word_mask)
            self.head >>= self.word_size

        z = self.head % m[symbol] + sum(m[0:symbol])
        self.head //= m[symbol]
        self.head = (self.head << self.precision) | z

    def pop(self, m):
        z = self.head & self.quantile_mask
        self.head >>= self.precision
        for symbol, m_symbol in enumerate(m):
            if z >= m_symbol:
                z -= m_symbol
            else:
                break
        self.head = self.head * m_symbol + z
        if (self.head >> self.word_size) == 0 and len(self.bulk) != 0:
            self.head = (self.head << self.word_size) | self.bulk.pop()
        return symbol

    def get_compressed(self):
        compressed = self.bulk.copy()
        head = self.head
        while head != 0:
            compressed.append(head & self.word_mask)
            head >>= self.word_size
        return compressed

## Compression 1: independently compressed matrix entries, with non-interleaved matrices

### Compress matrices

In [23]:
# with open("100-compressed-matrices.pickle", "rb") as f:
#     d = pickle.load(f)
#     compressed_stream = d["compressed_stream"]
#     coder_offsets = d["coder_offsets"]
#     del d

In [24]:
compressed_stream = []
coder_offsets = np.zeros((len(matrices), matrices[0].size), dtype=np.uint32)
coders = [AnsCoder(16, 16, [1, 1]) for _ in range(matrices[0].size)]
for i in tqdm(reversed(range(len(matrices)))):
    model = counts_16bit[i]
    inv_vocab = inv_vocabs[i]
    for j, entry, coder in zip(range(quantized_matrices[i].size - 1, -1, -1), reversed(quantized_matrices[i].ravel()), coders):
        symbol_id = inv_vocab[entry]
        coder.push(symbol_id, model)
        if len(coder.bulk) == 1:
            compressed_stream.append(coder.bulk[0])
            coder.bulk = []
        coder_offsets[i, j] = len(compressed_stream)
    print(f'Encoded matrix {i}; words emitted so far: {coder_offsets[i, -1]}')

1it [00:38, 38.52s/it]

Encoded matrix 99; words emitted so far: 0


2it [01:02, 30.03s/it]

Encoded matrix 98; words emitted so far: 127


3it [01:26, 27.25s/it]

Encoded matrix 97; words emitted so far: 6467


4it [01:49, 25.59s/it]

Encoded matrix 96; words emitted so far: 253130


5it [02:14, 25.16s/it]

Encoded matrix 95; words emitted so far: 1049018


6it [02:40, 25.62s/it]

Encoded matrix 94; words emitted so far: 1063608


7it [03:05, 25.52s/it]

Encoded matrix 93; words emitted so far: 1278127


8it [03:29, 25.04s/it]

Encoded matrix 92; words emitted so far: 2023904


9it [03:53, 24.48s/it]

Encoded matrix 91; words emitted so far: 2116859


10it [04:17, 24.61s/it]

Encoded matrix 90; words emitted so far: 2298457


11it [04:40, 23.92s/it]

Encoded matrix 89; words emitted so far: 2896717


12it [05:02, 23.46s/it]

Encoded matrix 88; words emitted so far: 3167526


13it [05:28, 24.04s/it]

Encoded matrix 87; words emitted so far: 3322073


14it [05:51, 23.98s/it]

Encoded matrix 86; words emitted so far: 3811291


15it [06:16, 24.26s/it]

Encoded matrix 85; words emitted so far: 4202162


16it [06:40, 24.06s/it]

Encoded matrix 84; words emitted so far: 4349498


17it [07:06, 24.56s/it]

Encoded matrix 83; words emitted so far: 4757518


18it [07:29, 24.33s/it]

Encoded matrix 82; words emitted so far: 5199084


19it [07:55, 24.54s/it]

Encoded matrix 81; words emitted so far: 5379838


20it [08:20, 24.72s/it]

Encoded matrix 80; words emitted so far: 5725565


21it [08:43, 24.24s/it]

Encoded matrix 79; words emitted so far: 6173949


22it [09:08, 24.41s/it]

Encoded matrix 78; words emitted so far: 6409281


23it [09:33, 24.83s/it]

Encoded matrix 77; words emitted so far: 6709065


24it [10:00, 25.50s/it]

Encoded matrix 76; words emitted so far: 7142597


25it [10:29, 26.31s/it]

Encoded matrix 75; words emitted so far: 7430117


26it [10:54, 25.91s/it]

Encoded matrix 74; words emitted so far: 7703842


27it [11:18, 25.31s/it]

Encoded matrix 73; words emitted so far: 8112292


28it [11:43, 25.49s/it]

Encoded matrix 72; words emitted so far: 8440046


29it [12:09, 25.52s/it]

Encoded matrix 71; words emitted so far: 8705917


30it [12:34, 25.20s/it]

Encoded matrix 70; words emitted so far: 9085751


31it [12:58, 24.96s/it]

Encoded matrix 69; words emitted so far: 9439095


32it [13:24, 25.15s/it]

Encoded matrix 68; words emitted so far: 9712584


33it [13:49, 25.18s/it]

Encoded matrix 67; words emitted so far: 10064987


34it [14:14, 25.06s/it]

Encoded matrix 66; words emitted so far: 10431059


35it [14:38, 24.77s/it]

Encoded matrix 65; words emitted so far: 10719160


36it [15:03, 24.80s/it]

Encoded matrix 64; words emitted so far: 11050484


37it [15:31, 25.79s/it]

Encoded matrix 63; words emitted so far: 11418528


38it [15:37, 19.83s/it]

Encoded matrix 62; words emitted so far: 11723220


39it [15:42, 15.62s/it]

Encoded matrix 61; words emitted so far: 12040964


40it [15:54, 14.57s/it]

Encoded matrix 60; words emitted so far: 12404147


41it [16:05, 13.28s/it]

Encoded matrix 59; words emitted so far: 12724645


42it [16:14, 12.04s/it]

Encoded matrix 58; words emitted so far: 13035800


43it [16:19,  9.98s/it]

Encoded matrix 57; words emitted so far: 13389566


44it [16:25,  8.73s/it]

Encoded matrix 56; words emitted so far: 13722423


45it [16:37,  9.62s/it]

Encoded matrix 55; words emitted so far: 14032420


46it [16:47,  9.86s/it]

Encoded matrix 54; words emitted so far: 14377229


47it [16:56,  9.76s/it]

Encoded matrix 53; words emitted so far: 14716943


48it [17:02,  8.41s/it]

Encoded matrix 52; words emitted so far: 15030458


49it [17:07,  7.55s/it]

Encoded matrix 51; words emitted so far: 15367171


50it [17:18,  8.45s/it]

Encoded matrix 50; words emitted so far: 15709229


51it [17:28,  9.11s/it]

Encoded matrix 49; words emitted so far: 16027738


52it [17:39,  9.45s/it]

Encoded matrix 48; words emitted so far: 16358402


53it [17:44,  8.22s/it]

Encoded matrix 47; words emitted so far: 16700571


54it [17:49,  7.36s/it]

Encoded matrix 46; words emitted so far: 17024316


55it [18:00,  8.19s/it]

Encoded matrix 45; words emitted so far: 17351043


56it [18:10,  9.00s/it]

Encoded matrix 44; words emitted so far: 17692096


57it [18:33, 12.98s/it]

Encoded matrix 43; words emitted so far: 18020469


58it [19:05, 18.81s/it]

Encoded matrix 42; words emitted so far: 18345010


59it [19:32, 21.13s/it]

Encoded matrix 41; words emitted so far: 18683072


60it [19:42, 17.82s/it]

Encoded matrix 40; words emitted so far: 19014902


61it [19:48, 14.45s/it]

Encoded matrix 39; words emitted so far: 19339455


62it [19:55, 12.08s/it]

Encoded matrix 38; words emitted so far: 19674720


63it [20:01, 10.33s/it]

Encoded matrix 37; words emitted so far: 20009138


64it [20:07,  9.03s/it]

Encoded matrix 36; words emitted so far: 20334360


65it [20:16,  8.94s/it]

Encoded matrix 35; words emitted so far: 20666556


66it [20:23,  8.28s/it]

Encoded matrix 34; words emitted so far: 21001769


67it [20:29,  7.75s/it]

Encoded matrix 33; words emitted so far: 21329014


68it [20:35,  7.32s/it]

Encoded matrix 32; words emitted so far: 21659662


69it [20:42,  7.14s/it]

Encoded matrix 31; words emitted so far: 21994858


70it [20:49,  7.08s/it]

Encoded matrix 30; words emitted so far: 22323778


71it [20:55,  6.73s/it]

Encoded matrix 29; words emitted so far: 22652648


72it [21:01,  6.44s/it]

Encoded matrix 28; words emitted so far: 22987681


73it [21:07,  6.23s/it]

Encoded matrix 27; words emitted so far: 23317647


74it [21:12,  6.09s/it]

Encoded matrix 26; words emitted so far: 23646511


75it [21:19,  6.27s/it]

Encoded matrix 25; words emitted so far: 23980131


76it [21:26,  6.40s/it]

Encoded matrix 24; words emitted so far: 24311348


77it [21:33,  6.55s/it]

Encoded matrix 23; words emitted so far: 24640586


78it [21:39,  6.64s/it]

Encoded matrix 22; words emitted so far: 24972879


79it [21:45,  6.41s/it]

Encoded matrix 21; words emitted so far: 25305323


80it [21:51,  6.26s/it]

Encoded matrix 20; words emitted so far: 25634047


81it [21:58,  6.28s/it]

Encoded matrix 19; words emitted so far: 25965546


82it [22:04,  6.28s/it]

Encoded matrix 18; words emitted so far: 26298092


83it [22:11,  6.64s/it]

Encoded matrix 17; words emitted so far: 26628261


84it [22:17,  6.50s/it]

Encoded matrix 16; words emitted so far: 26959404


85it [22:25,  6.91s/it]

Encoded matrix 15; words emitted so far: 27291500


86it [22:32,  6.96s/it]

Encoded matrix 14; words emitted so far: 27621713


87it [22:38,  6.66s/it]

Encoded matrix 13; words emitted so far: 27952438


88it [22:47,  7.10s/it]

Encoded matrix 12; words emitted so far: 28284730


89it [22:59,  8.73s/it]

Encoded matrix 11; words emitted so far: 28615584


90it [23:07,  8.50s/it]

Encoded matrix 10; words emitted so far: 28946132


91it [23:16,  8.65s/it]

Encoded matrix 9; words emitted so far: 29277629


92it [23:23,  8.28s/it]

Encoded matrix 8; words emitted so far: 29609392


93it [23:30,  7.82s/it]

Encoded matrix 7; words emitted so far: 29939068


94it [23:37,  7.63s/it]

Encoded matrix 6; words emitted so far: 30271055


95it [23:45,  7.52s/it]

Encoded matrix 5; words emitted so far: 30602529


96it [23:52,  7.52s/it]

Encoded matrix 4; words emitted so far: 30932715


97it [23:58,  7.05s/it]

Encoded matrix 3; words emitted so far: 31264234


98it [24:04,  6.69s/it]

Encoded matrix 2; words emitted so far: 31595555


99it [24:10,  6.41s/it]

Encoded matrix 1; words emitted so far: 31926723


100it [24:16, 14.56s/it]

Encoded matrix 0; words emitted so far: 32257481





In [25]:
for coder in coders:
    assert len(coder.bulk) == 0
    compressed = coder.get_compressed()
    assert len(compressed) == 2
    compressed_stream += compressed
compressed_stream = np.array(compressed_stream, dtype=np.uint16)

In [26]:
bitrate = len(compressed_stream) * 16 / quantized_matrices.size
len(compressed_stream), bitrate, bitrate - cross_entropies_16bit.mean()

(34686522, 5.292743225097656, 0.24308157751720483)

In [27]:
with open("100-compressed-matrices.pickle", "wb") as f:
    pickle.dump({"compressed_stream": compressed_stream, "coder_offsets": coder_offsets}, f)

In [28]:
with open("100-compressed-matrices.pickle", "rb") as f:
    d = pickle.load(f)
    compressed_stream = d["compressed_stream"]
    coder_offsets = d["coder_offsets"]
    del d

### Serialize to a File

In [29]:
np.array([counts_16bit[0][:j].sum() for j in range(len(counts_16bit[0]) + 1)], dtype=np.uint16)

array([    0,  3268,  6515,  9741, 12901, 16054, 19106, 22148, 25042,
       27922, 30621, 33314, 35771, 38206, 40446, 42679, 44669, 46653,
       48401, 50118, 51617, 53092, 54372, 55639, 56711, 57770, 58643,
       59505, 60220, 60925, 61494, 62061, 62510, 62953, 63301, 63634,
       63903, 64164, 64364, 64555, 64701, 64845, 64949, 65051, 65127,
       65199, 65253, 65306, 65344, 65379, 65405, 65431, 65448, 65464,
       65476, 65487, 65495, 65502, 65507, 65512, 65515, 65517, 65519,
       65521, 65523, 65524, 65525, 65526, 65527, 65528, 65529, 65530,
       65531, 65532, 65533, 65534, 65535], dtype=uint16)

In [30]:
def serialize_matrix(i):
    cdf = np.array([counts_16bit[i][:j].sum() for j in range(len(counts_16bit[i]) + 1)], dtype=np.uint16)
    data_offset = np.array([len(cdf) + 1], dtype=np.uint16)
    start = coder_offsets[i, 0]
    end = start if i == 0 else coder_offsets[i - 1, 0]
    data = compressed_stream[start:end][::-1]
    return np.concatenate([data_offset, cdf, data])

In [31]:
serialized_sizes = np.array([serialize_matrix(i).size for i in range(len(quantized_matrices))], dtype=np.uint32)
serialized_sizes, serialized_sizes.max()

(array([    78, 331969, 330835, 331246, 331402, 331597, 330262, 331553,
        332063, 329754, 331844, 331575, 330626, 330932, 332368, 330805,
        330289, 332175, 331222, 330250, 332626, 331577, 328800, 332523,
        332369, 329316, 331295, 333701, 328940, 330046, 335110, 328945,
        328999, 335272, 330726, 327325, 335291, 332275, 325300, 334495,
        335342, 324633, 331904, 338141, 324620, 328451, 341132, 326806,
        323824, 342247, 330742, 318590, 342134, 336792, 313595, 339791,
        344891, 310074, 332937, 353845, 311233, 320575, 363261, 317822,
        304770, 368120, 331403, 288177, 366149, 352480, 273568, 353422,
        379911, 265950, 327832, 408528, 273805, 287598, 433610, 299864,
        235412, 448461, 345804, 180835, 441641, 408099, 147418, 390949,
        489296, 154628, 270886, 598337, 181678,  93033, 745859, 214599,
         14668, 795966, 246742,   6420], dtype=uint32),
 795966)

In [34]:
with open("100-compressed-matrices.bin", "wb") as f:
    initial_coder_heads = compressed_stream[coder_offsets[0, 0]:][::-1]
    initial_coder_heads.tofile(f)
    serialized_sizes_lo = (serialized_sizes & ((1 << 16) - 1)).astype(np.uint16)
    serialized_sizes_hi = (serialized_sizes >> 16).astype(np.uint16)
    serialized_sizes_le = np.vstack([serialized_sizes_lo, serialized_sizes_hi]).T.ravel()
    serialized_sizes_le.tofile(f)
    for i in range(len(quantized_matrices)):
        m = serialize_matrix(i)
        m.tofile(f)
    # NOTE: this doesn't serialize the final compressed data that would be necessary
    # to reconstruct the original coder states (which is currently always `[1, 0]`).

In [35]:
!ls -lh 100-compressed-matrices.bin

-rw-rw-r-- 1 robamler robamler 67M Mai 12 18:48 100-compressed-matrices.bin


## Debugging

In [72]:
def get_checksum(i):
    checksum = 123
    for value in quantized_matrices[i].ravel():
        value = inv_vocabs[i][value]
        checksum = (checksum << 5) | (checksum >> 11) # rotate
        checksum = (checksum + 1) * (value + 1)
        checksum = checksum & 65535 # truncate to 16 bit
    return checksum

In [73]:
for i in range(100):
    print(f'checksum of matrix {i}: {get_checksum(i)}')

checksum of matrix 0: 49452
checksum of matrix 1: 5155
checksum of matrix 2: 28836
checksum of matrix 3: 9171
checksum of matrix 4: 37580
checksum of matrix 5: 43096
checksum of matrix 6: 23220
checksum of matrix 7: 19944
checksum of matrix 8: 16426
checksum of matrix 9: 37286
checksum of matrix 10: 65497
checksum of matrix 11: 32631
checksum of matrix 12: 55362
checksum of matrix 13: 23313
checksum of matrix 14: 56082
checksum of matrix 15: 61656
checksum of matrix 16: 40632
checksum of matrix 17: 4821
checksum of matrix 18: 7285
checksum of matrix 19: 33036
checksum of matrix 20: 352
checksum of matrix 21: 45423
checksum of matrix 22: 47965
checksum of matrix 23: 46764
checksum of matrix 24: 52984
checksum of matrix 25: 61631
checksum of matrix 26: 22544
checksum of matrix 27: 43713
checksum of matrix 28: 28384
checksum of matrix 29: 28430
checksum of matrix 30: 24583
checksum of matrix 31: 16624
checksum of matrix 32: 17010
checksum of matrix 33: 31050
checksum of matrix 34: 7081
ch