## Byte Pair Encoding

In [1]:
text = """
End-to-end autonomous driving, which involves learning a neural planner with raw sensor inputs, is considered a promising direction to achieve full autonomy.
Despite the promising progress in this field, recent studies have exposed multiple vulnerabilities and limitations of imitation learning (IL) methods, particularly the inherent issues in open-loop evaluation, such as the dysfunctional metrics and implicit biases.
This is critical as it fails to guarantee safety, efficiency, comfort, and compliance with traffic rules.
To address this main limitation, several works have proposed incorporating closed-loop metrics, which more effectively evaluate end-to-end autonomous driving by ensuring that the machine-learned planner meets essential criteria beyond merely mimicking human drivers.
Therefore, end-to-end planning is ideally a multi-target and multimodal task, where multi-target planning involves meeting various evaluation metrics from either open-loop and closed-loop settings.
In this context, multimodal indicates the existence of multiple optimal solutions for each metric.
Existing end-to-end approaches often try to consider closed-loop evaluation via post-processing, which is not streamlined and may result in the loss of additional information compared to a fully end-to-end pipeline.
Meanwhile, rule-based planners struggle with imperfect perception inputs.
These imperfect inputs degrade the performance of rule-based planning under both closed-loop and open-loop metrics, as they rely on predicted perception instead of ground truth (GT) labels.
To address the issues, we propose a novel end-to-end autonomous driving framework called Hydra-MDP (Multimodal Planning with Multi-target Hydra-distillation).
HydraMDP is based on a novel teacher-student knowledge distillation (KD) architecture. The student model learns diverse trajectory candidates tailored to various evaluation metrics
through KD from both human and rule-based teachers. We instantiate the multi-target Hydra-distillation with a multihead decoder, thus effectively integrating the knowledge from specialized teachers. Hydra-MDP also features an extendable KD architecture, allowing for easy integration of additional teachers.
The student model uses environmental observations during training, while the teacher models use ground truth (GT) data. This setup allows the teacher models to generate better planning predictions, helping the student model to learn effectively. By training the student model with environmental observations, it becomes adept at handling realistic conditions where GT perception is not accessible during testing.
Our contributions are summarized as follows:
1. We propose a universal framework of end-to-end multimodal planning via multi-target hydra-distillation, allowing the model to learn from both rule-based planners and
human drivers in a scalable manner.
2. Our approach achieves the state-of-the-art performance
under the simulation-based evaluation metrics on Navsim.
"""

In [2]:
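tokens = text.encode("utf-8")  # raw bytes
tokens = list(tokens)  # list of integers in range 0..255 for convenience
print("length:", len(text))  # number of characters (Unicode code points)
print('---')
print("length:", len(tokens))  # number of bytes; equal here because the text is pure ASCII
ids = tokens  # keep a reference to the original byte sequence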

length: 2999
---
length: 2999


In [3]:
def get_frequency(tokens):
    """Count how often each adjacent pair of tokens occurs."""
    counts = {}
    for pair in zip(tokens, tokens[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

In [4]:
counts = get_frequency(tokens)

In [5]:
top_pair = max(counts, key=counts.get)
top_pair

(116, 105)
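The pair above is a tuple of raw byte values, which is hard to read at a glance. A minimal sketch of a helper that renders such a pair as text follows; `show_pair` is a hypothetical name, not part of the notebook, and it only works for raw byte values (0..255), not for merged token ids of 256 and above.

```python
def show_pair(pair):
    """Render a pair of raw byte values (0..255) as readable text."""
    return bytes(pair).decode("utf-8", errors="replace")

print(repr(show_pair((116, 105))))  # 'ti'
print(repr(show_pair((115, 32))))   # 's '
```

So the most frequent pair in the text is the two-letter sequence "ti".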

In [6]:
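def replace_tokens(tokens, pair, new_token):
  """Return a new list with every non-overlapping occurrence of pair replaced by new_token."""
  new_tokens = []
  i = 0
  while i < len(tokens):
    # match the pair at position i, taking care not to read past the end
    if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i+1] == pair[1]:
      new_tokens.append(new_token)
      i += 2
    else:
      new_tokens.append(tokens[i])
      i += 1
  return new_tokens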

In [7]:
vocab_size = 276  # target vocabulary size: 256 raw byte values plus learned merges
num_merges = vocab_size - 256  # 20 merges

In [8]:
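merges = {}  # (int, int) -> int: maps each merged pair to its new token id

for i in range(num_merges):
  # recount adjacent pairs on the current sequence and take the most frequent
  stats = get_frequency(tokens)
  pair = max(stats, key=stats.get)

  new_token = 256 + i  # ids 0..255 are reserved for raw bytes

  print(f"merging {pair} into a new token {new_token}")
  tokens = replace_tokens(tokens, pair, new_token)
  merges[pair] = new_token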

merging (116, 105) into a new token 256
merging (115, 32) into a new token 257
merging (101, 32) into a new token 258
merging (105, 110) into a new token 259
merging (111, 110) into a new token 260
merging (101, 114) into a new token 261
merging (116, 104) into a new token 262
merging (100, 32) into a new token 263
merging (101, 110) into a new token 264
merging (97, 110) into a new token 265
merging (97, 108) into a new token 266
merging (256, 260) into a new token 267
merging (103, 32) into a new token 268
merging (116, 32) into a new token 269
merging (44, 32) into a new token 270
merging (117, 108) into a new token 271
merging (97, 114) into a new token 272
merging (259, 268) into a new token 273
merging (99, 104) into a new token 274
merging (101, 108) into a new token 275
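Each merge shortens the sequence, so a natural follow-up is to measure how much compression the merges buy. The sketch below is self-contained and therefore repeats the training loop in compact form; `train_bpe` and the sample string are hypothetical stand-ins, not the notebook's own objects, and the printed numbers depend on the sample.

```python
def train_bpe(data, num_merges):
    """Toy BPE trainer: returns (merged ids, merges dict)."""
    ids = list(data)
    merges = {}
    for i in range(num_merges):
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        if not counts:
            break  # fewer than two tokens left, nothing to merge
        pair = max(counts, key=counts.get)
        new_id = 256 + i
        out, j = [], 0
        while j < len(ids):
            if j < len(ids) - 1 and (ids[j], ids[j + 1]) == pair:
                out.append(new_id)
                j += 2
            else:
                out.append(ids[j])
                j += 1
        ids = out
        merges[pair] = new_id
    return ids, merges

# hypothetical sample text, chosen only because it contains repeated pairs
sample = "the planner then thinks the thing through " * 4
raw = list(sample.encode("utf-8"))
merged, merges = train_bpe(raw, num_merges=5)
print(f"{len(raw)} bytes -> {len(merged)} tokens, "
      f"ratio {len(raw) / len(merged):.2f}x")
```

In the notebook itself the same ratio should be available as `len(ids) / len(tokens)`, since `ids` still refers to the original byte list while `tokens` has been rebound to the merged sequence.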


In [9]:
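# base vocabulary: every single byte value maps to itself
vocab = {idx: bytes([idx]) for idx in range(256)}
# merged tokens are the concatenation of their two parts; this relies on
# dicts preserving insertion order, so earlier merges are defined first
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]
vocab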

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'S',
 ...

In [20]:
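def decode(ids):
  # concatenate the byte string of every token id, then decode as UTF-8
  tokens = b"".join(vocab[idx] for idx in ids)
  # errors="replace" maps invalid byte sequences to U+FFFD instead of raising
  text = tokens.decode("utf-8", errors="replace")
  return text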

print(decode([129, 45, 267]))

�-tion


In [23]:
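def encode(text):
  # start from the raw UTF-8 bytes, then replay the learned merges
  tokens = list(text.encode("utf-8"))
  while len(tokens) >= 2:
    stats = get_frequency(tokens)  # only the keys (adjacent pairs) are used here
    # pick the pair that was learned earliest, i.e. has the lowest merge index
    pair = min(stats, key=lambda p: merges.get(p, float("inf")))
    if pair not in merges:
      break # nothing else can be merged
    idx = merges[pair]
    tokens = replace_tokens(tokens, pair, idx)
  return tokens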

print(encode("th"))

[262]
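A useful sanity check for any encoder/decoder pair is that decoding an encoding gives back the original string. The sketch below checks this on a tiny hand-written merge table; the `merges` entries here are hypothetical and chosen for illustration, and `encode` and `decode` are compact re-implementations so the block runs on its own.

```python
# hypothetical merge table: "th" -> 256, then "th" + "e" -> 257
merges = {(116, 104): 256, (256, 101): 257}

vocab = {i: bytes([i]) for i in range(256)}
for (a, b), idx in merges.items():
    vocab[idx] = vocab[a] + vocab[b]

def encode(text):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        pairs = set(zip(ids, ids[1:]))
        # apply merges in the order they were learned (lowest id first)
        pair = min(pairs, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        out, i = [], 0
        while i < len(ids):
            if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
                out.append(merges[pair])
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
    return ids

def decode(ids):
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")

for s in ["the", "then", "héllo", ""]:
    assert decode(encode(s)) == s  # text -> ids -> text is lossless

print(encode("the"))  # [257]
```

The reverse direction is not guaranteed: an arbitrary id sequence may decode to invalid UTF-8, which `errors="replace"` turns into U+FFFD, so re-encoding the result need not return the original ids.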
