# TTP Classification Data Augmentation

## Overview
This notebook augments a TTP (Tactics, Techniques, and Procedures) classification dataset built from MITRE ATT&CK, expanding it from its current size to 2,000 records with LLM-generated variations of the technique descriptions.

### Data Source
- **Primary**: [MITRE CTI Repository](https://github.com/mitre/cti/tree/master) - Official MITRE ATT&CK dataset in STIX 2.0 format
- **Matrices Included**:
  - **Enterprise**: General enterprise techniques (`enterprise-attack.json`)
  - **ICS**: Industrial Control Systems techniques (`ics-attack.json`)
  - **Mobile**: Mobile platform techniques (`mobile-attack.json`)
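  - **Pre-ATT&CK**: Legacy pre-compromise techniques (`pre-attack.json`); MITRE has retired this matrix and folded its scope into the Enterprise Reconnaissance and Resource Development tactics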
- **Fallback**: Local dataset file (if download fails)
- **Format**: STIX 2.0 attack-pattern objects converted to classification format

### Task Description
- **Input**: Latest MITRE ATT&CK technique descriptions from official repository
- **Output**: Augmented dataset with variations of TTP descriptions while maintaining semantic accuracy
- **Goal**: Generate 2000 high-quality TTP classification records
- **Model**: Same as the entity extraction pipeline; the configuration cell below loads `unsloth/Qwen3-1.7B-bnb-4bit`

### Data Structure
Each record contains:
- `instruction`: TTP technique description (input text)
- `input`: null (not used)
- `output`: Structured TTP information with technique ID, name, description, and matrix
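
As an illustration, a single record in this layout might look like the sketch below. The ID and name refer to the ATT&CK technique T1055.011 (Extra Window Memory Injection); only the first sentence of its description is used, and the `matrix` and `domains` values are assumptions chosen for the example rather than values taken from a run of this notebook.

```python
import json

# Illustrative record only: the description is shortened to one sentence,
# and the "matrix" / "domains" values are assumed for the example.
description = (
    "Adversaries may inject malicious code into process via Extra Window "
    "Memory (EWM) in order to evade process-based defenses as well as "
    "possibly elevate privileges."
)

record = {
    "instruction": description,  # model input: the technique description
    "input": None,               # unused; serialized as JSON null
    "output": {
        "techniques": [
            {
                "id": "T1055.011",
                "name": "Extra Window Memory Injection",
                "description": description,
                "matrix": "enterprise-attack",
                "domains": ["enterprise-attack"],
            }
        ]
    },
}

# Round-trip through JSON to confirm the structure serializes cleanly.
serialized = json.dumps(record, indent=2)
print(serialized[:200])
```

Note that `instruction` and `output.techniques[0].description` carry the same text in an unaugmented record; augmentation replaces only `instruction`.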

### Augmentation Strategy
1. **Paraphrasing**: Generate semantic variations of technique descriptions
2. **Scenario expansion**: Create realistic attack scenarios using the techniques
3. **Context variation**: Present techniques in different operational contexts
4. **Technical detail expansion**: Add implementation specifics, prerequisites, and variations while preserving factual accuracy (alternative terminology is handled by the paraphrasing prompt)

### Advantages of Using Official MITRE CTI
- **Current**: Pulls the latest technique definitions each time the notebook runs
- **Authoritative**: Direct from MITRE's official repository
- **Complete**: Covers the Enterprise, ICS, Mobile, and Pre-ATT&CK matrices
- **Standardized**: STIX 2.0 format ensures consistency


In [1]:
import json
import os
import re
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
from collections import defaultdict
import datetime
import random

# Additional imports for STIX data processing and web requests
import requests
import urllib3

# Load environment and model setup
from dotenv import load_dotenv
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Load environment variables
load_dotenv()

print("🔧 Setting up TTP Data Augmentation Pipeline")
print("🌐 Data Source: MITRE ATT&CK CTI Repository (https://github.com/mitre/cti)")
print("=" * 70)

# Set random seeds for reproducibility
random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Disable SSL warnings for downloading from GitHub
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


🔧 Setting up TTP Data Augmentation Pipeline
🌐 Data Source: MITRE ATT&CK CTI Repository (https://github.com/mitre/cti)


In [2]:
# STIX download and parsing helpers

def download_mitre_attack_data() -> Optional[Dict]:
    """
    Download the latest MITRE ATT&CK data from all matrices in official CTI repository.
    Includes: Enterprise, ICS, Mobile, and Pre-ATT&CK matrices.
    """
    print("🌐 Downloading MITRE ATT&CK data from official repository...")

    # MITRE CTI repository URLs for all matrices
    matrices = {
        "enterprise": "https://raw.githubusercontent.com/mitre/cti/master/enterprise-attack/enterprise-attack.json",
        "ics": "https://raw.githubusercontent.com/mitre/cti/master/ics-attack/ics-attack.json",
        "mobile": "https://raw.githubusercontent.com/mitre/cti/master/mobile-attack/mobile-attack.json",
        "pre-attack": "https://raw.githubusercontent.com/mitre/cti/master/pre-attack/pre-attack.json"
    }

    all_objects = []
    download_stats = {}

    for matrix_name, url in matrices.items():
        try:
            print(f"📥 Downloading {matrix_name.upper()} matrix from: {url}")

            # Download the STIX file (TLS certificate verification stays enabled)
            response = requests.get(url, timeout=30)
            response.raise_for_status()

            # Parse STIX data
            stix_data = response.json()
            matrix_objects = stix_data.get('objects', [])

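            # Tag each technique with the ATT&CK domain name for this matrix
            # (e.g. "enterprise-attack"), matching the naming already used in x_mitre_domains
            domain_name = matrix_name if matrix_name.endswith("-attack") else f"{matrix_name}-attack"
            for obj in matrix_objects:
                if obj.get('type') == 'attack-pattern':
                    domains = obj.setdefault('x_mitre_domains', [])
                    if domain_name not in domains:
                        domains.append(domain_name)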

            all_objects.extend(matrix_objects)
            download_stats[matrix_name] = len(matrix_objects)

            print(f"   ✅ {matrix_name.upper()}: {len(matrix_objects)} objects")

        except Exception as e:
            print(f"   ⚠️  Failed to download {matrix_name.upper()}: {e}")
            download_stats[matrix_name] = 0
            continue

    if all_objects:
        combined_data = {"objects": all_objects}
        print(f"\n🎯 Combined Statistics:")
        print(f"   Total objects: {len(all_objects)}")
        for matrix, count in download_stats.items():
            if count > 0:
                print(f"   {matrix.upper()}: {count} objects")

        return combined_data
    else:
        print("❌ Failed to download any matrix data")
        return None

def parse_stix_to_ttp_dataset(stix_data: Dict) -> Dict:
    """
    Parse STIX 2.0 data from MITRE CTI repository and convert to TTP classification format.
    Handles all matrices: Enterprise, ICS, Mobile, and Pre-ATT&CK.
    """
    if not stix_data or 'objects' not in stix_data:
        return {"dataset": []}

    print("🔄 Converting STIX data to TTP classification format...")

    dataset = []
    technique_count = 0
    matrix_stats = {}

    for obj in stix_data['objects']:
        # Focus on attack-pattern objects (techniques)
        if obj.get('type') == 'attack-pattern':
            try:
                # Extract technique information
                technique_id = None
                technique_name = obj.get('name', 'Unknown Technique')
                description = obj.get('description', '')

                # Determine matrix/domain
                domains = obj.get('x_mitre_domains', ['enterprise'])
                if isinstance(domains, str):
                    domains = [domains]
                primary_matrix = domains[0] if domains else 'enterprise'

                # Extract technique ID from external references
                external_refs = obj.get('external_references', [])
                for ref in external_refs:
                    if ref.get('source_name') == 'mitre-attack':
                        technique_id = ref.get('external_id')
                        break

                if not technique_id or not description:
                    continue

                # Create record in the expected format
                record = {
                    "instruction": description,
                    "input": None,
                    "output": {
                        "techniques": [
                            {
                                "id": technique_id,
                                "name": technique_name,
                                "description": description,
                                "matrix": primary_matrix,
                                "domains": domains  # Include all applicable domains
                            }
                        ]
                    }
                }

                dataset.append(record)
                technique_count += 1

                # Track statistics by matrix
                if primary_matrix not in matrix_stats:
                    matrix_stats[primary_matrix] = 0
                matrix_stats[primary_matrix] += 1

            except Exception as e:
                print(f"⚠️  Error processing technique: {e}")
                continue

    print(f"✅ Converted {technique_count} techniques to TTP format")
    print(f"📊 Techniques by matrix:")
    for matrix, count in matrix_stats.items():
        print(f"   {matrix.upper()}: {count} techniques")

    return {"dataset": dataset}

def load_ttp_data_from_mitre() -> Dict:
    """
    Load TTP classification data directly from MITRE CTI repository.
    """
    print("🎯 Loading MITRE ATT&CK data from official source...")

    # Try to download fresh data
    stix_data = download_mitre_attack_data()

    if stix_data:
        # Convert STIX to our format
        ttp_data = parse_stix_to_ttp_dataset(stix_data)

        if ttp_data['dataset']:
            print(f"✅ Successfully loaded {len(ttp_data['dataset'])} techniques from MITRE CTI")
            return ttp_data

    # Fallback to local file if download fails
    print("🔄 Download failed, falling back to local dataset...")
    local_path = 'data/TTP-classification/merged_mitre_attack_dataset.json'
    try:
        with open(local_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"✅ Loaded {len(data['dataset'])} TTP records from local file")
        return data
    except Exception as e:
        print(f"❌ Error loading local file: {e}")
        return {"dataset": []}

# Load TTP classification data from MITRE CTI
ttp_data = load_ttp_data_from_mitre()

if ttp_data['dataset']:
    print(f"\n📊 Current dataset size: {len(ttp_data['dataset'])} records")
    print(f"🎯 Target size: 2000 records")
    print(f"📈 Need to generate: {2000 - len(ttp_data['dataset'])} additional records")
    print(f"\n📋 Sample record structure:")
    sample = ttp_data['dataset'][0]
    print(f"   Keys: {list(sample.keys())}")
    print(f"   Technique ID: {sample['output']['techniques'][0]['id']}")
    print(f"   Technique Name: {sample['output']['techniques'][0]['name']}")
    print(f"   Instruction length: {len(sample['instruction'])} chars")

    # Show sample of unique technique IDs
    unique_techniques = set()
    for record in ttp_data['dataset']:
        tech_id = record['output']['techniques'][0]['id']
        unique_techniques.add(tech_id)

    sample_ids = sorted(list(unique_techniques))[:10]
    print(f"\n🔍 Sample technique IDs: {sample_ids}")
    print(f"📊 Total unique techniques: {len(unique_techniques)}")
else:
    print("❌ No TTP data loaded. Cannot proceed with augmentation.")


🎯 Loading MITRE ATT&CK data from official source...
🌐 Downloading MITRE ATT&CK data from official repository...
📥 Downloading ENTERPRISE matrix from: https://raw.githubusercontent.com/mitre/cti/master/enterprise-attack/enterprise-attack.json
   ✅ ENTERPRISE: 22652 objects
📥 Downloading ICS matrix from: https://raw.githubusercontent.com/mitre/cti/master/ics-attack/ics-attack.json
   ✅ ICS: 1650 objects
📥 Downloading MOBILE matrix from: https://raw.githubusercontent.com/mitre/cti/master/mobile-attack/mobile-attack.json
   ✅ MOBILE: 2147 objects
📥 Downloading PRE-ATTACK matrix from: https://raw.githubusercontent.com/mitre/cti/master/pre-attack/pre-attack.json
   ✅ PRE-ATTACK: 268 objects

🎯 Combined Statistics:
   Total objects: 26717
   ENTERPRISE: 22652 objects
   ICS: 1650 objects
   MOBILE: 2147 objects
   PRE-ATTACK: 268 objects
🔄 Converting STIX data to TTP classification format...
✅ Converted 1076 techniques to TTP format
📊 Techniques by matrix:
   ENTERPRISE-ATTACK: 823 techniques
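
The CTI bundles can also contain techniques that MITRE has revoked or deprecated, and the conversion above keeps any attack-pattern that has an ID and a description. If such entries are not wanted in the training data, a filter along these lines could run before conversion. This is a sketch rather than part of the original pipeline: the helper names are hypothetical, and it assumes affected objects carry the `revoked` or `x_mitre_deprecated` flag. Applying it would lower the technique counts reported above.

```python
from typing import Any, Dict, List


def is_active_technique(obj: Dict[str, Any]) -> bool:
    """True for attack-pattern objects that are neither revoked nor deprecated."""
    if obj.get("type") != "attack-pattern":
        return False
    # `revoked` is a standard STIX property; `x_mitre_deprecated` is an
    # ATT&CK extension. Both are treated as False when absent.
    return not (obj.get("revoked", False) or obj.get("x_mitre_deprecated", False))


def filter_active_techniques(objects: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep only active techniques (hypothetical helper for this sketch)."""
    return [obj for obj in objects if is_active_technique(obj)]


# Tiny hand-made bundle (not real ATT&CK data) to show the behaviour.
demo_objects = [
    {"type": "attack-pattern", "name": "kept"},
    {"type": "attack-pattern", "name": "revoked one", "revoked": True},
    {"type": "attack-pattern", "name": "deprecated one", "x_mitre_deprecated": True},
    {"type": "malware", "name": "not a technique"},
]
active_names = [o["name"] for o in filter_active_techniques(demo_objects)]
print(active_names)  # ['kept']
```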

In [4]:
# Device setup
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print(f"🖥️  Using device: {device.upper()}")
print(f"🔧 PyTorch version: {torch.__version__}")

# Memory cleanup
if device == "cuda":
    torch.cuda.empty_cache()
elif device == "mps":
    import gc
    gc.collect()
    if hasattr(torch.mps, 'empty_cache'):
        torch.mps.empty_cache()


🖥️  Using device: CUDA
🔧 PyTorch version: 2.7.1+cu128


In [5]:
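# Model configuration - using same models as entity extraction
# NOTE: the fallback is currently identical to the default, so it only helps
# when the first load attempt fails for a transient reason
DEFAULT_MODEL = 'unsloth/Qwen3-1.7B-bnb-4bit'
FALLBACK_MODEL = 'unsloth/Qwen3-1.7B-bnb-4bit'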

# Get Hugging Face token from environment
HF_TOKEN = os.getenv('HF_TOKEN')

print(f'Default model: {DEFAULT_MODEL}')
print(f'Fallback model: {FALLBACK_MODEL}')
print(f'HF Token: {"✅ Found" if HF_TOKEN else "❌ Missing"}')


Default model: unsloth/Qwen3-1.7B-bnb-4bit
Fallback model: unsloth/Qwen3-1.7B-bnb-4bit
HF Token: ✅ Found


In [6]:
def setup_model_for_augmentation(model_name: str = None, hf_token: str = None):
    """
    Load model from Hugging Face for TTP data augmentation.
    """
    model_name = model_name or DEFAULT_MODEL
    hf_token = hf_token or HF_TOKEN

    print(f"🤖 Loading model: {model_name}")
    print(f"📱 Device: {device.upper()}")
    print(f"🔑 Token: {'✅ Found' if hf_token else '❌ Missing'}")

    try:
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=hf_token,
            trust_remote_code=True
        )
        tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

        # Setup data types and device mapping
        torch_dtype = torch.float16 if device == "cuda" else torch.float32
        device_map = "auto" if device == "cuda" else None

        # Load model
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            token=hf_token,
            trust_remote_code=True,
            torch_dtype=torch_dtype,
            device_map=device_map,
            use_cache=False
        )

        if device_map is None and device in ["mps", "cuda"]:
            model.to(device)

        if device_map is None:
            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                device=device,  # pass the device string ("mps" or "cpu") directly
                torch_dtype=torch_dtype,
                model_kwargs={"use_cache": False}
            )
        else:
            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                torch_dtype=torch_dtype,
                model_kwargs={"use_cache": False}
            )

        print(f"✅ Successfully loaded {model_name} on {device.upper()}")
        return pipe

    except Exception as e:
        print(f"❌ Error loading {model_name}: {e}")
        return setup_fallback_model(hf_token)

def setup_fallback_model(hf_token: str = None):
    """
    Load fallback model if main model fails.
    """
    fallback_name = FALLBACK_MODEL
    hf_token = hf_token or HF_TOKEN
    print(f"🔄 Loading fallback model: {fallback_name}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(fallback_name, token=hf_token)
        tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            fallback_name,
            token=hf_token,
            torch_dtype=torch.float32,
            use_cache=False
        )

        if device in ["cuda", "mps"]:
            model.to(device)

        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=device,  # pass the device string ("cuda", "mps", or "cpu") directly
            model_kwargs={"use_cache": False}
        )

        print(f"✅ {FALLBACK_MODEL} ready on {device.upper()}")
        return pipe

    except Exception as e:
        print(f"❌ Error loading {FALLBACK_MODEL} fallback: {e}")
        return None

# Load model
augmentation_model = setup_model_for_augmentation()


🤖 Loading model: unsloth/Qwen3-1.7B-bnb-4bit
📱 Device: CUDA
🔑 Token: ✅ Found


Device set to use cuda:0


✅ Successfully loaded unsloth/Qwen3-1.7B-bnb-4bit on CUDA


In [7]:
def create_ttp_augmentation_prompts(technique_record: Dict) -> List[Tuple[str, str]]:
    """
    Create different augmentation prompts for TTP technique descriptions.
    Returns multiple prompt variations for different augmentation strategies.
    """
    original_instruction = technique_record['instruction']
    technique_info = technique_record['output']['techniques'][0]
    technique_id = technique_info['id']
    technique_name = technique_info['name']

    # Truncate original instruction to avoid token limits
    instruction_truncated = (original_instruction[:1200] if original_instruction else "").replace('\n', ' ').strip()

    prompts = []

    # Strategy 1: Paraphrasing
    paraphrase_prompt = f"""Task: Rewrite the following cybersecurity technique description using different wording while preserving the exact meaning, technical accuracy, and all important details.

Requirements:
- Keep the same technical concepts and attack methodology
- Use alternative cybersecurity terminology where appropriate
- Maintain the same level of technical detail
- Preserve all specific examples, tools, or procedures mentioned
- Output only the rewritten description, no additional text

Original description: {instruction_truncated}

Rewritten description:"""
    prompts.append(("paraphrase", paraphrase_prompt))

    # Strategy 2: Scenario expansion
    scenario_prompt = f"""Task: Create a realistic attack scenario description that demonstrates the use of the following cybersecurity technique. Include specific context, tools, and step-by-step procedures.

Technique: {technique_name} ({technique_id})
Base description: {instruction_truncated}

Requirements:
- Create a realistic operational scenario
- Include specific attacker actions and tools
- Describe the technical implementation
- Maintain technical accuracy
- Output only the scenario description, no additional text

Attack scenario:"""
    prompts.append(("scenario", scenario_prompt))

    # Strategy 3: Context variation
    context_prompt = f"""Task: Rewrite the following cybersecurity technique description from a different perspective or context while maintaining technical accuracy.

Original context: {instruction_truncated}

Requirements:
- Present from defender's perspective OR incident response viewpoint OR threat hunting angle
- Include detection indicators or mitigation considerations
- Maintain all technical details about the technique
- Use professional cybersecurity language
- Output only the rewritten description, no additional text

Alternative perspective:"""
    prompts.append(("context", context_prompt))

    # Strategy 4: Technical detail expansion
    technical_prompt = f"""Task: Expand the following cybersecurity technique description with additional technical details, implementation specifics, and operational considerations.

Base description: {instruction_truncated}

Requirements:
- Add more technical implementation details
- Include specific tools, commands, or procedures
- Explain technical prerequisites or dependencies
- Describe variations or subtechniques
- Maintain factual accuracy
- Output only the expanded description, no additional text

Detailed description:"""
    prompts.append(("technical", technical_prompt))

    return prompts

# Test prompt creation
if ttp_data['dataset']:
    sample_prompts = create_ttp_augmentation_prompts(ttp_data['dataset'][0])
    print(f"📝 Generated {len(sample_prompts)} augmentation strategies:")
    for strategy, prompt in sample_prompts:
        print(f"\n🔹 Strategy: {strategy.upper()}")
        print(f"   Prompt length: {len(prompt)} chars")
        print(f"   Preview: {prompt[:150]}...")


📝 Generated 4 augmentation strategies:

🔹 Strategy: PARAPHRASE
   Prompt length: 1724 chars
   Preview: Task: Rewrite the following cybersecurity technique description using different wording while preserving the exact meaning, technical accuracy, and al...

🔹 Strategy: SCENARIO
   Prompt length: 1705 chars
   Preview: Task: Create a realistic attack scenario description that demonstrates the use of the following cybersecurity technique. Include specific context, too...

🔹 Strategy: CONTEXT
   Prompt length: 1710 chars
   Preview: Task: Rewrite the following cybersecurity technique description from a different perspective or context while maintaining technical accuracy.

Origina...

🔹 Strategy: TECHNICAL
   Prompt length: 1683 chars
   Preview: Task: Expand the following cybersecurity technique description with additional technical details, implementation specifics, and operational considerat...
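
The prompt builder cuts each description at a fixed 1200 characters, which can end the text mid-sentence or mid-word. One possible refinement is to cut at the last sentence boundary inside the limit. The sketch below is an assumption about how that could look, not the notebook's method; `truncate_at_sentence` is a hypothetical helper, and its `". "` heuristic will misfire on abbreviations such as "e.g. ".

```python
def truncate_at_sentence(text: str, max_chars: int = 1200) -> str:
    """Collapse whitespace, then cut at the last sentence end within max_chars."""
    flat = " ".join(text.split())
    if len(flat) <= max_chars:
        return flat
    window = flat[:max_chars]
    sentence_end = window.rfind(". ")
    if sentence_end != -1:
        return window[: sentence_end + 1]  # keep the closing period
    # No sentence boundary found: fall back to the last word boundary.
    word_end = window.rfind(" ")
    return window[:word_end] if word_end > 0 else window


demo = "First sentence. Second sentence is longer. Third."
print(truncate_at_sentence(demo, max_chars=30))  # First sentence.
```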


In [8]:
def augment_ttp_description(pipe, prompt: str, max_retries: int = 3) -> str:
    """
    Generate augmented TTP description using the LLM.
    """
    for attempt in range(max_retries):
        try:
            response = pipe(
                prompt,
                max_new_tokens=400,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=pipe.tokenizer.eos_token_id,
            )

            # Extract generated text
            generated_text = response[0]['generated_text']
            augmented_text = generated_text[len(prompt):].strip()

            # Clean up the response
            augmented_text = clean_augmented_text(augmented_text)

            if len(augmented_text) > 50:  # Ensure we have substantial content
                return augmented_text
            else:
                print(f"⚠️  Short response on attempt {attempt + 1}, retrying...")

        except Exception as e:
            print(f"❌ Error in augmentation attempt {attempt + 1}: {e}")

    return ""  # Return empty if all attempts fail

def clean_augmented_text(text: str) -> str:
    """
    Clean and validate the augmented text.
    """
    # Remove common unwanted prefixes/suffixes
    unwanted_prefixes = [
        "Here is", "Here's", "This is", "The following", "Below is",
        "Task:", "Requirements:", "Original description:", "Rewritten description:",
        "Attack scenario:", "Alternative perspective:", "Detailed description:"
    ]

    for prefix in unwanted_prefixes:
        if text.lower().startswith(prefix.lower()):
            text = text[len(prefix):].strip()
            if text.startswith(":"):
                text = text[1:].strip()

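    # Remove trailing blocks that begin with a label such as "Note:" or "Requirements:"
    # (anchored to the label so ordinary sentences containing these words are kept)
    text = re.sub(r'\n\s*\n\s*(?:Requirements|Note|Important)\s*:.*$', '', text, flags=re.IGNORECASE | re.DOTALL)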

    # Clean up extra whitespace
    text = ' '.join(text.split())

    return text

def create_augmented_record(original_record: Dict, augmented_instruction: str, strategy: str) -> Dict:
    """
    Create a new augmented record with the same output structure.
    """
    # Round-trip through JSON to get an independent deep copy of the nested
    # output, so later edits to one record cannot leak into another
    output_copy = json.loads(json.dumps(original_record["output"]))
    return {
        "instruction": augmented_instruction,
        "input": original_record["input"],  # Usually None
        "output": output_copy,  # Same technique information as the original
        "augmentation_strategy": strategy,
        "original_technique_id": original_record["output"]["techniques"][0]["id"]
    }

# Test the augmentation function
if augmentation_model and ttp_data['dataset']:
    print("\n🧪 Testing TTP augmentation on sample data...")
    sample_record = ttp_data['dataset'][0]
    sample_prompts = create_ttp_augmentation_prompts(sample_record)

    # Test first strategy (paraphrasing)
    strategy, prompt = sample_prompts[0]
    print(f"\n🔄 Testing {strategy} strategy...")
    print(f"Original length: {len(sample_record['instruction'])} chars")

    augmented_text = augment_ttp_description(augmentation_model, prompt)

    if augmented_text:
        print(f"Augmented length: {len(augmented_text)} chars")
        print(f"\n📄 Original: {sample_record['instruction'][:200]}...")
        print(f"\n📝 Augmented: {augmented_text[:200]}...")

        # Create augmented record
        augmented_record = create_augmented_record(sample_record, augmented_text, strategy)
        print(f"\n✅ Created augmented record with strategy: {strategy}")
    else:
        print("❌ Failed to generate augmented text")



🧪 Testing TTP augmentation on sample data...

🔄 Testing paraphrase strategy...
Original length: 2338 chars
Augmented length: 1336 chars

📄 Original: Adversaries may inject malicious code into process via Extra Window Memory (EWM) in order to evade process-based defenses as well as possibly elevate privileges. EWM injection is a method of executing...

📝 Augmented: Adversaries can infiltrate a process by exploiting a vulnerability in the Extra Window Memory (EWM) section of the Windows process memory to execute arbitrary code within the address space of a separa...

✅ Created augmented record with strategy: paraphrase
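
Since the goal is to vary wording while keeping the meaning, it can help to flag outputs that are near-copies of the original or that share almost no vocabulary with it. The sketch below uses Jaccard overlap on word sets as a crude lexical proxy; it does not measure semantic accuracy, the function names are hypothetical, and the `low` and `high` thresholds are arbitrary placeholders that would need tuning on real outputs.

```python
import re
from typing import Set


def _tokens(text: str) -> Set[str]:
    """Lowercased alphanumeric word set."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def jaccard_overlap(original: str, augmented: str) -> float:
    """Size of the shared vocabulary divided by the size of the combined vocabulary."""
    a, b = _tokens(original), _tokens(augmented)
    if not a or not b:
        return 0.0
    return len(a & b) / len(a | b)


def looks_acceptable(original: str, augmented: str,
                     low: float = 0.15, high: float = 0.9) -> bool:
    """Reject near-duplicates (score > high) and unrelated text (score < low)."""
    return low <= jaccard_overlap(original, augmented) <= high


original = "Adversaries may inject malicious code into processes"
paraphrase = "Attackers can inject harmful code into running processes"
unrelated = "The weather is sunny today"

print(looks_acceptable(original, paraphrase))  # True
print(looks_acceptable(original, original))    # False
print(looks_acceptable(original, unrelated))   # False
```

A check like this would sit between `augment_ttp_description` and `create_augmented_record`; an embedding-based similarity would be a stronger signal but needs an extra model.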


In [9]:
def _load_existing_augmented_records(output_path: Path) -> List[Dict]:
    """Load existing augmented JSON array from file; return [] if file doesn't exist/invalid."""
    if not output_path.exists():
        return []
    try:
        with output_path.open("r", encoding="utf-8") as f:
            data = json.load(f)
            if isinstance(data, dict) and "dataset" in data:
                return data["dataset"]
            elif isinstance(data, list):
                return data
    except Exception as e:
        print(f"⚠️  Error loading existing file: {e}")
    return []

def _persist_augmented_records(output_path: Path, all_records: List[Dict]) -> None:
    """Atomically write JSON dataset to file (UTF-8, pretty, no ASCII escaping)."""
    dataset_structure = {"dataset": all_records}
    tmp_path = output_path.with_suffix(output_path.suffix + ".tmp")
    with tmp_path.open("w", encoding="utf-8") as f:
        json.dump(dataset_structure, f, ensure_ascii=False, indent=2)
        f.write("\n")
    tmp_path.replace(output_path)

def calculate_augmentation_plan(current_size: int, target_size: int, strategies: List[str]) -> Dict:
    """
    Calculate how many records to generate per strategy to reach target size.
    """
    needed = target_size - current_size
    if needed <= 0:
        return {"needed": 0, "per_strategy": {}}

    # Distribute evenly across strategies
    base_per_strategy = needed // len(strategies)
    remainder = needed % len(strategies)

    plan = {}
    for i, strategy in enumerate(strategies):
        plan[strategy] = base_per_strategy + (1 if i < remainder else 0)

    return {
        "needed": needed,
        "per_strategy": plan,
        "total_planned": sum(plan.values())
    }

def process_ttp_augmentation(
    ttp_data: Dict,
    pipe,
    target_size: int = 2000,
    output_path: Path = None,
    batch_size: int = 50
) -> List[Dict]:
    """
    Process TTP data augmentation to reach target dataset size.
    """
    current_records = ttp_data['dataset']
    current_size = len(current_records)

    print(f"🎯 TTP Data Augmentation Plan:")
    print(f"   Current size: {current_size}")
    print(f"   Target size: {target_size}")

    if current_size >= target_size:
        print(f"✅ Already at target size. No augmentation needed.")
        return current_records

    # Setup output path
    if output_path is None:
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = Path(f"data/TTP-classification/augmented_ttp_dataset_{timestamp}.json")

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Load existing augmented data if any
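    # The saved file contains the original records as well, so keep only
    # entries tagged as augmented to avoid duplicating originals on resume
    existing_augmented = [
        record for record in _load_existing_augmented_records(output_path)
        if "augmentation_strategy" in record
    ]
    all_records = current_records + existing_augmented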

    print(f"   Existing augmented: {len(existing_augmented)}")
    print(f"   Total current: {len(all_records)}")

    if len(all_records) >= target_size:
        print(f"✅ Already have {len(all_records)} records. Target reached.")
        return all_records[:target_size]

    # Define augmentation strategies
    strategies = ["paraphrase", "scenario", "context", "technical"]

    # Calculate augmentation plan
    plan = calculate_augmentation_plan(len(all_records), target_size, strategies)
    print(f"\n📋 Augmentation Plan:")
    print(f"   Records needed: {plan['needed']}")
    for strategy, count in plan['per_strategy'].items():
        print(f"   {strategy}: {count} records")

    generated_count = 0
    strategy_counts = {s: 0 for s in strategies}

    # Start augmentation process
    print(f"\n🚀 Starting augmentation process...")

    while generated_count < plan['needed']:
        # Process in batches to save progress
        batch_generated = 0

        for original_record in current_records:
            if generated_count >= plan['needed']:
                break

            # Get augmentation prompts for this record
            prompts = create_ttp_augmentation_prompts(original_record)

            for strategy, prompt in prompts:
                # Check if we still need this strategy
                if strategy_counts[strategy] >= plan['per_strategy'][strategy]:
                    continue

                print(f"\n🔄 Generating {strategy} variation for {original_record['output']['techniques'][0]['id']}...")

                # Generate augmented text
                augmented_text = augment_ttp_description(pipe, prompt)

                if augmented_text and len(augmented_text) > 100:
                    # Create augmented record
                    augmented_record = create_augmented_record(original_record, augmented_text, strategy)
                    all_records.append(augmented_record)

                    generated_count += 1
                    strategy_counts[strategy] += 1
                    batch_generated += 1

                    print(f"   ✅ Generated record {generated_count}/{plan['needed']} ({strategy})")

                    # Save progress periodically
                    if batch_generated % batch_size == 0:
                        _persist_augmented_records(output_path, all_records)
                        print(f"   💾 Saved progress: {len(all_records)} total records")

                        # Memory cleanup
                        if device == "cuda":
                            torch.cuda.empty_cache()

                if generated_count >= plan['needed']:
                    break

            if generated_count >= plan['needed']:
                break

        # Safety check to avoid infinite loops
        if batch_generated == 0:
            print(f"⚠️  No new records generated in this batch. Stopping.")
            break

    # Final save
    _persist_augmented_records(output_path, all_records)

    print(f"\n🎉 Augmentation Complete!")
    print(f"   Original records: {current_size}")
    print(f"   Generated records: {generated_count}")
    print(f"   Total records: {len(all_records)}")
    print(f"   Output file: {output_path}")

    # Print strategy breakdown
    print(f"\n📊 Generation breakdown:")
    for strategy, count in strategy_counts.items():
        print(f"   {strategy}: {count} records")

    return all_records[:target_size]
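
Because the loop above walks `current_records` in order and tries every strategy for each record, the per-strategy quotas fill up on the earliest records when generations succeed, and later techniques may receive no variations at all. One way to spread coverage is to decide up front which source record each planned slot uses, cycling over a shuffled list of record indices. The following is a sketch under that assumption; `build_assignments` is a hypothetical helper, not part of the notebook.

```python
import random
from collections import Counter
from itertools import cycle
from typing import Dict, List, Tuple


def build_assignments(
    num_records: int,
    per_strategy: Dict[str, int],
    seed: int = 42,
) -> List[Tuple[int, str]]:
    """Pair each planned slot with a source-record index, spreading slots evenly."""
    if num_records <= 0:
        return []
    rng = random.Random(seed)  # local RNG so global random state is untouched
    order = list(range(num_records))
    rng.shuffle(order)
    index_cycle = cycle(order)  # wraps around if slots outnumber records
    assignments: List[Tuple[int, str]] = []
    for strategy, count in per_strategy.items():
        for _ in range(count):
            assignments.append((next(index_cycle), strategy))
    return assignments


demo = build_assignments(5, {"paraphrase": 3, "scenario": 3})
usage = Counter(index for index, _ in demo)
print(len(demo), sorted(usage.values()))  # 6 [1, 1, 1, 1, 2]
```

Each `(index, strategy)` pair would then select `current_records[index]` and the matching prompt from `create_ttp_augmentation_prompts`; failed generations could be retried with the next index from the cycle.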


In [10]:
# Configuration for augmentation run
TARGET_SIZE = 2000
BATCH_SIZE = 25  # Save progress every 25 records

# Generate timestamp for output file
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = Path(f"data/TTP-classification/augmented_ttp_dataset_{timestamp}.json")

print(f"🎯 Starting TTP Data Augmentation")
print(f"   Target dataset size: {TARGET_SIZE}")
print(f"   Output file: {output_file}")
print(f"   Batch size: {BATCH_SIZE}")

# Display current dataset statistics
if ttp_data['dataset']:
    print(f"\n📊 Current Dataset Statistics:")
    print(f"   Total records: {len(ttp_data['dataset'])}")

    # Count unique techniques
    unique_techniques = set()
    for record in ttp_data['dataset']:
        tech_id = record['output']['techniques'][0]['id']
        unique_techniques.add(tech_id)

    print(f"   Unique techniques: {len(unique_techniques)}")
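    # Clamp at zero so a dataset already at or above the target does not report a negative count
    print(f"   Records needed: {max(0, TARGET_SIZE - len(ttp_data['dataset']))}")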

    # Show sample of technique IDs
    sample_ids = list(unique_techniques)[:10]
    print(f"   Sample technique IDs: {sample_ids}")
else:
    print("❌ No TTP data loaded. Cannot proceed with augmentation.")


🎯 Starting TTP Data Augmentation
   Target dataset size: 2000
   Output file: data\TTP-classification\augmented_ttp_dataset_20250824_133112.json
   Batch size: 25

📊 Current Dataset Statistics:
   Total records: 1076
   Unique techniques: 1076
   Records needed: 924
   Sample technique IDs: ['T1592.004', 'T1027.015', 'T1619', 'T1208', 'T1564.001', 'T1583.007', 'T1217', 'T1599', 'T1485', 'T1418.001']


In [None]:
# Execute the augmentation process
if augmentation_model and ttp_data['dataset']:
    print("🚀 Starting TTP data augmentation process...")

    try:
        # Run the augmentation
        augmented_dataset = process_ttp_augmentation(
            ttp_data=ttp_data,
            pipe=augmentation_model,
            target_size=TARGET_SIZE,
            output_path=output_file,
            batch_size=BATCH_SIZE
        )

        print(f"\n✅ Augmentation process completed successfully!")
        print(f"   Final dataset size: {len(augmented_dataset)}")
        print(f"   Output saved to: {output_file}")

        # Verify the output file
        if output_file.exists():
            file_size = output_file.stat().st_size / (1024 * 1024)  # Size in MB
            print(f"   File size: {file_size:.2f} MB")

    except Exception as e:
        print(f"❌ Error during augmentation process: {e}")
        import traceback
        traceback.print_exc()

else:
    print("❌ Cannot start augmentation - model or data not available")
    if not augmentation_model:
        print("   - Model not loaded")
    if not ttp_data['dataset']:
        print("   - TTP data not loaded")


In [16]:
# Validation and Quality Assessment
def validate_augmented_dataset(dataset_path: Path) -> Dict:
    """
    Validate the quality and structure of the augmented dataset.
    """
    validation_results = {
        "total_records": 0,
        "original_records": 0,
        "augmented_records": 0,
        "strategies_used": {},
        "technique_coverage": {},
        "average_lengths": {},
        "quality_issues": []
    }

    try:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        records = data['dataset'] if isinstance(data, dict) else data
        validation_results["total_records"] = len(records)

        original_count = 0
        augmented_count = 0
        strategies = defaultdict(int)
        techniques = defaultdict(int)
        lengths = defaultdict(list)

        for record in records:
            # Check if record is original or augmented
            if "augmentation_strategy" in record:
                augmented_count += 1
                strategy = record["augmentation_strategy"]
                strategies[strategy] += 1
                lengths[f"augmented_{strategy}"].append(len(record["instruction"]))
            else:
                original_count += 1
                lengths["original"].append(len(record["instruction"]))

            # Track technique coverage
            tech_id = record["output"]["techniques"][0]["id"]
            techniques[tech_id] += 1

            # Quality checks: report an empty instruction once, not also as "short"
            instruction_text = record["instruction"].strip()
            if not instruction_text:
                validation_results["quality_issues"].append(f"Empty instruction in record with technique {tech_id}")
            elif len(instruction_text) < 50:
                validation_results["quality_issues"].append(f"Short instruction in record with technique {tech_id}")

        validation_results["original_records"] = original_count
        validation_results["augmented_records"] = augmented_count
        validation_results["strategies_used"] = dict(strategies)
        validation_results["technique_coverage"] = dict(techniques)

        # Calculate average lengths
        for category, length_list in lengths.items():
            if length_list:
                validation_results["average_lengths"][category] = sum(length_list) / len(length_list)

        print(f"📊 Dataset Validation Results:")
        print(f"   Total records: {validation_results['total_records']}")
        print(f"   Original records: {validation_results['original_records']}")
        print(f"   Augmented records: {validation_results['augmented_records']}")
        print(f"   Unique techniques: {len(validation_results['technique_coverage'])}")

        print(f"\n🎯 Augmentation Strategy Breakdown:")
        for strategy, count in validation_results['strategies_used'].items():
            print(f"   {strategy}: {count} records")

        print(f"\n📏 Average Instruction Lengths:")
        for category, avg_length in validation_results['average_lengths'].items():
            print(f"   {category}: {avg_length:.1f} chars")

        if validation_results['quality_issues']:
            print(f"\n⚠️  Quality Issues Found ({len(validation_results['quality_issues'])}):")
            for issue in validation_results['quality_issues'][:5]:  # Show first 5
                print(f"   - {issue}")
            if len(validation_results['quality_issues']) > 5:
                print(f"   ... and {len(validation_results['quality_issues']) - 5} more")
        else:
            print(f"\n✅ No quality issues detected!")

    except Exception as e:
        validation_results["error"] = str(e)
        print(f"❌ Validation error: {e}")

    return validation_results

# Run validation if output file exists
if 'output_file' in locals() and output_file.exists():
    print(f"\n🔍 Validating augmented dataset...")
    validation_results = validate_augmented_dataset(output_file)
else:
    print(f"\n⚠️  Cannot validate - output file not found or augmentation not completed")




## Summary and Next Steps

### What This Notebook Accomplishes
1. **Data Loading**: Downloads the latest MITRE ATT&CK dataset directly from the [official CTI repository](https://github.com/mitre/cti/tree/master)
2. **STIX Processing**: Converts STIX 2.0 attack-pattern objects to TTP classification format
3. **Model Setup**: Uses the same Qwen3-14B model as the entity extraction pipeline
4. **Augmentation Strategies**: Implements 4 different augmentation approaches:
   - **Paraphrasing**: Semantic variations while preserving technical accuracy
   - **Scenario Expansion**: Realistic attack scenarios demonstrating the technique
   - **Context Variation**: Different perspectives (defender, incident response, threat hunting)
   - **Technical Expansion**: Additional implementation details and specifics

5. **Quality Control**:
   - Text cleaning and validation
   - Progress saving with batch processing
   - Comprehensive validation and quality assessment

6. **Output**: Structured dataset with 2000 high-quality TTP classification records
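
The STIX processing step (item 2) can be sketched as below. This is a minimal illustration under stated assumptions, not the notebook's actual converter: the function name `stix_to_record` and the `skip_revoked` flag are hypothetical, and the output layout is inferred from the `record['output']['techniques'][0]['id']` access used in the cells above. It assumes the technique ID is stored as the `external_id` of a `mitre-*` entry in `external_references`, which is how MITRE's bundles are usually laid out.

```python
from typing import Optional


def stix_to_record(obj: dict, matrix: str, skip_revoked: bool = True) -> Optional[dict]:
    """Convert one STIX attack-pattern object into a classification record.

    Returns None when the object is not a usable technique.
    """
    if obj.get("type") != "attack-pattern":
        return None
    # Assumption: revoked/deprecated techniques are dropped; the notebook may keep them.
    if skip_revoked and (obj.get("revoked") or obj.get("x_mitre_deprecated")):
        return None

    description = (obj.get("description") or "").strip()
    # Assumption: the ATT&CK ID lives in a "mitre-*" external reference.
    tech_id = next(
        (
            ref.get("external_id")
            for ref in obj.get("external_references", [])
            if ref.get("source_name", "").startswith("mitre-") and ref.get("external_id")
        ),
        None,
    )
    if not tech_id or not description:
        return None

    return {
        "instruction": description,  # technique description is the input text
        "input": None,               # not used
        "output": {
            "techniques": [
                {
                    "id": tech_id,
                    "name": obj.get("name", ""),
                    "description": description,
                    "matrix": matrix,
                }
            ]
        },
    }


# Hand-written sample object for illustration only
sample = {
    "type": "attack-pattern",
    "name": "Command and Scripting Interpreter",
    "description": "Adversaries may abuse command and script interpreters to execute commands.",
    "external_references": [
        {"source_name": "mitre-attack", "external_id": "T1059"},
    ],
}
record = stix_to_record(sample, "enterprise")
print(record["output"]["techniques"][0]["id"])
```

Returning `None` for unusable objects lets the caller filter a whole bundle with a simple list comprehension.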

### Key Features
- **Official Data Source**: Uses the latest MITRE ATT&CK data from the official CTI repository
- **STIX 2.0 Support**: Processes official STIX format and converts to classification format
- **Fallback Mechanism**: Automatically falls back to local data if download fails
- **Preserves Original Structure**: Maintains the exact JSON format required for training
- **Incremental Processing**: Saves progress periodically to avoid data loss
- **Quality Validation**: Comprehensive checks for output quality and consistency
- **Memory Management**: Efficient GPU memory handling for large model inference
- **Strategy Tracking**: Tracks which augmentation strategy was used for each record
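
One way to make the incremental saves robust against interruption is to write to a temporary file and then swap it into place. The sketch below is an assumption about how a helper such as `_persist_augmented_records` could work, not its actual body; the name `persist_records_atomically` and the `{"dataset": [...]}` wrapper are illustrative (the validation cell accepts either a wrapped dict or a bare list).

```python
import json
import os
import tempfile
from pathlib import Path


def persist_records_atomically(output_path: Path, records: list) -> None:
    """Write records as JSON to a temp file, then replace the target with it."""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Keep the temp file in the target directory so the swap stays on one filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=output_path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump({"dataset": records}, f, ensure_ascii=False, indent=2)
        # Replace the target in one step; a crash before this line leaves the old file intact.
        os.replace(tmp_name, output_path)
    except BaseException:
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


# Demonstration in a throwaway directory
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / "out" / "augmented.json"
    persist_records_atomically(demo_path, [{"instruction": "demo", "input": None}])
    reloaded = json.loads(demo_path.read_text(encoding="utf-8"))
    leftover_tmp = list(demo_path.parent.glob("*.tmp"))
```

With this pattern a run that is interrupted mid-save still leaves the previous checkpoint readable.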

### Usage Instructions
1. Ensure internet connectivity for downloading MITRE data
2. Set required environment variables (HF_TOKEN)
3. Run cells sequentially to load data and model
4. Execute the augmentation process (Cell 10)
5. Validate the results (Cell 11)
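
A small pre-flight check can catch a missing `HF_TOKEN` before the model download starts. This is a hypothetical helper, not part of the notebook; `missing_prerequisites` and its `required` parameter are names chosen for illustration.

```python
import os
from typing import Mapping


def missing_prerequisites(env: Mapping[str, str], required=("HF_TOKEN",)) -> list:
    """Return the names of required environment variables that are unset or blank."""
    return [name for name in required if not env.get(name, "").strip()]


missing = missing_prerequisites(os.environ)
if missing:
    print(f"Missing environment variables: {', '.join(missing)}")
else:
    print("All required environment variables are set")
```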

### Expected Output
- **Input**: Latest MITRE ATT&CK technique records from all matrices:
  - **Enterprise**: ~823 techniques (general enterprise environments)
  - **ICS**: ~95 techniques (industrial control systems)
  - **Mobile**: ~188 techniques (mobile platforms)
  - **Pre-ATT&CK**: ~174 techniques (reconnaissance & resource development)
  - **Total**: ~1,280 techniques across all domains (approximate; the run recorded above loaded 1,076 records)
- **Output**: 2000 total records (original + augmented variants)
- **File Format**: JSON with same structure as training datasets
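- **Quality**: Validated for record counts, strategy and technique coverage, and empty or short (under 50 characters) instructions; technical accuracy is not checked automatically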
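
The arithmetic behind the target can be made explicit. The sketch below uses the figures from the run recorded above (1,076 loaded records, target 2,000) and assumes, purely for illustration, that the shortfall is split evenly across four strategies; the notebook's own planner may weight them differently, and the function name and strategy keys are hypothetical.

```python
def plan_augmentation(current_size: int, target_size: int, strategies: list) -> dict:
    """Compute how many records are needed and an even per-strategy split."""
    if not strategies:
        raise ValueError("at least one strategy is required")
    needed = max(0, target_size - current_size)
    base, remainder = divmod(needed, len(strategies))
    # Hand out any remainder one record at a time to the first strategies.
    per_strategy = {
        name: base + (1 if i < remainder else 0)
        for i, name in enumerate(strategies)
    }
    return {"needed": needed, "per_strategy": per_strategy}


plan = plan_augmentation(
    current_size=1076,
    target_size=2000,
    strategies=["paraphrase", "scenario", "context", "technical"],  # hypothetical keys
)
print(plan["needed"])        # 924
print(plan["per_strategy"])  # 231 records for each of the four strategies
```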

### Advantages of Comprehensive Data Source
- **Always Current**: Automatically gets the latest ATT&CK updates across all matrices
- **Authoritative**: Direct from MITRE's official repository
- **Complete Coverage**: All ATT&CK matrices and domains included
- **Domain Diversity**: Covers enterprise, ICS, mobile, and pre-attack scenarios
- **Quality Assured**: Official STIX 2.0 format ensures data consistency

### Integration with LLM-TIKG Pipeline
This comprehensive augmented dataset can be used for:
- **Multi-Domain TTP Classification**: Fine-tuning models across enterprise, ICS, mobile, and pre-attack domains
- **Comprehensive Knowledge Graphs**: Building threat intelligence graphs with full ATT&CK coverage
- **Cross-Platform Analysis**: Understanding attack techniques across different environments
- **Advanced Threat Research**: Supporting research with complete, current ATT&CK methodology
- **Operational Security**: Improving defense strategies across all attack domains

### Dataset Coverage Benefits
- **Enterprise Security**: Traditional IT infrastructure and network attacks
- **Industrial Security**: SCADA, PLC, and critical infrastructure threats
- **Mobile Security**: iOS and Android platform-specific techniques
- **Early-Stage Threats**: Reconnaissance and resource development methodologies
- **Cross-Domain Attacks**: Understanding how techniques apply across multiple environments
