In [None]:
# Dependencies

%pip install rigging tqdm

In [None]:
# References

# C source of the Linux netfilter helper `nft_payload_copy_vlan`. It is sent
# verbatim to every model as the code under review (see the `run_over` call in the
# Model calls cell).
# As the constant's name indicates, this is a known-vulnerable version, kept
# unpatched on purpose so the models have a real bug to find.
# NOTE(review): this appears to match the pre-fix code associated with
# CVE-2023-0179 (the `ethlen -= ... + vlan_hlen` length computation). Confirm
# before citing.
# Do not reformat or edit the text below: any change alters the prompt, and new
# results would no longer be comparable with previously saved runs.
VULNERABLE_FUNCTION = """\
static bool nft_payload_copy_vlan(u32 *d, const struct sk_buff *skb, u8 offset, u8 len)
{
    int mac_off = skb_mac_header(skb) - skb->data;
    u8 *vlanh, *dst_u8 = (u8 *) d;
    struct vlan_ethhdr veth;
    u8 vlan_hlen = 0;

    if ((skb->protocol == htons(ETH_P_8021AD) ||
         skb->protocol == htons(ETH_P_8021Q)) &&
        offset >= VLAN_ETH_HLEN && offset < VLAN_ETH_HLEN + VLAN_HLEN)
        vlan_hlen += VLAN_HLEN;

    vlanh = (u8 *) &veth;

    if (offset < VLAN_ETH_HLEN + vlan_hlen) {
        u8 ethlen = len;

        if (vlan_hlen &&
            skb_copy_bits(skb, mac_off, &veth, VLAN_ETH_HLEN) < 0)
            return false;
        else if (!nft_payload_rebuild_vlan_hdr(skb, mac_off, &veth))
            return false;

        if (offset + len > VLAN_ETH_HLEN + vlan_hlen)
            ethlen -= offset + len - VLAN_ETH_HLEN + vlan_hlen;

        memcpy(dst_u8, vlanh + offset - vlan_hlen, ethlen);

        len -= ethlen;
        if (len == 0)
            return true;

        dst_u8 += ethlen;
        offset = ETH_HLEN + vlan_hlen;
    } else {
        offset -= VLAN_HLEN + vlan_hlen;
    }

    return skb_copy_bits(skb, offset + mac_off, dst_u8, len) == 0;
}
"""


In [None]:
# Model calls

from dataclasses import dataclass

import litellm
from tqdm.notebook import tqdm

import rigging as rg

