In [1]:
# install the Hugging Face libraries used below (transformers, datasets, peft)
!pip install transformers datasets peft

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (1

In [2]:
# mount Google Drive and change into the project folder
from google.colab import drive
drive.mount("/content/drive")
%cd /content/drive/MyDrive/KV_Compression

Mounted at /content/drive
/content/drive/MyDrive/KV_Compression


In [3]:
# imports
import math
import os
import torch
import random
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM,DataCollatorForLanguageModeling
from datasets import load_dataset, Dataset
from transformers import Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
from torch.utils.data import Dataset
from itertools import chain
from copy import deepcopy
import re
import numpy as np
import selective_context
from selective_context import SelectiveContext
from accelerate.logging import MultiProcessAdapter
from datasets import DatasetDict
from transformers import PreTrainedTokenizer

from collections import deque
import re
import sys

tqdm.pandas()

Loading dependencies...
Using device: cuda


## Pre-process Dataset
We want the model to learn to predict the next words from the hidden representation of the compress_right token rather than from the words inside the span. To make sure that the preceding representations are relevant, we split the WikiText articles into examples, each containing the text of one subsection of an article. These subsection texts have the following statistics:

*   mean word count: 365.8
*   std deviation of words: 270.3
*   75th percentile (upper quartile): 483 words




In [4]:
# Load the raw WikiText-2 dataset from the Hugging Face Hub
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Percentage of the training set to hold out when no validation split is provided
validation_split_percentage = 10

# Create a validation split only if the dataset does not already ship one
if "validation" not in dataset.keys():
    print("Validation split not found. Splitting the training set...")
    # Split once and take both halves from the same result, so the train and
    # validation sets are guaranteed to be disjoint and complementary.
    train_validation_split = dataset["train"].train_test_split(
        test_size=validation_split_percentage / 100, shuffle=True, seed=42
    )
    dataset["validation"] = train_validation_split["test"]
    dataset["train"] = train_validation_split["train"]

# Print the dataset keys and sizes
print(f"Dataset splits: {list(dataset.keys())}")
print(f"Train size: {len(dataset['train'])}")
print(f"Validation size: {len(dataset['validation'])}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Dataset splits: ['test', 'train', 'validation']
Train size: 36718
Validation size: 3760


In [5]:
def find_word_level_indexes(words, phrases):
    """
    Find the word-level position of each phrase in `words`, scanning left to right.

    Phrases are processed in the order given. The search for each phrase starts
    right after the previous match, so the returned spans are in increasing
    order and never overlap. Only the first match (from the current search
    position) is recorded for each phrase; a phrase that cannot be found from
    that position is skipped without moving the search position.

    Args:
        words (list): The text split into words.
        phrases (iterable): Phrases (strings) to locate; each is split on
            whitespace before matching. Empty or whitespace-only phrases are
            ignored.

    Returns:
        list: A list of `(start, end)` tuples of inclusive word indexes, one
            per phrase that was found.
    """
    search_from = 0
    phrase_locs = []

    for phrase in phrases:
        # Split the phrase into words
        phrase_words = phrase.split()
        phrase_len = len(phrase_words)

        # An empty phrase would "match" anywhere and yield an invalid
        # (i, i - 1) span, so skip it.
        if phrase_len == 0:
            continue

        # Find the first occurrence at or after the current search position
        for i in range(search_from, len(words) - phrase_len + 1):
            if words[i:i + phrase_len] == phrase_words:
                phrase_locs.append((i, i + phrase_len - 1))
                # Continue the next search after this match
                search_from = i + phrase_len
                break

    return phrase_locs

In [6]:
def remove_back_to_back_cr_cl(tokens, cr_token, cl_token):
    """
    Removes every `cr_token` that is immediately followed by `cl_token`,
    together with that `cl_token`, so two adjacent spans become one.

    The list is scanned left to right; once a `cr_token`/`cl_token` pair is
    dropped, scanning resumes after the pair.

    Args:
        tokens (list): The list of tokens. May be empty.
        cr_token (str): The token indicating the end of a masked phrase.
        cl_token (str): The token indicating the start of a masked phrase.

    Returns:
        list: A new list of tokens with consecutive `cr_token` and `cl_token`
            pairs removed. An empty input yields an empty list.
    """
    result = []
    position = 0
    token_count = len(tokens)

    while position < token_count:
        has_next = position + 1 < token_count
        if has_next and tokens[position] == cr_token and tokens[position + 1] == cl_token:
            # Drop both the closing and the reopening sentinel
            position += 2
        else:
            result.append(tokens[position])
            position += 1

    return result

In [7]:
# after calling sc for masked_phrases
def insert_sentinel_tokens(input_text, masked_phrases, cl_token, cr_token):
    """
    Wrap each located masked phrase in `input_text` with sentinel tokens.

    Args:
        input_text (str): The original text; it is split on whitespace.
        masked_phrases (list): Phrases to enclose in spans.
        cl_token (str): Sentinel inserted before the first word of a span.
        cr_token (str): Sentinel inserted after the last word of a span.

    Returns:
        str: The words joined by single spaces with sentinels inserted, or
            `input_text` unchanged when no phrase could be located.
    """
    # Drop blank phrases, strip the '�' character and surrounding whitespace,
    # and drop anything that ends up empty after that cleanup.
    cleaned_phrases = []
    for raw_phrase in masked_phrases:
        if not raw_phrase.strip():
            continue
        cleaned = raw_phrase.replace('�', '').strip()
        if cleaned:
            cleaned_phrases.append(cleaned)

    words = input_text.split()

    # Inclusive (start, end) word positions of the phrases, consumed in order
    pending_spans = deque(find_word_level_indexes(words, cleaned_phrases))

    # Nothing to mark: hand back the text untouched
    if not pending_spans:
        return input_text

    span_start, span_end = pending_spans.popleft()
    marked_tokens = []

    for position, word in enumerate(words):
        if position == span_start:
            marked_tokens.append(cl_token)

        marked_tokens.append(word)

        if position == span_end:
            marked_tokens.append(cr_token)
            # Move on to the next span, if any remain
            if pending_spans:
                span_start, span_end = pending_spans.popleft()

    # Merge spans that touch: a cr_token directly followed by a cl_token is removed
    merged_tokens = remove_back_to_back_cr_cl(marked_tokens, cr_token, cl_token)

    return " ".join(merged_tokens)

## Test the functions and check that the expected proportion of words is covered by spans

In [39]:
def strategic_tokenize_function_debug(
    examples: DatasetDict,
    compress: bool,
    tokenizer: PreTrainedTokenizer,
    text_column_name: str,
    max_span_length: int,
    bound_ratio: float,
    cl_token: str,
    cr_token: str,
    logger: MultiProcessAdapter
) -> DatasetDict:
    """
    Debug variant of the tokenize step: marks masked phrases with sentinel
    tokens instead of tokenizing.

    When `compress` is false, the text column is passed straight to
    `tokenizer` and its result is returned. Otherwise every example with at
    least 30 whitespace-separated words has its text replaced by the output of
    `insert_sentinel_tokens`; shorter examples keep their original text.

    Args:
        examples: Dataset object; must support `examples[text_column_name]`
            and `examples.map(fn)`.
        compress: Whether to insert sentinel tokens.
        tokenizer: Only called when `compress` is false.
        text_column_name: Name of the column holding the text.
        max_span_length: Not used by this function.
        bound_ratio: Passed to SelectiveContext as `reduce_ratio`.
        cl_token: Sentinel marking the start of a span.
        cr_token: Sentinel marking the end of a span.
        logger: Not used by this function.

    Returns:
        The tokenizer output when `compress` is false, otherwise the result
        of `examples.map(...)` with the text column rewritten.
    """
    if compress:
        # Model used to pick the phrases to mask
        # NOTE(review): presumably ranks phrases by self-information — confirm
        # against the selective_context package.
        sc = SelectiveContext(model_type='gpt2', lang='en')

        def process_example(example):
            text = example[text_column_name]

            # Only examples with at least 30 words are rewritten
            if len(text.split()) >= 30:
                # The second value returned by the call is the list of masked phrases
                _, masked_phrases = sc(text, reduce_ratio=bound_ratio, reduce_level='phrase')
                # Enclose the masked phrases in sentinel-delimited spans
                text = insert_sentinel_tokens(text, masked_phrases, cl_token, cr_token)

            return {text_column_name: text}

        # Apply to each example
        return examples.map(process_example)

    return tokenizer(examples[text_column_name])


In [40]:
def calculate_span_ratios(examples, cl_token, cr_token, text_column_name="text"):
    """
    Calculate, for each example, the ratio of words inside a span (between
    `cl_token` and `cr_token`) to the total number of words in the example.
    The sentinel tokens themselves are not counted as words.

    Parameters:
        examples (iterable): Examples, each a mapping whose
            `text_column_name` entry is a string.
        cl_token (str): Token that opens a span.
        cr_token (str): Token that closes a span.
        text_column_name (str): Key of the text field in each example.
            Defaults to "text".

    Returns:
        list: A list of ratios, one per example. An example with no words
            gets a ratio of 0.
    """
    ratios = []

    for example in examples:
        # Split the text on whitespace
        tokens = example[text_column_name].split()

        in_span = False
        span_word_count = 0
        total_word_count = 0

        for token in tokens:
            if token == cl_token:
                in_span = True  # Words from here on are inside a span
            elif token == cr_token:
                in_span = False  # Span closed
            else:
                # A regular word: always counts toward the total, and toward
                # the span count when a span is currently open
                total_word_count += 1
                if in_span:
                    span_word_count += 1

        # Guard against division by zero for empty examples
        ratio = span_word_count / total_word_count if total_word_count > 0 else 0

        ratios.append(ratio)

    return ratios

def calculate_average_ratio(ratios):
    """
    Return the arithmetic mean of the per-example ratios.

    Parameters:
        ratios (list): List of ratios for each example.

    Returns:
        float: The mean of `ratios`, or 0 when the list is empty.
    """
    ratio_sum = sum(ratios)
    if not ratios:
        # No examples: avoid dividing by zero
        return 0
    return ratio_sum / len(ratios)


In [45]:
import copy
# Target reduce ratios to evaluate
comp_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
print(comp_ratios)

# Use the training split, keeping only examples with at least 30 words
examples = dataset['train']
examples = examples.filter(lambda example: len(example['text'].split()) >= 30)

# Settings for strategic_tokenize_function_debug that are the same for every ratio
tokenizer = None
text_column_name = "text"
compress = True
max_span_length = 20
cl_token = "<CL>"
cr_token = "<CR>"
logger = None

for comp_ratio in comp_ratios:
    # Work on a fresh copy of the filtered examples for each ratio
    examples_copy = copy.deepcopy(examples)
    bound_ratio = comp_ratio

    # Insert sentinel tokens around the phrases selected at this ratio
    tokenized_examples = strategic_tokenize_function_debug(
        examples=examples_copy,
        compress=compress,
        tokenizer=tokenizer,
        text_column_name=text_column_name,
        max_span_length=max_span_length,
        bound_ratio=bound_ratio,
        cl_token=cl_token,
        cr_token=cr_token,
        logger=logger,
    )

    # Measure, per example, the share of words enclosed in spans, then average
    span_ratios = calculate_span_ratios(tokenized_examples, cl_token, cr_token)
    average_ratio = calculate_average_ratio(span_ratios)
    print(f"At r: {comp_ratio}, Average Ratio: {average_ratio}")





[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
model loaded


Map:   0%|          | 0/14719 [00:00<?, ? examples/s]

At r: 0.1, Average Ratio: 0.07896478407203295
model loaded


Map:   0%|          | 0/14719 [00:00<?, ? examples/s]

At r: 0.2, Average Ratio: 0.1609350349209587
model loaded


Map:   0%|          | 0/14719 [00:00<?, ? examples/s]

At r: 0.3, Average Ratio: 0.252482257753118
model loaded


Map:   0%|          | 0/14719 [00:00<?, ? examples/s]

At r: 0.4, Average Ratio: 0.35079480119462614
model loaded


Map:   0%|          | 0/14719 [00:00<?, ? examples/s]

At r: 0.5, Average Ratio: 0.4543381259523914
model loaded


Map:   0%|          | 0/14719 [00:00<?, ? examples/s]

At r: 0.6, Average Ratio: 0.566761240415236
model loaded


Map:   0%|          | 0/14719 [00:00<?, ? examples/s]

At r: 0.7, Average Ratio: 0.6794352545081181
model loaded


Map:   0%|          | 0/14719 [00:00<?, ? examples/s]

At r: 0.8, Average Ratio: 0.7884591239912913
model loaded


Map:   0%|          | 0/14719 [00:00<?, ? examples/s]

At r: 0.9, Average Ratio: 0.8914215428439306
