# Task 2. PyTorch Compression Enchancement

### I. Import Libraries

In [1]:
# !pip install llmlingua
# !pip install openai==0.28
# !pip install spacy
# !python -m spacy download en_core_web_sm
# !pip install scikit-learn
# !pip install tensorboard

# !pip install datasets

In [2]:
# import test dataset

from datasets import load_dataset
ds_test = load_dataset("openai/gsm8k", "main", split="test")
original_prompts = []
for idx, instance in enumerate(ds_test):
  original_prompts.append("Question: "+instance['question']+instance['answer'])

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [3]:
import os
os.makedirs('../../../results/models/xlm-roberta-large-gsm8k-only', exist_ok=True)

In [4]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch.nn.functional as F
import torch

model_gsm8k = '../../../results/models/xlm-roberta-large-gsm8k-only'
model_meetingbank = "microsoft/llmlingua-2-xlm-roberta-large-meetingbank"

tokenizer = AutoTokenizer.from_pretrained(model_meetingbank)
model = AutoModelForTokenClassification.from_pretrained(model_meetingbank)


The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

tokenizer_config.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/752 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

### II. Update token scoring to handle domain shifts

For gsm8k, the probabilities of preserved tokens of numbers and operators are multiplied by a weight, which is followed by selecting the topK tokens for compression.

In [5]:
import re

# List of English number words
number_words = [
    "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
    "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen", "seventeen",
    "eighteen", "nineteen", "twenty", "thirty", "forty", "fifty", "sixty", "seventy",
    "eighty", "ninety", "hundred", "thousand", "million", "billion", "trillion"
]

# Compile a regex pattern for matching the number words
number_pattern = re.compile(r'\b(' + '|'.join(number_words) + r')\b', re.IGNORECASE)

def contains_number(s):
    return bool(re.search(r'\d', s)) or bool(number_pattern.search(s))

def contains_math_ops(s):
    return bool(re.search(r'[+\-*/%^=]', s))

In [6]:
def compare_updated_scoring(prompt, compression_ratio=0.6, weight=1.2):

  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, is_split_into_words=False)

  # Perform inference
  with torch.no_grad():
      outputs = model(**inputs)

  # Extract logits and predict token labels
  logits = outputs.logits
  predictions = torch.argmax(logits, dim=2)
  probabilities = F.softmax(logits, dim=-1)

  # Decode token labels
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
  predicted_labels = [pred for pred in predictions[0]]

  # Preserved tokens
  preserved_tokens = []

  for token, label, prob in zip(tokens, predicted_labels, probabilities[0]):
    if label==0:
      continue

    if contains_number(token) or contains_math_ops(token):
      preserved_tokens.append((token, prob[1].item(), weight*prob[1].item()))
    else:
      preserved_tokens.append((token, prob[1].item(), prob[1].item()))


  # Compressed tokens without updating token scoring
  top_k = int(compression_ratio * len(preserved_tokens))
  sorted_tokens = sorted(preserved_tokens, key=lambda x: x[1], reverse=True)
  prob_threshold = sorted_tokens[top_k][1]

  compressed_tokens = []
  for token, prob, _ in preserved_tokens:
    if prob<prob_threshold:
      continue

    compressed_tokens.append(token)

  # Compressed tokens after updating token scoring
  sorted_tokens = sorted(preserved_tokens, key=lambda x: x[2], reverse=True)
  prob_threshold = sorted_tokens[top_k][2]

  compressed_tokens_updated_score = []
  for token, _, prob in preserved_tokens:
    if prob<prob_threshold:
      continue

    compressed_tokens_updated_score.append(token)

  original_tokens = "".join(tokens)
  compressed_tokens = "".join(compressed_tokens)
  compressed_tokens_updated_score = "".join(compressed_tokens_updated_score)

  return original_tokens, compressed_tokens, compressed_tokens_updated_score


In [7]:
# Tokenize the input
prompt = original_prompts[0]

original, compressed, compressed_updated = compare_updated_scoring(prompt, 0.6, 1.5)

print("original tokens")
print(original)

print("compressed tokens")
print(compressed)

print("compressed tokens after updating scores")
print(compressed_updated)

