# Token-based DETHCOD

In [3]:
# Install the notebook's dependencies into the kernel's conda environment.
# NOTE(review): versions are unpinned, and torch / matplotlib (imported later in
# this notebook) are not installed here — pin versions and list them for reproducible runs.
%conda install -c conda-forge transformers wandb requests_cache datasets tqdm python-dotenv

Channels:
 - conda-forge
 - defaults
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done

# All requested packages already installed.


Note: you may need to restart the kernel to use updated packages.


## Download Data

In [4]:
import io
import os
import sys
import zipfile

import requests
import requests_cache
from tqdm import tqdm


zip_link = "http://www.mattmahoney.net/dc/enwik8.zip"
data_folder = "dataset"
cache_file = "download_cache"
zip_path = os.path.join(data_folder, "enwik8.zip")

# Ensure the data folder exists
os.makedirs(data_folder, exist_ok=True)

# Use a dedicated cached session rather than `requests_cache.install_cache`, which
# patches `requests` globally and would also cache every later HTTP request made in
# this kernel by other libraries built on `requests`.
with requests_cache.CachedSession(os.path.join(data_folder, cache_file)) as session:
    # Download the ZIP file with a progress bar (served from the cache on re-runs).
    response = session.get(zip_link, stream=True)
    response.raise_for_status()

    # Total file size for the progress bar (0 if the server does not report it).
    total_size = int(response.headers.get("content-length", 0))

    with open(zip_path, "wb") as file, tqdm(
        total=total_size, unit="B", unit_scale=True, desc="Downloading"
    ) as pbar:
        # 1 MiB chunks keep the per-chunk Python overhead small for a ~36 MB file.
        for data in response.iter_content(chunk_size=1 << 20):
            file.write(data)
            pbar.update(len(data))

# Extract straight from the file on disk instead of first reading it into memory.
with zipfile.ZipFile(zip_path) as zip_file:
    zip_file.extractall(data_folder)

print("File downloaded and decompressed successfully.", file=sys.stderr)


Downloading: 100%|████████████████████████████████████████| 36.4M/36.4M [00:00<00:00, 655MB/s]
File downloaded and decompressed successfully.


## Data

In [1]:
from datasets import load_dataset

# Load the extracted enwik8 file with the `text` builder and keep only its
# "train" split.
dataset = load_dataset("text", data_files=["dataset/enwik8"])["train"]

In [2]:
from transformers import AutoTokenizer

# Tokenizer of the pretrained T5 checkpoint; used for filtering and for model inputs.
MODEL_ID = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_ID)

In [3]:
# Removing large and empty samples
MAX_LENGTH = 128

def filter_samples(example):
    tokenized = tokenizer(
        example["text"],
        truncation=True,
        max_length=MAX_LENGTH + 1,
        return_attention_mask=False,
        return_length=True,
    )

    return [
        1 < sample_length <= MAX_LENGTH
        for sample_length in tokenized.length
    ]

dataset = dataset.filter(filter_samples, batched=True)

In [4]:
import random
# Show one random filtered example as a quick sanity check of the data.
sample = random.choice(dataset)
print(f"{sample['text']!r}")

'== Gestapo counterintelligence =='


## Model

In [5]:
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import transformers
import transformers.modeling_outputs


class CompressionConfig(transformers.T5Config):
    """Configuration for `CompressionModel`; adds no fields beyond `T5Config` yet."""


@dataclass
class CompressionOutput(transformers.modeling_outputs.Seq2SeqLMOutput):
    """`Seq2SeqLMOutput` extended with the critic's value estimates.

    Attributes:
        value_predictions: one scalar value estimate per decoder position (the
            critic head applied to the last decoder hidden state, with the
            trailing size-1 dimension squeezed), or ``None`` when decoder hidden
            states were not computed.
    """

    # A single tensor, not a tuple of tensors.
    value_predictions: Optional[torch.FloatTensor] = None


