# Mock Data for Compressed Matrices

In [1]:
import numpy as np
import constriction
from tqdm import tqdm
import pickle

## Preparation

### Create Random Quantized Matrices

In [22]:
np.random.seed(202404151)
matrices = np.random.randn(100, 1024, 1024)

In [23]:
quantized_matrices = np.round(matrices * 8).astype(np.int8)
print(f'mins: {quantized_matrices.min(axis=(1, 2))}')
print(f'maxs: {quantized_matrices.max(axis=(1, 2))}')

mins: [-41 -38 -40 -38 -44 -38 -37 -41 -41 -37 -39 -40 -39 -37 -39 -40 -37 -41
 -40 -40 -39 -41 -38 -42 -39 -39 -42 -39 -41 -41 -38 -39 -40 -38 -38 -38
 -42 -41 -37 -44 -38 -38 -37 -39 -39 -36 -39 -42 -40 -40 -41 -38 -36 -41
 -40 -41 -42 -38 -38 -38 -42 -37 -39 -39 -38 -36 -39 -37 -38 -40 -43 -36
 -37 -37 -41 -39 -41 -37 -40 -42 -38 -37 -38 -37 -38 -42 -41 -42 -38 -42
 -40 -46 -39 -37 -42 -41 -37 -40 -38 -42]
maxs: [38 39 36 36 39 37 36 38 36 39 38 37 39 40 37 40 37 38 39 41 38 40 37 38
 36 36 39 39 36 39 37 37 37 38 39 39 41 41 37 36 37 38 35 37 37 40 42 36
 41 37 37 41 40 39 38 40 42 43 38 41 37 36 41 37 37 38 38 39 40 36 38 39
 39 40 39 38 39 37 38 40 39 43 43 44 36 42 44 39 43 39 37 37 38 43 40 38
 40 44 39 39]


### Create Entropy Models

In [4]:
values = []
counts = []
counts_12bit = []
entropies = []
cross_entropies_12bit = []
for quantized_matrix in tqdm(quantized_matrices):
    v, c = np.unique(quantized_matrix, return_counts=True)
    order = np.argsort(c)[::-1]
    v = v[order]
    c = c[order]
    values.append(v)
    counts.append(c)
    entropies.append(np.log2(quantized_matrix.size) - c @ np.log2(c) / quantized_matrix.size)
    c12bit = np.maximum(np.round(c / 256).astype(np.int32), 1)
    excess = sum(c12bit) - (1 << 12)
    assert excess >= 0 and excess <= len(c)
    if excess != 0:
        assert c12bit[excess - 1] > 1
    c12bit[:excess] -= 1
    counts_12bit.append(c12bit)
    cross_entropies_12bit.append(12 - c12bit @ np.log2(c12bit) / (1 << 12))
entropies = np.array(entropies)
cross_entropies_12bit = np.array(cross_entropies_12bit)
overheads = cross_entropies_12bit - entropies
print(f'Maximum absolute overhead: {overheads.max():.4f}')
print(f'Maximum relative overhead: {(overheads * 100 / entropies).max():.2f} %')
entropies, cross_entropies_12bit

100%|██████████| 100/100 [00:15<00:00,  6.47it/s]

Maximum absolute overhead: 0.0471
Maximum relative overhead: 0.93 %