original tokens
<s>▁Question:▁Janet’s▁ducks▁lay▁16▁eggs▁per▁day.▁She▁eats▁three▁for▁breakfast▁every▁morning▁and▁bakes▁muffins▁for▁her▁friends▁every▁day▁with▁four.▁She▁sells▁the▁remainder▁at▁the▁farmers'▁market▁daily▁for▁$2▁per▁fresh▁duck▁egg.▁How▁much▁in▁dollars▁does▁she▁make▁every▁day▁at▁the▁farmers'▁market?Janet▁sells▁16▁-▁3▁-▁4▁=▁<<16-3-4=9>>9▁duck▁eggs▁a▁day.▁She▁makes▁9▁*▁2▁=▁$<<9*2=18>>18▁every▁day▁at▁the▁farmer’s▁market.▁####▁18</s>
compressed tokens
▁Janet’s▁ducks▁la▁16▁eggs▁eats▁three▁bakes▁muffins▁friends▁sells▁remainder▁farmers▁market▁$2▁per▁egg▁dollars?Janet▁sells▁16▁-▁3▁-▁4=99▁duck▁eggs▁makes▁9▁*▁2▁=*2=18>>18▁farmer▁market</s>
compressed tokens after updating scores
▁Janet’s▁ducks▁la▁16▁eggs▁eats▁three▁bakes▁muffins▁friends▁four▁sells▁remainder▁market▁$2▁per▁eggJanet▁sells▁16▁-▁3▁-▁4▁=16-3-4=99▁duck▁eggs▁makes▁9▁*▁2▁=9*2=1818▁market</s>


In [8]:
prompt = original_prompts[1]

original, compressed, compressed_updated = compare_updated_scoring(prompt, 0.6, 1.5)

print("original tokens")
print(original)

print("compressed tokens")
print(compressed)

print("compressed tokens after updating scores")
print(compressed_updated)

original tokens
<s>▁Question:▁A▁robe▁takes▁2▁bolts▁of▁blue▁fiber▁and▁half▁that▁much▁white▁fiber.▁How▁many▁bolts▁in▁total▁does▁it▁take?It▁takes▁2/2=<<2/2=1>>1▁bolt▁of▁white▁fiber▁So▁the▁total▁amount▁of▁fabric▁is▁2+1=<<2+1=3>>3▁bolts▁of▁fabric▁####▁3</s>
compressed tokens
▁robe▁takes▁2▁bolts▁blue▁half▁white▁bolts▁total▁2/2=11▁bolt▁white▁total▁fabric▁2+12+1=3>>3▁bolts
compressed tokens after updating scores
▁robe▁takes▁2▁bolt▁blue▁half▁white▁bolts▁2/2=2/2=11▁white▁total▁fabric▁2+12+1=33▁bolts</s>


In [9]:
prompt = original_prompts[2]

original, compressed, compressed_updated = compare_updated_scoring(prompt, 0.6, 1.5)

print("original tokens")
print(original)

print("compressed tokens")
print(compressed)

print("compressed tokens after updating scores")
print(compressed_updated)

original tokens
<s>▁Question:▁Josh▁decides▁to▁try▁flipping▁a▁house.▁He▁buys▁a▁house▁for▁$80,000▁and▁then▁puts▁in▁$50,000▁in▁repairs.▁This▁increased▁the▁value▁of▁the▁house▁by▁150%.▁How▁much▁profit▁did▁he▁make?The▁cost▁of▁the▁house▁and▁repairs▁came▁out▁to▁80,000+50,000=$<<80000+50000=130000>>130,000▁He▁increased▁the▁value▁of▁the▁house▁by▁80,000*1.5=<<80000*1.5=120000>>120,000▁So▁the▁new▁value▁of▁the▁house▁is▁120,000+80,000=$<<120000+80000=200000>>200,000▁So▁he▁made▁a▁profit▁of▁200,000-130,000=$<<200000-130000=70000>>70,000▁####▁70000</s>
compressed tokens
▁Josh▁buys▁$80,000▁$50,000▁repairs▁increased▁value▁150%▁profit▁cost▁80,000+50,0008+50000=130000>>130,000▁increased▁80,000*1.5120,000▁new▁value▁120,000+80,00012000080000=200000>>200,000▁profit▁200,000-130,000200000-130000=7>>70,000
compressed tokens after updating scores
▁Josh▁$80,000▁$50,000▁repairs▁150%▁80,000+50,00080000+50000=130000130,00080,000*1.580000*1.5=120000120,000▁120,000+80,000=120000+80000=200000200,000▁200,000-130,000=2000

