In [None]:
# Dependencies

%pip install rigging tqdm editdistance

In [None]:
# References

# Pre-patch source of nft_payload_copy_vlan; this is the code handed to the models to fix.
# It differs from PATCHED_FUNCTION in exactly one line: the `ethlen -=` adjustment ends in
# "+ vlan_hlen" here and in "- vlan_hlen" in the patched version.
# The text is sent verbatim as prompt input, so any edit changes what the models see.
VULNERABLE_FUNCTION = """\
static bool nft_payload_copy_vlan(u32 *d, const struct sk_buff *skb, u8 offset, u8 len)
{
    int mac_off = skb_mac_header(skb) - skb->data;
    u8 *vlanh, *dst_u8 = (u8 *) d;
    struct vlan_ethhdr veth;
    u8 vlan_hlen = 0;

    if ((skb->protocol == htons(ETH_P_8021AD) ||
         skb->protocol == htons(ETH_P_8021Q)) &&
        offset >= VLAN_ETH_HLEN && offset < VLAN_ETH_HLEN + VLAN_HLEN)
        vlan_hlen += VLAN_HLEN;

    vlanh = (u8 *) &veth;

    if (offset < VLAN_ETH_HLEN + vlan_hlen) {
        u8 ethlen = len;

        if (vlan_hlen &&
            skb_copy_bits(skb, mac_off, &veth, VLAN_ETH_HLEN) < 0)
            return false;
        else if (!nft_payload_rebuild_vlan_hdr(skb, mac_off, &veth))
            return false;

        if (offset + len > VLAN_ETH_HLEN + vlan_hlen)
            ethlen -= offset + len - VLAN_ETH_HLEN + vlan_hlen;

        memcpy(dst_u8, vlanh + offset - vlan_hlen, ethlen);

        len -= ethlen;
        if (len == 0)
            return true;

        dst_u8 += ethlen;
        offset = ETH_HLEN + vlan_hlen;
    } else {
        offset -= VLAN_HLEN + vlan_hlen;
    }

    return skb_copy_bits(skb, offset + mac_off, dst_u8, len) == 0;
}
"""

# Reference fix of the same function. In this notebook it is only used to build REAL_PATCH
# (the diff from VULNERABLE_FUNCTION to this text) that model-generated diffs are scored against.
# The only changed line is the `ethlen -=` adjustment, which subtracts vlan_hlen instead of adding it.
PATCHED_FUNCTION = """\
static bool nft_payload_copy_vlan(u32 *d, const struct sk_buff *skb, u8 offset, u8 len)
{
    int mac_off = skb_mac_header(skb) - skb->data;
    u8 *vlanh, *dst_u8 = (u8 *) d;
    struct vlan_ethhdr veth;
    u8 vlan_hlen = 0;

    if ((skb->protocol == htons(ETH_P_8021AD) ||
         skb->protocol == htons(ETH_P_8021Q)) &&
        offset >= VLAN_ETH_HLEN && offset < VLAN_ETH_HLEN + VLAN_HLEN)
        vlan_hlen += VLAN_HLEN;

    vlanh = (u8 *) &veth;

    if (offset < VLAN_ETH_HLEN + vlan_hlen) {
        u8 ethlen = len;

        if (vlan_hlen &&
            skb_copy_bits(skb, mac_off, &veth, VLAN_ETH_HLEN) < 0)
            return false;
        else if (!nft_payload_rebuild_vlan_hdr(skb, mac_off, &veth))
            return false;

        if (offset + len > VLAN_ETH_HLEN + vlan_hlen)
            ethlen -= offset + len - VLAN_ETH_HLEN - vlan_hlen;

        memcpy(dst_u8, vlanh + offset - vlan_hlen, ethlen);

        len -= ethlen;
        if (len == 0)
            return true;

        dst_u8 += ethlen;
        offset = ETH_HLEN + vlan_hlen;
    } else {
        offset -= VLAN_HLEN + vlan_hlen;
    }

    return skb_copy_bits(skb, offset + mac_off, dst_u8, len) == 0;
}
"""

