# Token based DETHCOD

<a href="https://colab.research.google.com/github/khoda81/dethcod/blob/main/TokenDethcod.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%conda install -c conda-forge transformers wandb requests_cache datasets tqdm python-dotenv

Channels:
 - conda-forge
 - defaults
 - pytorch
Platform: linux-64
Collecting package metadata (repodata.json): - ^C

Note: you may need to restart the kernel to use updated packages.


In [4]:
import os
import wandb

try:
    from dotenv import load_dotenv
    # Load environment variables from .env file
    load_dotenv()

except ImportError as e:
    print(f"Error importing dotenv: {e}")


# Check if running in Colab
try:
    from google.colab import userdata
    # If running in Colab, use userdata.get to retrieve the token
    wandb.login(key=userdata.get('wandb_token'))

except ImportError:
    # If not in Colab, load the token from the environment variable
    wandb_token = os.getenv('WANDB_TOKEN')
    if wandb_token:
        wandb.login(key=wandb_token)
    else:
        print("W&B token not found in environment variable. Please set WANDB_TOKEN in your environment.")


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/khodabandeh/.netrc


## Download Data

In [5]:
import io
import os
import sys
import zipfile

import requests
import requests_cache
from tqdm import tqdm


zip_link = "http://www.mattmahoney.net/dc/enwik8.zip"
data_folder = "dataset"
cache_file = "download_cache"

# Ensure the data folder exists
if not os.path.exists(data_folder):
    os.makedirs(data_folder)

# Initialize requests_cache
requests_cache.install_cache(os.path.join(data_folder, cache_file))

# Download the ZIP file with progress bar
response = requests.get(zip_link, stream=True)
response.raise_for_status()

# Get the total file size for the progress bar
total_size = int(response.headers.get("content-length", 0))

# Open the ZIP file from the content
with open(os.path.join(data_folder, "enwik8.zip"), "wb") as file:
    with tqdm(
        total=total_size, unit="B", unit_scale=True, desc="Downloading"
    ) as pbar:
        for data in response.iter_content(chunk_size=1024):
            file.write(data)
            pbar.update(len(data))

# Open the cached file
with open(os.path.join(data_folder, "enwik8.zip"), "rb") as file:
    # Open the ZIP file from the content
    with zipfile.ZipFile(io.BytesIO(file.read())) as zip_file:
        # Extract all contents to the data folder
        zip_file.extractall(data_folder)

print("File downloaded and decompressed successfully.", file=sys.stderr)


Downloading: 100%|██████████████████████████████████████████████████| 36.4M/36.4M [00:00<00:00, 562MB/s]
File downloaded and decompressed successfully.


## Data

In [6]:
from datasets import load_dataset

dataset = load_dataset("text", data_files=["dataset/enwik8"])
dataset = dataset["train"]

Generating train split: 0 examples [00:00, ? examples/s]

In [7]:
from transformers import AutoTokenizer

MODEL_ID = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [8]:
# Removing large and empty samples
MAX_LENGTH = 128

def filter_samples(example):
    tokenized = tokenizer(
        example["text"],
        truncation=True,
        max_length=MAX_LENGTH + 1,
        return_attention_mask=False,
        return_length=True,
    )

    return [
        1 < sample_length <= MAX_LENGTH
        for sample_length in tokenized.length
    ]

dataset = dataset.filter(filter_samples, batched=True)

Filter:   0%|          | 0/1128024 [00:00<?, ? examples/s]

In [9]:
import random
sample = random.choice(dataset)
print(repr(sample["text"]))

'==The ideology=='


## Model

In [10]:
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import transformers
import transformers.modeling_outputs


class CompressionConfig(transformers.T5Config): ...


@dataclass
class CompressionOutput(transformers.modeling_outputs.Seq2SeqLMOutput):
    value_predictions: Optional[Tuple[torch.FloatTensor, ...]] = None


class CompressionModel(transformers.T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)

        self.critic_head = nn.Linear(config.d_model, 1)
        self.critic_head.weight.data.normal_(mean=0.0, std=(1 / config.d_model))
        self.critic_head.bias.data.zero_()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], CompressionOutput]:
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if output.decoder_hidden_states is not None:
            last_hidden_state = output.decoder_hidden_states[-1]
            value_predictions = self.critic_head(last_hidden_state).squeeze(-1)
        else:
            value_predictions = None

        return CompressionOutput(
            value_predictions=value_predictions,
            logits=output.logits,
            past_key_values=output.past_key_values,
            decoder_hidden_states=output.decoder_hidden_states,
            decoder_attentions=output.decoder_attentions,
            cross_attentions=output.cross_attentions,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=output.encoder_hidden_states,
            encoder_attentions=output.encoder_attentions,
        )


