<a href="https://colab.research.google.com/github/michaelmherrera/cs224-final-proj-compressor/blob/main/NeuralTextCompressor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Hugging Face `transformers` (GPT-2 model/tokenizer) and `datasets`
# (CNN/DailyMail loading); both are imported further down.
!pip install transformers
!pip install datasets



In [None]:
# Varint encoding and decoding functions

import sys

# Adapted from
# https://github.com/bright-tools/varints

# Value ranges of the short encodings: up to ONE_BYTE_LIMIT fits in a single
# byte, up to TWO_BYTE_LIMIT in two bytes, and up to THREE_BYTE_LIMIT in a
# THREE_BYTE_HEADER byte followed by two bytes.
ONE_BYTE_LIMIT = 240
TWO_BYTE_LIMIT = 2287
THREE_BYTE_LIMIT = 67823

# Larger values are a header byte followed by N big-endian data bytes, so the
# largest value of an (N + 1)-byte encoding is 2**(8*N) - 1.
FOUR_BYTE_LIMIT = 2**24 - 1
FIVE_BYTE_LIMIT = 2**32 - 1
SIX_BYTE_LIMIT = 2**40 - 1
SEVEN_BYTE_LIMIT = 2**48 - 1
EIGHT_BYTE_LIMIT = 2**56 - 1
NINE_BYTE_LIMIT = 2**64 - 1

# First-byte values that announce the longer encodings.
THREE_BYTE_HEADER = 249
FOUR_BYTE_HEADER = 250
FIVE_BYTE_HEADER = 251
SIX_BYTE_HEADER = 252
SEVEN_BYTE_HEADER = 253
EIGHT_BYTE_HEADER = 254
NINE_BYTE_HEADER = 255

BYTE_VALS = 2**8    # distinct values of one byte
SHORT_VALS = 2**16  # distinct values of two bytes

# buckets[i] uses i + 3 data bytes; encode_int and decode_val use
# BUCKET_OFFSET to convert between a bucket index and its byte loop count.
BUCKET_OFFSET = 2

# Range of integers this encoding can represent.
minint = 0
maxint = NINE_BYTE_LIMIT

# Size classes for values above THREE_BYTE_LIMIT, smallest first: each entry
# pairs the largest value the class can hold with the header byte announcing it.
buckets = [
    {'limit': limit, 'header': header}
    for limit, header in (
        (FOUR_BYTE_LIMIT, FOUR_BYTE_HEADER),
        (FIVE_BYTE_LIMIT, FIVE_BYTE_HEADER),
        (SIX_BYTE_LIMIT, SIX_BYTE_HEADER),
        (SEVEN_BYTE_LIMIT, SEVEN_BYTE_HEADER),
        (EIGHT_BYTE_LIMIT, EIGHT_BYTE_HEADER),
        (NINE_BYTE_LIMIT, NINE_BYTE_HEADER),
    )
]


def writeToFile(payload, filename):
    """Varint-encode *payload* and write the resulting bytes to *filename*.

    The payload is encoded before the file is opened, so a payload that
    cannot be encoded leaves an existing file untouched instead of
    truncating it to zero bytes.

    Raises:
        ValueError: propagated from the encoder for out-of-range integers.
        TypeError: if the payload is not something varint_encode can encode.
    """
    encoded = varint_encode(payload)
    if encoded is None:
        raise TypeError(
            "Cannot varint-encode payload of type {}".format(type(payload).__name__))
    with open(filename, "wb") as f:
        f.write(encoded)

def readFromFile(filename):
    """Read *filename* in binary mode and decode its contents with varint_decode.

    Returns an int for a single stored value, a list of ints for several,
    or None for an empty file (the conventions of generic_decode).
    """
    with open(filename, "rb") as f:
        # Named `data` so the builtin `bytes` is not shadowed.
        data = f.read()
    return varint_decode(data)

def varint_encode( num ):
    """Varint-encode an int or a list of ints via generic_encode and ``funcs``."""
    encoded = generic_encode( num, funcs )
    return encoded

