# Token-based DETHCOD

In [None]:
!pip install transformers wandb requests_cache datasets tqdm python-dotenv peft accelerate bitsandbytes>0.37.0

In [1]:
import os
import wandb

# Optional convenience: read variables from a local .env file when python-dotenv
# is installed; a missing package is reported but not fatal.
try:
    from dotenv import load_dotenv
except ImportError as e:
    print(f"Error importing dotenv: {e}")
else:
    load_dotenv()


# Log in to Weights & Biases. Importing google.colab doubles as the check for
# whether the notebook is running in Colab.
try:
    from google.colab import userdata
    # In Colab the token is retrieved through userdata.get
    wandb.login(key=userdata.get('wandb_token'))

except ImportError:
    # Outside Colab the token is read from the WANDB_TOKEN environment variable
    wandb_token = os.getenv('WANDB_TOKEN')
    if not wandb_token:
        print("W&B token not found in environment variable. Please set WANDB_TOKEN in your environment.")
    else:
        wandb.login(key=wandb_token, relogin=True)


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/khodabandeh/.netrc


## Download Data

In [2]:
import io
import os
import sys
import zipfile

import requests
import requests_cache
from tqdm.auto import tqdm


zip_link = "http://www.mattmahoney.net/dc/enwik8.zip"
data_folder = "dataset"
cache_file = "download_cache"
zip_path = os.path.join(data_folder, "enwik8.zip")

# Ensure the data folder exists (no error if it is already there)
os.makedirs(data_folder, exist_ok=True)

# Use a dedicated cached session instead of requests_cache.install_cache():
# install_cache patches `requests` for the whole kernel, so every later HTTP
# request made through `requests` in this process would be written to this cache.
with requests_cache.CachedSession(os.path.join(data_folder, cache_file)) as session:
    # Download the ZIP file with progress bar; the timeout keeps a stalled
    # connection from hanging the cell forever.
    response = session.get(zip_link, stream=True, timeout=60)
    response.raise_for_status()

    # Get the total file size for the progress bar (0 when the header is absent)
    total_size = int(response.headers.get("content-length", 0))

    # Stream the archive to disk chunk by chunk
    with open(zip_path, "wb") as file, tqdm(
        total=total_size, unit="B", unit_scale=True, desc="Downloading"
    ) as pbar:
        for data in response.iter_content(chunk_size=64 * 1024):
            file.write(data)
            pbar.update(len(data))

# Extract straight from the file on disk; there is no need to load the whole
# archive into memory first.
with zipfile.ZipFile(zip_path) as zip_file:
    zip_file.extractall(data_folder)

print("File downloaded and decompressed successfully.", file=sys.stderr)


Downloading:   0%|          | 0.00/36.4M [00:00<?, ?B/s]

File downloaded and decompressed successfully.


## Data

In [1]:
from datasets import load_dataset

# Load the extracted enwik8 file with the "text" builder and keep its "train" split.
dataset = load_dataset("text", data_files=["dataset/enwik8"])["train"]

In [2]:
from transformers import AutoTokenizer

# Checkpoint id used for the tokenizer here, and for the base model weights
# when LOAD_LATEST is False in the "Load Model" cell.
MODEL_ID = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [3]:
# Removing large and empty samples
MAX_LENGTH = 128


def filter_samples(example):
    """Return one keep/drop flag per text in a batch.

    A text is kept when its token count is greater than 1 and at most
    ``MAX_LENGTH``. Tokenization truncates at ``MAX_LENGTH + 1`` tokens, so a
    text that is too long reports a length above the limit and is dropped.

    Args:
        example: a batch with a ``"text"`` entry holding the texts to check.

    Returns:
        A list of booleans, one per text, in input order.
    """
    lengths = tokenizer(
        example["text"],
        max_length=MAX_LENGTH + 1,
        truncation=True,
        return_length=True,
        return_attention_mask=False,
    ).length

    keep_flags = []
    for n_tokens in lengths:
        keep_flags.append(n_tokens > 1 and n_tokens <= MAX_LENGTH)
    return keep_flags