In [11]:
import transformers
import transformers.modeling_outputs


class DecompressionConfig(transformers.T5Config): ...


class DecompressionModel(transformers.T5ForConditionalGeneration): ...

In [12]:
from pathlib import Path

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
MODEL_PATH = Path("./data/models/token-dethcod/a2c-v1")

### Load Model

In [13]:
LOAD_ORIGINAL = True
if LOAD_ORIGINAL:
    model_path = "google-t5/t5-small"
    print(f"Loading model from {model_path}")
    compressor = CompressionModel.from_pretrained(model_path).to(device)
    compressor.critic_head.reset_parameters()
    decompressor = DecompressionModel.from_pretrained(model_path).to(device)
else:
    compressor = CompressionModel.from_pretrained(MODEL_PATH / "compressor").to(device)
    decompressor = DecompressionModel.from_pretrained(MODEL_PATH / "decompressor").to(device)

Loading model from google-t5/t5-small


Some weights of CompressionModel were not initialized from the model checkpoint at google-t5/t5-small and are newly initialized: ['critic_head.bias', 'critic_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Train

In [14]:
wandb.init(
    name = "Token Training",
    project="DETHCOD",
    config={
        "compressor_model_config": compressor.config.to_dict(),
        "decompressor_model_config": decompressor.config.to_dict(),
    },
)

In [15]:
LR = 1e-4

compressor_optimizer = torch.optim.Adam(compressor.parameters(), lr=LR)
decompressor_optimizer = torch.optim.Adam(decompressor.parameters(), lr=LR)

In [16]:
# TODO: increase batch size
BATCH_SIZE = 16
TOKEN_COST = 0.01
train_dataset = dataset
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
import torch.nn.functional as F
import tqdm.auto as tqdm
from transformers import GenerationConfig
from transformers import modeling_outputs

generation_config = GenerationConfig(
    do_sample=True,
    num_beams=1,
    max_new_tokens=100,
    decoder_start_token_id=compressor.generation_config.decoder_start_token_id,
    eos_token_id=compressor.generation_config.eos_token_id,
    pad_token_id=compressor.generation_config.pad_token_id,
    return_dict_in_generate=True,
    output_logits=True,
    # output_scores = True,
)

with tqdm.tqdm(data_loader) as pbar:
    for batch in pbar:
        input_ids = tokenizer(
            batch["text"],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).input_ids.to(device)

        compressed = compressor.generate(input_ids=input_ids, generation_config=generation_config)
        decompressed = decompressor.forward(input_ids=compressed.sequences, labels=input_ids)

        actions = compressed.sequences[..., 1:]
        action_distributions = torch.stack(compressed.logits).transpose(0, 1)
        values = compressor.forward(input_ids=input_ids, decoder_input_ids=compressed.sequences).value_predictions[..., :-1]
        action_mask = actions != generation_config.pad_token_id
        # FIX: if we hit generation limit, there would be no <eos> so this would fail
        (_, eos_indices) = (actions == generation_config.eos_token_id).nonzero(as_tuple=True)

        losses = F.cross_entropy(
            decompressed.logits.flatten(0, -2),
            target=input_ids.flatten(),
            ignore_index=0,
            reduction="none",
        ).view(input_ids.shape)
        decompressor_loss = losses.mean()

        sequence_compression_loss = losses.detach().sum(dim=-1)
        rewards = torch.where(
            actions == generation_config.eos_token_id,
            -sequence_compression_loss.unsqueeze(-1),
            -TOKEN_COST,
        ) * action_mask
        qs = rewards.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])

        advantage = (qs - values) * action_mask
        critic_loss = (advantage * advantage).mean()

        action_logits = F.cross_entropy(
            action_distributions.flatten(0, -2),
            target=actions.flatten(),
            ignore_index=0,
            reduction="none",
        ).view(actions.shape)
        actor_loss = (action_logits * advantage.detach()).mean()

        pbar.set_description(f"{critic_loss=:.2f}, {actor_loss=:.2f}, {decompressor_loss=:.2f}")

        compressor_optimizer.zero_grad()
        decompressor_optimizer.zero_grad()

        (actor_loss + critic_loss).backward()
        decompressor_loss.backward()

        compressor_optimizer.step()
        decompressor_optimizer.step()

        with torch.no_grad():
            wandb.log(
                {
                    "actor_loss": actor_loss,
                    "critic_loss": critic_loss,
                    "reward": rewards.sum(dim=-1).mean(),
                    "decompressor_loss": decompressor_loss,
                    "accuracy": (-losses.sum(dim=-1)).exp().mean(),
                    "compressed_size": eos_indices.float().mean(),
                }
            )

  0%|          | 0/53444 [00:00<?, ?it/s]