def encode_int( num ):
    """Encode a single non-negative integer as a varint.

    Layout, as implemented below:
      * 0..ONE_BYTE_LIMIT (240): one byte holding the value itself.
      * ONE_BYTE_LIMIT+1..TWO_BYTE_LIMIT (2287): two bytes. With
        ``top = num - 240`` (1..2047), the first byte is ``top // 256 + 241``
        (241..248) and the second is ``top % 256``.
      * TWO_BYTE_LIMIT+1..THREE_BYTE_LIMIT (67823): THREE_BYTE_HEADER (249)
        followed by ``num - 2288`` as two big-endian bytes.
      * Larger values: the header of the first bucket whose limit is >= num,
        followed by ``num`` itself as 3..8 big-endian data bytes.

    Returns:
        The encoding as built by ``varint_storage`` (``bytes`` on Python 3).

    Raises:
        ValueError: if ``num`` is negative or larger than NINE_BYTE_LIMIT.
    """
    ret_val = None
    if num < 0:
        raise ValueError("Negative numbers not handled")

    if( num <= ONE_BYTE_LIMIT ):
        ret_val = varint_storage( num )
    elif( num <= TWO_BYTE_LIMIT ):
        # Offset from ONE_BYTE_LIMIT: high part goes into the first byte
        # (shifted up past ONE_BYTE_LIMIT), low 8 bits into the second.
        top = num-ONE_BYTE_LIMIT
        ret_val = varint_storage( (top // BYTE_VALS)+ONE_BYTE_LIMIT+1 ) + \
                  varint_storage( top % BYTE_VALS )
    elif( num <= THREE_BYTE_LIMIT ):
        # Offset from the first three-byte value, stored big-endian.
        top = num-(TWO_BYTE_LIMIT+1)
        ret_val = varint_storage( THREE_BYTE_HEADER ) + \
                  varint_storage( top // BYTE_VALS ) + \
                  varint_storage( top % BYTE_VALS )
    else:
        start = 0

        # Work out how many bytes are needed to store this value
        while(( start < len( buckets )) and
              ( num > buckets[start]['limit'])):
            start = start + 1

        if( start == len( buckets )):
            raise ValueError("Too large")

        ret_val = varint_storage( buckets[start]['header'] )
        # Weight (a power of 256) of the most significant data byte.
        mod = (buckets[start]['limit']+1) // BYTE_VALS
        # Bucket i uses i + 3 data bytes; after this the loop below runs
        # start + 1 times, emitting one byte per iteration, most significant first.
        start = start + BUCKET_OFFSET

        while( start >= 0 ):
            start = start - 1
            ret_val = ret_val + varint_storage( num // mod )
            num = num % mod
            mod = mod // BYTE_VALS

    return ret_val

def varint_decode( num ):
    """Decode varint data (``bytes`` on Python 3) via generic_decode and ``funcs``.

    Returns an int for a single value, a list of ints for several values,
    or None for empty or unsupported input.
    """
    decoded = generic_decode( num, funcs )
    return decoded

def decode_val( num ):
    """Decode the varint that starts at ``num[0]``.

    ``num`` is an indexable run of stored bytes (``bytes`` on Python 3,
    ``str`` on Python 2; see ``store_to_num``). Bytes after the first
    value are not read.

    Returns:
        ``(value, bytes_used)``, where ``bytes_used`` is how many leading
        bytes of ``num`` encoded ``value``, so the caller can skip past them.

    An input that ends in the middle of a value raises IndexError.
    """
    ret_val = None
    bytes_used = 1
    first = store_to_num( num[ 0 ] )
    if( first <= ONE_BYTE_LIMIT ):
        # 0..240: the byte is the value itself.
        ret_val = first
    elif( first < THREE_BYTE_HEADER ):
        # 241..248: (first - 241) is the high part of (value - 240); the
        # next byte holds its low 8 bits.
        second = store_to_num( num[ 1 ] )
        ret_val = ONE_BYTE_LIMIT+(BYTE_VALS*(first-(ONE_BYTE_LIMIT+1)))+second
        bytes_used = 2
    elif( first == THREE_BYTE_HEADER ):
        # 249: (value - 2288) follows as two big-endian bytes.
        second = store_to_num( num[ 1 ] )
        third = store_to_num( num[ 2 ] )
        ret_val = (TWO_BYTE_LIMIT+1)+(BYTE_VALS*second)+third
        bytes_used = 3
    else:
        # Headers 250..255 are followed by 3..8 big-endian data bytes.
        data_bytes = first-247
        start = data_bytes - 1
        ret_val = 0
        i = 1

        # Weight of the most significant data byte; start - BUCKET_OFFSET
        # is the bucket index (data_bytes - 3) for this header.
        mod = (buckets[start-BUCKET_OFFSET]['limit']+1) // BYTE_VALS

        while( start >= 0 ):
            ret_val = ret_val + (mod * store_to_num( num[ i ] )) 
            i = i + 1
            start = start - 1
            mod = mod // BYTE_VALS

        bytes_used = data_bytes + 1

    return (ret_val, bytes_used)

# Format-specific handlers passed to generic_encode / generic_decode.
funcs = dict(decode_val=decode_val, encode_int=encode_int)

# Byte-storage primitives, chosen once for the running interpreter. Encoded
# data is ``bytes`` on Python 3 (indexing yields ints) and ``str`` on
# Python 2 (indexing yields one-character strings).
if sys.version_info[0] > 2:
    def empty_varint_storage():
        """Return an empty container that encodings are concatenated onto."""
        return bytes()
    def varint_storage(b):
        """Store one byte value (0..255) as a one-byte ``bytes`` object."""
        return bytes((b, ))
    def store_to_num(b):
        """Convert one stored element to its int value (already an int here)."""
        return b
    def num_types():
        """Return the tuple of types accepted as a single integer to encode."""
        # `(int)` is just `int`; use a real one-element tuple to match the
        # Python 2 branch's `(int, long)`.
        return (int, )
else:
    def empty_varint_storage():
        """Return an empty container that encodings are concatenated onto."""
        return ""
    def varint_storage(b):
        """Store one byte value (0..255) as a one-character ``str``."""
        return chr(b)
    def store_to_num(b):
        """Convert one stored character back to its int value."""
        return ord(b)
    def num_types():
        """Return the tuple of types accepted as a single integer to encode."""
        return (int,long)

def dump( num ):
    print( "Len: {}".format( len(num) ))
    for element in num:
        print( "B: {}".format( store_to_num(element) ))

def generic_encode( num, funcs ):
    """Encode an int, or a list/tuple of ints, with the handlers in *funcs*.

    Args:
        num: a single integer (an instance of ``num_types()``) or a list or
            tuple of them.
        funcs: dispatch dict providing ``'encode_int'``.

    Returns:
        The concatenated encoding (``bytes`` on Python 3).

    Raises:
        TypeError: if *num* is neither an integer nor a list/tuple. Returning
            None here would only fail later, with a less useful error.
    """
    if isinstance(num, (list, tuple)):
        return encode_list(num, funcs)
    if isinstance(num, num_types()):
        return funcs['encode_int'](num)
    raise TypeError(
        "Cannot varint-encode object of type {}".format(type(num).__name__))

def encode_list( num, funcs ):
    """Encode each integer in *num* with ``funcs['encode_int']`` and concatenate.

    A single join keeps the cost linear in the output size; repeated
    ``bytes + bytes`` copies the growing result on every element.
    """
    return empty_varint_storage().join(
        funcs['encode_int'](val) for val in num)

def generic_decode( num, funcs ):
    """Decode every varint in *num* using ``funcs['decode_val']``.

    Returns None if *num* is not ``str``/``bytes`` or is empty, the bare
    value if it holds exactly one varint, and a list of values otherwise.
    """
    if not isinstance(num, (str, bytes)):
        return None
    decoded = None
    cursor = 0
    total = len(num)
    while cursor < total:
        value, width = funcs['decode_val'](num[cursor:])
        cursor += width
        if decoded is None:
            decoded = value
            continue
        # Promote the first scalar to a list once a second value appears.
        if isinstance(decoded, num_types()):
            decoded = [decoded]
        decoded.append(value)
    return decoded


In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gzip

# GPT-2 tokenizer; the end-of-text token is reused as the padding token so
# that batches can be padded.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
# GPT-2 language model on the GPU, switched to eval (inference) mode. Assigning
# to `_` keeps the notebook from echoing eval()'s return value.
model = GPT2LMHeadModel.from_pretrained('gpt2').to('cuda')
_ = model.eval()

In [None]:
# Derived from the loaded model and tokenizer rather than hard-coded, so they
# stay in sync if the checkpoint changes. For 'gpt2' these are 50257 (the
# size of the logits' last dimension) and 50256 (the end-of-text id, which is
# used as the pad token).
VOCAB_SIZE = model.config.vocab_size
PAD_TOKEN = tokenizer.pad_token_id

In [None]:
# Neural compression functions

def valid_encodings(shifted_inputs, encoded_msgs, sorted_tokens):
  """Check that rank-encoded messages map back to the original tokens.

  At each timestep, use the encoded message to select the token at the
  specified index of the list of sorted tokens, reconstructing the original
  message, then compare against the original message.

  Parameters
    shifted_inputs: (batch_size, token_len) token ids; column t holds the
      token expected at position t. The last column is ignored.
    encoded_msgs: (batch_size, token_len - 1) ranks into sorted_tokens.
    sorted_tokens: (batch_size, token_len, vocab_size) token ids per position,
      ordered by descending logit.

  Returns
    A 0-dim bool tensor that is True iff every position decodes correctly.
  """
  batch_size, token_len, vocab_size = sorted_tokens.size()
  msg_len = token_len - 1

  # Flatten the tensor of sorted tokens to make indexing easier, and add a
  # per-position offset of t * vocab_size to the ranks to account for the
  # flattening. The offsets are built on the inputs' device rather than a
  # hard-coded 'cuda', so this also works on CPU.
  sorted_tokens_flat = sorted_tokens.reshape(batch_size, -1)
  offsets = torch.arange(msg_len, device=encoded_msgs.device) * vocab_size
  decoded_msgs_cand = torch.gather(sorted_tokens_flat, 1, encoded_msgs + offsets)
  return torch.all(decoded_msgs_cand == shifted_inputs[:, :-1])
  

def trans_encode(tokenized_msgs, attentions, vocab_size):
  """Rank-encode a batch of token sequences with the language model.

  Every token after the first is replaced by its index in the model's
  prediction for that position, sorted by descending logit (0 = top
  prediction). The first token is kept verbatim so decoding can bootstrap.

  Parameters
    tokenized_msgs: shape (batch_size, msg_len)
    attentions: shape (batch_size, msg_len)
    vocab_size: integer; must equal the last dimension of the model's logits

  Returns
    encoded_msgs: shape (batch_size, msg_len); column 0 is the first token
      id, column t >= 1 is the rank of token t.
    logits_arr: shape (batch_size, msg_len, vocab_size), for debugging.

  Raises
    RuntimeError: if the ranks fail to reproduce the input (self-check).
  """
  device = tokenized_msgs.device
  batch_size, msgs_len = tokenized_msgs.size()

  model.eval()
  with torch.no_grad():
    # In theory, I should be able to avoid the loop because the transformer
    # automatically masks the input. But in practice, this causes the logit
    # outputs to differ slightly between the encoder and decoder
    logits_arr = torch.zeros(batch_size, msgs_len, vocab_size, device=device)
    for i in range(msgs_len):
      msgs_slice = tokenized_msgs[:, :i+1]
      attentions_slice = attentions[:, :i+1]
      logits = model(msgs_slice, attention_mask=attentions_slice).logits
      logits_arr[:, i] = logits[:, i]

    # Sort the indices of the logits in descending order of logit value.
    # This means that the model's top predicted token is the first element
    # in the sorted list, the second highest predicted token is the second
    # element, and so on. We then find the ground-truth token in this list
    # and save its index as the encoding of the token.
    #
    # Shift inputs left along the sequence dimension so column t holds the
    # token predicted at position t. dims=1 keeps the roll within each row
    # (without it torch.roll flattens the whole batch first).
    shifted_inputs = torch.roll(tokenized_msgs, -1, dims=1)
    _, sorted_tokens = torch.sort(logits_arr, dim=2, descending=True, stable=True)
    shifted_inputs_reshaped = shifted_inputs.view(batch_size, msgs_len, 1)
    # Each sorted row is a permutation of the vocabulary, so there is exactly
    # one match per (row, position); nonzero() lists them in row-major order.
    encoded_msgs = (sorted_tokens == shifted_inputs_reshaped).nonzero()[:, 2].reshape(batch_size, -1)
    encoded_msgs = encoded_msgs[:, :-1] # Discard the last index because it overflows the original message

    # Explicit check instead of `assert` so it still runs under `python -O`.
    if not valid_encodings(shifted_inputs, encoded_msgs, sorted_tokens):
      raise RuntimeError("Rank encoding does not reproduce the input tokens")

  # We need to include the first token as part of the encoded message so that we
  # can bootstrap generation
  encoded_msgs = torch.cat((tokenized_msgs[:, :1], encoded_msgs), dim=1)

  return encoded_msgs, logits_arr # Logits for debugging

def trans_decode(encoded_msgs, vocab_size):
  """Invert trans_encode: rebuild token ids from a rank-encoded batch.

  Parameters
    encoded_msgs: shape (batch_size, msg_len), as returned by trans_encode.
    vocab_size: integer; must equal the last dimension of the model's logits.

  Returns
    decoded_msgs: token ids with the same shape as encoded_msgs. Positions
      after a message's end are not meaningful (verify_msgs masks them).
    logits_arr: shape (batch_size, msg_len - 1, vocab_size), for debugging.
  """
  # Decoding must reproduce the encoder's logits exactly, so run the model in
  # the same eval mode that trans_encode uses.
  model.eval()
  with torch.no_grad():
    # The first value in the encoded message
    # is the first token of the original message
    first_tokens = encoded_msgs[:, :1]
    ranks = encoded_msgs[:, 1:]

    batch_size, msg_len = ranks.size()
    logits_arr = torch.zeros(batch_size, msg_len, vocab_size, device=encoded_msgs.device) # For debugging
    decoded_msgs = first_tokens
    for i in range(msg_len):
      # decoded_msgs holds i + 1 tokens, so logits[:, i] predicts token i + 1.
      logits = model(decoded_msgs).logits
      logits_arr[:, i] = logits[:, i] # For debugging
      _, indices = torch.sort(logits[:, i, :], dim=1, descending=True, stable=True)
      decoded_tokens = torch.gather(indices, 1, ranks[:, i:i+1])
      decoded_msgs = torch.cat((decoded_msgs, decoded_tokens), dim=1)
  return decoded_msgs, logits_arr # Logits for debugging

def verify_msgs(decoded_msgs, original_msgs, attentions, pad_token=None):
  """Return a 0-dim bool tensor: True iff decoding reproduced every message.

  Parameters
    decoded_msgs: (batch_size, msg_len) token ids from trans_decode.
    original_msgs: (batch_size, msg_len) original token ids.
    attentions: (batch_size, msg_len) mask; 0 marks padding positions.
    pad_token: id written into padding positions before comparing;
      defaults to PAD_TOKEN.
  """
  if pad_token is None:
    pad_token = PAD_TOKEN
  # We do this masking because the decompressor will spit out garbage output
  # after the end of a message but we don't care about this because we can identify
  # end-of-message by looking for the first padding token.
  # The mask is moved to decoded_msgs' device (no CPU round trip).
  attentions_bool_mask = attentions.to(device=decoded_msgs.device, dtype=torch.bool)
  pad_token_mask = torch.full_like(decoded_msgs, pad_token)
  decoded_msgs_cleaned = torch.where(attentions_bool_mask, decoded_msgs, pad_token_mask)
  return torch.all(decoded_msgs_cleaned == original_msgs)

In [None]:
# Evaluation functions

def get_compressed_size(encoded_msgs, attentions, messages):
  """Compare the size of each neurally encoded message with gzip.

  Parameters
    encoded_msgs: shape (batch_size, msg_len), as returned by trans_encode.
    attentions: shape (batch_size, msg_len) attention mask; the first 0 in a
      row marks where that row's padding starts.
    messages: the original message strings, indexed like the batch rows.

  Returns
    One (trans_encoding_size, gzip_encoding_size) tuple of byte counts per
    row: the varint-encoded rank sequence vs. the UTF-8 message compressed by
    gzip at level 9.
  """
  sizes = []
  for i, encoded in enumerate(encoded_msgs):
    encoding = encoded.tolist()
    attention = attentions[i].tolist()
    # Get end of message by looking at the attentions. A row with no 0 is
    # entirely message (no padding) and is kept whole.
    if 0 in attention:
      end = attention.index(0)
      # NOTE(review): this keeps end + 1 entries, i.e. also the rank of the
      # first padding token, presumably as an end-of-message marker. Confirm.
      encoding = encoding[:end + 1]
    trans_encoding_size = len(varint_encode(encoding))

    orig_msg_bytes = bytes(messages[i], 'utf-8')
    gzip_encoding_size = len(gzip.compress(orig_msg_bytes, compresslevel=9))
    sizes.append((trans_encoding_size, gzip_encoding_size))
  return sizes

In [None]:
from datasets import load_dataset, DatasetDict
news_dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0")
# Drop the unneeded "id" and "highlights" columns; only "article" is used.
news_dataset = news_dataset.remove_columns(["id", "highlights"])

Downloading and preparing dataset cnn_dailymail/3.0.0 to /root/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f...


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset cnn_dailymail downloaded and prepared to /root/.cache/huggingface/datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [187]:
dev_set = news_dataset['validation'].select(range(1000))
tokenized_dev_set = dev_set.map(
    lambda example: tokenizer(example['article'], return_tensors="np", padding='max_length', truncation=True, max_length=512),
    batched=True,
    batch_size=16
)
tokenized_dev_set.set_format(type='torch', columns=['input_ids', 'attention_mask'])

# Keep only articles whose tokenized length is below max_length (512), i.e.
# rows whose attention mask has at least one padding position.
# We want to compare the gzipped full article to the neurally compressed full
# article. Compressing truncated articles would mean that gzip compresses the
# untruncated article while the neural compressor only sees the truncated
# version.
# An article of exactly 512 tokens has no padding either, so it is dropped
# too (conservative).
tokenized_dev_set = tokenized_dev_set.filter(
    lambda example: not torch.all(example['attention_mask'] == 1)
)


  0%|          | 0/63 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [188]:
# Display the filtered dataset (features and number of remaining rows).
tokenized_dev_set

Dataset({
    features: ['article', 'input_ids', 'attention_mask'],
    num_rows: 301
})

In [190]:
from torch.utils.data import DataLoader

# Batches of two rows from the filtered dev set, in dataset order.
eval_dataloader = DataLoader(tokenized_dev_set, batch_size=2, shuffle=False)


In [194]:
# Encode just the first batch of the dev set.
batch = next(iter(eval_dataloader))
encoded_msgs, encoder_logits = trans_encode(batch['input_ids'].to('cuda'), batch['attention_mask'].to('cuda'), VOCAB_SIZE)


In [196]:
# Reconstruct the first batch's token ids from its rank encoding.
decoded_msgs, logits_arr = trans_decode(encoded_msgs, vocab_size=VOCAB_SIZE)


RuntimeError: ignored

In [197]:
# True iff the decoded batch matches the original token ids once padding is masked.
verify_msgs(decoded_msgs, original_msgs=batch['input_ids'].to('cuda'), attentions=batch['attention_mask'].to('cuda'))

tensor(True, device='cuda:0')

In [None]:
# Compare neural vs. gzip size for the batch encoded above. set_format(...,
# columns=[...]) limits DataLoader batches to the token columns, so fetch the
# matching article strings separately. The DataLoader does not shuffle, so the
# first batch is the first len(batch) rows.
batch_articles = tokenized_dev_set.with_format(None)[:len(batch['input_ids'])]['article']
get_compressed_size(encoded_msgs, batch['attention_mask'], batch_articles)

In [None]:
# Small hand-written batch for a quick encode/decode round trip. The next cell
# reads `sample_inputs` and `attentions` from here.
messages = [" But if you are preparing data and doing cat in each iteration, it gets really slow when the tensor you are generating gets very large. My solution was to cat into", 
            "msg 2 baby", 
            "The Boat Race 2021 comprised two side-by-side rowing races that took place on 4 April. The Boat Race is contested annually between crews from the universities of Oxford and Cambridge. Traditionally held on the Championship Course in London, the 2021 race instead took place on the River Great Ouse near Ely (course map pictured). This was the 75th women's race and the 166th men's race;"]

tokenized = tokenizer(messages, return_tensors="pt", padding="longest", truncation=True, max_length=1024)
attentions = tokenized.attention_mask.to('cuda')
sample_inputs = tokenized.input_ids.to('cuda')

In [None]:
# Round-trip the hand-written sample batch and report whether it decodes exactly.
encoded_msgs, encoder_logits = trans_encode(sample_inputs, attentions, vocab_size=VOCAB_SIZE)
decoded_msgs, decoder_logits = trans_decode(encoded_msgs, vocab_size=VOCAB_SIZE)
verify_msgs(decoded_msgs, sample_inputs, attentions).item()

True