dataset = dataset.filter(filter_samples, batched=True)

In [4]:
import random
# Show one randomly picked example; the repr form keeps whitespace and quoting visible.
sample = random.choice(dataset)
print(f'{sample["text"]!r}')

'area_note =|'


## Model

In [6]:
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import transformers
import transformers.modeling_outputs


# T5Config subclass with an empty body: it adds no fields or behavior of its own.
# Presumably meant to pair with CompressionModel; it is not referenced elsewhere in this notebook.
class CompressionConfig(transformers.T5Config): ...


@dataclass
class CompressionOutput(transformers.modeling_outputs.Seq2SeqLMOutput):
    """Seq2SeqLMOutput extended with the critic head's value predictions."""

    # Output of CompressionModel.critic_head applied to the last decoder hidden
    # state with the trailing size-1 dimension squeezed away; None when decoder
    # hidden states were not returned.
    value_predictions: Optional[torch.FloatTensor] = None


class CompressionModel(transformers.T5ForConditionalGeneration):
    """T5 conditional-generation model with an extra scalar "critic" head.

    The critic head is a linear layer applied to the last decoder hidden state
    and yields one value per decoder position. Its result is exposed as
    ``value_predictions`` on the returned :class:`CompressionOutput`.
    """

    def __init__(self, config):
        super().__init__(config)

        # d_model -> 1 projection, initialised with small weights and zero bias.
        self.critic_head = nn.Linear(config.d_model, 1)
        self.critic_head.weight.data.normal_(mean=0.0, std=(1 / config.d_model))
        self.critic_head.bias.data.zero_()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], CompressionOutput]:
        """Run the underlying T5 model and add critic value predictions.

        All arguments are passed through to ``T5ForConditionalGeneration.forward``.
        ``output_hidden_states`` defaults to True here because the critic head
        reads the last decoder hidden state; when hidden states are not
        returned, ``value_predictions`` is None.

        Returns:
            A :class:`CompressionOutput`, or its tuple form when ``return_dict``
            (or, if that is None, the config default) is False. ``loss`` is the
            cross-entropy between logits and ``labels`` with pad-token targets
            ignored, or None when no labels are given.
        """
        # Remember what the caller asked for, but always request the dict form
        # from the parent: the attribute accesses below need it.
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        if output.decoder_hidden_states is not None:
            last_hidden_state = output.decoder_hidden_states[-1]
            value_predictions = self.critic_head(last_hidden_state).squeeze(-1)
        else:
            value_predictions = None

        # The parent's own loss is discarded; the loss reported here ignores
        # pad-token targets instead.
        loss = None
        if labels is not None:
            logits = output.logits
            pad_token_id = self.config.pad_token_id
            # Map the -100 "ignore" convention onto the pad id so such labels are
            # skipped rather than rejected as out-of-range targets.
            targets = labels.masked_fill(labels == -100, pad_token_id).to(logits.device)
            loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
            # reshape (not view): labels may be a non-contiguous slice.
            loss = loss_fct(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        result = CompressionOutput(
            loss=loss,
            value_predictions=value_predictions,
            logits=output.logits,
            past_key_values=output.past_key_values,
            decoder_hidden_states=output.decoder_hidden_states,
            decoder_attentions=output.decoder_attentions,
            cross_attentions=output.cross_attentions,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=output.encoder_hidden_states,
            encoder_attentions=output.encoder_attentions,
        )
        return result if return_dict else result.to_tuple()


In [7]:
import transformers
import transformers.modeling_outputs


# T5Config subclass with an empty body: it adds no fields or behavior of its own.
# Presumably meant to pair with DecompressionModel; it is not referenced elsewhere in this notebook.
class DecompressionConfig(transformers.T5Config): ...


# Plain T5ForConditionalGeneration under a separate name (empty body, no overrides).
# The evaluation loop feeds it the compressor's sampled sequences as input_ids.
class DecompressionModel(transformers.T5ForConditionalGeneration): ...

In [8]:
from pathlib import Path

# Prefer the second GPU when the machine has one (as before), fall back to the
# first GPU on single-GPU machines where "cuda:1" does not exist, and to the CPU
# when CUDA is unavailable.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Directory holding the "compressor" and "decompressor" checkpoints to load.
MODEL_PATH = Path("./data/models/token-dethcod/a2c-v2-reward-norm")

### Load Model

In [9]:
# True: load the checkpoints saved under MODEL_PATH.
# False: start both models from the pretrained MODEL_ID weights.
LOAD_LATEST = True

if LOAD_LATEST:
    compressor = CompressionModel.from_pretrained(MODEL_PATH / "compressor").to(device)
    decompressor = DecompressionModel.from_pretrained(MODEL_PATH / "decompressor").to(device)

else:
    print(f"Loading {MODEL_ID}")
    # Loaded without a quantization config: none is defined anywhere in this
    # notebook (referencing one here raised NameError), and the models are moved
    # with .to(device) right after loading.
    compressor = CompressionModel.from_pretrained(MODEL_ID)
    compressor = compressor.to(device)
    # Re-draw the critic head with nn.Linear's default initialisation;
    # presumably the base checkpoint carries no weights for it.
    compressor.critic_head.reset_parameters()
    decompressor = DecompressionModel.from_pretrained(MODEL_ID).to(device)

## Eval

In [10]:
import math

# Evaluation settings
BATCH_SIZE = 16
REWARD_SCALING = 0.01
# Cost charged per token: natural log of the compressor's vocabulary size
MAX_TOKEN_COST = math.log(compressor.config.vocab_size)

# The filtered dataset is iterated in shuffled batches
train_dataset = dataset
data_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [12]:
import wandb

# Settings stored with the run: the full config of both models
run_config = {
    "compressor_model_config": compressor.config.to_dict(),
    "decompressor_model_config": decompressor.config.to_dict(),
    # TODO: Add other parameters
}

# Start the W&B run that the evaluation loop logs to
wandb.init(project="DETHCOD", name="Evaluation", config=run_config)

VBox(children=(Label(value='0.047 MB of 0.071 MB uploaded (0.004 MB deduped)\r'), FloatProgress(value=0.663507…

wandb: ERROR Error uploading "requirements.txt": CommError, <Response [403]>


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113023866588871, max=1.0…

In [None]:
# graph = wandb.watch((compressor.critic_head, compressor.lm_head), log_freq=100, log="all", log_graph=True)

### Evaluation Loop

In [None]:
import torch.nn.functional as F
import tqdm.auto as tqdm
from transformers import GenerationConfig

# Evaluation only: nothing in this cell updates parameters. Note that this turns
# autograd off for the whole process, so it also applies to cells run afterwards
# in the same kernel.
torch.set_grad_enabled(False)

# Sampling configuration for the compressor. Special-token ids are copied from
# the compressor's own generation config; per-step logits are requested so the
# sampled tokens can be scored below.
generation_config = GenerationConfig(
    do_sample=True,
    num_beams=1,
    max_new_tokens=128,
    decoder_start_token_id=compressor.generation_config.decoder_start_token_id,
    eos_token_id=compressor.generation_config.eos_token_id,
    pad_token_id=compressor.generation_config.pad_token_id,
    return_dict_in_generate=True,
    output_logits=True,
)

# Special-token ids, named once instead of repeating a literal 0 in the masks below.
input_pad_id = tokenizer.pad_token_id
action_pad_id = generation_config.pad_token_id
eos_id = generation_config.eos_token_id

# Running totals over the whole pass, used for the overall compression ratio.
total_decompressed_size = 0
total_compressed_size = 0

with tqdm.tqdm(data_loader) as pbar:
    for step, batch in enumerate(pbar):
        # Price charged per emitted token; constant throughout this notebook.
        token_cost = MAX_TOKEN_COST

        input_ids = tokenizer(
            batch["text"],
            return_tensors="pt",
            padding=True,
            # TODO: Test if this has any effect
            truncation=True,
        ).input_ids.to(device)

        # Sample one compressed sequence per input, then let the decompressor
        # predict the original tokens from it. Modules are invoked via __call__
        # (not .forward) so that any registered hooks run.
        compressed = compressor.generate(input_ids=input_ids, generation_config=generation_config)
        decompressed = decompressor(input_ids=compressed.sequences, labels=input_ids)

        # Sequences that contain no EOS at all get their last token overwritten
        # with EOS, so every sequence has a position carrying the terminal cost.
        # NOTE(review): the decompressor call above sees the sequences before
        # this overwrite -- confirm that ordering is intended.
        full_episodes = (compressed.sequences != eos_id).all(dim=-1)
        sequences_copy = compressed.sequences.clone()
        sequences_copy[..., full_episodes, -1] = eos_id
        compressed.sequences = sequences_copy

        # Actions are the generated tokens, i.e. everything after the first position.
        actions = compressed.sequences[..., 1:]
        # compressed.logits is a sequence of L tensors of shape (B, V);
        # stacking gives (L, B, V) and the transpose gives (B, L, V).
        action_distributions = torch.stack(compressed.logits).transpose(0, 1)
        # TODO: Give the `actions` as decoder_input_ids instead
        values = compressor(input_ids=input_ids, decoder_input_ids=compressed.sequences).value_predictions[..., :-1]
        action_mask = actions != action_pad_id
        is_eos = actions == eos_id
        # Number of emitted tokens per sequence, excluding padding and EOS.
        compressed_length = actions.size(-1) - (~action_mask).logical_or(is_eos).sum(dim=-1)

        # Per-token reconstruction loss of the decompressor; padded input
        # positions are ignored.
        losses = F.cross_entropy(
            decompressed.logits.flatten(0, -2),
            target=input_ids.flatten(),
            ignore_index=input_pad_id,
            reduction="none",
        ).view(input_ids.shape)
        # NOTE(review): this mean runs over padded positions too, so its value
        # depends on how much padding a batch contains -- confirm intended.
        decompressor_loss = losses.mean()

        # Reward: minus the summed reconstruction loss at the EOS action, minus
        # the token cost at every other non-pad action.
        sequence_compression_loss = losses.detach().sum(dim=-1)
        rewards = torch.where(
            is_eos,
            -sequence_compression_loss.unsqueeze(-1),
            -token_cost,
        ) * action_mask * REWARD_SCALING
        # TODO: Implement temporal difference learning
        # Reward-to-go: suffix sums of the rewards along the sequence.
        qs = rewards.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])

        advantage = (qs - values) * action_mask
        num_actions = action_mask.sum()
        expected_advantage = advantage.sum() / num_actions
        critic_loss = (advantage * advantage).sum() / num_actions

        # Unscaled cost of the compressed representation: MAX_TOKEN_COST per
        # emitted token plus the reconstruction loss at EOS.
        data_costs = torch.where(
            is_eos,
            sequence_compression_loss.unsqueeze(-1),
            MAX_TOKEN_COST,
        ) * action_mask
        compressed_size = data_costs.sum(dim=-1)
        total_compressed_size += compressed_size.sum().item()
        # Baseline cost of the input: MAX_TOKEN_COST per non-pad token.
        decompressed_size = (input_ids != input_pad_id).sum(dim=-1) * MAX_TOKEN_COST
        total_decompressed_size += decompressed_size.sum().item()
        compression_ratio = (decompressed_size / compressed_size).mean()

        # Cross-entropy of each taken action under the logits it was sampled
        # from, weighted by the advantage.
        action_logits = F.cross_entropy(
            action_distributions.flatten(0, -2),
            target=actions.flatten(),
            ignore_index=action_pad_id,
            reduction="none",
        ).view(actions.shape)
        actor_loss = (action_logits * advantage.detach()).mean()

        compressor_loss = actor_loss + critic_loss

        pbar.set_description(f"{compression_ratio=:.2f}, {critic_loss=:.2f}, {actor_loss=:.2f}, {decompressor_loss=:.2f}")

        # Log plain Python numbers rather than tensors; autograd is already
        # disabled above, so no extra no_grad context is needed here.
        wandb.log(
            {
                "actor_loss": actor_loss.item(),
                "critic_loss": critic_loss.item(),
                "reward": rewards.sum(dim=-1).mean().item(),
                "decompressor_loss": decompressor_loss.item(),
                "accuracy": (-sequence_compression_loss).exp().mean().item(),
                "compressed_size": compressed_length.float().mean().item(),
                "compression_ratio": compression_ratio.item(),
                "expected_advantage": expected_advantage.item(),
                "token_cost": token_cost,
                "total_compressed_size": total_compressed_size,
                "total_decompressed_size": total_decompressed_size,
                "running_compression_ratio": (
                    total_decompressed_size / total_compressed_size
                    if total_compressed_size
                    else float("nan")
                ),
            }
        )


  0%|          | 0/53444 [00:00<?, ?it/s]

In [18]:
# End the W&B run started by wandb.init above.
wandb.finish()

VBox(children=(Label(value='5.613 MB of 5.613 MB uploaded (0.074 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▅▃▂▃▆▂▁▁▄▃▁▄▄▅▄▃▃▂▁▅▃▄▃▂▂▃▂▂▄▃▁▁█▂▁▁▄▁▄▃
actor_loss,▃▃▅▆▃▇▄▄▅▄▄▃▆▄▄▅▂▆▅▅▇▅▆▄▄▄▅▃▇▆▅▄▂▁█▃▂▆▅▅
compressed_size,▅▅▄▂▄▂▅▅▃▄▅▆▃▃▄▁▆▃▃▂▂▃▄▂▂▂▂▅▁▂▃▅█▇▁▅▆▂▅▄
compression_ratio,▂▁▂▁▅▄▁▂▄▁▃▂█▆▂▃▃▁▃▅▄▃▁▅▂▇▃▃▄▃▂▁▂▁▃▁▄▅▁▃
critic_loss,▄▃▁▁▄▁▂▃▂▃▂▃▂▂▁▁▅▂▂▂▁▁▁▂▃▃▂▂▁▂▂▂▃▅▁▃█▁▁▂
decompressor_loss,▃▃▃▃▂▃▁▄▄▃▃▅▁▂▂▄▅▇▃▄▃▄▂▃▄▅▄▆▄▄▂▃▄▅█▄▂▄▄▃
expected_advantage,▄▄▆▆▃▇▅▄▅▄▄▄▇▅▆▆▂▇▆▇▇▆▇▅▅▅▆▄▇▇▆▄▃▂█▄▁▆▆▆
reward,▆▅▇▄▇▅▆▆▅▄▅▅▇▆▆▂▄▇▄█▆▄▅▆▃▆▄▆▄▅▆▃▆▅▄▁▄▂▁▅
running_compression_ratio,▁█▄▃▃▃▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
token_cost,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.00253
actor_loss,0.56255
compressed_size,2.5
compression_ratio,1.84862
critic_loss,0.05163
decompressor_loss,1.1767
expected_advantage,0.20043
reward,-0.36534
running_compression_ratio,2.70787
token_cost,10.37748


### Save

In [None]:
# NOTE(review): the models were loaded from ".../a2c-v2-reward-norm", but this
# repoints MODEL_PATH at "a2c-v1-reward-norm", so the save cells below would
# write the loaded weights over the v1 checkpoint -- confirm this is intended.
MODEL_PATH = Path("./data/models/token-dethcod/a2c-v1-reward-norm")

In [None]:
# Write the compressor to the "compressor" subdirectory of MODEL_PATH.
compressor.save_pretrained(MODEL_PATH / "compressor")

In [None]:
# Write the decompressor to the "decompressor" subdirectory of MODEL_PATH.
decompressor.save_pretrained(MODEL_PATH / "decompressor")