In [18]:
wandb.finish()

VBox(children=(Label(value='0.275 MB of 0.275 MB uploaded (0.042 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▅▆▃▆▅▅▅█▂█▄▆▃▆▅▅▂▄▅▇▅▅▇▄▁▆▄▅▄▇█▄▇▇▆▂▇▆▆▆
actor_loss,█▄▂▂▁▂▂▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
compressed_size,▂▁ █
critic_loss,█▂▁▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
decompressor_loss,▄▃▇▃▄▄▃▁▇▁▆▃▇▂▄▃▇▇▄▁▄▄▂▄▇▄▅▄▅▂▁▅▂▂▃█▂▄▃▃
reward,▂▇██▁███████████████████████████████████

0,1
accuracy,0.81237
actor_loss,0.00544
compressed_size,
critic_loss,1e-05
decompressor_loss,0.66343
reward,-1.0


### Save

In [21]:
compressor.save_pretrained(MODEL_PATH / "compressor")

In [22]:
decompressor.save_pretrained(MODEL_PATH / "decompressor")

## Playground

In [145]:
import random

sample = random.choice(dataset)
print(repr(sample["text"]))

input_ids = tokenizer(
    sample["text"],
    return_tensors="pt",
    padding=True,
    truncation=True,
).input_ids.to(device)

compressed = compressor.generate(
    input_ids=input_ids,
    generation_config=generation_config,
)

with torch.no_grad():
    compression_output = compressor.forward(
        input_ids=input_ids,
        decoder_input_ids=compressed,
    )

with torch.no_grad():
    decompressed = decompressor.forward(
        input_ids=compressed,
        labels=input_ids,
    )

num_ids = decompressed.logits.size(-1)
losses = F.cross_entropy(
    decompressed.logits.view(-1, num_ids),
    target=input_ids.view(-1),
    ignore_index=0,
    reduction='none',
)

preds = torch.argmax(decompressed.logits, dim=-1)

TOKEN_COST = 1
len_compressed = compressed.shape[1]
reward = -TOKEN_COST * len_compressed - losses.detach().sum()

value = compression_output.value_predictions.squeeze(-1)
value = value[..., :-1]
Q = torch.ones_like(value) * reward

advantage = Q - value
critic_loss = torch.nn.functional.mse_loss(value, Q, reduction='mean')

num_ids = compression_output.logits.size(-1)
# TODO: add a negative sign if didn't work
action_logits = F.cross_entropy(
    compression_output.logits[:, :-1].view(-1, num_ids),
    target=compressed[:, 1:].view(-1),
    ignore_index=0,
    reduction='none',
)

actor_loss = (advantage.detach() * action_logits).mean()
decompressor_loss = losses.mean()

print(f"reward = {reward.item():.2f}")
print(f"values = {value[0].double().round(decimals=2).tolist()}")
print(f"advantages = {advantage[0].double().round(decimals=2).tolist()}")

"|align=&quot;left&quot;|''[[Medúlla]]''"
reward = -42.99
values = [-31.68]
advantages = [-11.31]


In [17]:
tokenizer.decode(compressed[0])

'<pad></s>'

[-33.48]

In [52]:
action_logits

tensor([-0.], device='cuda:0')

In [54]:
compressed[0, 1] = 4

In [76]:
values, indices = compression_output.logits[0, -1].sort(descending=True)

In [78]:
indices

tensor([    1,     3,   183,  ..., 26948, 31655, 31245], device='cuda:0')

In [62]:
F.cross_entropy(
    compression_output.logits[:, :-1, :].view(-1, num_ids),
    target=compressed[:, 1:].flatten(),
    ignore_index=0,
    reduction='none',
) * advantage.flatten()

tensor([-1923.3522], device='cuda:0')

In [None]:
compression_output.keys()

In [None]:
len(action_logits)

In [None]:
compressed

In [None]:
sample["text"]

In [None]:
advantage

In [None]:
losses

In [None]:
reward

In [None]:
len_compressed

In [None]:
advantage

In [None]:
actor_loss

In [146]:
critic_loss

tensor(128.0074, device='cuda:0')