In [5]:
import json
from collections import Counter, defaultdict
import logging
from collections.abc import Iterable, Iterator
#from heapq import heapify_max, heappush_max, heappop_max  # requires python3.14
from heapq import heapify, heappush, heappop  # before python3.14
import math
from itertools import chain
#from .log import get_logger  # logging logic in local module
#from .pretokenizer import PRE_TOKENIZE_PAT
import regex as re


#log = get_logger("bpe", level=logging.DEBUG)

UTF8 = "utf8"

PRE_TOKENIZE_PAT = re.compile(
    r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)

def count_pretokens(txt: str, counter: Counter[str]):
    "Pretokenize txt and count pretokens w/ counter."
    # makes no sense to consider overlapped matches as it will produce
    # tons of duplicated pre-tokens.
    for m in re.finditer(PRE_TOKENIZE_PAT, txt):
        counter[txt[m.start() : m.end()]] += 1

In [27]:
type(PRE_TOKENIZE_PAT.pattern)

str

In [34]:
d = { (b't', b'h'): 3, (b'st', 't'): 3, (b'q', b' '): 2 }

max(d, key=lambda p: PairCountSortKey(pair=p, cnt=d[p]))

(b't', b'h')

In [47]:
d = {}
v = d.get('k', [])
v.append('foo')
v1 = d.get('bar', [])
print(v, v1)

['foo'] []


In [43]:
help(d.get)

Help on built-in function get:

get(key, default=None, /) method of builtins.dict instance
    Return the value for key if key is in the dictionary, else default.



# Scratch: Baseline BPE & WIP optimization

In [3]:
class BpeIterationStates:
    """
    Manage the state of a single BPE iteartion run. It does following:
        - Keeps the mapping from each token pair found to following:
            - the pair's running count, to find out the final pair to create
              new token.
            - Set of pre-token(s) where the pair is found. To efficiently
              update the mapping of pre-tokens to their count to reflect the
              presence of new token after its creation.
        - Keeps the most frequent token pair(s) found during a single BPE
          iteration run. To find the final pair to create new token.

    (IMO the pretoken -> count mapping shall be part of this state as well,
    esp. to adopt the bpe optimization idea in assignment)
    """

    def __init__(self) -> None:
        self.counter: Counter[tuple[bytes, bytes]] = Counter()
        self.pair_to_pretokens: DefaultDict[
            tuple[bytes, bytes], set[tuple[bytes, ...]]
        ] = defaultdict(set)
        self.max_cnt = 0
        self.most_pairs: list[tuple[bytes, bytes]] = []

    def update(
        self, pair: tuple[bytes, bytes], pretoken: tuple[bytes, ...], pretoken_cnt: int
    ):
        """
        Update the state w/ given the token pair, the pretoken where the pair is
        found, and the count of pretoken in text corpus.

        Spec:
            Increase the pair's running count by pretoken_cnt.
            Save pair -> pretoken mapping to pair_to_pretokens.
            Compare the pair's running cnt w/ self.max_cnt:
            - If cnt < self.max_cnt, nop.
            - If cnt == self.max_cnt, append pair to self.most_pairs.
            - If cnt > self.max_cnt, set self.max_cnt = cnt and set
                self.most_pairs to only contain pair.
        """
        self.counter[pair] += pretoken_cnt
        self.pair_to_pretokens[pair].add(pretoken)
        cnt = self.counter[pair]
        if cnt == self.max_cnt:
            self.most_pairs.append(pair)
        elif cnt > self.max_cnt:
            self.most_pairs = [pair]
            self.max_cnt = cnt

    def pair_to_merge(self) -> tuple[bytes, bytes] | None:
        """
        Returns pair of highest count and highest lexical order, or None if no
        pair found. Only call this after BPE iteration run finishes.
        """
        return max(self.most_pairs, default=None)