# Extra request parameters for the Vertex AI generators. The dangerous-content
# safety category is set to BLOCK_ONLY_HIGH, presumably so that analysis of
# exploitable code is not filtered out. Confirm against the Vertex AI safety docs.
vertex_safety_override = dict(
    safety_settings=[
        dict(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH"),
    ],
)

# Models under evaluation, in the order they are queried.
# The Vertex AI entries are built as generator objects so the relaxed safety
# settings travel with them; each gets its own GenerateParams instance.
# Every other entry is a plain model identifier string.
MODELS = [
    *(
        rg.get_generator(vertex_id, params=rg.GenerateParams(extra=vertex_safety_override))
        for vertex_id in ("vertex_ai/gemini-1.5-pro-001", "vertex_ai/codechat-bison")
    ),
    "openai/gpt-4-turbo",
    "openai/gpt-4o",
    "anthropic/claude-3-sonnet-20240229",
    "mistral/mistral-large-latest",
    "mistral/mistral-medium-latest",
    "mistral/codestral-latest",
    "groq/llama3-70b-8192",
]

# System message that frames every triage conversation (note the trailing newline).
SYSTEM_PROMPT = "You are a senior Linux developer who specializes in code security.\n"

# Base chat pipeline: Claude 3 Opus primed with the security-reviewer system message.
# NOTE(review): `run_over(MODELS, ...)` below presumably swaps in each model from
# MODELS, so Opus acts only as the default generator. Confirm against the rigging docs.
# `litellm.APIError` failures use on_failed="skip", presumably dropping that chat
# rather than aborting the whole batch.
pipeline = (
    rg.get_generator("anthropic/claude-3-opus-20240229")
    .chat({"role": "system", "content": SYSTEM_PROMPT})
    .catch(litellm.APIError, on_failed="skip")
)


# Structured result of one `is_vulnerable` prompt call (its declared return type).
# NOTE(review): rigging presumably derives the expected model output from these
# fields, so keep their names, order and types stable.
@dataclass
class Triage:
    # Conversation that produced this verdict. Later cells attach metadata to it
    # and read `.generator.model`; presumably populated by rigging itself.
    chat: rg.Chat
    # The model's verdict: True when it judged the source to contain a vulnerability.
    is_vulnerable: bool


# Prompt definition bound to `pipeline`. The function body is intentionally empty:
# the docstring appears to serve as the task instruction sent to the model, so
# editing it changes the experiment.
# `source` receives the code under review (VULNERABLE_FUNCTION in the loop below),
# and the model's answer is parsed into a `Triage`.
@pipeline.prompt
async def is_vulnerable(source: str) -> Triage:
    """
    Analyze this source code and identify if it contains a security vulnerability.
    """

triages: list[Triage] = []
for _ in tqdm(range(25)):
    triages.extend(await is_vulnerable.run_over(MODELS, VULNERABLE_FUNCTION))

In [None]:
# Save data

for triage in triages:
    triage.chat.meta(
        vulnerable=triage.is_vulnerable,
        model=triage.chat.generator.model,
    )

# await rg.watchers.write_chats_to_jsonl("data/triage.jsonl")([t.chat for t in triages])

In [None]:
# Read data

# Load saved triage chats: one JSON-serialized rg.Chat per line.
# NOTE: nothing earlier in this notebook writes this file unless the export in the
# Save cell is enabled, so on a fresh run it must already exist.
all_chats: list[rg.Chat] = []

# JSONL is UTF-8; model responses may contain non-ASCII text, so don't depend on
# the platform's default locale encoding.
with open("data/triage.jsonl", encoding="utf-8") as f:
    for line in f:  # stream lines instead of materializing them all via readlines()
        if not line.strip():
            # Skip blank lines (e.g. a trailing empty line) rather than failing JSON validation.
            continue
        all_chats.append(rg.Chat.model_validate_json(line))

In [None]:
# Pull sample data

# Print one example of a model judging the function vulnerable with a substantive
# (non-trivial length) final response.
# "gpt-4-" (with the hyphen) matches openai/gpt-4-turbo but not openai/gpt-4o.
SAMPLE_MODEL_SUBSTRING = "gpt-4-"
SAMPLE_MIN_RESPONSE_CHARS = 50

# A generator expression is already an iterator, so no iter() wrapper is needed.
searcher = (
    chat
    for chat in all_chats
    if SAMPLE_MODEL_SUBSTRING in chat.metadata["model"]
    and chat.metadata["vulnerable"]
    and len(chat.last.content) > SAMPLE_MIN_RESPONSE_CHARS
)

chat = next(searcher, None)
if chat is None:
    # Fail with an explanation instead of a bare StopIteration.
    raise LookupError(
        f"no chat from a model matching {SAMPLE_MODEL_SUBSTRING!r} was flagged vulnerable "
        f"with a final response longer than {SAMPLE_MIN_RESPONSE_CHARS} characters"
    )
print(chat.conversation)

In [None]:
# Analysis

# Per-model detection rate: the number of a model's triages that flagged the
# function as vulnerable, out of all triages recorded for that model.
flat: dict[str, list[bool]] = {}
for chat in all_chats:
    # Append in place. The previous `flat.get(model, []) + [...]` rebuilt the whole
    # list on every chat, which is quadratic in the number of chats per model.
    flat.setdefault(chat.metadata["model"], []).append(chat.metadata["vulnerable"])

for model, results in flat.items():
    # True counts as 1, so sum() gives the number of "vulnerable" verdicts.
    print(f"{model:<40}: {sum(results)}/{len(results)}")