* By prioritizing numbers and operators, more of them are kept in compressed tokens

### III. Add fallback mechanisms for critical content

Numbers and operators are taken as critical content for gsm8k. A mark ('critical' or 'non-critical') is appended to preserved_tokens. Thus a token will be kept during compression if its preserved probability is greater than probability threshold or it is 'critical'.

In [10]:
def compare_fallback(prompt, compression_ratio=0.6, fallback=True):

  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, is_split_into_words=False)

  # Perform inference
  with torch.no_grad():
      outputs = model(**inputs)

  # Extract logits and predict token labels
  logits = outputs.logits
  predictions = torch.argmax(logits, dim=2)
  probabilities = F.softmax(logits, dim=-1)

  # Decode token labels
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
  predicted_labels = [pred for pred in predictions[0]]

  # Preserved tokens
  preserved_tokens = []

  for token, label, prob in zip(tokens, predicted_labels, probabilities[0]):
    if label==0:
      continue

    if contains_number(token) or contains_math_ops(token):
      preserved_tokens.append((token, prob[1].item(), 'critical'))
    else:
      preserved_tokens.append((token, prob[1].item(), 'non-critical'))


  # Compressed tokens without fallback
  top_k = int(compression_ratio * len(preserved_tokens))
  sorted_tokens = sorted(preserved_tokens, key=lambda x: x[1], reverse=True)
  prob_threshold = sorted_tokens[top_k][1]

  compressed_tokens = []
  for token, prob, _ in preserved_tokens:
    if prob<prob_threshold:
      continue

    compressed_tokens.append(token)

  # Compressed tokens with fallback

  compressed_tokens_fallback = []
  for token, prob, iscritical in preserved_tokens:
    if prob<prob_threshold and iscritical=='non-critical':
      continue

    compressed_tokens_fallback.append(token)

  original_tokens = "".join(tokens)
  compressed_tokens = "".join(compressed_tokens)
  compressed_tokens_fallback = "".join(compressed_tokens_fallback)

  return original_tokens, compressed_tokens, compressed_tokens_fallback

In [11]:
prompt = original_prompts[0]

original, compressed, compressed_fallback = compare_fallback(prompt)

print("original tokens")
print(original)

print("compressed tokens")
print(compressed)

print("compressed tokens after adding fallback")
print(compressed_fallback)

original tokens
<s>▁Question:▁Janet’s▁ducks▁lay▁16▁eggs▁per▁day.▁She▁eats▁three▁for▁breakfast▁every▁morning▁and▁bakes▁muffins▁for▁her▁friends▁every▁day▁with▁four.▁She▁sells▁the▁remainder▁at▁the▁farmers'▁market▁daily▁for▁$2▁per▁fresh▁duck▁egg.▁How▁much▁in▁dollars▁does▁she▁make▁every▁day▁at▁the▁farmers'▁market?Janet▁sells▁16▁-▁3▁-▁4▁=▁<<16-3-4=9>>9▁duck▁eggs▁a▁day.▁She▁makes▁9▁*▁2▁=▁$<<9*2=18>>18▁every▁day▁at▁the▁farmer’s▁market.▁####▁18</s>
compressed tokens
▁Janet’s▁ducks▁la▁16▁eggs▁eats▁three▁bakes▁muffins▁friends▁sells▁remainder▁farmers▁market▁$2▁per▁egg▁dollars?Janet▁sells▁16▁-▁3▁-▁4=99▁duck▁eggs▁makes▁9▁*▁2▁=*2=18>>18▁farmer▁market</s>
compressed tokens after adding fallback
▁Janet’s▁ducks▁la▁16▁eggs▁eats▁three▁bakes▁muffins▁friends▁four▁sells▁remainder▁farmers▁market▁$2▁per▁egg▁dollars?Janet▁sells▁16▁-▁3▁-▁4▁=16-3-4=99▁duck▁eggs▁makes▁9▁*▁2▁=9*2=18>>18▁farmer▁market</s>


