# Domain Specific Language for Decoding

frontier research labs that provide inference apis are too greedy to share their logits with us and are not up to date with the best sampling practices.

we basically only get the temperature parameter to play around with and that's usually it. maybe also top_k if we are lucky.

granted, even if we had the logits, doing the decoding on our own machine would be a throughput nightmare, since every generated token would need a network round trip (thanks to @stochasm for pointing this out). so what should we do?

the answer came from @lun_aaaaa, as an offhand remark (I assume), but it's actually brilliant.

I have taken that idea and propose what is, to my knowledge, the first domain specific language for defining sampling behavior: instead of handing us the logits, the provider runs a short sampling program that we send along with the request.

as of now, it only consists of 3 commands:

- `sort`, which sorts the logit tensor (default is descending order, `sort +` for asc, `sort -` for desc)

- `slice n:m`, which keeps the logits from the $n$-th (inclusive) to the $m$-th (exclusive) index; either bound may be omitted. this is implemented by masking the logits at the out-of-range positions with $-\infty$.

- `threshold op n`, which masks (sets to $-\infty$) every logit for which `logit op n` holds, e.g. `threshold < n` masks all logits below $n$

we can create `top_k` sampling through this dsl code:

```
sort -
slice :k
```
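after `sort -`, the surviving values no longer sit at their token positions, so the output alone doesn't say which tokens were kept. for reference, here is a minimal numpy sketch of the same top-k cut that leaves every logit where it was (`top_k_in_place` is a hypothetical helper, not a dsl command):

```python
import numpy as np

def top_k_in_place(logits: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest logits at their original positions; mask the rest with -inf."""
    if not 1 <= k <= logits.shape[-1]:
        raise ValueError(f"k must be in [1, {logits.shape[-1]}], got {k}")
    # argpartition moves the indices of the k largest values into the last k slots (unordered)
    keep = np.argpartition(logits, -k)[-k:]
    out = np.full_like(logits, -np.inf, dtype=float)
    out[keep] = logits[keep]
    return out

logits = np.array([0.1, 2.3, 1.1, 0.7, -1.0])
print(top_k_in_place(logits, 3))  # tokens 1, 2 and 3 survive, in their original slots
```

`argpartition` avoids a full sort, and exactly k entries survive even when values tie.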

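likewise, we can approximate `min_p` through the following code, with the caveat that this is an absolute cutoff on the logits, while `min_p` proper is relative to the most likely token: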

```
threshold < p
```
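`min_p` as usually defined keeps a token only if its probability is at least $p$ times that of the most likely token. since softmax divides every logit by the same normalizer $Z$, the condition turns into a cutoff in logit space:

$$\frac{e^{\ell_i}/Z}{e^{\ell_{\max}}/Z} \ge p \iff \ell_i \ge \ell_{\max} + \ln p$$

so a fixed `threshold < p` only matches `min_p` if the cutoff is recomputed from the current maximum. a minimal sketch of that relative version (`min_p_logits` is a hypothetical helper, not something the dsl can express yet):

```python
import numpy as np

def min_p_logits(logits: np.ndarray, p: float) -> np.ndarray:
    """Mask tokens whose probability is below p * (max probability), working in logit space."""
    if not 0.0 < p <= 1.0:
        raise ValueError(f"p must be in (0, 1], got {p}")
    # p_i >= p * p_max  <=>  logit_i >= logit_max + ln(p)
    cutoff = logits.max() + np.log(p)
    return np.where(logits < cutoff, -np.inf, logits)

logits = np.array([0.1, 2.3, 1.1, 0.7, -1.0])
print(min_p_logits(logits, 0.2))  # cutoff is 2.3 + ln(0.2), roughly 0.69
```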

# a first interpreter

let's implement the primitives for our basic dsl

```python
import numpy as np
```

```python
def sort_logits(logits, order="-"):
    """
    Sort the logits tensor.
    
    Args:
        logits (np.ndarray): The logits to sort.
        order (str): "+" for ascending, "-" for descending (default).

    Returns:
        np.ndarray: Sorted logits.
    """
    if order == "-":
        return np.sort(logits)[::-1]
    elif order == "+":
        return np.sort(logits)
    else:
        raise ValueError(f"Invalid order '{order}'. Use '+' or '-'.")
```
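`sort_logits` returns only the sorted values, so once a program uses `sort`, a position in the array no longer identifies a token. one way around this, assuming the interpreter is allowed to carry a bit of extra state, is to keep the permutation from `np.argsort` and scatter the transformed values back at the end (`sort_with_perm` and `unsort` are hypothetical helpers):

```python
import numpy as np

def sort_with_perm(logits: np.ndarray, order: str = "-"):
    """Sort logits and also return the permutation that produced the order."""
    if order not in ("+", "-"):
        raise ValueError(f"Invalid order '{order}'. Use '+' or '-'.")
    key = -logits if order == "-" else logits
    perm = np.argsort(key, kind="stable")  # perm[j] = original index of the j-th sorted logit
    return logits[perm], perm

def unsort(sorted_logits: np.ndarray, perm: np.ndarray) -> np.ndarray:
    """Scatter sorted (and possibly masked) logits back to their original token positions."""
    restored = np.empty_like(sorted_logits)
    restored[perm] = sorted_logits
    return restored

logits = np.array([0.1, 2.3, 1.1, 0.7, -1.0])
values, perm = sort_with_perm(logits)
values[3:] = -np.inf  # the effect of `slice :3` on the sorted values
print(unsort(values, perm))  # masked logits back in vocabulary order
```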

```python
def slice_logits(logits, n=None, m=None):
    """
    Slice logits from n-th to m-th index (masking out-of-bounds logits).
    
    Args:
        logits (np.ndarray): The logits to slice.
        n (int, optional): Start index (inclusive). Defaults to None (start from 0).
        m (int, optional): End index (exclusive). Defaults to None (go to end).

    Returns:
        np.ndarray: Logits with out-of-bounds positions masked as -np.inf.
    """
    if n is None:
        n = 0
    if m is None:
        m = len(logits)
    if n < 0 or m > len(logits) or n >= m:
        raise ValueError(f"Invalid slice range: {n}:{m} for logits of size {len(logits)}.")
    
    mask = np.zeros_like(logits, dtype=bool)
    mask[n:m] = True
    return np.where(mask, logits, -np.inf)
```
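`slice_logits` raises when `m` exceeds the number of logits, so `slice :k` fails for any `k` larger than the vocabulary. if one would rather treat that as "keep everything" (a design choice the dsl doesn't specify either way), clamping `m` is enough. `slice_logits_clamped` below is a hypothetical variant:

```python
import numpy as np

def slice_logits_clamped(logits: np.ndarray, n=None, m=None) -> np.ndarray:
    """Like slice_logits, but clamps m to len(logits) instead of raising."""
    size = len(logits)
    n = 0 if n is None else n
    m = size if m is None else min(m, size)  # `slice :k` with k > size keeps everything
    if n < 0 or n >= m:
        raise ValueError(f"Invalid slice range: {n}:{m} for logits of size {size}.")
    mask = np.zeros(size, dtype=bool)
    mask[n:m] = True
    return np.where(mask, logits, -np.inf)

logits = np.array([0.1, 2.3, 1.1, 0.7, -1.0])
print(slice_logits_clamped(logits, m=10))  # no error, nothing masked
```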

```python
def threshold_logits(logits, op, value):
    """
    Mask logits above or below a threshold.
    
    Args:
        logits (np.ndarray): The logits to threshold.
        op (str): Comparison operator, one of "<", ">", "<=", ">=".
        value (float): Threshold value.

    Returns:
        np.ndarray: Logits with masked values as -np.inf.
    """
    if op == "<":
        return np.where(logits < value, -np.inf, logits)
    elif op == ">":
        return np.where(logits > value, -np.inf, logits)
    elif op == "<=":
        return np.where(logits <= value, -np.inf, logits)
    elif op == ">=":
        return np.where(logits >= value, -np.inf, logits)
    else:
        raise ValueError(f"Invalid operator '{op}'. Use '<', '>', '<=', or '>='.")
```
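the four branches differ only in the comparison, so a lookup table over python's `operator` module gives an equivalent, more compact version. this is a stylistic alternative with the same behavior; `_OPS` and `threshold_logits_table` are names made up for this sketch:

```python
import operator
import numpy as np

# maps dsl operator strings to comparisons that work element-wise on numpy arrays
_OPS = {"<": operator.lt, ">": operator.gt, "<=": operator.le, ">=": operator.ge}

def threshold_logits_table(logits: np.ndarray, op: str, value: float) -> np.ndarray:
    """Mask (set to -inf) every logit for which `logit <op> value` holds."""
    try:
        compare = _OPS[op]
    except KeyError:
        raise ValueError(f"Invalid operator '{op}'. Use '<', '>', '<=', or '>='.") from None
    return np.where(compare(logits, value), -np.inf, logits)

logits = np.array([0.1, 2.3, 1.1, 0.7, -1.0])
print(threshold_logits_table(logits, "<", 1.0))
```

adding a new comparison then means adding one dictionary entry instead of another branch.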

now, let's build an interpreter. the idea is that every company can have their own interpreter under the hood, either in jax, pytorch or whatever they use. this is just a reference implementation (and a rudimentary one at that).

```python
def parse_dsl(dsl_string):
    """
    Parse a DSL string into a list of commands.

    Args:
        dsl_string (str): The DSL string.

    Returns:
        list[str]: Parsed commands as individual strings.
    """
    commands = [cmd.strip() for cmd in dsl_string.split("\n") if cmd.strip()]
    return commands
```
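`parse_dsl` only splits lines, so a malformed command is only discovered while the logits are already being transformed. a provider would probably want to reject a bad program before running it, so here is a hypothetical stricter parser (`parse_program` and its tuple format are assumptions, not part of the reference interpreter) that validates every line up front:

```python
import re

_NUMBER = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)"

# one regex per command; the groups capture the arguments
_PATTERNS = {
    "sort": re.compile(r"sort(?:\s+([+-]))?"),
    "slice": re.compile(r"slice\s+(\d*):(\d*)"),
    "threshold": re.compile(r"threshold\s+(<=|>=|<|>)\s+" + _NUMBER),
}

def parse_program(dsl_string: str) -> list[tuple]:
    """Validate a dsl program and return (command, *args) tuples, raising on the first bad line."""
    program = []
    for lineno, line in enumerate(dsl_string.splitlines(), start=1):
        line = line.strip()
        if not line:
            continue
        name = line.split()[0]
        pattern = _PATTERNS.get(name)
        match = pattern.fullmatch(line) if pattern else None
        if match is None:
            raise ValueError(f"line {lineno}: cannot parse {line!r}")
        if name == "sort":
            program.append(("sort", match.group(1) or "-"))
        elif name == "slice":
            n, m = match.groups()
            program.append(("slice", int(n) if n else None, int(m) if m else None))
        else:
            program.append(("threshold", match.group(1), float(match.group(2))))
    return program

print(parse_program("sort -\nslice :3\nthreshold < 1"))
```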

```python
def interpret_dsl(logits, dsl_string):
    """
    Parse and execute DSL commands to transform logits.

    Args:
        logits (np.ndarray): The logits to process.
        dsl_string (str): A DSL string with multiple commands.

    Returns:
        np.ndarray: Transformed logits.
    """
    commands = parse_dsl(dsl_string)

    for command in commands:
        parts = command.split()
        cmd = parts[0]
        
        if cmd == "sort":
            order = parts[1] if len(parts) > 1 else "-"
            logits = sort_logits(logits, order)
        
        elif cmd == "slice":
            n, m = parts[1].split(":")
            n = int(n) if n else None
            m = int(m) if m else None
            logits = slice_logits(logits, n, m)
        
        elif cmd == "threshold":
            op, value = parts[1], float(parts[2])
            logits = threshold_logits(logits, op, value)
        
        else:
            raise ValueError(f"Unknown command: {cmd}")
    
    return logits
```
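the interpreter stops at transformed logits; turning them into a token is the usual softmax followed by a categorical draw, in which every `-inf` entry gets probability exactly zero. a minimal sketch (`sample_from_logits` is a hypothetical helper, and applying temperature at this step is an assumption). note that if the program used `sort`, the returned index is a position in the sorted array, not a token id:

```python
import numpy as np

def sample_from_logits(logits: np.ndarray, temperature: float = 1.0, rng=None) -> int:
    """Sample a position from (possibly -inf-masked) logits via a temperature-scaled softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = np.random.default_rng() if rng is None else rng
    scaled = logits / temperature
    finite = np.isfinite(scaled)
    if not finite.any():
        raise ValueError("every logit is masked; nothing to sample")
    # subtract the largest finite logit for numerical stability; exp(-inf) is 0 for masked entries
    z = np.exp(scaled - scaled[finite].max())
    probs = z / z.sum()
    return int(rng.choice(len(probs), p=probs))

masked = np.array([-np.inf, 2.3, 1.1, -np.inf, -np.inf])  # e.g. the output of `threshold < 1`
print(sample_from_logits(masked, rng=np.random.default_rng(0)))  # always 1 or 2
```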


let's make sure `top_k` and `min_p` work

```python
top_3 = """sort
slice :3"""

min_1 = "threshold < 1"
```

```python
logits = np.array([0.1, 2.3, 1.1, 0.7, -1.0])
print(f"original logits: {logits}")

result = interpret_dsl(logits, top_3)
print(f"top-3 logits: {result}")
```

```
original logits: [ 0.1  2.3  1.1  0.7 -1. ]
top-3 logits: [ 2.3  1.1  0.7 -inf -inf]
```


```python
logits = np.array([0.1, 2.3, 1.1, 0.7, -1.0])
print(f"original logits: {logits}")

result = interpret_dsl(logits, min_1)
print(f"thresholded logits: {result}")
```

```
original logits: [ 0.1  2.3  1.1  0.7 -1. ]
thresholded logits: [-inf  2.3  1.1 -inf -inf]
```


# what's next

I coded this up at around 2am in 20 mins. a lot is still missing, and I will be adding it over the next few days if there is interest.

a few things that seem reasonable:
- add syntactic sugar for top_k, min_p and other widely used stuff
- add more ops
- add statefulness for beam search and such
    - perhaps operating on the full logit matrix, not just the latest logit vector