# Reference description of the bug. It is passed as `vulnerability_description` when the
# reference-description fixes (`ref_fixes`) are gathered.
# The whole text, including the embedded C snippets and its original wording, is prompt
# input: changing it changes what the models are asked, so it is kept verbatim.
# Note the string is not raw, so the "\n" inside the printf example is a real newline.
VULNERABILITY_DESCRIPTION = """\
The vulnerability consists of a stack buffer overflow due to an integer underflow vulnerability inside the nft_payload_copy_vlan function, which is invoked with nft_payload expressions as long as a VLAN tag is present in the current skb. (net/netfilter/nft_payload.c)

The initial checks look for a second VLAN tag from the EtherType field and, if the offset falls between the first VLAN_ETH_HLEN bytes and VLAN_ETH_HLEN plus the size of another VLAN header, then nftables should also try and process the second VLAN. The if statement preceeding memcopy correctly checks the boundary of the header using the offset and len variables (8-bit unsigned ints), evaluating to true whenever offset + len exceeds the double-tagged VLAN header. The use of inline statements successfully prevents wrappings because u8 types are automatically promoted before the comparison.

However, on the next line, the subtraction does not grant type promotion, and ethlen (u8) may wrap to UINT8_MAX under certain conditions. Some examples of vulnerable offset and len pairs are:

offset: 19 & len: 4 & ethlen = 251 offset: 16 & len: 19 & ethlen = 254 offset: 20 & len: 32 & ethlen = 250 ... Other pairs can be listed with the following algorithm:

```
uint8_t vlan_hlen = VLAN_HLEN, ethlen;
for (uint8_t len = 0; len < UINT8_MAX; len++) {
    for (uint8_t offset = 0; offset < UINT8_MAX; offset++) {
        if (offset < VLAN_ETH_HLEN + vlan_hlen) {
            uint8_t ethlen = len;
            if (offset + len > VLAN_ETH_HLEN + vlan_hlen) {
                ethlen -= offset + len - VLAN_ETH_HLEN + vlan_hlen;
                printf("offset: %hhu & len: %hhu & ethlen = %hhu\n",
offset, len, ethlen);
            }
        }
    }
}
```

Finally, during the memcpy an up to 255-byte buffer gets copied to the destination register located on the stack, overwriting the adjacent memory. Since we can control the destination register, we can pick NFT_REG32_15 to trigger a 251-byte OOB write on the stack (since NFT_REG32_15 occupies 4 bytes). The vulnerable code path can be reached if the function skb_vlan_tag_present(skb) evaluates to true, that is if the skb->vlan_tci field is set. This is known to happen when the host is placed inside a VLAN, although a modified skb could also be forged manually. (perhaps by forging the packet itself or with some other nft_expr that can edit packets?)

The calling function is nft_payload_eval which evaluates the Nftables expression:

```
void nft_payload_eval(const struct nft_expr *expr,
              struct nft_regs *regs,
              const struct nft_pktinfo *pkt) {
    const struct nft_payload *priv = nft_expr_priv(expr);
    const struct sk_buff *skb = pkt->skb;
    u32 *dest = &regs->data[priv->dreg]; <===== (0)
    int offset;

    if (priv->len % NFT_REG32_SIZE)
        dest[priv->len / NFT_REG32_SIZE] = 0;

    switch (priv->base) {
    case NFT_PAYLOAD_LL_HEADER:             <===== (1)
        if (!skb_mac_header_was_set(skb))
            goto err;

        if (skb_vlan_tag_present(skb)) {
            if (!nft_payload_copy_vlan(dest, skb,
                           priv->offset, priv->len)) <===== (2)
                goto err;
            return;
        }
        ...
```

At (0) dest is set to the chosen destination register, where the payload expression will store its result. If the payload offset base is NFT_PAYLOAD_LL_HEADER (1) and a mac header is present, the vulnerable code path will be taken (2). Furthermore, the kernel must be built with the configuration CONFIG_NETFILTER, CONFIG_NF_TABLES, CONFIG_VLAN_8021Q enabled, and the CAP_NET_ADMIN capability must be enabled, which can be accomplished by entering a new user namespace beforehand.
"""

In [None]:
import difflib
from dataclasses import dataclass

import litellm
from tqdm.notebook import tqdm

import rigging as rg
from rigging.watchers import write_chats_to_jsonl