def bpe_baseline(
    pretokens: dict[str, int], vocab_size: int, special_tokens: list[str]
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """
    TODO:
    - UTs

    Train a BPE tokenizer from pre-tokenized input.

    Spec:
        1. Read pre-tokenized data from input_path, which is in form of a json
            object whose key is pre-token and value is the pre-token's frequency in
            text corpus, into a dict[tuple[bytes], int]
        2. Initialize vocabulary as a dict[int, bytes], w/ int key 0 - 255 maps to
            numerically identical 1-byte bytes value. Then assign new IDs
            starting from 256 for each given special token and add
            corresponding mapping to vocabulary. Initialize list[tuple[bytes,
            bytes]] to store merges resulted from BPE.

            At this point:
                1. id for next new token = 256 + len(special_tokens).
                2. # remaining available slots in vocabulary = vocab_size - id_of_next_new_token
        3. Build up vocabulary by running BPE till vocabular size hits
            vocab_size iteratively. Start iteration from pairing adjacent byte-level
            tokens present in each pre token (indiviaul key from step 1). Use
            a Counter to manage the running frequency count of each pair found. Each
            iteration is to find the token pair w/ the highest frequency, and
            when there are multiple pairs w/ the same highest frequency, break
            tie by choosing the pair w/ highest lexico order. Specifically:
                For each pre-token pt:
                    Start from byte index 0, pick pair of adjacent byte-level
                    tokens and increment its count in Counter.
                    Need a way to record the mapping from a pair to pre-tokens
                    which contain such pair, so that later on we can find these
                    pre-tokens in O(1) instead of blindly visiting all
                    pre-tokens. Need a mapping dict[bytes, set[pre-tokens that
                    contains the pair]]

            For efficiency, Counter doesn't
            have an eas way to get keys whose count is the highest so will need
            new logic to track those keys if we want to avoid repeatedly
            sorting Counter's entries to find them out. Must to track all pairs
            w/ highest frequency as future pairs of highest frequency are built
            on top of them.

            Once we find the pair of highest freq and break the tie, give it a
            new ID and add it to vocabulary, also add the pair to merges list.

    """
    token_seq_cnts = {
        tuple(bytes([b]) for b in pt.encode(UTF8)): cnt for pt, cnt in pretokens.items()
    }
    vocab = {i: bytes([i]) for i in range(256)}
    merges: list[tuple[bytes, bytes]] = []
    for t in special_tokens:
        # len(vocab) is the id of next new token
        new_token_id = len(vocab)
        vocab[new_token_id] = t.encode(UTF8)

    # NOTE for bpe we only concern about tokens created from merging existing
    # ones.
    while len(vocab) < vocab_size:
        state = BpeIterationStates()
        for tokens, pretoken_cnt in token_seq_cnts.items():
            # Ignore 1-byte pretoken
            if len(tokens) == 1:
                continue
            # Current pretoken contain > 1 tokens. Iterate each token and
            # collect pair of itself and its successor as merge candidate
            for p in zip(tokens, tokens[1:]):
                state.update(p, tokens, pretoken_cnt)

        p_to_merge = state.pair_to_merge()
        if p_to_merge is None:
            log.warn("Pair to merge is not found. This shall not happen!")
            break
        new_token = b"".join(p_to_merge)
        # NOTE check whether the new token already existed in vocab?
        new_token_id = len(vocab)
        vocab[new_token_id] = new_token
        merges.append(p_to_merge)
        print(
            f"Merging pair {p_to_merge} of count {state.counter[p_to_merge]} to new token {new_token_id}"
        )
        # update token_seq_cnts to reflect the presence of new token
        pretokens_to_update = state.pair_to_pretokens[p_to_merge]
        for pretoken in pretokens_to_update:
            # keep its count in token_seq_cnts
            cnt = token_seq_cnts.pop(pretoken)
            # replace all non-overlapping occurrences of token pair in pretoken w/ new token
            idx, ln, updated_pretoken = 0, len(pretoken), []
            while idx < ln:
                if (
                    idx < ln - 1
                    and pretoken[idx] == p_to_merge[0]
                    and pretoken[idx + 1] == p_to_merge[1]
                ):
                    updated_pretoken.append(new_token)
                    idx += 2
                else:
                    updated_pretoken.append(pretoken[idx])
                    idx += 1
            # save mapping updated_pretoken -> cnt back to token_seq_cnts
            token_seq_cnts[tuple(updated_pretoken)] = cnt

    return vocab, merges


# Functional BPE impl expected to be more efficient than baseline, but only provided mediocre speedup

Due to bottleneck of blindly repeatedly sorting all known pair counts.

In [1]:
'''
Correct, functional impl of time-optimized BPE.
'''
def bpe_time_suboptimal(
    pretokens: dict[str, int], vocab_size: int, special_tokens: list[str]
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:

    # Need a mapping from pretoken str -> pretoken's tokens. NOTE the value can change
    # as we create new merged tokens w/ BPE.
    class Pretoken:
        def __init__(self, seq, cnt):
            '''
            seq: Token sequence representing the pretoken.
            cnt: Count of pretoken in corpus.
            '''
            self.seq = seq
            self.cnt = cnt

        def __repr__(self) -> str:
            return str((self.seq, self.cnt))

    pretoken_info: dict[str, Pretoken] = {}
    for pt, cnt in pretokens.items():
        seq = tuple(bytes([b]) for b in pt.encode(UTF8))
        # BPE merge needs at least 2 tokens
        if len(seq) < 2:
            continue
        pretoken_info[pt] = Pretoken(seq, cnt)

    vocab = {i: bytes([i]) for i in range(256)}
    merges: list[tuple[bytes, bytes]] = []
    for t in special_tokens:
        # len(vocab) is the id of next new token
        new_token_id = len(vocab)
        vocab[new_token_id] = t.encode(UTF8)

    # NOTE for bpe we only concern about tokens created from merging existing ones.
    pair_cnts: Counter[tuple[bytes, bytes]] = Counter()
    # token pair -> pretokens which contain such pair
    pair_to_pretokens: DefaultDict[tuple[bytes, bytes], set[str]] = (
        defaultdict(set)
    )

    # iterative BPE runs
    merged_p: tuple[bytes, bytes] = None
    while len(vocab) < vocab_size:
        if merged_p is None:
            # Initial condition: Haven't identify any merged token
            # run a full pass of pretokens to collect initial byte-token pairs
            for pt, v in pretoken_info.items():
                # Assumption: pt contains > 1 tokens. Iterate each token and
                # collect pair of itself and its successor as merge candidate
                for p in zip(v.seq, v.seq[1:]):
                    pair_cnts[p] += v.cnt # CHIU increment by pretoken's count instead of 1!
                    pair_to_pretokens[p].add(pt)
        elif len(pair_cnts) == 0:
            # Unlikely to happen; But if this was true then it means there are no more
            # new merged token to be created. So log and exit loop
            print(f'Cannot merge further as token pairs have run out. Actual vocab size: {len(vocab)} Expected: {vocab_size}')
            break
        else:
            # Previous iteration has identified a merged token
            # Find candidates of the merged pair (to be identified in current iteration),
            # which can be:
            # - A pair already in pair_cnts, OR
            # - A new pair resulted from creation of merged token from previous iteration
            # and update mapping pair -> count and pair -> pretokens accordingly
            for pt in pair_to_pretokens[merged_p]:
                v = pretoken_info[pt]
                updated_token_seq(pt, v, merged_p, pair_cnts, pair_to_pretokens)
            # Drop merged_p's entry from pair_to_pretoken as it is longer needed
            pair_to_pretokens.pop(merged_p)

        # FIXME: profiling w/ test data from UT shows sorting logic below is the efficiency bottleneck.
        merged_p = max(pair_cnts, key=lambda p: PairCountSortKey(pair=p, cnt=pair_cnts[p]))        
        merged_token = b''.join(merged_p)
        new_token_id = len(vocab)
        vocab[new_token_id] = merged_token
        merges.append(merged_p)
        print(f"Merging pair {merged_p} of count {pair_cnts[merged_p]} to new token {new_token_id}")
        debug_pair_to_pretoken_info = { p: [pretoken_info[pt] for pt in pts] for p, pts in pair_to_pretokens.items() }
        #print(f'DEBUG: pair_cnts = {pair_cnts}\npair_to_pretokens = {debug_pair_to_pretoken_info}')
        # Now remove the merged pair from pair_cnts to clear way for next merged pair
        pair_cnts.pop(merged_p)
        
    return vocab, merges

def updated_token_seq(
    pt: str,
    ptv: Pretoken,
    pair: tuple[bytes, bytes],
    pair_cnts: Counter[tuple[bytes, bytes]],
    pair_to_pretokens: DefaultDict[tuple[bytes, bytes], set[str]],
):
    '''
    Create updated token sequence of a pretoken given
    its existing token sequence and the token pair to merge.

    This is to reflect the merge in the pretoken, the implication is that
    pair(s) which previously overlap w/ pair to merge in pretoken are now
    gone due to the merge, thus we must decrement their count accordingly

    (In this light, using a heap to manage pair counts is more awkward as
    update to items in heap are not straightforward -- one needs to find
    the item, pop it out of heap, update count then push it back. So drop
    such idea)

    We have to save the count of *new* pairs resulted from merging the merged token to pair_cnts!

    pt: Pretoken string
    ptv: `Pretoken` value contain pt's token sequence and corpus count 
    pair: Token pair to merge.
    pair_cnts: Token pair counter.
    '''
    # replace all non-overlapping occurrences of token pair w/ new token
    old = ptv.seq
    idx, ln = 0, len(old)
    new = []
    merged_token =  b''.join(pair)
    #print(f'DEBUG: updating token sequence - merged pair {pair} pretoken str: "{pt}" token seq: {old}')
    while idx < ln:
        if (
            idx < ln - 1
            and old[idx] == pair[0]
            and old[idx + 1] == pair[1]
        ):
            new.append(merged_token)
            # Find overlapping pairs and update their counts:
            # (old[idx-1], old[idx]) and (old[idx+1], old[idx+2])
            # Also record token(s) which can be built by merging
            # the merged token and its neighbor.
            # NOTE!!! Here pair count increments/decrements by
            # pretoken's corpus count, not 1
            if idx-1 >= 0:
                p_gone = (old[idx-1], old[idx])
                if p_gone in pair_cnts:
                    pair_cnts[p_gone] -= ptv.cnt
                    if pair_cnts[p_gone] == 0:
                        pair_cnts.pop(p_gone)
                
                new_p_w_merge_token = (old[idx-1], merged_token)
                pair_cnts[new_p_w_merge_token] += ptv.cnt
                pair_to_pretokens[new_p_w_merge_token].add(pt)
                #print(f'DEBUG: new pair w/ merged token {new_p_w_merge_token} - merged token {merged_token}')
                
            if idx+2 < ln:
                p_gone = (old[idx+1], old[idx+2])
                if p_gone in pair_cnts:
                    pair_cnts[p_gone] -= ptv.cnt
                    if pair_cnts[p_gone] == 0:
                        pair_cnts.pop(p_gone)
                
                new_p_w_merge_token = (merged_token, old[idx+2])
                pair_cnts[new_p_w_merge_token] += ptv.cnt
                pair_to_pretokens[new_p_w_merge_token].add(pt)
                #print(f'DEBUG: new pair w/ merged token {new_p_w_merge_token} - merged token {merged_token}')
                
            idx += 2
        else:
            new.append(old[idx])
            idx += 1

    ptv.seq = new


class PairCountSortKey:
    '''
    Sort key to find token pair of highest count and largest lexical order.

    NOTE this is a useful way to encapsulate complex comparison logic which
    cannot fit into a one-liner lambda function:

    Suppose pairs is a list of token pairs.
    Before:
    sorted(pairs, key=lambda p: # cannot fit in logic to first compare count then lexical order! ...)
    After:
    sorted(pairs, key=lambda p: PairCountSortKey(pair=p, cnt=pair_cnts[pair]))
    '''
    def __init__(self, pair: tuple[bytes, bytes], cnt: int) -> None:
        self.pair = pair
        self.cnt = cnt

    def __lt__(self, other: "PairCountSortKey") -> bool:
        """
        https://docs.python.org/3/reference/datamodel.html#object.__lt__

        This pair is deemed less than other if it has a lowr count, or it is
        lexically smaller when there is a tie on count.

        This seems to work w/ max() too as long as two values in comparison
        is of same type. See https://stackoverflow.com/a/72880603
        """
        if self.cnt != other.cnt:
            return self.cnt < other.cnt
        # A tie on count; Break it by lexical ordering
        return self.pair < other.pair

    def __repr__(self) -> str:
        return str((self.pair, self.cnt))

# Correctly optimized BPE impl

In [5]:
'''
Functional, truely optimized and performant BPE algo implementation.
'''

class Pretoken:
    def __init__(self, seq, cnt):
        """
        seq: Token sequence representing the pretoken.
        cnt: Count of pretoken in corpus.
        """
        self.seq = seq
        self.cnt = cnt

    def __repr__(self) -> str:
        return str((self.seq, self.cnt))


def bpe_time_optimized(
    pretokens: dict[str, int], vocab_size: int, special_tokens: list[str]
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """
    Start: Make a full pass to all pretokens and get byte-token pair -> count
        mapping. This also yields the 1st merged token, aka the byte-token pair
        w/ the highest count. Denote this pair as p1.

    In byte-token pair -> count mapping, find pairs which are not p1 but
    overlap w/ p1; Suppose pair p meets such criteria, then it means either of
    following is true:
    - p[0] = p1[1], e.g. (b't', b'h') and (b' ', b't')
    - p[1] = p1[0], e.g. (b't', b'h') and ('h', 'e')

    NOTE such property is DIFFERENT from that p shares 1 common token w/
    p1, as the latter case implies there exist cases e.g. p[0] = p1[0] or p[1]
    = p1[1]. However we cannot merge p and p1 to create a new token -- merge is
    only possible when we can concatenate token over the overlapped element,
    e.g. merge p = (b't', b'h') and p1 = b' t' (previously merged token)
    yields pair (b' t', b'h') and token b' th'

    To determine the 2nd merged token, BPE requires us to find
    corresponding token pair w/ the highest count in text corpus; Note such
    pair may or may not overlap w/ p1.
    (as we have seen in `TinyStoriesV2-GPT4-valid.txt.pretokens.json`)

    IMO the optimization idea mentioned by assignment tries relying on the
    hypothesis that tokens created by merging a already merged token and one
    which has overlaps with it will have
    higher count compared to picking up pair of random adjacent token and count
    the pair's occurrence, hence the efficiency increase. As mentioned above,
    such hypothesis can be false, which makes us find out the pair that has
    the 2nd highest count but doesn't overlap w/ p1.

    Idea to remove the bottleneck: Use a max heap to track the next merged
    pair. Item in the heap will be of PairCount type, and we will use
    PairCount.cnt < 0 to signal that an item is no longer valid because
    the corresponding pair's count has changed. We will rely on keep pushing
    items of updated pair-cnt and that for new pairs to maintain max heap's
    correctness. This is not very ideal (aka tracking the next merged pair
    which has highest count and lexical order in O(1) time), but at least we no
    longer need to sort all the known pair-cnt entries (O(nlgn)) - we will only
    need 2 x |pairs w/ updated count| + |new pairs| heap push/pop operations,
    (where 2x is due to popping invalid pairs and pushing updated pairs)
    resulting time complexity of O((2 x |pairs w/ updated count| + |new pairs|)lgn)
    """

    # Need a mapping from pretoken str -> token sequence which forms the
    # pretoken. Because the token sequence can change as we identify token pair
    # to merge and create new merged token from the pair.
    # To facilitate access and update of pretoken's token sequence and
    # pretoken's count, we need a "container" type, hence the Pretoken definition
    pretoken_info: dict[str, Pretoken] = {}
    for pt, cnt in pretokens.items():
        seq = tuple(bytes([b]) for b in pt.encode(UTF8))
        # BPE merge needs at least 2 tokens
        if len(seq) < 2:
            continue
        pretoken_info[pt] = Pretoken(seq, cnt)

    vocab = {i: bytes([i]) for i in range(256)}
    merges: list[tuple[bytes, bytes]] = []
    for t in special_tokens:
        # len(vocab) is the id of next new token
        new_token_id = len(vocab)
        vocab[new_token_id] = t.encode(UTF8)

    pair_cnts: dict[tuple[bytes, bytes], PairCount] = {}
    # token pair -> pretokens which contain such pair
    pair_to_pretokens: DefaultDict[tuple[bytes, bytes], set[str]] = defaultdict(set)
    # PairCount max heap
    pch: list[PairCount] = []

    # iterative BPE runs
    merged_p: tuple[bytes, bytes] | None = None
    while len(vocab) < vocab_size:
        if merged_p is None:
            # Initial condition: Haven't identify any merged token
            # run a full pass of pretokens to collect initial byte-token pairs
            for pt, v in pretoken_info.items():
                # Assumption: pt contains > 1 tokens. Iterate each token and
                # collect pair of itself and its successor as merge candidate
                for p in zip(v.seq, v.seq[1:]):
                    pc = pair_cnts.get(p, PairCount(p, 0))
                    pc.cnt += v.cnt  # NOTE change in count is not 1!
                    pair_cnts[p] = pc
                    pair_to_pretokens[p].add(pt)
            # Build a max heap from created PairCount values
            pch = list(pair_cnts.values())
            heapify_max(pch)
        elif len(pair_cnts) == 0:
            # Unlikely to happen; But if this was true then it means there are no more
            # new merged token to be created. So log and exit loop
            print(
                f"Cannot merge further as token pairs have run out. Actual vocab size: {len(vocab)} Expected: {vocab_size}"
            )
            break
        else:
            # Previous iteration has identified a merged token
            # Find candidates of the merged pair (to be identified in current iteration),
            # which can be:
            # - A pair already in pair_cnts, OR
            # - A new pair resulted from creation of merged token from previous iteration
            # and update mapping pair -> count and pair -> pretokens accordingly
            proto_update_token_seqs(merged_p, pair_to_pretokens, pretoken_info, pair_cnts, pch)
            #for pt in pair_to_pretokens[merged_p]:
            #    v = pretoken_info[pt]
            #    updated_token_seq(pt, v, merged_p, pair_cnts, pair_to_pretokens)
            # Drop merged_p's entry from pair_to_pretoken as it is longer needed
            pair_to_pretokens.pop(merged_p)
            if not pch:
                print(
                    f"Cannot merge further as token pairs have run out. Actual vocab size: {len(vocab)} Expected: {vocab_size}"
                )
                break

        # Q: How can we ensure the top of hcp at this point refers to the
        # merged pair found by the current iteration?
        #print(f'DEBUG: heap before removing the merged pair: {pch}')
        merged_pc = heappop_max(pch)
        merged_p = merged_pc.pair
        merged_token = b"".join(merged_p)
        new_token_id = len(vocab)
        vocab[new_token_id] = merged_token
        merges.append(merged_p)
        print(
            f"Merging pair {merged_p} of count {pair_cnts[merged_p].cnt} to new token {new_token_id}"
        )
        # NOTE comment out following and test speed again
        #debug_pair_to_pretoken_info = {
        #    p: [pretoken_info[pt] for pt in pts] for p, pts in pair_to_pretokens.items()
        #}
        # print(f'DEBUG: pair_cnts = {pair_cnts}\npair_to_pretokens = {debug_pair_to_pretoken_info}')
        # Now remove the merged pair from pair_cnts to clear way for next merged pair
        pair_cnts.pop(merged_p)

    return vocab, merges

def proto_update_token_seqs(
        p: tuple[bytes, bytes],
        pair_to_pretokens: DefaultDict[tuple[bytes, bytes], set[str]],
        pretoken_info: dict[str, Pretoken],
        pair_cnts: dict[tuple[bytes, bytes], PairCount],
        pch: list[PairCount],
):
    '''
    p: merged pair

    Spec:
    Invalidate PairCount values in pch whose pair's count has been updated,
    besides updating other states. After a pair's count has been fully updated,
    create a new PairCount value and push it to heap. Same for new pairs
    created from merging the merged token.
    '''
    new_pcs: dict[tuple[bytes, bytes], PairCount] = {}
    for pt in pair_to_pretokens[p]:
        ptv = pretoken_info[pt]
        # replace all non-overlapping occurrences of token pair w/ new token
        old = ptv.seq
        idx, ln = 0, len(old)
        new = []
        merged_token = b"".join(p)
        # print(f'DEBUG: updating token sequence - merged pair {pair} pretoken str: "{pt}" token seq: {old}')
        while idx < ln:
            if idx < ln - 1 and old[idx] == p[0] and old[idx + 1] == p[1]:
                new.append(merged_token)
                # Find overlapping pairs and update their counts:
                # (old[idx-1], old[idx]) and (old[idx+1], old[idx+2])
                # Also record token(s) which can be built by merging
                # the merged token and its neighbor.
                # NOTE!!! Here pair count increments/decrements by
                # pretoken's corpus count, not 1
                if idx - 1 >= 0:
                    p_gone = (old[idx - 1], old[idx])
                    if p_gone in pair_cnts:
                        # NOTE default value is returned only when we see this
                        # pair in the loop for the 1st time.
                        pc = new_pcs.get(p_gone, pair_cnts[p_gone].copy())
                        pc.cnt -= ptv.cnt
                        new_pcs[p_gone] = pc
                        # Invalidate the PairCount value from pair_cnts as its
                        # count has changed. This will later help us discard
                        # invalid PairCount values in the heap. NOTE this will
                        # NOT break the loose ordering b/w items in heap; Updating
                        # the count value of items in heap will, so avoid it.
                        pair_cnts[p_gone].valid = False
                        if pc.cnt == 0:
                            pair_cnts.pop(p_gone)
                            new_pcs.pop(p_gone)
                            # TODO optional to pop p_gone from
                            # pair_to_pretokens

                    new_p_w_merge_token = (old[idx - 1], merged_token)
                    new_pc_w_merge_token = new_pcs.get(
                            new_p_w_merge_token, PairCount(new_p_w_merge_token, 0))
                    new_pc_w_merge_token.cnt += ptv.cnt
                    new_pcs[new_p_w_merge_token] = new_pc_w_merge_token
                    pair_to_pretokens[new_p_w_merge_token].add(pt)
                    # print(f'DEBUG: new pair w/ merged token {new_p_w_merge_token} - merged token {merged_token}')

                if idx + 2 < ln:
                    p_gone = (old[idx + 1], old[idx + 2])
                    if p_gone in pair_cnts:
                        pc = new_pcs.get(p_gone, pair_cnts[p_gone].copy())
                        pc.cnt -= ptv.cnt
                        new_pcs[p_gone] = pc
                        pair_cnts[p_gone].valid = False
                        if pair_cnts[p_gone] == 0:
                            pair_cnts.pop(p_gone)
                            new_pcs.pop(p_gone)

                    new_p_w_merge_token = (merged_token, old[idx + 2])
                    new_pc_w_merge_token = new_pcs.get(
                            new_p_w_merge_token, PairCount(new_p_w_merge_token, 0))
                    new_pc_w_merge_token.cnt += ptv.cnt
                    new_pcs[new_p_w_merge_token] = new_pc_w_merge_token
                    pair_to_pretokens[new_p_w_merge_token].add(pt)
                    # print(f'DEBUG: new pair w/ merged token {new_p_w_merge_token} - merged token {merged_token}')

                idx += 2
            else:
                new.append(old[idx])
                idx += 1

        ptv.seq = new
    # Now we have collected pair -> count entries for:
    # 1. Existent pairs whose count has been decremented (not down to 0)
    # 2. New pairs built from merging the merged token
    # Now restore the (count and lexical) ordering among pairs by putting
    # these entries to pair_cnts and pair count max heap.
    for pair, pc in new_pcs.items():
        pair_cnts[pair] = pc
        heappush_max(pch, pc)
    # Finally pop the heap until we see a valid PairCount at heap top, which is
    # the next merged pair
    while pch and (not pch[0].valid or pch[0].cnt == 0):
        heappop_max(pch)

class PairCount:
    """
    Sort key to find token pair of highest count and largest lexical order.

    NOTE this is a useful way to encapsulate complex comparison logic which
    cannot fit into a one-liner lambda function:

    Suppose pairs is a list of token pairs.
    Before:
    sorted(pairs, key=lambda p: # cannot fit in logic to first compare count then lexical order! ...)
    After:
    sorted(pairs, key=lambda p: PairCount(pair=p, cnt=pair_cnts[pair]))

    NOTE we mark a PairCount as invalid by setting its count to < 0.
    """

    def __init__(self, pair: tuple[bytes, bytes], cnt: int) -> None:
        self.pair = pair
        self.cnt = cnt
        # Check this when we pop a PairCount value from max heap
        # Discard the value if it is false. NOTE value of this field
        # shall NOT influence a PairCount's ordering in heap -- We
        # shall refrain from deliberately violating heap's invariance.
        self.valid = True

    def __lt__(self, other: "PairCount") -> bool:
        """
        https://docs.python.org/3/reference/datamodel.html#object.__lt__

        This pair is deemed less than other if it has a lowr count, or it is
        lexically smaller when there is a tie on count.

        This seems to work w/ max() too as long as two values in comparison
        is of same type. See https://stackoverflow.com/a/72880603
        """
        if self.cnt != other.cnt:
            return self.cnt < other.cnt
        # A tie on count; Break it by lexical ordering
        return self.pair < other.pair

    def __repr__(self) -> str:
        return str((self.pair, self.valid, self.cnt))

    def copy(self) -> PairCount:
        'Return a copy of itself.'
        return PairCount(self.pair, self.cnt)


# Experiment & Test

In [6]:
txt = '''I've got following inspiring quotes from Ed, the Green Lake legend OG:

Life is a testing place not a resting place.

Your attidue decides your altitude.

Hooping. No begging. Stop running your mouth.

End quote.'''
#txt = 'aab abac'
pretoken_counts = Counter()
vocab_size = 300
special_tokens = ['<|endoftext|>']

count_pretokens(txt, pretoken_counts)

In [9]:
vocab, merges = bpe_baseline(pretoken_counts, vocab_size, special_tokens)

Merging pair (b'i', b'n') of count 8 to new token 257
Merging pair (b'in', b'g') of count 7 to new token 258
Merging pair (b'o', b'u') of count 4 to new token 259
Merging pair (b'o', b't') of count 4 to new token 260
Merging pair (b'e', b's') of count 4 to new token 261
Merging pair (b' ', b'a') of count 4 to new token 262
Merging pair (b'ou', b'r') of count 3 to new token 263
Merging pair (b'y', b'our') of count 2 to new token 264
Merging pair (b'u', b'ot') of count 2 to new token 265
Merging pair (b't', b'ing') of count 2 to new token 266
Merging pair (b't', b'i') of count 2 to new token 267
Merging pair (b't', b'h') of count 2 to new token 268
Merging pair (b'q', b'uot') of count 2 to new token 269
Merging pair (b'p', b'l') of count 2 to new token 270
Merging pair (b'pl', b'a') of count 2 to new token 271
Merging pair (b'pla', b'c') of count 2 to new token 272
Merging pair (b'plac', b'e') of count 2 to new token 273
Merging pair (b'o', b'p') of count 2 to new token 274
Merging pair 

In [45]:
vocab, merges = bpe_time_suboptimal(pretoken_counts, vocab_size, special_tokens)

Merging pair (b'i', b'n') of count 3 to new token 257
Merging pair (b' ', b'a') of count 3 to new token 258
Merging pair (b't', b'in') of count 2 to new token 259
Merging pair (b'tin', b'g') of count 2 to new token 260
Merging pair (b's', b'ting') of count 2 to new token 261
Merging pair (b'p', b'l') of count 2 to new token 262
Merging pair (b'pl', b'a') of count 2 to new token 263
Merging pair (b'pla', b'c') of count 2 to new token 264
Merging pair (b'plac', b'e') of count 2 to new token 265
Merging pair (b'e', b'sting') of count 2 to new token 266
Merging pair (b'e', b'r') of count 2 to new token 267
Merging pair (b' ', b'place') of count 2 to new token 268
Merging pair (b' ', b'o') of count 2 to new token 269
Merging pair (b'v', b'e') of count 1 to new token 270
Merging pair (b've', b'd') of count 1 to new token 271
Merging pair (b't', b'esting') of count 1 to new token 272
Merging pair (b's', b'er') of count 1 to new token 273
Merging pair (b'ser', b'ved') of count 1 to new token 2

In [7]:
vocab_n, merges_n = bpe_time_optimized(pretoken_counts, vocab_size, special_tokens)

Merging pair (b'i', b'n') of count 8 to new token 257
Merging pair (b'in', b'g') of count 7 to new token 258
Merging pair (b'o', b'u') of count 4 to new token 259
Merging pair (b'o', b't') of count 4 to new token 260
Merging pair (b'e', b's') of count 4 to new token 261
Merging pair (b' ', b'a') of count 4 to new token 262
Merging pair (b'ou', b'r') of count 3 to new token 263
Merging pair (b'y', b'our') of count 2 to new token 264
Merging pair (b'u', b'ot') of count 2 to new token 265
Merging pair (b't', b'ing') of count 2 to new token 266
Merging pair (b't', b'i') of count 2 to new token 267
Merging pair (b't', b'h') of count 2 to new token 268
Merging pair (b'q', b'uot') of count 2 to new token 269
Merging pair (b'p', b'l') of count 2 to new token 270
Merging pair (b'pl', b'a') of count 2 to new token 271
Merging pair (b'pla', b'c') of count 2 to new token 272
Merging pair (b'plac', b'e') of count 2 to new token 273
Merging pair (b'o', b'p') of count 2 to new token 274
Merging pair 

In [10]:
assert merges == merges_n, "Merges from optimized BPE doesn't match that from baseline"

In [11]:
vocab_n

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [12]:
merges

[(b'i', b'n'),
 (b'in', b'g'),
 (b'o', b'u'),
 (b'o', b't'),
 (b'e', b's'),
 (b' ', b'a'),
 (b'ou', b'r'),
 (b'y', b'our'),
 (b'u', b'ot'),
 (b't', b'ing'),
 (b't', b'i'),
 (b't', b'h'),
 (b'q', b'uot'),
 (b'p', b'l'),
 (b'pl', b'a'),
 (b'pla', b'c'),
 (b'plac', b'e'),
 (b'o', b'p'),
 (b'n', b'd'),
 (b'es', b'ting'),
 (b'e', b'g'),
 (b'd', b'e'),
 (b' ', b'your'),
 (b' ', b'r'),
 (b' ', b'quot'),
 (b' ', b'place'),
 (b' ', b'f'),
 (b'w', b'ing'),
 (b'v', b'e'),
 (b'u', b'n'),
 (b'un', b'n'),
 (b'unn', b'ing'),
 (b'u', b'e'),
 (b'u', b'de'),
 (b'ti', b't'),
 (b'tit', b'ude'),
 (b'ti', b'd'),
 (b'tid', b'ue'),
 (b'th', b'e'),
 (b't', b'tidue'),
 (b't', b'op'),
 (b't', b'esting'),
 (b's', b'p')]

# Tokenizer encoding and decoding logic

In [6]:
import math
from itertools import chain

class Tokenizer:
    '''
    Tokenizer encodes given text to token sequence and decodes given token
    sequence to text.
    '''

    def __init__(self, vocab: dict[int, bytes], merges: list[tuple[bytes,
                                                                   bytes]],
                 special_tokens: list[str] | None =None) -> None:
        '''
        Spec:
        Keep vocab for decoding token sequence to text.
        Keep merges for encoding text to token sequence.
        Need a way to look up a token's index given its bytes representation,
        so need to build a reverse bytes -> int mapping from vocab.
        For special tokens, first filter out those already existed in the
        reversed index built from vocab, then assign new id to the remaining
        ones and add mapping to both vocab and the reversed index, starting
        from |vocab|.
        '''
        self.vocab  = vocab
        # Merges need to proceed in the same order as merged pair creation so
        # use index as ordering indicator -- A pass to find token pair to merge
        # in a pretoken will need to find the pair of smallest index in mapping
        # below.
        self.merges: dict[tuple[bytes, bytes], int] = { p: idx for idx, p in
                                                       enumerate(merges) }
        # the reverse bytes -> int mapping from given vocab
        self.bytes_to_token: dict[bytes, int] = { v: k for k, v in
                                                 vocab.items() }
        self.pretoken_to_tokens: dict[str, list[int]] = {}
        self.pretokenize_pat = PRE_TOKENIZE_PAT
        self.special_tokens_pat = None
        special_tokens_regexes = []
        if isinstance(special_tokens, list) :
            for t in sorted(special_tokens, reverse=True):
                # to capture each special token in pretokenization
                special_tokens_regexes.append(re.escape(t))
                tb = t.encode(UTF8)
                if tb not in self.bytes_to_token:
                    new_token_id = len(vocab)
                    vocab[new_token_id] = tb
                    self.bytes_to_token[tb] = new_token_id
                self.pretoken_to_tokens[t] = [self.bytes_to_token[tb]]

            if special_tokens_regexes:
                self.special_tokens_pat = re.compile('|'.join(special_tokens_regexes))

    @classmethod
    def from_files(cls, vocab_filepath: str, merges_filepath: str,
                   special_tokens: list[str] | None = None) -> 'Tokenizer':
        # TODO what does the content look like in vocab_filepath and merges_filepath?
        with open(vocab_filepath) as f_vocab:
            with open(merges_filepath) as f_merges:
                # FIXME clawning for now
                vocab = json.load(f_vocab)
                merges = json.load(f_merges)
                return cls(vocab, merges, special_tokens)

    def encode(self, text: str) -> list[int]:
        '''
        Spec:

        Pretokenize text. The result shall be:
        1. A mapping of pretokens -> list[int], whose value set to None
        2. A list of pretokens in order of given text.
        NOTE before pretokenization we must split text by given special
        tokens if any to avoid tokenizing the special tokens.

        For each pretoken in the mapping, find corresponding token list by
        merging the pretoken's byte-level presentation.

        Finally compose the return value by initing an empty list l, then for
        each pretoken in the list from point 2 above, extend l to
        include the corresponding token list.

        If encoding is called fairly frequently, it can be desirable for the
        Toekenizer instance to maintain a pretoken -> token list mapping as an
        attribute and use resulting cache effect to speed up encoding.
        '''
        txt_and_spt_pairs_iter = [(text, '')]
        if self.special_tokens_pat:
            # Python has handy, comprehensive builtins for dealing w/ iteration:
            # https://docs.python.org/3/library/itertools.html
            special_tokens_iter = chain(
                map(
                    lambda m: text[m.start() : m.end()],
                    re.finditer(self.special_tokens_pat, text)),
                # Below is necessary tomake the loop cover the last txt piece
                # after splitting the given text by special tokens
                [''],
            )
            txt_and_spt_pairs_iter = zip(
                re.splititer(self.special_tokens_pat, text),
                special_tokens_iter,
            )

        tokens = []
        for txt, spt in txt_and_spt_pairs_iter:
            # TODO what if txt is empty str?
            for m in re.finditer(self.pretokenize_pat, txt):
                pt = txt[m.start():m.end()]
                if pt in self.pretoken_to_tokens:
                    tokens.extend(self.pretoken_to_tokens[pt])
                    continue
                # TODO pt not in pretoken -> token list mapping; Compute token list
                # and cache. Start merging from byte-level tokens
                pt_tokens = [bytes([b]) for b in pt.encode(UTF8)]
                # Enumerate the pairs in ptb and replace, until no replacement can
                # be found. TODO finally cache the pretoken -> token list entry
                # to self.pretoken_to_tokens
                # Edge case: pretoken contains only 1 byte
                while True:
                    merged_p, merged_p_order = None, math.inf
                    for p in zip(pt_tokens, pt_tokens[1:]):
                        p_order = self.merges.get(p, math.inf)
                        if p_order is not math.inf and p_order < merged_p_order:
                            merged_p = p
                            merged_p_order = p_order
                    # Found pair to merge or it is still None. If it is latter
                    # break, as no new merged pair is found; Otherwise replace token list
                    # w/ one w/ the merged token.
                    if merged_p is None:
                        break
                    merged_token = b''.join(merged_p)
                    print(f'Processing pretoken "{pt}": Merging {merged_p} into {merged_token}')
                    pt_tokens_w_merged_p = []
                    idx, ln = 0, len(pt_tokens)
                    while idx < ln:
                        if idx < ln-1 and pt_tokens[idx] == merged_p[0] and pt_tokens[idx+1] == merged_p[1]:
                            pt_tokens_w_merged_p.append(merged_token)
                            idx+=2
                        else:
                            pt_tokens_w_merged_p.append(pt_tokens[idx])
                            idx+=1
                    # prepare for next merge run
                    pt_tokens = pt_tokens_w_merged_p
                # Now look up the id of final tokens given their bytes representation
                pt_tokens = [ self.bytes_to_token[b] for b in pt_tokens ]
                self.pretoken_to_tokens[pt] = pt_tokens
                # finally extend the merged tokens to final token list
                tokens.extend(pt_tokens)
            # append id of special token to token list
            if spt != '':
                tokens.extend(self.pretoken_to_tokens[spt])

        return tokens

    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        '''
        More on generator see:
        - https://stackoverflow.com/a/1756156
        - https://wiki.python.org/moin/Generators
        Spec:

        Iterate the given iterable - For each txt in iterable:
            tokens = self.encode(txt)
            for t in tokens:
                yield t
        '''
        for txt in iterable:
            yield from self.encode(txt)


    def decode(self, ids: list[int]) -> str:
        '''
        Spec:
        Start w/ an empty byte array.
        For each token id in ids:
            Look up vocab to get the bytes for token identified by id
            Put the bytes into byte array
        Decode byte array directly to str (feasible?)
        '''
        b = bytearray()
        for t_id in ids:
            b.extend(self.vocab[t_id])
        return b.decode(encoding=UTF8, errors='replace')

In [14]:
special_tokens = ['<|endoftext|>', '<|endofprompt|>']

tk = Tokenizer(vocab_n, merges_n, special_tokens)

txt_to_encode = 'I love<|endofprompt|>hooping ÊâìÁêÉüèÄ!<|endoftext|>hoop is life yoyo<|endoftext|>The last txt piece'
#txt_to_encode = 'I love sexy white girls.'

[ (i, vocab_n[i]) for i in tk.encode(txt_to_encode) ]

[(73, b'I'),
 (32, b' '),
 (108, b'l'),
 (111, b'o'),
 (285, b've'),
 (300, b'<|endofprompt|>'),
 (104, b'h'),
 (111, b'o'),
 (274, b'op'),
 (258, b'ing'),
 (32, b' '),
 (230, b'\xe6'),
 (137, b'\x89'),
 (147, b'\x93'),
 (231, b'\xe7'),
 (144, b'\x90'),
 (131, b'\x83'),
 (240, b'\xf0'),
 (159, b'\x9f'),
 (143, b'\x8f'),
 (128, b'\x80'),
 (33, b'!'),
 (256, b'<|endoftext|>'),
 (104, b'h'),
 (111, b'o'),
 (274, b'op'),
 (32, b' '),
 (105, b'i'),
 (115, b's'),
 (32, b' '),
 (108, b'l'),
 (105, b'i'),
 (102, b'f'),
 (101, b'e'),
 (32, b' '),
 (121, b'y'),
 (111, b'o'),
 (121, b'y'),
 (111, b'o'),
 (256, b'<|endoftext|>'),
 (84, b'T'),
 (104, b'h'),
 (101, b'e'),
 (32, b' '),
 (108, b'l'),
 (97, b'a'),
 (115, b's'),
 (116, b't'),
 (32, b' '),
 (116, b't'),
 (120, b'x'),
 (116, b't'),
 (32, b' '),
 (112, b'p'),
 (105, b'i'),
 (101, b'e'),
 (99, b'c'),
 (101, b'e')]

In [15]:
assert tk.decode(tk.encode(txt_to_encode)) == txt_to_encode, 'Tokenizer encode decode trip failed'

In [16]:
txt_list = [
    "Today is raining, ",
    "plus it is foggy.",
]

tks_from_gen = [ t for t in tk.encode_iterable(txt_list) ]

tks_from_encode = tk.encode(''.join(txt_list))

assert tks_from_gen == tks_from_encode, "Tokens generated by encode_iterable don't match that from encode"

In [77]:
from itertools import chain

'''
Split given txt by given special tokens.
Then map each splitted piece to the special token which comes right after the piece.
Pretokenize the piece as normal, resulting a list of token ids.
Append the token id of the following special token to the list.
'''
pieces = re.splititer(tk.special_tokens_pat, txt_to_encode)
spts = map(lambda m: txt_to_encode[m.start() : m.end()], re.finditer(tk.special_tokens_pat, txt_to_encode))
spts = chain(spts, [''])
for piece, spt in zip(pieces, spts):
    print(f'piece: "{piece}" special token = "{spt}"')

piece: "I am running " special token = "<|endofprompt|>"
piece: "and hooping ÊâìÁêÉüèÄ!" special token = "<|endoftext|>"
piece: "some foobar text yoyo" special token = "<|endofprompt|>"
piece: "some other txt heyhey" special token = "<|endoftext|>"
piece: "The last txt piece" special token = ""


In [7]:
from tests.test_tokenizer import get_tokenizer_from_vocab_merges_path


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.3.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/doobdoob/projects/stanford.css336/assignment1-basics/.venv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/doobdoob/projects/stanford.css336/assignment1-basics/.venv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/doobdoob/projects/stanford.css336/assignment1-b

In [38]:
from pathlib import Path

FIXTURE_PATH = '../tests/fixtures/'
_tk = get_tokenizer_from_vocab_merges_path(
    vocab_path=Path(FIXTURE_PATH) / 'gpt2_vocab.json',
    merges_path=Path(FIXTURE_PATH) / 'gpt2_merges.txt',
)

In [39]:
tk = Tokenizer(vocab=_tk.vocab, merges=_tk.merges.keys(), special_tokens=["<|endoftext|>", "<|endoftext|><|endoftext|>"])

In [13]:
print(f'gpt2 vocab len = {len(vocab)}, merges len = {len(merges)}')
vocab[50256]

gpt2 vocab len = 50257, merges len = 50000


b'<|endoftext|>'

In [20]:
merges[(b'Hel', b'lo')]

KeyError: (b'Hel', b'lo')

In [41]:
txt = "Hello, how <|endoftext|><|endoftext|> are you?<|endoftext|>"

[ (i, vocab[i]) for i in tk.encode(txt) ]

[(15496, b'Hello'),
 (11, b','),
 (703, b' how'),
 (220, b' '),
 (50256, b'<|endoftext|>'),
 (50256, b'<|endoftext|>'),
 (389, b' are'),
 (345, b' you'),
 (30, b'?'),
 (50256, b'<|endoftext|>')]

In [18]:
vocab_bytes = set(vocab.values())
b'Hello' in vocab_bytes

True

In [30]:
merges[(b'e', b'l')]

161

In [1]:
help(sorted)

Help on built-in function sorted in module builtins:

sorted(iterable, /, *, key=None, reverse=False)
    Return a new list containing all items from the iterable in ascending order.

    A custom key function can be supplied to customize the sort order, and the
    reverse flag can be set to request the result in descending order.

