In [None]:
!pip install dahuffman

In [None]:
from datasets import load_dataset

# v177  (NOTE(review): presumably the dataset snapshot version -- confirm)
# Load the "train" split of "arxiv_dataset", reading the raw files from the local
# "datasets" directory. trust_remote_code=True and verification_mode="no_checks" are
# passed straight through to `load_dataset`.
dataset = load_dataset(
    "arxiv_dataset",
    data_dir="datasets",
    split="train",
    trust_remote_code=True,
    verification_mode="no_checks",
)

def keep_first_arxiv_category(example):
    """Store the first space-separated token of ``example["categories"]`` as ``"category"``.

    The mapping is updated in place and the same object is returned, which is the
    shape expected by ``dataset.map``.
    """
    # Everything before the first single space (the whole string if there is none).
    primary, _, _ = example["categories"].partition(' ')
    example["category"] = primary
    return example

# Add a "category" column holding only the first listed category of each row.
dataset = dataset.map(keep_first_arxiv_category)
#dataset = dataset.class_encode_column("category")

# Drop the listed metadata columns, including the original multi-valued "categories".
dataset = dataset.remove_columns([
    "id",
    "submitter",
    "authors",
    "comments",
    "journal-ref",
    "doi",
    "report-no",
    "categories",
    "license",
    "abstract",
    "update_date",
])
dataset

In [None]:
from collections import Counter
from dahuffman import HuffmanCodec

# Count how many rows fall into each category; these counts are the symbol
# weights for the Huffman code built in the next cell.
frequencies = Counter()
frequencies.update(dataset["category"])
frequencies

In [None]:
# Build a Huffman code over the categories, weighted by their row counts.
# The first category seen is passed as the EOF symbol.
# NOTE(review): presumably this reuses a real symbol so that no extra EOF entry is
# added to the code table -- confirm against the dahuffman implementation.
codec = HuffmanCodec.from_frequencies(
    frequencies,
    eof=list(frequencies)[0],
)
codec.print_code_table()

In [None]:
from math import ceil, log2

# Total number of bits needed to encode every row's category with `codec`:
# each category contributes (its code length) * (number of rows carrying it).
# The first element of a code-table entry is used as the code length in bits.
bit_sum = sum(
    codec._table[symbol][0] * occurrences
    for symbol, occurrences in frequencies.items()
)

classes = len(frequencies)
print('Classes:', classes)
# Fixed-width baseline: the smallest whole number of bits that can index every class.
print('Packed bits/value:', ceil(log2(classes)))
print('Huffman bits/value:', bit_sum / dataset.num_rows)

In [None]:
import pandas as pd

# Load the saved model outputs.
df = pd.read_csv('arxiv_model/distilbert_arxiv_20240504-074557 outputs.csv')

# Sanity check: every column after the first two must correspond to one class.
assert df.columns[2:].size == classes

df

In [None]:
from tqdm import tqdm

# Map integer label ids to class names taken from the score columns.
# NOTE(review): assumes the CSV is laid out as [example, label, <one score column per
# class>] with the score columns in label-id order -- confirm against the exporter.
id2label = dict(zip(range(classes), df.columns[2:]))
learned_bit_sum = 0
correct_predictions = 0

# Loop-invariant, so it is computed once instead of once per row.
# NOTE(review): presumably a real class is reused as EOF so that no extra EOF entry
# is added to each per-row code table -- confirm against dahuffman.
eof_symbol = list(frequencies)[0]

for i, row in tqdm(df.iterrows(), total=len(df)):
    label_id = row['label']
    label = id2label[label_id]

    # Keep only the per-class scores of this row; they act as the symbol weights
    # of a Huffman code built specifically for this example.
    freqs = row.drop(['example', 'label'])

    # The prediction is correct when the highest-scoring column is the true label.
    if label_id == freqs.values.argmax():
        correct_predictions += 1

    row_codec = HuffmanCodec.from_frequencies(freqs, eof=eof_symbol)
    # Code length (in bits) that the row-specific code assigns to the true label.
    bit_length = row_codec._table[label][0]
    learned_bit_sum += bit_length

# Average over the rows that were actually scored in the loop above. Dividing by
# dataset.num_rows would silently skew both numbers whenever the outputs CSV does
# not contain exactly one row per dataset example.
num_evaluated = len(df)
print('Overall accuracy:', correct_predictions / num_evaluated)
print('Learned retrieval bits/value:', learned_bit_sum / num_evaluated)