In [None]:
%pip install wandb requests_cache datasets tqdm python-dotenv

Collecting wandb
  Downloading wandb-0.17.5-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting requests_cache
  Downloading requests_cache-1.2.1-py3-none-any.whl.metadata (9.9 kB)
Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting python-dotenv
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl.metadata (13 kB)
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-2.11.0-py2.py3-none-any.whl.metadata (14 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB)
Collecting cattrs>=22.2 (from requests

In [None]:
import os
import wandb

# Optional step: read variables from a local .env file when python-dotenv is
# installed; a missing package is reported but does not stop the notebook.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError as exc:
    print(f"Error importing dotenv: {exc}")


# Authenticate with W&B. The Colab secrets store is tried first; when the
# google.colab package cannot be imported, the environment is used instead.
try:
    from google.colab import userdata
    wandb.login(key=userdata.get('wandb_token'))
except ImportError:
    wandb_token = os.environ.get('WANDB_TOKEN')
    if not wandb_token:
        print("W&B token not found in environment variable. Please set WANDB_TOKEN in your environment.")
    else:
        wandb.login(key=wandb_token)


[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## Download Data

In [None]:
import io
import os
import sys
import zipfile

import requests
import requests_cache
from tqdm import tqdm


zip_link = "http://www.mattmahoney.net/dc/enwik8.zip"
data_folder = "dataset"
cache_file = "download_cache"
zip_path = os.path.join(data_folder, "enwik8.zip")

# Ensure the data folder exists (no error if it is already there).
os.makedirs(data_folder, exist_ok=True)

# Use a cache-backed session that is local to this cell. install_cache() would
# patch requests globally, so any later HTTP call made through requests in this
# kernel would also be written to the same cache file.
with requests_cache.CachedSession(os.path.join(data_folder, cache_file)) as session:
    # The timeout keeps a stalled connection from hanging the cell forever.
    response = session.get(zip_link, stream=True, timeout=60)
    response.raise_for_status()

    # Total size for the progress bar; 0 when the server does not report it.
    total_size = int(response.headers.get("content-length", 0))

    # Stream the archive to disk chunk by chunk while updating the progress bar.
    with open(zip_path, "wb") as file, tqdm(
        total=total_size, unit="B", unit_scale=True, desc="Downloading"
    ) as pbar:
        for data in response.iter_content(chunk_size=64 * 1024):
            file.write(data)
            pbar.update(len(data))

# Open the archive straight from disk (no need to load it into memory first)
# and extract all contents to the data folder.
with zipfile.ZipFile(zip_path) as zip_file:
    zip_file.extractall(data_folder)

print("File downloaded and decompressed successfully.", file=sys.stderr)


Downloading: 100%|██████████| 36.4M/36.4M [00:00<00:00, 445MB/s]
File downloaded and decompressed successfully.


## Data

In [None]:
from datasets import load_dataset

# Load the extracted enwik8 file with the "text" loader and keep only its
# train split.
dataset = load_dataset("text", data_files=["dataset/enwik8"], split="train")

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
from transformers import AutoTokenizer

# Hub id of the pretrained checkpoint; the same id is used again further down
# when the compressor and decompressor are loaded.
MODEL_ID = "google-t5/t5-small"
# Tokenizer matching that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

In [None]:
# Record the dataset size at this point so it can be compared with the size
# after filtering.
prev_len = len(dataset)

In [None]:
# Drop samples whose tokenized length is too small or too large.
# TODO: speed up filtering by batch tokenizing
max_len = 128

def filter_samples(samples):
    """Return one keep/drop flag per sample of a batch.

    Args:
        samples: batch mapping with a 'text' entry holding the raw strings.

    Returns:
        A list of booleans; True when the sample's token count is greater
        than 1 and at most ``max_len``.
    """
    token_counts = tokenizer(samples['text'], return_length=True)['length']
    keep_flags = []
    for count in token_counts:
        keep_flags.append(count > 1 and count <= max_len)
    return keep_flags

dataset = dataset.filter(filter_samples, batched=True)


Filter:   0%|          | 0/1128024 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (596 > 512). Running this sequence through the model will result in indexing errors


In [None]:
# Compare the current dataset size with the size stored in prev_len and report
# which fraction of the samples is gone.
new_len = len(dataset)
removed_fraction = 1 - (new_len / prev_len)
print(f"Removed {removed_fraction:.2f} of dataset")

Removed 0.24 of dataset


In [None]:
import random
# Print one randomly chosen sample; repr() keeps whitespace and escape
# sequences visible. No seed is set, so the pick differs between runs.
sample = random.choice(dataset)
print(repr(sample["text"]))

'[[Category:1966 singles]]'


## Model

In [None]:
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import transformers
import transformers.modeling_outputs


class CompressionConfig(transformers.T5Config):
    """Configuration type for the compression model; adds nothing to ``T5Config``."""


@dataclass
class CompressionOutput(transformers.modeling_outputs.Seq2SeqLMOutput):
    """Seq2seq LM output extended with the critic head's value predictions."""

    # Output of the critic head applied to the last decoder hidden state; this
    # is a single tensor (not a tuple), or None when hidden states are absent.
    value_predictions: Optional[torch.FloatTensor] = None


class CompressionModel(transformers.T5ForConditionalGeneration):
    """T5 conditional-generation model with an additional scalar critic head.

    The critic head is a single linear layer applied to the last decoder hidden
    state. Its output is returned as ``value_predictions`` next to the regular
    language-model outputs.
    """

    def __init__(self, config):
        super().__init__(config)

        # Maps each decoder hidden vector (size d_model) to one scalar value.
        self.critic_head = nn.Linear(config.d_model, 1)
        # NOTE(review): a saved inspection cell in this notebook shows
        # critic_head.weight with a std around 1e33, which does not match this
        # initialisation. Confirm that it still applies when the model is
        # created through from_pretrained().
        self.critic_head.weight.data.normal_(mean=0.0, std=(1 / config.d_model))
        self.critic_head.bias.data.zero_()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], CompressionOutput]:
        """Run the parent forward pass and add critic value predictions.

        All arguments are passed unchanged to
        ``T5ForConditionalGeneration.forward``, except ``return_dict`` (see
        below). ``output_hidden_states`` defaults to True here because the
        critic head needs the decoder hidden states.

        Returns:
            A ``CompressionOutput`` carrying the parent's fields (including
            ``loss`` when ``labels`` is given) plus ``value_predictions``.
            ``value_predictions`` is None when the parent returned no decoder
            hidden states. When ``return_dict`` resolves to False, the same
            content is returned as a tuple.
        """
        # Resolve the caller's preference, but always request a ModelOutput
        # from the parent: the code below reads its fields by attribute, which
        # would fail on a plain tuple.
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        if output.decoder_hidden_states is not None:
            # One value per decoder position, computed from the final layer.
            last_hidden_state = output.decoder_hidden_states[-1]
            value_predictions = self.critic_head(last_hidden_state)
        else:
            value_predictions = None

        result = CompressionOutput(
            # Forward the parent's loss too; it was previously dropped.
            loss=output.loss,
            value_predictions=value_predictions,
            logits=output.logits,
            past_key_values=output.past_key_values,
            decoder_hidden_states=output.decoder_hidden_states,
            decoder_attentions=output.decoder_attentions,
            cross_attentions=output.cross_attentions,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=output.encoder_hidden_states,
            encoder_attentions=output.encoder_attentions,
        )
        return result if return_dict else result.to_tuple()


