In [1]:
from mitre_attack.client import SemaphoreClient
from pydantic import BaseModel
import os

from typing import Generator, Tuple, List
from pydantic import BaseModel
import pandas as pd
from mitre_attack.prompt import SYSTEM_PROMPT, USER_PROMPT
from mitre_attack.models import ModelCreate, VulnerabilityLog
from mitre_attack.utils import random_datetime, add_random_time
from datetime import datetime
import random
from uuid import uuid4

In [5]:
from mitre_attack.models import VulnerabilityLog


class Technique(BaseModel):
    technique_id: str
    name: str
    description: str

    @property
    def prompt(self) -> str:
        return f"""
        Technique ID: {self.technique_id}
        Name: {self.name}
        Description: {self.description}
        """


class MitreTechnique:
    def __init__(self, df_path: str):
        self.df = pd.read_csv(df_path)
    
    def generate_with_samples(self, n: int = 5) -> Generator[Tuple[Technique, List[Technique]], None, None]:
        """
        Generator that yields technique information with random samples from other techniques.
        
        Args:
            n: Number of random samples from other techniques (default: 5)
            
        Yields:
            Tuple of (technique, samples) where:
                - technique: Technique model of the current technique
                - samples: List of Technique models of n random other techniques
        """

        returned = 0
        for idx, row in self.df.iterrows():
            # Create the main technique object
            technique = Technique(
                technique_id=row['technique_id'],
                name=row['name'],
                description=row['description']
            )
            
            # Get random samples from other techniques (excluding the current one)
            other_techniques = self.df[self.df['technique_id'] != row['technique_id']]
            if len(other_techniques) > 0:
                samples_df = other_techniques.sample(n=min(n, len(other_techniques)))
                samples = [
                    Technique(
                        technique_id=sample_row['technique_id'],
                        name=sample_row['name'],
                        description=sample_row['description']
                    )
                    for _, sample_row in samples_df.iterrows()
                ]
            else:
                samples = []

            returned += 1
            if returned >= n:
                break
            
            yield technique, samples


class SyntheticTechniqueGenerator:
    def __init__(self, client: SemaphoreClient, sampler: MitreTechnique):
        self.client = client
        self.sampler = sampler

    def generate_message_from_technique_and_samples(self, technique: Technique, samples: list[Technique]) -> List[dict[str, str]]:
        return [
            {
                "role": "system",
                "content": SYSTEM_PROMPT
            },
            {
                "role": "user",
                "content": USER_PROMPT.format(required_technique=technique.prompt, possible_techniques="\n".join([sample.prompt for sample in samples]))
            }
        ]

    def generate_messages_list(self, limit: int = 10) -> List[List[dict[str, str]]]:
        messages_list = []
        for technique, samples in self.sampler.generate_with_samples(limit):
            message = self.generate_message_from_technique_and_samples(technique, samples)
            messages_list.append(message)
            messages_list.append(message)
        return messages_list

    def augment_model_creates(self, model_creates: list[ModelCreate]) -> list[VulnerabilityLog]:

        timestamp = random_datetime(datetime(2024, 1, 1), datetime(2025, 11, 22))
        attacker_id = str(uuid4())

        
        vulnerability_logs = []
        for model_create in model_creates:
            try:
                
                # Always update the session id
                session_id = str(uuid4())

                # Change the attacker id most of the time, but not always
                if random.random() < 0.7:
                    attacker_id = str(uuid4())

                
                for step in model_create.exploitation_steps:
                    timestamp = add_random_time(timestamp, 2, 0.5)
                    vulnerability_log = VulnerabilityLog(
                        created_at=timestamp,
                        vulnerability_type=model_create.vulnerability_type,
                        session_id=session_id,
                        attacker_id=attacker_id,
                        technique_id=step.technique_id,
                    )
                    vulnerability_logs.append(vulnerability_log)
            except Exception as e:
                print(f"Error with model create: {e}")
                continue

        return vulnerability_logs


    async def execute(self, limit: int = 10) -> list[VulnerabilityLog]:
        # Generate messages
        messages_list = self.generate_messages_list(limit)

        # Execute messages to make ModelCreate objects

        print(f"Executing {len(messages_list)} messages")
        model_creates = await self.client.batch_parse_completions(model='gpt-5-mini', messages_list=messages_list, response_format=ModelCreate)

        # Convert the ModelCreate objects to VulnerabilityLog objects
        vulnerability_logs = self.augment_model_creates(model_creates)

        return vulnerability_logs


In [8]:
df = pd.read_csv(".data/parent.csv")
len(df)

216

In [9]:
client = SemaphoreClient()
sampler = MitreTechnique(".data/parent.csv")

generator = SyntheticTechniqueGenerator(client, sampler)

logs = await generator.execute(limit=432)

Executing 432 messages
Making request 0
Making request 0
Making request 1
Making request 1
Making request 2
Making request 2
Making request 3
Making request 3
Making request 4
Making request 4
Making request 5
Making request 5
Making request 6
Making request 6
Making request 7
Making request 7
Making request 8
Making request 8
Making request 9
Making request 9
Making request 10
Making request 10
Making request 11
Making request 11
Making request 12
Making request 12
Making request 13
Making request 13
Making request 14
Making request 14
Making request 15
Making request 16
Making request 17
Making request 18
Making request 19
Making request 20
Making request 21
Making request 22
Making request 23
Making request 24
Making request 25
Making request 26
Making request 27
Making request 28
Making request 29
Making request 30
Making request 31
Making request 32
Making request 33
Making request 34
Making request 35
Making request 36
Making request 37
Making request 38
Making request 39
Making 

In [10]:
len(logs)

3829

In [11]:
for log in logs:
    log.session_id = log.session_id[10:]

In [12]:
logs[0]

VulnerabilityLog(technique_id='T1595', id=UUID('4935b846-84a4-4a3d-b2e1-960d39e7a6e7'), created_at=datetime.datetime(2024, 9, 22, 3, 11, 38, 241317), vulnerability_type='broken-function-level-authorization', session_id='5fd-4bb0-bd6c-c28c9aa05e9d', attacker_id='18400763-dcb8-45cf-9a42-05138805572c')

In [14]:
from typing import Any

def log_to_dict(logs: List[VulnerabilityLog]) -> List[dict[str, Any]]:

    res = []

    for log in logs:
        res.append({
            "id": log.id,
            "base_url": "https://example.com",
            "vulnerability_type": log.vulnerability_type,
            "technique_id": log.technique_id,
            "timestamp": log.created_at,
            "attacker_id": log.attacker_id,
            "session_id": log.session_id,
            "is_synthetic": True,
        })

    return res

df = pd.DataFrame(log_to_dict(logs))
df.to_csv(".data/synthetic_logs.csv", index=False)

In [16]:
df = df[df['technique_id'].str.startswith('T', na=False)]

In [18]:
df.to_csv(".data/synthetic_logs.csv", index=False)

In [19]:
# Update the id column to a new set of uuids
df['id'] = df['id'].apply(lambda x: str(uuid4()))
df.to_csv(".data/synthetic_logs.csv", index=False)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['id'] = df['id'].apply(lambda x: str(uuid4()))
