In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [72]:
input_string = """
The University of Washington (UW and informally U-Dub or U Dub) is a public research university in Seattle, Washington, United States. Founded in 1861, the University of Washington is one of the oldest universities on the West Coast of the United States.

The university has a 703-acre (284 ha) main campus located in the city's University District. It also has satellite campuses in nearby cities of Tacoma and Bothell. Overall, UW encompasses more than 500 buildings and over 20 million gross square footage of space, including one of the largest library systems in the world with more than 26 university libraries, art centers, museums, laboratories, lecture halls, and stadiums.
"""

In [112]:
model_name = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [113]:
# Tokenize the input
inputs = tokenizer(input_string, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)

In [114]:
input_ids

tensor([[128000,    198,    791,   3907,    315,   6652,    320,     52,     54,
            323,   6179,    750,    549,   9607,    392,    477,    549,  17533,
              8,    374,    264,    586,   3495,  12374,    304,  16759,     11,
           6652,     11,   3723,   4273,     13,  78811,    304,    220,   9714,
             16,     11,    279,   3907,    315,   6652,    374,    832,    315,
            279,  24417,  23978,    389,    279,   4410,  16377,    315,    279,
           3723,   4273,    382,    791,  12374,    706,    264,    220,  20436,
          64434,    320,  17058,   6520,      8,   1925,  15679,   7559,    304,
            279,   3363,    596,   3907,  11182,     13,   1102,   1101,    706,
          24088,  53008,    304,  14373,   9919,    315,  85628,    323,  11995,
            616,     13,  28993,     11,  66716,  71010,    810,   1109,    220,
           2636,  14016,    323,    927,    220,    508,   3610,  20547,   9518,
          22609,    315,   3

In [115]:
with torch.no_grad():
    outputs = model(input_ids=input_ids)

logits = outputs.logits

In [116]:
logits

tensor([[[ 6.8699,  8.7921, 12.9562,  ..., -4.4354, -4.4354, -4.4355],
         [ 8.9626,  7.1396, 12.8710,  ..., -6.3913, -6.3915, -6.3915],
         [-1.3854, -0.5908, -1.5957,  ..., -9.4544, -9.4544, -9.4544],
         ...,
         [ 2.6486,  2.1777,  0.4392,  ..., -3.0379, -3.0381, -3.0380],
         [ 8.4127,  5.4962,  8.9823,  ..., -2.1911, -2.1912, -2.1911],
         [ 7.6229,  5.7835,  7.6811,  ..., -6.9104, -6.9103, -6.9103]]],
       device='cuda:0')

In [117]:
sorted_indices = torch.argsort(logits, descending=True)

In [118]:
sorted_indices

tensor([[[ 14924,    755,      2,  ..., 103273,  51202,    350],
         [   475,      2,   1527,  ..., 113640, 118554,  82000],
         [ 17200,  31613,   3907,  ...,  97865,  13027,  31488],
         ...,
         [ 10034,  29703,  74175,  ...,  60285, 104880,   3482],
         [    13,    382,     11,  ...,  62758, 117298, 126503],
         [128001,    791,     52,  ...,  48046,  13765,  96693]]],
       device='cuda:0')

In [119]:
_, rank = (input_ids.unsqueeze(-1).squeeze(0)[1:] == sorted_indices.squeeze(0)[:-1]).nonzero(as_tuple=True)

In [120]:
rank # rank is the compressed data

tensor([49, 56,  2,  0,  3,  6,  0,  0, 24, 20,  0,  1,  2,  0,  1,  0,  2,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  1,  2,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  0,  3,  0,  0,
         0,  3,  0,  2,  0,  4,  4,  0, 10,  0,  1,  1,  0,  0,  1,  1,  0,  0,
         1,  9,  0,  0,  1,  1,  1,  0,  2,  0,  0, 19,  2,  3,  1,  0,  0,  0,
         1, 36,  0,  1,  0,  2,  0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0,  1,
         0,  0,  0,  0,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         1,  1,  0, 48,  5,  0,  0,  0,  3,  0, 16,  0,  0,  0,  4,  4],
       device='cuda:0')

In [125]:
# Start decompression
seq_len = input_ids.shape[1]

decompressed_outputs = torch.zeros_like(input_ids)
decompressed_outputs[0, 0] = 128000 # 128000 is the special token index indicating the start of sequence
for step in range(1, seq_len):
    with torch.no_grad():
        outputs = model(input_ids=decompressed_outputs[..., :step])
    logits = outputs.logits[0, step - 1]
    sorted_indices = torch.argsort(logits, descending=True)
    decompressed_token = sorted_indices[rank[step - 1]]
    decompressed_outputs[0, step] = decompressed_token

In [126]:
decompressed_outputs

tensor([[128000,    198,    791,   3907,    315,   6652,    320,     52,     54,
            323,   6179,    750,    549,   9607,    392,    477,    549,  17533,
              8,    374,    264,    586,   3495,  12374,    304,  16759,     11,
           6652,     11,   3723,   4273,     13,  78811,    304,    220,   9714,
             16,     11,    279,   3907,    315,   6652,    374,    832,    315,
            279,  24417,  23978,    389,    279,   4410,  16377,    315,    279,
           3723,   4273,    382,    791,  12374,    706,    264,    220,  20436,
          64434,    320,  17058,   6520,      8,   1925,  15679,   7559,    304,
            279,   3363,    596,   3907,  11182,     13,   1102,   1101,    706,
          24088,  53008,    304,  14373,   9919,    315,  85628,    323,  11995,
            616,     13,  28993,     11,  66716,  71010,    810,   1109,    220,
           2636,  14016,    323,    927,    220,    508,   3610,  20547,   9518,
          22609,    315,   3

In [127]:
generated_text = tokenizer.decode(
    decompressed_outputs[0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
    spaces_between_special_tokens=False,
)

In [124]:
generated_text

"\nThe University of Washington (UW and informally U-Dub or U Dub) is a public research university in Seattle, Washington, United States. Founded in 1861, the University of Washington is one of the oldest universities on the West Coast of the United States.\n\nThe university has a 703-acre (284 ha) main campus located in the city's University District. It also has satellite campuses in nearby cities of Tacoma and Bothell. Overall, UW encompasses more than 500 buildings and over 20 million gross square footage of space, including one of the largest library systems in the world with more than 26 university libraries, art centers, museums, laboratories, lecture halls, and stadiums.\n"