In [None]:
import transformers
import transformers.modeling_outputs


class DecompressionConfig(transformers.T5Config):
    """Configuration type for the decompression model; adds nothing to ``T5Config``."""


class DecompressionModel(transformers.T5ForConditionalGeneration):
    """Decompression model; a ``T5ForConditionalGeneration`` subclass with no overrides."""

## Train

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both models are loaded from the same checkpoint (MODEL_ID) and moved to the
# selected device.
compressor = CompressionModel.from_pretrained(MODEL_ID).to(device)
decompressor = DecompressionModel.from_pretrained(MODEL_ID).to(device)

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

Some weights of CompressionModel were not initialized from the model checkpoint at google-t5/t5-small and are newly initialized: ['critic_head.bias', 'critic_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [None]:
# Start a W&B run and store both model configurations with it.
wandb.init(
    name = "Token Training",
    project="DETHCOD",
    # TODO: Remove this for the primary run
    # mode="disabled",
    config={
        "compressor_model_config": compressor.config.to_dict(),
        "decompressor_model_config": decompressor.config.to_dict(),
    },
)

[34m[1mwandb[0m: Currently logged in as: [33maxiom[0m ([33mchihuahuas[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Learning rate shared by both optimizers.
LR = 1e-6

# One Adam optimizer per model; compressor.parameters() also covers the critic
# head, since it is a submodule of the compressor.
compressor_optimizer = torch.optim.Adam(compressor.parameters(), lr=LR)
decompressor_optimizer = torch.optim.Adam(decompressor.parameters(), lr=LR)

In [None]:
# TODO: increase batch size
batch_size = 1
train_dataset = dataset
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
import torch.nn.functional as F
import tqdm.auto as tqdm
from transformers import GenerationConfig
from transformers import modeling_outputs

# Sampling settings used to draw the compressed sequence; the special-token ids
# are copied from the compressor's own generation config.
generation_config = GenerationConfig(
    do_sample=True,
    num_beams=1,
    max_new_tokens=100,
    # output_scores = True,
    decoder_start_token_id = compressor.generation_config.decoder_start_token_id,
    eos_token_id = compressor.generation_config.eos_token_id,
    pad_token_id = compressor.generation_config.pad_token_id,
)

# Cost charged per token of the compressed sequence when computing the reward.
TOKEN_COST = 1
# Target id that every cross-entropy below leaves out.
# NOTE(review): presumably the pad id; confirm it equals tokenizer.pad_token_id.
IGNORE_INDEX = 0

with tqdm.tqdm(data_loader) as pbar:
    for batch in pbar:
        encoded = tokenizer(batch['text'], return_tensors="pt", padding=True, truncation=True)
        input_ids = encoded.input_ids.to(device)
        # Passing the mask keeps padded positions out of the encoder attention
        # once batch_size is raised above 1.
        attention_mask = encoded.attention_mask.to(device)

        # Sample a compressed token sequence from the compressor.
        compressed = compressor.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
        )

        # Teacher-forced pass over the sampled sequence to obtain its logits
        # and the critic's value predictions. The module is called directly
        # (not .forward) so that registered hooks run.
        compression_output = compressor(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=compressed,
        )

        logits = compression_output.logits

        # The decompressor reads the compressed tokens and is trained to
        # reproduce the original input tokens.
        # NOTE(review): no attention mask is passed here, so padded positions
        # in `compressed` are attended to when batch_size > 1 - confirm intent.
        decompressed = decompressor(
            input_ids=compressed,
            labels=input_ids,
        )

        # Per-token reconstruction loss, flattened over batch and sequence.
        num_ids = decompressed.logits.size(-1)
        losses = F.cross_entropy(
            decompressed.logits.reshape(-1, num_ids),
            target=input_ids.reshape(-1),
            ignore_index=IGNORE_INDEX,
            reduction='none',
        )

        # Accuracy over the positions that are actually scored, so the value
        # stays meaningful for padded batches.
        preds = torch.argmax(decompressed.logits, dim=-1)
        scored = input_ids != IGNORE_INDEX
        correct_predictions = torch.sum((preds == input_ids) & scored)
        accuracy = correct_predictions.item() / max(scored.sum().item(), 1)

        # Scalar reward: penalise the compressed length and the reconstruction loss.
        # NOTE(review): with batch_size > 1 this single reward is shared by all
        # samples of the batch; a per-sample reward may be intended.
        len_compressed = compressed.shape[1]
        reward = (-TOKEN_COST * len_compressed - losses.detach().sum()) / 100.0

        # Value predictions for every position except the last, so they line up
        # with the actions compressed[:, 1:].
        value = compression_output.value_predictions.squeeze(-1)
        value = value[..., :-1]
        Q = torch.ones_like(value) * reward

        advantage = Q - value

        critic_loss = torch.nn.functional.mse_loss(value, Q, reduction='mean')

        # cross_entropy returns -log p(action), so advantage * cross_entropy is
        # the policy-gradient loss -advantage * log p(action); no extra sign is
        # needed. reshape() is used because the slices are not contiguous for
        # batch_size > 1, and the result is shaped like `advantage` so the
        # product below is element-wise.
        num_ids = logits.size(-1)
        action_logits = F.cross_entropy(
            logits[:, :-1].reshape(-1, num_ids),
            target=compressed[:, 1:].reshape(-1),
            ignore_index=IGNORE_INDEX,
            reduction='none',
        ).view_as(advantage)

        actor_loss = (advantage.detach() * action_logits).mean()
        decompressor_loss = losses.mean()
        pbar.set_description(f"actor_loss={actor_loss:.2f}, critic_loss={critic_loss:.2f}, decompressor_loss={decompressor_loss:.2f}")

        compressor_optimizer.zero_grad()
        decompressor_optimizer.zero_grad()

        # The compressed tokens are discrete ids, so the two backward passes
        # run over separate graphs.
        (actor_loss + critic_loss).backward()
        decompressor_loss.backward()

        compressor_optimizer.step()
        decompressor_optimizer.step()

        # Log plain Python numbers rather than graph-attached tensors.
        wandb.log({
            "actor_loss": actor_loss.item(),
            "critic_loss": critic_loss.item(),
            "reward": reward.item(),
            "decompressor_loss": decompressor_loss.item(),
            "accuracy": accuracy,
        })

  0%|          | 0/855090 [00:00<?, ?it/s]

In [None]:
# Inspect the critic targets (Q) left over from the last executed training step.
Q

tensor([[-214.2921, -214.2921, -214.2921, -214.2921, -214.2921, -214.2921,
         -214.2921, -214.2921, -214.2921, -214.2921, -214.2921, -214.2921,
         -214.2921, -214.2921, -214.2921, -214.2921, -214.2921, -214.2921,
         -214.2921, -214.2921, -214.2921, -214.2921, -214.2921, -214.2921,
         -214.2921, -214.2921, -214.2921, -214.2921, -214.2921, -214.2921,
         -214.2921, -214.2921, -214.2921, -214.2921, -214.2921, -214.2921,
         -214.2921]])

In [None]:
# Close the W&B run opened by wandb.init.
wandb.finish()

In [None]:
# List the fields that are set on the compressor's most recent forward output.
compression_output.keys()

odict_keys(['logits', 'past_key_values', 'decoder_hidden_states', 'encoder_last_hidden_state', 'encoder_hidden_states', 'value_predictions'])

In [None]:
# Scale check: std and mean of the compressor's last encoder hidden state.
torch.std_mean(compression_output.encoder_last_hidden_state)

(tensor(0.1653, grad_fn=<StdMeanBackward0>),
 tensor(0.0011, grad_fn=<StdMeanBackward0>))

In [None]:
# Scale check: std and mean of the critic head's weight.
torch.std_mean(compressor.critic_head.weight)

(tensor(5.6074e+33, grad_fn=<StdMeanBackward0>),
 tensor(5.4796e+32, grad_fn=<StdMeanBackward0>))

In [None]:
# Recompute the value predictions by hand from the final decoder hidden state,
# mirroring what CompressionModel.forward does.
last_hidden_state = compression_output.decoder_hidden_states[-1]
value_predictions = compressor.critic_head(last_hidden_state)

In [None]:
# Scale check: std and mean of the final decoder hidden state.
torch.std_mean(last_hidden_state)

(tensor(0.6742, grad_fn=<StdMeanBackward0>),
 tensor(-0.0450, grad_fn=<StdMeanBackward0>))

In [None]:
# Scale check: std and mean of the recomputed value predictions.
torch.std_mean(value_predictions)

(tensor(0., grad_fn=<StdMeanBackward0>),
 tensor(5.7685e-21, grad_fn=<StdMeanBackward0>))

In [None]:
# Std of the final decoder hidden state on its own.
last_hidden_state.std()

tensor(0.7415, grad_fn=<StdBackward0>)

In [None]:
# Drop the notebook's references to both models (presumably to free memory -
# confirm). Cells that use `compressor` or `decompressor` fail after this until
# the models are loaded again.
del compressor, decompressor

In [None]:
# Debug: fields set on the last compressor output (same check as earlier).
compression_output.keys()

In [None]:
# Debug: per-token cross-entropy of the sampled actions from the last training step.
action_logits

In [None]:
# Debug: advantage (Q - value) from the last training step.
advantage

In [None]:
# Debug: per-token reconstruction losses from the last training step.
losses

In [None]:
# Debug: scalar reward from the last training step.
reward

In [None]:
# Debug: length (second dimension) of the last compressed sequence.
len_compressed

In [None]:
# Debug: advantage again (duplicate of the earlier inspection cell).
advantage

In [None]:
# Debug: actor loss from the last training step.
actor_loss

In [None]:
# Debug: critic loss from the last training step.
critic_loss

In [None]:
# Debug: value predictions next to their targets from the last training step.
value, Q