def diff(a: str, b: str, all_lines: bool = True) -> str:
    """
    Build a unified diff ("before" -> "after") between two blocks of text.

    Args:
        a: Original text.
        b: Modified text.
        all_lines: When True, use enough context to show every line of both
            inputs; when False, use 3 lines of context around each change.

    Returns:
        The diff as one string in which every line ends with a newline, or an
        empty string when both inputs consist of the same lines.
    """
    # Compare newline-stripped lines and add the terminators ourselves
    # (lineterm=""). Feeding keepends lines to unified_diff glues the final
    # "-"/"+" entries together when an input lacks a trailing newline, and
    # reports a bogus change when only the trailing newline differs.
    a_lines = a.splitlines()
    b_lines = b.splitlines()
    # Context is measured in lines; the longer input always covers the full text
    num_lines = max(len(a_lines), len(b_lines)) if all_lines else 3
    diff_gen = difflib.unified_diff(
        a_lines, b_lines, fromfile="before", tofile="after", n=num_lines, lineterm=""
    )
    return "".join(f"{line}\n" for line in diff_gen)


# Extra generation parameters handed to the Vertex AI generators through
# GenerateParams(extra=...): sets the HARM_CATEGORY_DANGEROUS_CONTENT threshold
# to BLOCK_ONLY_HIGH
vertex_safety_override = dict(
    safety_settings=[
        dict(
            category="HARM_CATEGORY_DANGEROUS_CONTENT",
            threshold="BLOCK_ONLY_HIGH",
        ),
    ],
)

# Models to query. The two Vertex AI entries are built as generator objects up
# front so they carry the safety override; the remaining entries stay as plain
# identifier strings. Consumers therefore have to accept both kinds of entry.
MODELS = [
    *(
        rg.get_generator(vertex_id, params=rg.GenerateParams(extra=vertex_safety_override))
        for vertex_id in ("vertex_ai/gemini-1.5-pro-001", "vertex_ai/codechat-bison")
    ),
    "openai/gpt-4-turbo",
    "openai/gpt-4o",
    "anthropic/claude-3-sonnet-20240229",
    "mistral/mistral-large-latest",
    "mistral/mistral-medium-latest",
    "mistral/codestral-latest",
    "groq/llama3-70b-8192",
]

# System message for `pipeline`; a single line terminated by a newline
SYSTEM_PROMPT = "You are a senior Linux developer who specializes in code security.\n"

# Base chat pipeline: the claude-3-opus generator, started with the system prompt
pipeline = rg.get_generator("anthropic/claude-3-opus-20240229").chat(
    dict(role="system", content=SYSTEM_PROMPT)
)
# NOTE(review): on_failed="skip" presumably drops a chat that hit litellm.APIError
# instead of raising — confirm against the rigging documentation
pipeline = pipeline.catch(litellm.APIError, on_failed="skip")


# Structured result of the `fix_code` prompt (it is that prompt's return annotation).
# NOTE(review): documented with comments rather than a class docstring on purpose —
# rigging presumably inspects this class to build/parse the prompt output; confirm
# before adding a docstring or changing the fields.
@dataclass
class Fix:
    chat: rg.Chat  # chat this fix came from; its generator and metadata are used in later cells
    fixed_function: str  # rewritten function source returned for the prompt

    @property
    def diff(self) -> str:
        # Unified diff from VULNERABLE_FUNCTION to this fix, using the module-level diff() helper
        return diff(VULNERABLE_FUNCTION, self.fixed_function)


# Prompt definition bound to `pipeline`; the function intentionally has no body.
# NOTE(review): rigging presumably derives the prompt from the signature, the return
# annotation and the docstring below, so the docstring is prompt content rather than
# documentation — rewording it changes what the models are asked.
@pipeline.prompt
async def fix_code(vulnerable_function: str, vulnerability_description: str) -> Fix:
    """
    Rewrite the source code to fix the vulnerability described.
    """

# Show the prompt template for inspection
print(fix_code.template)

In [None]:
# Gather fixes with reference description

ref_fixes: list[Fix] = []
for _ in tqdm(range(10)):
    ref_fixes.extend(await fix_code.run_over(MODELS, VULNERABLE_FUNCTION, VULNERABILITY_DESCRIPTION))

In [None]:
# Save ref fixes

# Attach what the scoring cells need to each chat: the diff against the vulnerable
# function, the producing model, and the fact that the reference description was used
for fix in ref_fixes:
    ref_meta = {
        "diff": fix.diff,
        "model": fix.chat.generator.model,
        "used_ref_description": True,
    }
    fix.chat.meta(**ref_meta)