class CompressionModel(transformers.T5ForConditionalGeneration):
    """T5 seq2seq model with an additional critic (value) head.

    The critic is a single linear layer applied to the last decoder hidden
    state, giving one scalar per decoder position. `forward` returns a
    `CompressionOutput` with the usual LM fields (including `loss` when
    `labels` are passed) plus `value_predictions`.
    """

    def __init__(self, config):
        super().__init__(config)

        # Scalar value head with a small weight init and zero bias, so initial
        # value estimates are close to zero.
        self.critic_head = nn.Linear(config.d_model, 1)
        self.critic_head.weight.data.normal_(mean=0.0, std=(1 / config.d_model))
        self.critic_head.bias.data.zero_()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], CompressionOutput]:
        """Run the T5 forward pass followed by the critic head.

        All arguments are forwarded to `T5ForConditionalGeneration.forward`.
        `output_hidden_states` defaults to True because the critic reads the
        last decoder hidden state; when hidden states are not returned,
        `value_predictions` is None.

        Returns:
            A `CompressionOutput` (with `loss` set when `labels` is given), or
            its `to_tuple()` when `return_dict` resolves to False.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Always request a ModelOutput from T5 so fields can be read by name
        # below; with return_dict=False T5 would hand back a plain tuple.
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        if output.decoder_hidden_states is not None:
            last_hidden_state = output.decoder_hidden_states[-1]
            value_predictions = self.critic_head(last_hidden_state).squeeze(-1)
        else:
            value_predictions = None

        result = CompressionOutput(
            # Propagate the LM loss that T5 computes when `labels` is given.
            loss=output.loss,
            value_predictions=value_predictions,
            logits=output.logits,
            past_key_values=output.past_key_values,
            decoder_hidden_states=output.decoder_hidden_states,
            decoder_attentions=output.decoder_attentions,
            cross_attentions=output.cross_attentions,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=output.encoder_hidden_states,
            encoder_attentions=output.encoder_attentions,
        )
        return result if return_dict else result.to_tuple()


In [6]:
import transformers
import transformers.modeling_outputs


class DecompressionConfig(transformers.T5Config):
    """Configuration for `DecompressionModel`; adds no fields beyond `T5Config` yet."""


class DecompressionModel(transformers.T5ForConditionalGeneration):
    """Decompressor: maps a compressed token message back to the original tokens.

    Currently an unmodified `T5ForConditionalGeneration`.
    """

In [7]:
from pathlib import Path

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Root directory for this experiment's checkpoints.
MODEL_PATH = Path("data") / "models" / "token-dethcod" / "a2c-v1"

### Load Model

In [8]:
# True: resume both models from checkpoints under MODEL_PATH.
# False: start both models from the pretrained T5 weights.
LOAD_LATEST = False

if not LOAD_LATEST:
    model_path = "google-t5/t5-small"
    print(f"Loading model from {model_path}")
    compressor = CompressionModel.from_pretrained(model_path).to(device)
    # critic_head is not part of the T5 checkpoint (it is reported as newly
    # initialized), so re-initialize it explicitly.
    # NOTE(review): reset_parameters() replaces the normal(0, 1/d_model) init from
    # CompressionModel.__init__ with nn.Linear's default init — confirm which is intended.
    compressor.critic_head.reset_parameters()
    decompressor = DecompressionModel.from_pretrained(model_path).to(device)
else:
    # NOTE(review): the Save cells write "compressor-v2" / "decompressor-v2", so
    # these paths load an older checkpoint rather than the latest one.
    compressor = CompressionModel.from_pretrained(MODEL_PATH / "compressor").to(device)
    decompressor = DecompressionModel.from_pretrained(MODEL_PATH / "decompressor").to(device)

Loading model from google-t5/t5-small


Some weights of CompressionModel were not initialized from the model checkpoint at google-t5/t5-small and are newly initialized: ['critic_head.bias', 'critic_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Train

In [9]:
import os
import wandb

try:
    from dotenv import load_dotenv
except ImportError as e:
    print(f"Error importing dotenv: {e}")
else:
    # Load environment variables from a .env file, if present.
    load_dotenv()


def _read_wandb_token():
    """Return the W&B API token, or None when none is configured.

    In Colab the token is read from the 'wandb_token' notebook secret. Outside
    Colab, or when that secret cannot be read, the WANDB_TOKEN environment
    variable is used instead.
    """
    try:
        from google.colab import userdata
    except ImportError:
        userdata = None  # not running in Colab

    if userdata is not None:
        try:
            return userdata.get("wandb_token")
        except Exception as e:  # e.g. secret missing or notebook access not granted
            print(f"Could not read 'wandb_token' from Colab secrets: {e}")

    return os.getenv("WANDB_TOKEN")


wandb_token = _read_wandb_token()
if wandb_token:
    wandb.login(key=wandb_token, relogin=True)
else:
    print("W&B token not found. Set the 'wandb_token' Colab secret or the WANDB_TOKEN environment variable.")


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/khodabandeh/.netrc


In [10]:
# Learning rates of the two independently optimized models.
COMPRESSOR_LR = 1e-3
DECOMPRESSOR_LR = 1e-3

# One Adam optimizer per model. (An earlier experiment gave `critic_head.bias`
# its own, larger learning rate via a separate parameter group.)
compressor_optimizer = torch.optim.Adam(params=compressor.parameters(), lr=COMPRESSOR_LR)
decompressor_optimizer = torch.optim.Adam(params=decompressor.parameters(), lr=DECOMPRESSOR_LR)

In [11]:
import math

BATCH_SIZE = 8
# Cost in nats of one token drawn uniformly from the vocabulary; the token-cost
# schedule ramps up to this value, and it is the per-token size used in the
# compression-ratio metric.
MAX_TOKEN_COST = math.log(compressor.config.vocab_size)

train_dataset = dataset
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Phase lengths as fractions of one epoch (len(data_loader) batches).
SCHEDULING_FRACTION = 1.0e-2  # token cost ramps 0 -> MAX_TOKEN_COST over 1% of an epoch
PRETRAINING_FRACTION = 2.0e-2  # supervised warm-up of the compressor for the first 2% of an epoch
SCHEDULING_STEPS = len(data_loader) * SCHEDULING_FRACTION
PRETRAINING_STEPS = len(data_loader) * PRETRAINING_FRACTION
# NOTE(review): both phases start at step 0, so with these values the token-cost
# ramp finishes before pretraining ends and the RL phase always sees MAX_TOKEN_COST.

In [12]:
import wandb

wandb.init(
    name="Token Training",
    project="DETHCOD",
    config={
        "compressor_model_config": compressor.config.to_dict(),
        "decompressor_model_config": decompressor.config.to_dict(),
        # Training hyperparameters, so runs can be compared and reproduced.
        "max_length": MAX_LENGTH,
        "batch_size": BATCH_SIZE,
        "compressor_lr": COMPRESSOR_LR,
        "decompressor_lr": DECOMPRESSOR_LR,
        "max_token_cost": MAX_TOKEN_COST,
        "scheduling_steps": SCHEDULING_STEPS,
        "pretraining_steps": PRETRAINING_STEPS,
    },
)

[34m[1mwandb[0m: Currently logged in as: [33maxiom[0m ([33mchihuahuas[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111310427594516, max=1.0)…

In [13]:
class TokenCostScheduler:
    """Step-indexed schedule for the cost charged per emitted compressed token.

    Each call to `get_token_cost` returns the cost for the current step and then
    advances the internal step counter by one.

    Args:
        total_steps: number of steps over which the default linear schedule ramps
            from 0 to `max_token_cost`. A value <= 0 means "no ramp": the default
            schedule returns `max_token_cost` from the first step.
        max_token_cost: value at which the default schedule saturates.
        schedule_fn: optional callable taking the scheduler and returning the cost
            for `scheduler.step_count`. Defaults to `linear_schedule`.
    """

    def __init__(self, total_steps, max_token_cost, schedule_fn=None):
        self.total_steps = total_steps
        self.max_token_cost = max_token_cost
        self.step_count = 0
        # Stored on the instance, so it is called unbound with the scheduler passed explicitly.
        self.schedule_fn = schedule_fn if schedule_fn is not None else TokenCostScheduler.linear_schedule

    @staticmethod
    def linear_schedule(scheduler):
        """Linear ramp from 0 to `max_token_cost` over `total_steps`, then constant."""
        if scheduler.total_steps <= 0:
            # Avoid division by zero: an empty ramp means full cost immediately.
            return scheduler.max_token_cost
        progress = min(scheduler.step_count / scheduler.total_steps, 1.0)
        return progress * scheduler.max_token_cost

    def get_token_cost(self):
        """Return the current token cost and advance the schedule by one step."""
        token_cost = self.schedule_fn(self)
        self.step_count += 1
        return token_cost

In [14]:
# Log gradients and parameters of the two output heads every 100 steps, plus their graph.
watched_heads = (compressor.critic_head, compressor.lm_head)
graph = wandb.watch(watched_heads, log="all", log_freq=100, log_graph=True)

[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


### RL Training Loop

In [None]:
import torch.nn.functional as F
import tqdm.auto as tqdm
from transformers import GenerationConfig
from transformers import modeling_outputs


# Rewards are multiplied by this factor before computing returns; the critic
# therefore learns values in these scaled units.
REWARD_SCALE = 0.01

# Sampling configuration for compressor rollouts.
generation_config = GenerationConfig(
    do_sample=True,
    num_beams=1,
    # GenerationConfig's default top_k filtering would make the sampling policy
    # differ from the full softmax whose log-probabilities the actor loss uses;
    # top_k=0 disables it so rollouts are on-policy.
    top_k=0,
    max_new_tokens=128,
    decoder_start_token_id=compressor.generation_config.decoder_start_token_id,
    eos_token_id=compressor.generation_config.eos_token_id,
    pad_token_id=compressor.generation_config.pad_token_id,
    return_dict_in_generate=True,
    # Only the playground cells read `compressed.logits`; training re-scores below.
    output_logits=True,
)

# Initialize the scheduler
token_cost_scheduler = TokenCostScheduler(total_steps=SCHEDULING_STEPS, max_token_cost=MAX_TOKEN_COST)

with tqdm.tqdm(data_loader) as pbar:
    for step, batch in enumerate(pbar):
        # Current per-token cost from the schedule.
        token_cost = token_cost_scheduler.get_token_cost()

        # Samples were filtered to <= MAX_LENGTH tokens, so truncation is
        # presumably a no-op safeguard here.
        encoded_inputs = tokenizer(
            batch["text"],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        input_ids = encoded_inputs.input_ids
        # Passed explicitly to every encoder pass so generation, re-scoring and
        # pretraining all ignore the batch padding.
        input_attention_mask = encoded_inputs.attention_mask

        # --- 1. Rollout -------------------------------------------------------
        # `generate` decodes without gradient tracking, so its logits cannot drive
        # the actor update; the sampled tokens are re-scored in step 3.
        compressed = compressor.generate(
            input_ids=input_ids,
            attention_mask=input_attention_mask,
            generation_config=generation_config,
        )

        # Force the last token to be EOS for episodes with no EOS (terminated by
        # max_new_tokens), so every episode receives its terminal reward. Done
        # before decompression so the decompressor sees exactly this message.
        no_eos = (compressed.sequences != generation_config.eos_token_id).all(dim=-1)
        forced_sequences = compressed.sequences.clone()
        forced_sequences[no_eos, -1] = generation_config.eos_token_id
        compressed.sequences = forced_sequences

        # --- 2. Decompression ---------------------------------------------------
        # Mask pad positions of the message (the post-EOS padding, and the leading
        # decoder-start token when it equals the pad id).
        decompressed = decompressor.forward(
            input_ids=compressed.sequences,
            attention_mask=(compressed.sequences != generation_config.pad_token_id).long(),
            labels=input_ids,
        )

        # --- 3. Actor/critic forward pass (with gradients) ------------------------
        # Teacher-forced re-scoring: decoder position t sees sequences[:, :t + 1]
        # and predicts sequences[:, t + 1] == actions[:, t], so dropping the last
        # position aligns logits and values with `actions`.
        actions = compressed.sequences[..., 1:]
        compressor_output = compressor.forward(
            input_ids=input_ids,
            attention_mask=input_attention_mask,
            decoder_input_ids=compressed.sequences,
        )
        # (B, L, V) differentiable policy logits for each sampled action.
        action_distributions = compressor_output.logits[..., :-1, :]
        values = compressor_output.value_predictions[..., :-1]

        action_mask = actions != generation_config.pad_token_id
        is_pad = actions == generation_config.pad_token_id
        is_eos = actions == generation_config.eos_token_id
        compressed_length = actions.size(-1) - is_pad.logical_or(is_eos).sum(dim=-1)

        # --- 4. Decompressor loss (pad targets, id 0, are ignored) -----------------
        losses = F.cross_entropy(
            decompressed.logits.flatten(0, -2),
            target=input_ids.flatten(),
            ignore_index=0,
            reduction="none",
        ).view(input_ids.shape)
        decompressor_loss = losses.mean()

        # --- 5. Rewards, returns and advantages ---------------------------------
        # Each emitted non-EOS token costs `token_cost`; the EOS action receives
        # minus the total reconstruction loss; padding receives nothing.
        sequence_compression_loss = losses.detach().sum(dim=-1)
        rewards = torch.where(
            actions == generation_config.eos_token_id,
            -sequence_compression_loss.unsqueeze(-1),
            -token_cost,
        ) * action_mask * REWARD_SCALE
        # Undiscounted reward-to-go: qs[:, t] = rewards[:, t:].sum(-1).
        qs = rewards.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])

        advantage = (qs - values) * action_mask
        masked_advantage = advantage[action_mask]
        critic_loss = (masked_advantage * masked_advantage).mean()

        # Sizes in nats: each token (excluding EOS) costs MAX_TOKEN_COST, and the
        # compressed side also pays the residual reconstruction loss.
        compressed_size = (action_mask.sum(dim=-1) - 1) * MAX_TOKEN_COST + sequence_compression_loss
        decompressed_size = ((input_ids != 0).sum(dim=-1) - 1) * MAX_TOKEN_COST
        compression_ratio = (decompressed_size / compressed_size).mean()

        if step < PRETRAINING_STEPS:
            # Warm-up: train the compressor to reproduce its input. Padding labels
            # are set to -100, the index T5's LM loss ignores.
            pretraining_labels = input_ids.masked_fill(input_attention_mask == 0, -100)
            actor_loss = super(CompressionModel, compressor).forward(
                input_ids=input_ids,
                attention_mask=input_attention_mask,
                labels=pretraining_labels,
            ).loss

        else:
            # cross_entropy(x, a) = -ln(e^x[a] / sum(e^x)) = -ln pi(a | s), so
            # minimising (-ln pi(a | s)) * A raises the probability of actions with
            # positive advantage. Despite its name, `action_logits` holds these
            # negative log-probabilities.
            action_logits = F.cross_entropy(
                action_distributions.flatten(0, -2),
                target=actions.flatten(),
                ignore_index=0,
                reduction="none",
            ).view(actions.shape)
            actor_loss = (action_logits * advantage.detach()).mean()

        compressor_loss = actor_loss + critic_loss

        pbar.set_description(f"{compression_ratio=:.2f}, {critic_loss=:.2f}, {actor_loss=:.2f}, {decompressor_loss=:.2f}")

        compressor_optimizer.zero_grad()
        compressor_loss.backward()
        compressor_optimizer.step()

        decompressor_optimizer.zero_grad()
        decompressor_loss.backward()
        decompressor_optimizer.step()

        with torch.no_grad():
            wandb.log(
                {
                    "actor_loss": actor_loss.item(),
                    "critic_loss": critic_loss.item(),
                    "reward": rewards.sum(dim=-1).mean().item(),
                    "decompressor_loss": decompressor_loss.item(),
                    "accuracy": (-sequence_compression_loss).exp().mean().item(),
                    "compressed_size": compressed_length.float().mean().item(),
                    "compression_ratio": compression_ratio.item(),
                    "expected_advantage": masked_advantage.mean().item(),
                    "advantage_std": masked_advantage.std().item(),
                    "advantage": wandb.Histogram(masked_advantage.detach().cpu().numpy()),
                    "token_cost": token_cost,
                }
            )


  0%|          | 0/106887 [00:00<?, ?it/s]

In [None]:
# Finish the W&B run started by `wandb.init` in the Train section.
wandb.finish()

### Save

In [21]:
# Save the trained compressor.
# NOTE(review): the Load Model cell reads MODEL_PATH / "compressor", not "compressor-v2".
compressor.save_pretrained(MODEL_PATH / "compressor-v2")

In [22]:
# Save the trained decompressor.
# NOTE(review): the Load Model cell reads MODEL_PATH / "decompressor", not "decompressor-v2".
decompressor.save_pretrained(MODEL_PATH / "decompressor-v2")

## Playground

In [None]:
import torch
import matplotlib.pyplot as plt

top_token_count = 100

# `action_distributions` is (batch, steps, vocab); inspect the second episode's
# per-step logits, shape (steps, vocab).
logits = action_distributions[1].detach().cpu()

# Rank vocabulary entries by their logit averaged over the generation steps.
avg_logits = logits.mean(dim=0)
top_indices = torch.topk(avg_logits, top_token_count).indices
top_tokens = tokenizer.convert_ids_to_tokens(top_indices.tolist())

# Heatmap: one row per generation step, one column per top token.
fig, ax = plt.subplots(figsize=(10, 2))
image = ax.imshow(logits[..., top_indices].numpy(), cmap="viridis", aspect="auto", interpolation="nearest")
fig.colorbar(image, ax=ax, label="Logit Value")
ax.set_ylabel("Generation step")
ax.set_xticks(range(top_token_count))
ax.set_xticklabels(top_tokens, rotation="vertical")
ax.set_title(f"Top {top_token_count} Tokens by Average Logit")
plt.show()


In [None]:
# Per-action advantages (qs - values, zeroed on padding) from the last step that set them.
advantage

In [None]:
# Critic value predictions aligned with `actions`.
values

In [None]:
# Gradient-free copy of the critic values for the offset-fitting experiment.
val_tmp = values.detach()

In [None]:
# Trainable scalar offset and its own Adam optimizer, used to probe how far the
# frozen critic values are from the returns by a constant shift.
bias = nn.Parameter(torch.zeros((), device=device))
optim_tmp = torch.optim.Adam([bias])

In [None]:
# Set the offset optimizer's learning rate to 0.1.
for param_group in optim_tmp.param_groups:
    param_group['lr'] = 0.1

In [None]:
# Fit a single scalar offset `bias` so that (val_tmp + bias) matches the returns
# `qs` in masked mean-squared error.
with tqdm.tqdm(range(10000)) as pbar:
    for _ in pbar:
        # Residual of the shifted critic: qs - (val_tmp + bias).
        advantage = (qs - (val_tmp + bias)) * action_mask
        num_actions = action_mask.sum()
        expected_advantage = advantage.sum() / num_actions
        critic_loss = (advantage * advantage).sum() / num_actions

        optim_tmp.zero_grad()
        critic_loss.backward()
        optim_tmp.step()

        pbar.set_postfix({
            "critic_loss": critic_loss.item(),
            "bias": bias.item(),
            "E(adv)": expected_advantage.item(),
        })

In [None]:
# Advantages as overwritten by the offset-fitting loop above.
advantage

In [None]:
# Number of non-padding actions.
action_mask.sum()

In [None]:
# Per-action cross-entropy (negative log-probability) used in the actor loss.
action_logits

In [None]:
# Compare critic predictions with the reward-to-go for the first episode.
values[0], qs[0]

In [None]:
# Per-step policy logits for each episode.
action_distributions

In [None]:
# Advantage along the first episode's generation steps.
plt.plot(advantage[0].detach().cpu())

In [None]:
import random

# Compress one random example and report the same diagnostics as the training loop.
sample = random.choice(dataset)
print(repr(sample["text"]))

# Tokenize the sampled text itself, so the printed input is the one being compressed.
input_ids = tokenizer(
    [sample["text"]],
    return_tensors="pt",
    padding=True,
    truncation=True,
).input_ids.to(device)

with torch.no_grad():
    compressed = compressor.generate(input_ids=input_ids, generation_config=generation_config)

    # Force a final EOS on episodes that hit max_new_tokens, so every episode
    # receives the terminal (reconstruction) reward.
    no_eos = (compressed.sequences != generation_config.eos_token_id).all(dim=-1)
    forced_sequences = compressed.sequences.clone()
    forced_sequences[no_eos, -1] = generation_config.eos_token_id
    compressed.sequences = forced_sequences

    print(repr(tokenizer.decode(compressed.sequences[0])))
    decompressed = decompressor.forward(input_ids=compressed.sequences, labels=input_ids)

    actions = compressed.sequences[..., 1:]
    # (B, L, V) logits recorded during generation.
    action_distributions = torch.stack(compressed.logits).transpose(0, 1)
    values = compressor.forward(input_ids=input_ids, decoder_input_ids=compressed.sequences).value_predictions[..., :-1]

action_mask = actions != generation_config.pad_token_id
is_pad = actions == generation_config.pad_token_id
is_eos = actions == generation_config.eos_token_id
compressed_length = actions.size(-1) - is_pad.logical_or(is_eos).sum(dim=-1)

losses = F.cross_entropy(
    decompressed.logits.flatten(0, -2),
    target=input_ids.flatten(),
    ignore_index=0,
    reduction="none",
).view(input_ids.shape)
decompressor_loss = losses.mean()

sequence_compression_loss = losses.detach().sum(dim=-1)
rewards = torch.where(
    actions == generation_config.eos_token_id,
    -sequence_compression_loss.unsqueeze(-1),
    # There is no `TOKEN_COST` in this notebook; use the fully ramped per-token cost.
    -MAX_TOKEN_COST,
) * action_mask * 0.01  # same 0.01 reward scale as training, the units the critic predicts in
qs = rewards.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])

advantage = (qs - values) * action_mask
# Mean over real actions only, as in training.
critic_loss = (advantage[action_mask] ** 2).mean()

action_logits = F.cross_entropy(
    action_distributions.flatten(0, -2),
    target=actions.flatten(),
    ignore_index=0,
    reduction="none",
).view(actions.shape)
actor_loss = (action_logits * advantage.detach()).mean()

print(f"actor_loss={actor_loss}")
print(f"critic_loss={critic_loss}")
print(f"reward={rewards.sum(dim=-1).mean()}")
print(f"decompressor_loss={decompressor_loss}")
print(f"accuracy={(-losses.sum(dim=-1)).exp().mean()}")
print(f"compressed_size={compressed_length.float().mean()}")

In [None]:
# Manually override one sampled action for experimentation.
# NOTE(review): `actions` is a slice (view) of `compressed.sequences`, so this also
# modifies the generated sequence in place; it requires `actions` to have >= 3 rows.
actions[2][4] = 1

In [None]:
# Sampled compressed tokens (generated sequences without the decoder start token).
actions

In [None]:
# NOTE(review): `_61` is IPython's cached output of execution count 61, which does
# not exist after a kernel restart; replace it with the variable it referred to.
_61.tolist()

In [None]:
# Decode the first generated message. `compressed` is a generate() output object;
# its `sequences` tensor holds the token ids (indexing the object itself yields a 2-D batch).
tokenizer.decode(compressed.sequences[0])

In [None]:
# Per-action cross-entropy (negative log-probability) used in the actor loss.
action_logits

In [None]:
# Overwrite one token of the first generated message for experimentation.
# `compressed` is a generate() output object, so index its `sequences` tensor.
compressed.sequences[0, 1] = 4

In [None]:
# Rank the vocabulary by the last-position logits of the first example.
# NOTE(review): `compression_output` is not defined anywhere in this notebook (left
# over from an earlier version), so this raises NameError; it would also overwrite
# the critic `values`.
values, indices = compression_output.logits[0, -1].sort(descending=True)

In [None]:
# Token ids sorted by descending logit (from the previous cell).
indices

In [None]:
# Advantage-weighted per-token cross-entropy of the compressed tokens.
# NOTE(review): `compression_output` and `num_ids` are not defined anywhere in this
# notebook, so this cell raises NameError as written.
F.cross_entropy(
    compression_output.logits[:, :-1, :].view(-1, num_ids),
    target=compressed[:, 1:].flatten(),
    ignore_index=0,
    reduction='none',
) * advantage.flatten()

In [None]:
# NOTE(review): `compression_output` is not defined anywhere in this notebook.
compression_output.keys()

In [None]:
# Size of the first dimension (number of episodes) of `action_logits`.
len(action_logits)

In [None]:
# Raw generate() output (sequences and, with output_logits, per-step logits).
compressed

In [None]:
# Text of the most recently sampled example.
sample["text"]

In [None]:
# Current per-action advantages.
advantage

In [None]:
# Per-token decompressor cross-entropy (0 at padding positions).
losses

In [None]:
# Per-action rewards (the notebook defines `rewards`; `reward` is undefined).
rewards

In [None]:
# Number of emitted tokens per episode, excluding EOS and padding
# (the notebook defines `compressed_length`; `len_compressed` is undefined).
compressed_length

In [None]:
# Current per-action advantages.
advantage

In [None]:
# Most recent actor loss.
actor_loss

In [None]:
# Most recent critic loss.
critic_loss