(array([5.04873129, 5.04657062, 5.04677307, 5.04884595, 5.0473163 ,
        5.0486622 , 5.04661374, 5.04849787, 5.04695842, 5.04865761,
        5.04702978, 5.04732894, 5.04798469, 5.04782882, 5.04794853,
        5.04891135, 5.04778863, 5.0474892 , 5.04845118, 5.04870344,
        5.04692647, 5.0466256 , 5.04761184, 5.04772784, 5.04902663,
        5.04909297, 5.0483021 , 5.0477679 , 5.04680325, 5.0470743 ,
        5.04774866, 5.04810709, 5.04799223, 5.04751116, 5.04805406,
        5.04894296, 5.04914948, 5.04813692, 5.04773768, 5.04766498,
        5.04820142, 5.04674641, 5.04612653, 5.04954575, 5.04798425,
        5.04911074, 5.04806961, 5.04883531, 5.04730741, 5.04820349,
        5.0475805 , 5.04728574, 5.0476444 , 5.04874619, 5.04859174,
        5.04892564, 5.04605025, 5.04941888, 5.04747965, 5.04812066,
        5.04752634, 5.04696448, 5.04841835, 5.0466188 , 5.04660691,
        5.04749418, 5.04658043, 5.04970472, 5.04751697, 5.0466669 ,
        5.0487524 , 5.04709681, 5.0476797 , 5.04

In [5]:
inv_vocabs = [{v: i for i, v in enumerate(vs)} for vs in values]

### Entropy Coder

In [6]:
class AnsCoder:
    def __init__(self, precision, word_size, compressed=[]):
        self.precision = precision
        self.word_size = word_size
        self.word_mask = (1 << word_size) - 1
        self.quantile_mask = (1 << precision) - 1
        self.bulk = compressed.copy()
        self.head = 0
        while len(self.bulk) != 0 and (self.head >> word_size) == 0:
            self.head = (self.head << word_size) | self.bulk.pop()

    def push(self, symbol, m):
        if (self.head >> (2 * self.word_size - self.precision)) >= m[symbol]:
            self.bulk.append(self.head & self.word_mask)
            self.head >>= self.word_size

        z = self.head % m[symbol] + sum(m[0:symbol])
        self.head //= m[symbol]
        self.head = (self.head << self.precision) | z

    def pop(self, m):
        z = self.head & self.quantile_mask
        self.head >>= self.precision
        for symbol, m_symbol in enumerate(m):
            if z >= m_symbol:
                z -= m_symbol
            else:
                break
        self.head = self.head * m_symbol + z
        if (self.head >> self.word_size) == 0 and len(self.bulk) != 0:
            self.head = (self.head << self.word_size) | self.bulk.pop()
        return symbol

    def get_compressed(self):
        compressed = self.bulk.copy()
        head = self.head
        while head != 0:
            compressed.append(head & self.word_mask)
            head >>= self.word_size
        return compressed

## Compression 1: independently compressed matrix entries, with non-interleaved matrices

### Compress matrices

In [7]:
# with open("100-compressed-matrices.pickle", "rb") as f:
#     d = pickle.load(f)
#     compressed_stream = d["compressed_stream"]
#     coder_offsets = d["coder_offsets"]
#     del d

In [8]:
compressed_stream = []
coder_offsets = np.zeros((len(matrices), matrices[0].size), dtype=np.uint32)
coders = [AnsCoder(12, 16, [1, 1]) for _ in range(matrices[0].size)]
for i in tqdm(reversed(range(len(matrices)))):
    model = counts_12bit[i]
    inv_vocab = inv_vocabs[i]
    for j, entry, coder in zip(range(quantized_matrices[i].size - 1, -1, -1), reversed(quantized_matrices[i].ravel()), coders):
        symbol_id = inv_vocab[entry]
        coder.push(symbol_id, model)
        if len(coder.bulk) == 1:
            compressed_stream.append(coder.bulk[0])
            coder.bulk = []
        coder_offsets[i, j] = len(compressed_stream)
    print(f'Encoded matrix {i}; words emitted so far: {coder_offsets[i, -1]}')

1it [00:10, 10.33s/it]

Encoded matrix 99; words emitted so far: 0


2it [00:16,  7.64s/it]

Encoded matrix 98; words emitted so far: 0


3it [00:21,  6.83s/it]

Encoded matrix 97; words emitted so far: 6442


4it [00:28,  6.55s/it]

Encoded matrix 96; words emitted so far: 256628


5it [00:33,  6.27s/it]

Encoded matrix 95; words emitted so far: 1048779


6it [00:39,  6.13s/it]

Encoded matrix 94; words emitted so far: 1061698


7it [00:45,  6.09s/it]

Encoded matrix 93; words emitted so far: 1274851


8it [00:51,  6.01s/it]

Encoded matrix 92; words emitted so far: 2035984


9it [00:57,  6.01s/it]

Encoded matrix 91; words emitted so far: 2114211


10it [01:03,  6.04s/it]

Encoded matrix 90; words emitted so far: 2294396


11it [01:09,  6.00s/it]

Encoded matrix 89; words emitted so far: 2910784


12it [01:15,  6.01s/it]

Encoded matrix 88; words emitted so far: 3164586


13it [01:21,  6.01s/it]

Encoded matrix 87; words emitted so far: 3316701


14it [01:27,  5.99s/it]

Encoded matrix 86; words emitted so far: 3820261


15it [01:33,  5.97s/it]

Encoded matrix 85; words emitted so far: 4203805


16it [01:39,  6.01s/it]

Encoded matrix 84; words emitted so far: 4343773


17it [01:45,  6.04s/it]

Encoded matrix 83; words emitted so far: 4762724


18it [01:51,  6.06s/it]

Encoded matrix 82; words emitted so far: 5207027


19it [01:58,  6.11s/it]

Encoded matrix 81; words emitted so far: 5373774


20it [02:04,  6.12s/it]

Encoded matrix 80; words emitted so far: 5727253


21it [02:10,  6.11s/it]

Encoded matrix 79; words emitted so far: 6184446


22it [02:16,  6.19s/it]

Encoded matrix 78; words emitted so far: 6404355


23it [02:22,  6.16s/it]

Encoded matrix 77; words emitted so far: 6708662


24it [02:28,  6.12s/it]

Encoded matrix 76; words emitted so far: 7152625


25it [02:34,  6.08s/it]

Encoded matrix 75; words emitted so far: 7428623


26it [02:40,  6.09s/it]

Encoded matrix 74; words emitted so far: 7701463


27it [02:46,  6.10s/it]

Encoded matrix 73; words emitted so far: 8120840


28it [02:53,  6.12s/it]

Encoded matrix 72; words emitted so far: 8441861


29it [02:59,  6.09s/it]

Encoded matrix 71; words emitted so far: 8702378


30it [03:05,  6.05s/it]

Encoded matrix 70; words emitted so far: 9093002


31it [03:11,  6.01s/it]

Encoded matrix 69; words emitted so far: 9444019


32it [03:16,  5.99s/it]

Encoded matrix 68; words emitted so far: 9709118


33it [03:22,  5.97s/it]

Encoded matrix 67; words emitted so far: 10069516


34it [03:28,  5.94s/it]

Encoded matrix 66; words emitted so far: 10437339


35it [03:34,  5.92s/it]

Encoded matrix 65; words emitted so far: 10716957


36it [03:40,  5.93s/it]

Encoded matrix 64; words emitted so far: 11052762


37it [03:46,  5.93s/it]

Encoded matrix 63; words emitted so far: 11425707


38it [03:52,  5.92s/it]

Encoded matrix 62; words emitted so far: 11722858


39it [03:58,  5.92s/it]

Encoded matrix 61; words emitted so far: 12041996


40it [04:04,  5.91s/it]

Encoded matrix 60; words emitted so far: 12410758


41it [04:10,  5.91s/it]

Encoded matrix 59; words emitted so far: 12726113


42it [04:16,  5.92s/it]

Encoded matrix 58; words emitted so far: 13035503


43it [04:22,  5.92s/it]

Encoded matrix 57; words emitted so far: 13395640


44it [04:27,  5.92s/it]

Encoded matrix 56; words emitted so far: 13725240


45it [04:34,  5.98s/it]

Encoded matrix 55; words emitted so far: 14032249


46it [04:40,  6.07s/it]

Encoded matrix 54; words emitted so far: 14382219


47it [04:46,  6.03s/it]

Encoded matrix 53; words emitted so far: 14721414


48it [04:52,  6.03s/it]

Encoded matrix 52; words emitted so far: 15030574


49it [04:58,  6.13s/it]

Encoded matrix 51; words emitted so far: 15370722


50it [05:04,  6.13s/it]

Encoded matrix 50; words emitted so far: 15714206


51it [05:10,  6.09s/it]

Encoded matrix 49; words emitted so far: 16029058


52it [05:17,  6.13s/it]

Encoded matrix 48; words emitted so far: 16360833


53it [05:23,  6.10s/it]

Encoded matrix 47; words emitted so far: 16706085


54it [05:29,  6.08s/it]

Encoded matrix 46; words emitted so far: 17026548


55it [05:35,  6.08s/it]

Encoded matrix 45; words emitted so far: 17353267


56it [05:41,  6.10s/it]

Encoded matrix 44; words emitted so far: 17697071


57it [05:47,  6.15s/it]

Encoded matrix 43; words emitted so far: 18023448


58it [05:53,  6.15s/it]

Encoded matrix 42; words emitted so far: 18346812


59it [05:59,  6.13s/it]

Encoded matrix 41; words emitted so far: 18687526


60it [06:05,  6.11s/it]

Encoded matrix 40; words emitted so far: 19018871


61it [06:11,  6.09s/it]

Encoded matrix 39; words emitted so far: 19341560


62it [06:17,  6.08s/it]

Encoded matrix 38; words emitted so far: 19679065


63it [06:23,  6.06s/it]

Encoded matrix 37; words emitted so far: 20013155


64it [06:30,  6.06s/it]

Encoded matrix 36; words emitted so far: 20336870


65it [06:36,  6.05s/it]

Encoded matrix 35; words emitted so far: 20670225


66it [06:42,  6.04s/it]

Encoded matrix 34; words emitted so far: 21006191


67it [06:48,  6.05s/it]

Encoded matrix 33; words emitted so far: 21332327


68it [06:54,  6.06s/it]

Encoded matrix 32; words emitted so far: 21663193


69it [07:00,  6.04s/it]

Encoded matrix 31; words emitted so far: 21999554


70it [07:06,  6.03s/it]

Encoded matrix 30; words emitted so far: 22327112


71it [07:12,  6.01s/it]

Encoded matrix 29; words emitted so far: 22655947


72it [07:18,  6.01s/it]

Encoded matrix 28; words emitted so far: 22991942


73it [07:24,  6.00s/it]

Encoded matrix 27; words emitted so far: 23321391


74it [07:30,  6.01s/it]

Encoded matrix 26; words emitted so far: 23649216


75it [07:36,  6.04s/it]

Encoded matrix 25; words emitted so far: 23984877


76it [07:42,  6.07s/it]

Encoded matrix 24; words emitted so far: 24315656


77it [07:48,  6.03s/it]

Encoded matrix 23; words emitted so far: 24644003


78it [07:54,  6.02s/it]

Encoded matrix 22; words emitted so far: 24977312


79it [08:00,  6.02s/it]

Encoded matrix 21; words emitted so far: 25309427


80it [08:06,  6.03s/it]

Encoded matrix 20; words emitted so far: 25637462


81it [08:12,  6.04s/it]

Encoded matrix 19; words emitted so far: 25969723


82it [08:18,  6.07s/it]

Encoded matrix 18; words emitted so far: 26302483


83it [08:24,  6.08s/it]

Encoded matrix 17; words emitted so far: 26632281


84it [08:30,  6.08s/it]

Encoded matrix 16; words emitted so far: 26963760


85it [08:37,  6.15s/it]

Encoded matrix 15; words emitted so far: 27296000


86it [08:43,  6.12s/it]

Encoded matrix 14; words emitted so far: 27625990


87it [08:49,  6.10s/it]

Encoded matrix 13; words emitted so far: 27956974


88it [08:55,  6.09s/it]

Encoded matrix 12; words emitted so far: 28289043


89it [09:01,  6.06s/it]

Encoded matrix 11; words emitted so far: 28620249


90it [09:07,  6.04s/it]

Encoded matrix 10; words emitted so far: 28950098


91it [09:13,  6.03s/it]

Encoded matrix 9; words emitted so far: 29282315


92it [09:19,  6.02s/it]

Encoded matrix 8; words emitted so far: 29613851


93it [09:25,  6.01s/it]

Encoded matrix 7; words emitted so far: 29943588


94it [09:31,  6.00s/it]

Encoded matrix 6; words emitted so far: 30275837


95it [09:37,  6.00s/it]

Encoded matrix 5; words emitted so far: 30607618


96it [09:43,  6.00s/it]

Encoded matrix 4; words emitted so far: 30937586


97it [09:49,  6.03s/it]

Encoded matrix 3; words emitted so far: 31268963


98it [09:55,  6.02s/it]

Encoded matrix 2; words emitted so far: 31600869


99it [10:01,  6.02s/it]

Encoded matrix 1; words emitted so far: 31931007


100it [10:07,  6.07s/it]

Encoded matrix 0; words emitted so far: 32261990





In [9]:
for coder in coders:
    assert len(coder.bulk) == 0
    compressed = coder.get_compressed()
    assert len(compressed) == 2
    compressed_stream += compressed
compressed_stream = np.array(compressed_stream, dtype=np.uint16)

In [12]:
bitrate = len(compressed_stream) * 16 / quantized_matrices.size
len(compressed_stream), bitrate, bitrate - cross_entropies_12bit.mean()

(34691303, 5.293472747802735, 0.20617554845228003)

In [13]:
# with open("100-compressed-matrices.pickle", "wb") as f:
#     pickle.dump({"compressed_stream": compressed_stream, "coder_offsets": coder_offsets}, f)

In [6]:
with open("100-compressed-matrices.pickle", "rb") as f:
    d = pickle.load(f)
    compressed_stream = d["compressed_stream"]
    coder_offsets = d["coder_offsets"]
    del d

### Serialize to a File

In [34]:
def serialize_matrix(i):
    cdf = np.array([counts_12bit[i][:j].sum() for j in range(len(counts_12bit[i]) + 1)], dtype=np.uint16)
    data_offset = np.array([len(cdf) + 1], dtype=np.uint16)
    start = coder_offsets[i, 0]
    end = start if i == 0 else coder_offsets[i - 1, 0]
    data = compressed_stream[start:end][::-1]
    return np.concatenate([data_offset, cdf, data])

In [35]:
serialized_sizes = np.array([serialize_matrix(i).size for i in range(len(quantized_matrices))], dtype=np.uint32)
serialized_sizes, serialized_sizes.max()

(array([    78, 332241, 331060, 330216, 331987, 331455, 330044, 331860,
        332325, 329815, 331617, 332295, 329927, 331284, 332145, 331064,
        330066, 332319, 331558, 329879, 332840, 332339, 328111, 332194,
        333385, 328425, 330857, 335742, 327901, 329529, 336072, 328910,
        327637, 336437, 330944, 326216, 336044, 333434, 323793, 334167,
        337582, 322769, 331419, 340793, 323443, 326455, 343883, 326798,
        320542, 345330, 331853, 314933, 343560, 340227, 309240, 339272,
        350052, 307086, 329680, 360216, 309468, 315432, 368840, 319216,
        297229, 373021, 335884, 279694, 367900, 360475, 265178, 351095,
        390701, 260596, 321099, 419455, 272920, 276076, 444041, 304387,
        219989, 457270, 353557, 166826, 444379, 419030, 140050, 383622,
        503638, 152196, 253879, 616465, 180265,  78305, 761215, 213233,
         12997, 792229, 250265,   6522], dtype=uint32),
 792229)

In [81]:
# with open("100-compressed-matrices.bin", "wb") as f:
#     initial_coder_heads = compressed_stream[coder_offsets[0, 0]:][::-1]
#     initial_coder_heads.tofile(f)
#     serialized_sizes_lo = (serialized_sizes & ((1 << 16) - 1)).astype(np.uint16)
#     serialized_sizes_hi = (serialized_sizes >> 16).astype(np.uint16)
#     serialized_sizes_le = np.vstack([serialized_sizes_lo, serialized_sizes_hi]).T.ravel()
#     serialized_sizes_le.tofile(f)
#     for i in range(len(quantized_matrices)):
#         m = serialize_matrix(i)
#         m.tofile(f)
#     # NOTE: this doesn't serialize the final compressed data that would be necessary
#     # to reconstruct the original coder states (which is currently always `[1, 0]`).

In [7]:
!ls -lh 100-compressed-matrices.bin

-rw-rw-r-- 1 robamler robamler 67M Mai 12 16:47 100-compressed-matrices.bin


## Debugging

In [24]:
# Checksum adapted from rust's FxHash, but restricted to 26 bit hashes so that the
# multiplication doesn't exceed the range of exactly representable integers in JavaScript.
def get_checksum(i):
    checksum = 0
    for value in quantized_matrices[i].ravel():
        value = inv_vocabs[i][value]
        checksum = ((checksum & 0x001f_ffff) << 5) | (checksum >> 21) # rotate
        checksum = (checksum ^ value) * 0x0322_0a95
        checksum = checksum & 0x03ff_ffff # truncate to 26 bit
    return checksum

In [57]:
[inv_vocabs[1][i] for i in quantized_matrices[1].ravel()[[32_768]]]

[13]

In [None]:
for i in range(100):
    print(f'checksum of matrix {i}: {get_checksum(i)}')

checksum of matrix 0: 55445785
checksum of matrix 1: 15163511
checksum of matrix 2: 6664559
checksum of matrix 3: 15747500
checksum of matrix 4: 5573655
checksum of matrix 5: 27711411
checksum of matrix 6: 20047205
checksum of matrix 7: 10641525
checksum of matrix 8: 64606195
checksum of matrix 9: 65232545
checksum of matrix 10: 51903642
checksum of matrix 11: 57475160
checksum of matrix 12: 54761909
checksum of matrix 13: 30428408
checksum of matrix 14: 22313942
checksum of matrix 15: 33179881
checksum of matrix 16: 37189258
checksum of matrix 17: 16278593
checksum of matrix 18: 56573661
checksum of matrix 19: 41121083
checksum of matrix 20: 55929650
checksum of matrix 21: 43756132
checksum of matrix 22: 28003007
checksum of matrix 23: 56424187
checksum of matrix 24: 28589585
checksum of matrix 25: 37515944
checksum of matrix 26: 58783295
checksum of matrix 27: 57634150
checksum of matrix 28: 6008159
checksum of matrix 29: 16510335
checksum of matrix 30: 64138918
checksum of matrix 31