# await write_chats_to_jsonl("data/fixes.jsonl")([f.chat for f in ref_fixes])

In [None]:
# Gather model-generated descriptions

# Upper bound on how many descriptions are kept per model
MAX_MODEL_REFS = 5

# triage.jsonl holds one serialized chat per line
with open("triage.jsonl") as f:
    triage_chats: list[rg.Chat] = [rg.Chat.model_validate_json(line) for line in f]

# Bucket the chats by the model that generated them
chats_per_model: dict[str, list[rg.Chat]] = {}
for chat in triage_chats:
    chats_per_model.setdefault(chat.generator.model, []).append(chat)

# Per model, the contents of the last message of its longest chats (longest first).
# A model with fewer than MAX_MODEL_REFS chats gets a correspondingly shorter list.
longest: dict[str, list[str]] = {
    model: [
        ranked.last.content
        for ranked in sorted(chats, key=lambda c: len(c.last.content), reverse=True)[:MAX_MODEL_REFS]
    ]
    for model, chats in chats_per_model.items()
}

In [None]:
# Generate fixes with model references

import asyncio # noqa: I001


model_pipes: list[rg.ChatPipeline] = [pipeline]
for model in MODELS:
    clone = pipeline.clone()
    clone.generator = rg.get_generator(model) if isinstance(model, str) else model
    model_pipes.append(clone)

model_fixes: list[Fix] = []
for i in tqdm(range(MAX_MODEL_REFS)):
    descriptions = [longest[pipe.generator.model][i] for pipe in model_pipes]
    coros = [pipe.run_prompt(fix_code, VULNERABLE_FUNCTION, description) for pipe, description in zip(model_pipes, descriptions)]
    model_fixes.extend(await asyncio.gather(*coros))

In [None]:
# Save model-generated fixes

# Attach what the scoring cells need to each chat: the diff against the vulnerable
# function, the producing model, and the fact that a model-written description was used
for fix in model_fixes:
    model_meta = {
        "diff": fix.diff,
        "model": fix.chat.generator.model,
        "used_ref_description": False,
    }
    fix.chat.meta(**model_meta)

# await write_chats_to_jsonl("data/fixes.jsonl")([f.chat for f in model_fixes])

In [None]:
# Actual patch

# Ground truth for scoring: diff from the vulnerable source to the patched source
REAL_PATCH = diff(a=VULNERABLE_FUNCTION, b=PATCHED_FUNCTION)

print(REAL_PATCH)

In [None]:
# Read all chats from disk

# data/fixes.jsonl holds one serialized chat per line
with open("data/fixes.jsonl") as f:
    all_chats: list[rg.Chat] = [rg.Chat.model_validate_json(line) for line in f]

# Empty the parts list of every chat's last message
# NOTE(review): the reason for clearing the parts is not visible in this notebook
for chat in all_chats:
    chat.last.parts.clear()

len(all_chats)

In [None]:
import editdistance  # noqa: I001


@dataclass
class Distances:
    """Edit distances to REAL_PATCH for one model, kept in two buckets
    according to each chat's `used_ref_description` metadata flag."""

    ref_distances: list[int]
    model_distances: list[int]


# Per model: edit distance between every stored diff and the real patch
stats: dict[str, Distances] = {}

for chat in all_chats:
    model = chat.generator.model
    if model not in stats:
        stats[model] = Distances([], [])

    distance = editdistance.eval(REAL_PATCH, chat.metadata["diff"])

    # Chats generated from the reference description go to ref_distances and
    # chats generated from a model-written description go to model_distances.
    # The two branches used to be swapped, which swapped the reported averages.
    if chat.metadata["used_ref_description"]:
        stats[model].ref_distances.append(distance)
    else:
        stats[model].model_distances.append(distance)

stats

In [None]:
# Format stats

# Print the mean distance of each non-empty bucket, one paragraph per model
for model, distances in stats.items():
    print(model)
    buckets = (
        ("Ref:   ", distances.ref_distances),
        ("Model: ", distances.model_distances),
    )
    for label, values in buckets:
        if values:
            print(label, sum(values) / len(values))
    print()

In [None]:
# Look at the first stored chat whose model name contains "gpt-4-"
# (not a random pick; call next(gen) again to step to the following match)

gen = (chat for chat in all_chats if "gpt-4-" in chat.generator.model)

inspect = next(gen)

print(inspect.metadata["diff"])