* As can be seen, all the numbers and operators are kept in compressed prompt and in corresponding order

### IV. Implement adaptive thresholds

Instead of selecting topk tokens within a sample, the preserved probabilities of all samples are sorted and the probability threshold are set based on the topk probability over all samples. During compression, only the tokens with probablity greater than threshold will be kept. This will guarantee a fixed overall compression rate for all samples, but the compression rate for each sample may vary.

In [13]:
def compare_adaptive_thresholds(prompts, compression_ratio=0.6):

  # collecting all preserved tokens with probs
  tokens_all = []
  tokens_org = []
  for prompt in prompts:

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, is_split_into_words=False)
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=2)
    probabilities = F.softmax(logits, dim=-1)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    predicted_labels = [pred for pred in predictions[0]]

    preserved_tokens = []
    for token, label, prob in zip(tokens, predicted_labels, probabilities[0]):
      if label==0:
        continue

      preserved_tokens.append((token, prob[1].item()))

    tokens_all.append(preserved_tokens)
    tokens_org.append("".join(tokens))

  # compute probability threshold of all samples
  probs = []
  for tokens in tokens_all:
    probs += [prob for _, prob in tokens]

  top_k = int(compression_ratio * len(probs))
  sorted_probs = sorted(probs, reverse=True)
  prob_threshold = sorted_probs[top_k]

  # compressed tokens based on prob_threshold
  compressed_tokens_all = []
  for tokens in tokens_all:
    compressed_tokens = []
    for token, prob in tokens:
      if prob<prob_threshold:
        continue

      compressed_tokens.append(token)

    compressed_tokens_all.append("".join(compressed_tokens))

  return tokens_org, compressed_tokens_all

In [14]:
tokens_org, compressed_tokens_all  = compare_adaptive_thresholds(original_prompts[0:3])

for original, compressed in zip(tokens_org, compressed_tokens_all):

  print("====================")
  print("original tokens")
  print(original)

  print("compressed tokens")
  print(compressed)



original tokens
<s>▁Question:▁Janet’s▁ducks▁lay▁16▁eggs▁per▁day.▁She▁eats▁three▁for▁breakfast▁every▁morning▁and▁bakes▁muffins▁for▁her▁friends▁every▁day▁with▁four.▁She▁sells▁the▁remainder▁at▁the▁farmers'▁market▁daily▁for▁$2▁per▁fresh▁duck▁egg.▁How▁much▁in▁dollars▁does▁she▁make▁every▁day▁at▁the▁farmers'▁market?Janet▁sells▁16▁-▁3▁-▁4▁=▁<<16-3-4=9>>9▁duck▁eggs▁a▁day.▁She▁makes▁9▁*▁2▁=▁$<<9*2=18>>18▁every▁day▁at▁the▁farmer’s▁market.▁####▁18</s>
compressed tokens
▁Janet’s▁ducks▁la▁16▁eggs▁eats▁three▁bakes▁muffins▁friends▁sells▁remainder▁farmer▁market▁$2▁per▁eggJanet▁sells▁16▁-▁3▁-▁4=99▁duck▁eggs▁makes▁9▁*▁22=18>>18▁market
original tokens
<s>▁Question:▁A▁robe▁takes▁2▁bolts▁of▁blue▁fiber▁and▁half▁that▁much▁white▁fiber.▁How▁many▁bolts▁in▁total▁does▁it▁take?It▁takes▁2/2=<<2/2=1>>1▁bolt▁of▁white▁fiber▁So▁the▁total▁amount▁of▁fabric▁is▁2+1=<<2+1=3>>3▁bolts▁of▁fabric▁####▁3</s>
compressed tokens
▁robe▁takes▁2▁bolt▁blue▁half▁white▁bolts▁2/2=11▁bolt▁white▁total▁fabric▁2+12+1=3>>3▁bolts
original tokens