In [20]:
import json
import os
import torch
import numpy as np
import pandas as pd
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Tuple
import logging
import re
from tqdm import tqdm

# Transformers and training
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print("📚 Dependencies loaded successfully!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🤖 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


📚 Dependencies loaded successfully!
🔥 PyTorch version: 2.8.0
🤖 CUDA available: False


In [21]:
# Load merged MITRE ATT&CK dataset
dataset_path = "../data/TTP-classification/merged_mitre_attack_dataset.json"

print("📂 Loading MITRE ATT&CK dataset...")
with open(dataset_path, 'r', encoding='utf-8') as f:
    mitre_data = json.load(f)

# Extract training samples
training_samples = mitre_data['dataset']
print(f"✅ Loaded {len(training_samples):,} training samples")

# Analyze dataset structure
print("\n📊 Dataset Analysis:")
matrices = {'enterprise': 0, 'mobile': 0, 'ics': 0}
techniques_count = 0
sub_techniques_count = 0
technique_ids = set()

for sample in training_samples:
    technique = sample['output']['techniques'][0]
    matrix = technique['matrix']
    technique_id = technique['id']

    matrices[matrix] += 1
    technique_ids.add(technique_id)

    if '.' in technique_id:
        sub_techniques_count += 1
    else:
        techniques_count += 1

print(f"  • Enterprise: {matrices['enterprise']:,} samples")
print(f"  • Mobile: {matrices['mobile']:,} samples")
print(f"  • ICS: {matrices['ics']:,} samples")
print(f"  • Techniques: {techniques_count:,} samples")
print(f"  • Sub-techniques: {sub_techniques_count:,} samples")
print(f"  • Unique technique IDs: {len(technique_ids):,}")

# Show sample data
print("\n📝 Sample Training Example:")
sample = training_samples[0]
print(f"Instruction: {sample['instruction'][:100]}...")
print(f"Technique ID: {sample['output']['techniques'][0]['id']}")
print(f"Technique Name: {sample['output']['techniques'][0]['name']}")
print(f"Matrix: {sample['output']['techniques'][0]['matrix']}")


📂 Loading MITRE ATT&CK dataset...
✅ Loaded 921 training samples

📊 Dataset Analysis:
  • Enterprise: 691 samples
  • Mobile: 135 samples
  • ICS: 95 samples
  • Techniques: 406 samples
  • Sub-techniques: 515 samples
  • Unique technique IDs: 921

📝 Sample Training Example:
Instruction: Adversaries may inject malicious code into process via Extra Window Memory (EWM) in order to evade p...
Technique ID: T1055.011
Technique Name: Extra Window Memory Injection
Matrix: enterprise


In [22]:
def format_training_sample(sample: Dict[str, Any]) -> Dict[str, str]:
    """
    Format training sample for instruction tuning

    Input format:
    {
        "instruction": "Adversary behavior description...",
        "output": {
            "techniques": [{
                "id": "T1055.011",
                "name": "Extra Window Memory Injection",
                "description": "...",
                "matrix": "enterprise"
            }]
        }
    }

    Output format for instruction tuning:
    {
        "text": "<|im_start|>system\nYou are a cybersecurity expert...\n<|im_end|>\n<|im_start|>user\n...\n<|im_end|>\n<|im_start|>assistant\n...\n<|im_end|>"
    }
    """
    instruction = sample['instruction']
    technique = sample['output']['techniques'][0]

    # Create system prompt
    system_prompt = """You are a cybersecurity expert specializing in MITRE ATT&CK framework. Your task is to analyze threat intelligence descriptions and identify the corresponding MITRE ATT&CK techniques.

Given a description of adversary behavior, identify the most relevant MITRE ATT&CK technique and provide:
1. Technique ID (e.g., T1055.011)
2. Technique Name
3. Matrix (enterprise/mobile/ics)

Respond in JSON format."""

    # Create user input
    user_input = f"Analyze this threat behavior and identify the MITRE ATT&CK technique:\n\n{instruction}"

    # Create assistant response
    assistant_response = json.dumps({
        "technique_id": technique['id'],
        "technique_name": technique['name'],
        "matrix": technique['matrix'],
        "description": technique['description']
    }, ensure_ascii=False)

    # Format for Qwen chat template
    formatted_text = f"""<|im_start|>system
{system_prompt}<|im_end|>
<|im_start|>user
{user_input}<|im_end|>
<|im_start|>assistant
{assistant_response}<|im_end|>"""

    return {"text": formatted_text}

# Format all training samples
print("🔄 Formatting training data...")
formatted_samples = []
for sample in tqdm(training_samples):
    formatted_sample = format_training_sample(sample)
    formatted_samples.append(formatted_sample)

print(f"✅ Formatted {len(formatted_samples):,} training samples")

# Show formatted example
print("\n📝 Formatted Training Example:")
print(formatted_samples[0]['text'][:500] + "...")


🔄 Formatting training data...


100%|██████████| 921/921 [00:00<00:00, 143231.52it/s]

✅ Formatted 921 training samples

📝 Formatted Training Example:
<|im_start|>system
You are a cybersecurity expert specializing in MITRE ATT&CK framework. Your task is to analyze threat intelligence descriptions and identify the corresponding MITRE ATT&CK techniques.

Given a description of adversary behavior, identify the most relevant MITRE ATT&CK technique and provide:
1. Technique ID (e.g., T1055.011)
2. Technique Name
3. Matrix (enterprise/mobile/ics)

Respond in JSON format.<|im_end|>
<|im_start|>user
Analyze this threat behavior and identify the MITRE ...





In [23]:
# Model configuration
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
MAX_LENGTH = 2048
BATCH_SIZE = 4
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
WARMUP_STEPS = 100

print(f"🤖 Setting up model: {MODEL_NAME}")

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🔧 Using device: {device}")

# Load tokenizer
print("📝 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="right"
)

# Add padding token if not exists
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("🔧 Set pad_token = eos_token")

print(f"✅ Tokenizer loaded")
print(f"📊 Vocab size: {len(tokenizer):,}")
print(f"🔑 Special tokens: {tokenizer.special_tokens_map}")

# Load model
print("🤖 Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True
)

print(f"✅ Model loaded")
print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"🎯 Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")


🤖 Setting up model: Qwen/Qwen2.5-1.5B-Instruct
🔧 Using device: cpu
📝 Loading tokenizer...
✅ Tokenizer loaded
📊 Vocab size: 151,665
🔑 Special tokens: {'eos_token': '<|im_end|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>', '<|object_ref_start|>', '<|object_ref_end|>', '<|box_start|>', '<|box_end|>', '<|quad_start|>', '<|quad_end|>', '<|vision_start|>', '<|vision_end|>', '<|vision_pad|>', '<|image_pad|>', '<|video_pad|>']}
🤖 Loading model...
✅ Model loaded
📊 Model parameters: 1,543,714,304
🎯 Trainable parameters: 1,543,714,304


In [24]:
def tokenize_function(examples):
    """
    Tokenize training examples
    """
    # Tokenize the text
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",  # SỬA LỖI: Pad to max length
        max_length=MAX_LENGTH,
        return_tensors=None
    )

    # For causal language modeling, labels are the same as input_ids
    tokenized["labels"] = tokenized["input_ids"].copy()

    return tokenized

# Create train/validation split
print("📊 Creating train/validation split...")
train_samples, val_samples = train_test_split(
    formatted_samples,
    test_size=0.1,
    random_state=42,
    stratify=[sample['text'].split('"matrix": "')[1].split('"')[0] for sample in formatted_samples]
)

print(f"📚 Training samples: {len(train_samples):,}")
print(f"🔍 Validation samples: {len(val_samples):,}")

# Create HuggingFace datasets
print("🔄 Creating datasets...")
train_dataset = Dataset.from_list(train_samples)
val_dataset = Dataset.from_list(val_samples)

# Tokenize datasets
print("🔤 Tokenizing datasets...")
train_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
    desc="Tokenizing train data"
)

val_dataset = val_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
    desc="Tokenizing validation data"
)

print("✅ Datasets prepared")
print(f"📊 Train dataset: {len(train_dataset):,} samples")
print(f"📊 Validation dataset: {len(val_dataset):,} samples")

# Check tokenization
sample_tokens = train_dataset[0]
print(f"\n📝 Sample tokenization:")
print(f"Input IDs length: {len(sample_tokens['input_ids'])}")
print(f"Attention mask length: {len(sample_tokens['attention_mask'])}")
print(f"Labels length: {len(sample_tokens['labels'])}")


📊 Creating train/validation split...
📚 Training samples: 828
🔍 Validation samples: 93
🔄 Creating datasets...
🔤 Tokenizing datasets...


Tokenizing train data:   0%|          | 0/828 [00:00<?, ? examples/s]

Tokenizing validation data:   0%|          | 0/93 [00:00<?, ? examples/s]

✅ Datasets prepared
📊 Train dataset: 828 samples
📊 Validation dataset: 93 samples

📝 Sample tokenization:
Input IDs length: 2048
Attention mask length: 2048
Labels length: 2048


In [25]:
# Create output directory
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = f"../models/qwen-ttp-classification-{timestamp}"
os.makedirs(output_dir, exist_ok=True)

print(f"📁 Output directory: {output_dir}")

# Training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,

    # Training parameters
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_steps=WARMUP_STEPS,

    # Optimization
    fp16=True,
    dataloader_pin_memory=False,
    gradient_checkpointing=True,

    # Logging and saving
    logging_dir=f"{output_dir}/logs",
    logging_steps=10,
    save_steps=100,
    save_total_limit=3,
    eval_steps=100,
    eval_strategy="steps",  # FIXED: was evaluation_strategy

    # Other settings
    remove_unused_columns=False,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # Report
    report_to=None,  # Disable wandb/tensorboard
    run_name=f"qwen-ttp-classification-{timestamp}"
)

print("✅ Training arguments configured")
print(f"🎯 Batch size: {BATCH_SIZE}")
print(f"📈 Learning rate: {LEARNING_RATE}")
print(f"🔄 Epochs: {NUM_EPOCHS}")
print(f"🔥 FP16: {training_args.fp16}")
print(f"💾 Gradient checkpointing: {training_args.gradient_checkpointing}")

# Data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # No masking for causal LM
)

print("✅ Data collator configured")


📁 Output directory: ../models/qwen-ttp-classification-2025-08-09_10-27-13
✅ Training arguments configured
🎯 Batch size: 4
📈 Learning rate: 2e-05
🔄 Epochs: 3
🔥 FP16: True
💾 Gradient checkpointing: True
✅ Data collator configured


In [26]:
# ✅ CODE ĐÃ SỬA HOÀN CHỈNH CHO MPS (APPLE SILICON)
print("🔧 Initializing MPS-compatible training...")

import torch
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datetime import datetime
import os

# Clear memory
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
    print("🧹 MPS cache cleared")

# Configuration
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
BATCH_SIZE = 2  # Smaller for MPS
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
WARMUP_STEPS = 100

# Create output directory
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = f"../models/qwen-ttp-mps-{timestamp}"
os.makedirs(output_dir, exist_ok=True)
print(f"📁 Output: {output_dir}")

# ✅ MPS-COMPATIBLE TrainingArguments (SỬA LỖI)
training_args_mps = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,

    # Training params
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=2,  # Tăng để bù đắp batch nhỏ
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    warmup_steps=WARMUP_STEPS,

    # ⚠️ QUAN TRỌNG: MPS settings
    fp16=False,              # PHẢI False cho MPS
    bf16=False,              # PHẢI False cho MPS
    dataloader_pin_memory=False,
    gradient_checkpointing=True,

    # Logging
    logging_steps=10,
    save_steps=100,
    eval_steps=100,
    eval_strategy="steps",   # SỬA: không phải evaluation_strategy
    save_total_limit=3,

    # Evaluation
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # Disable external services
    report_to=None,
    dataloader_num_workers=0,  # Single thread cho MPS
)

# ✅ Data collator (SỬA LỖI)
data_collator_mps = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

print("✅ Configuration created!")
print(f"🔥 FP16: {training_args_mps.fp16} (phải False)")
print(f"🍎 BF16: {training_args_mps.bf16} (phải False)")

# ✅ Initialize Trainer (SỬA LỖI - KHÔNG có accelerator argument)
print("\n🏃‍♂️ Creating trainer...")

os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
os.environ["ACCELERATE_DISABLE_RICH"] = "1"

# Khởi tạo trainer
trainer = Trainer(
    model=model,
    args=training_args_mps,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator_mps,
)

print("✅ Trainer initialized successfully!")
print("🍎 Ready for MPS training!")
print(f"📊 Batch size: {BATCH_SIZE}")
print(f"🔄 Effective batch: {BATCH_SIZE * 2} (với gradient accumulation)")
print(f"📈 Learning rate: {LEARNING_RATE}")

# ✅ Bây giờ có thể bắt đầu training
print("\n🚀 Ready to train! Run: trainer.train()")

🔧 Initializing MPS-compatible training...
🧹 MPS cache cleared
📁 Output: ../models/qwen-ttp-mps-2025-08-09_10-27-13
✅ Configuration created!
🔥 FP16: False (phải False)
🍎 BF16: False (phải False)

🏃‍♂️ Creating trainer...
✅ Trainer initialized successfully!

  trainer = Trainer(



🍎 Ready for MPS training!
📊 Batch size: 2
🔄 Effective batch: 4 (với gradient accumulation)
📈 Learning rate: 2e-05

🚀 Ready to train! Run: trainer.train()


In [27]:
# Evaluate the model
print("📊 Evaluating model...")
eval_result = trainer.evaluate()

print("\n📈 Evaluation Results:")
for key, value in eval_result.items():
    print(f"  {key}: {value:.4f}")

# Save evaluation results
eval_file = os.path.join(output_dir, "evaluation_results.json")
with open(eval_file, 'w') as f:
    json.dump(eval_result, f, indent=2)

print(f"✅ Evaluation results saved to: {eval_file}")


📊 Evaluating model...



KeyboardInterrupt



In [28]:
def test_ttp_classification(model, tokenizer, threat_description: str) -> Dict[str, Any]:
    """
    Test TTP classification on a threat description
    """
    system_prompt = """You are a cybersecurity expert specializing in MITRE ATT&CK framework. Your task is to analyze threat intelligence descriptions and identify the corresponding MITRE ATT&CK techniques.

Given a description of adversary behavior, identify the most relevant MITRE ATT&CK technique and provide:
1. Technique ID (e.g., T1055.011)
2. Technique Name
3. Matrix (enterprise/mobile/ics)

Respond in JSON format."""

    user_input = f"Analyze this threat behavior and identify the MITRE ATT&CK technique:\n\n{threat_description}"

    # Format input
    prompt = f"""<|im_start|>system
{system_prompt}<|im_end|>
<|im_start|>user
{user_input}<|im_end|>
<|im_start|>assistant
"""

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode response
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    assistant_response = full_response.split("<|im_start|>assistant\n")[-1]

    return {
        "input": threat_description,
        "response": assistant_response,
        "full_prompt": prompt
    }

# Test examples
test_cases = [
    "Adversaries may inject malicious code into processes in order to evade process-based defenses or elevate privileges.",
    "Attackers send phishing emails with malicious attachments to gain initial access to the target system.",
    "The malware establishes persistence by creating scheduled tasks that execute at system startup.",
    "Adversaries may abuse elevation control mechanisms to gain higher-level permissions on a system."
]

print("🧪 Testing model inference...")
for i, test_case in enumerate(test_cases, 1):
    print(f"\n📝 Test Case {i}:")
    print(f"Input: {test_case}")

    result = test_ttp_classification(model, tokenizer, test_case)
    print(f"Output: {result['response']}")
    print("-" * 80)


🧪 Testing model inference...

📝 Test Case 1:
Input: Adversaries may inject malicious code into processes in order to evade process-based defenses or elevate privileges.
Output: system
You are a cybersecurity expert specializing in MITRE ATT&CK framework. Your task is to analyze threat intelligence descriptions and identify the corresponding MITRE ATT&CK techniques.

Given a description of adversary behavior, identify the most relevant MITRE ATT&CK technique and provide:
1. Technique ID (e.g., T1055.011)
2. Technique Name
3. Matrix (enterprise/mobile/ics)

Respond in JSON format.
user
Analyze this threat behavior and identify the MITRE ATT&CK technique:

Adversaries may inject malicious code into processes in order to evade process-based defenses or elevate privileges.
assistant
```json
{
  "TechniqueID": "T1078",
  "TechniqueName": "Elevate Privileges via Process Injection",
  "Matrix": "enterprise"
}
```

Explanation:
- **TechniqueID**: `T1078` - This technique involves elevating priv

In [None]:
import os
import re
import json
import uuid
from datetime import datetime
from typing import Dict, Any, List, Tuple, Optional

import pandas as pd
import networkx as nx
from tqdm import tqdm

# ========== Config ==========
# Input file with threat intel. You can switch to any structured JSON/JSONL/CSV.
INPUT_FILE = "../data/raw/threat_intelligence_multi_source_20250726_231524.json"  # adjust if needed
# Optional: limit number of records for quick tests
MAX_RECORDS: Optional[int] = 5

# Field mappings (adjust to your data schema)
FIELD_MAP = {
    "id": ["id", "_id", "uuid"],
    "threat_type": ["threat_type", "type", "category", "topic"],
    "title": ["title", "headline"],
    "text": ["text", "content", "description", "body"],
    "hashes": ["hashes", "artifacts.hashes", "ioc.hashes", "ioc"],
}

# Output directory
REL_TIMESTAMP = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
REL_OUT_DIR = f"../models/ttp_batch_analysis_{REL_TIMESTAMP}"
os.makedirs(REL_OUT_DIR, exist_ok=True)

HASH_REGEX = re.compile(r"\b([A-Fa-f0-9]{32}|[A-Fa-f0-9]{40}|[A-Fa-f0-9]{64})\b")  # MD5|SHA1|SHA256


def _coerce_text(raw: Any) -> str:
    """Coerce list/dict/other into a single plain text string."""
    if raw is None:
        return ""
    if isinstance(raw, str):
        return raw
    if isinstance(raw, list):
        return " ".join([str(x) for x in raw if isinstance(x, (str, int, float))])
    if isinstance(raw, dict):
        return " ".join([str(v) for v in raw.values() if isinstance(v, (str, int, float))])
    return str(raw)

def _first_present(d: Dict[str, Any], keys: List[str], default: Any = None):
    for k in keys:
        # support dotted paths like 'artifacts.hashes'
        node = d
        valid = True
        for part in k.split('.'):
            if isinstance(node, dict) and part in node:
                node = node[part]
            else:
                valid = False
                break
        if valid:
            return node
    return default


def _extract_hashes(record: Dict[str, Any]) -> List[str]:
    # Try mapped fields first
    hashes_field = _first_present(record, FIELD_MAP["hashes"], default=None)
    found: List[str] = []
    if isinstance(hashes_field, list):
        for x in hashes_field:
            if isinstance(x, str) and HASH_REGEX.fullmatch(x):
                found.append(x.lower())
            elif isinstance(x, dict):
                for v in x.values():
                    if isinstance(v, str) and HASH_REGEX.fullmatch(v):
                        found.append(v.lower())
    elif isinstance(hashes_field, dict):
        for v in hashes_field.values():
            if isinstance(v, str) and HASH_REGEX.fullmatch(v):
                found.append(v.lower())
            elif isinstance(v, list):
                for x in v:
                    if isinstance(x, str) and HASH_REGEX.fullmatch(x):
                        found.append(x.lower())
    # Fallback: scan text (coerce list/dict to string)
    raw_text = _first_present(record, FIELD_MAP["text"], default="")
    text = _coerce_text(raw_text)
    found += [h.lower() for h in HASH_REGEX.findall(text)]
    # Normalize unique
    uniq = sorted(set(found))
    return uniq


def _make_prompt(threat_text: str) -> str:
    system_prompt = (
        "You are a cybersecurity expert specializing in MITRE ATT&CK framework. "
        "Your task is to analyze threat intelligence descriptions and identify the corresponding MITRE ATT&CK techniques.\n\n"
        "Given a description of adversary behavior, identify the most relevant MITRE ATT&CK technique and provide:\n"
        "1. Technique ID (e.g., T1055.011)\n"
        "2. Technique Name\n"
        "3. Matrix (enterprise/mobile/ics)\n\n"
        "Respond in JSON format."
    )
    user_input = f"Analyze this threat behavior and identify the MITRE ATT&CK technique:\n\n{threat_text}"
    return (
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{user_input}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )


def classify_one(text_value: str) -> Tuple[Optional[str], Optional[str], Optional[str], str]:
    """Run model inference; return (technique_id, technique_name, matrix, raw_response)."""
    if not text_value or not isinstance(text_value, str):
        return None, None, None, ""

    prompt = _make_prompt(text_value)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.0,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    assistant_response = full_response.split("<|im_start|>assistant\n")[-1]

    # Try to parse JSON
    tech_id = tech_name = matrix = None
    try:
        # Extract JSON block if wrapped
        json_match = re.search(r"\{[\s\S]*\}", assistant_response)
        if json_match:
            data = json.loads(json_match.group(0))
        else:
            data = json.loads(assistant_response)
        tech_id = (data.get("technique_id") or data.get("id") or "").strip() or None
        tech_name = (data.get("technique_name") or data.get("name") or "").strip() or None
        matrix = (data.get("matrix") or "").strip() or None
    except Exception:
        pass

    # Ensure technique id label format
    if tech_id and not re.match(r"^T\d{4}(?:\.\d{3})?$", tech_id):
        tech_id = None

    return tech_id, tech_name, matrix, assistant_response


def load_records(input_path: str) -> List[Dict[str, Any]]:
    if input_path.endswith(".jsonl"):
        rows = []
        with open(input_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    rows.append(json.loads(line))
        return rows
    elif input_path.endswith(".json"):
        with open(input_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # try top-level list or known keys
        if isinstance(data, list):
            return data
        for key in ["data", "items", "records", "dataset", "documents"]:
            if key in data and isinstance(data[key], list):
                return data[key]
        # fallback: single object
        return [data]
    elif input_path.endswith(".csv"):
        df = pd.read_csv(input_path)
        return df.to_dict(orient="records")
    else:
        raise ValueError(f"Unsupported file type: {input_path}")


print(f"📂 Loading records from: {INPUT_FILE}")
records = load_records(INPUT_FILE)
if MAX_RECORDS is not None:
    records = records[:MAX_RECORDS]
print(f"✅ Loaded {len(records)} records")

# Inference
results: List[Dict[str, Any]] = []
print("🧠 Running TTP classification...")
for rec in tqdm(records):
    rec_id = _first_present(rec, FIELD_MAP["id"], default=str(uuid.uuid4()))
    raw_val = _first_present(rec, FIELD_MAP["text"], default="")
    text_val = _coerce_text(raw_val)
    threat_type = _first_present(rec, FIELD_MAP["threat_type"], default=None)
    title = _first_present(rec, FIELD_MAP["title"], default=None)

    tech_id, tech_name, matrix, raw_resp = classify_one(text_val)
    hashes = _extract_hashes(rec)

    results.append({
        "record_id": rec_id,
        "title": title,
        "threat_type": threat_type,
        "text": text_val,
        "technique_id": tech_id,      # label: technique ID
        "technique_name": tech_name,
        "matrix": matrix,
        "hashes": hashes,
        "raw_response": raw_resp,
    })

# Save predictions
pred_df = pd.DataFrame(results)
pred_csv = os.path.join(REL_OUT_DIR, "ttp_predictions.csv")
pred_jsonl = os.path.join(REL_OUT_DIR, "ttp_predictions.jsonl")
pred_df.to_csv(pred_csv, index=False)
with open(pred_jsonl, "w", encoding="utf-8") as f:
    for row in results:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")
print(f"📄 Saved predictions to: {pred_csv} and {pred_jsonl}")

# Build relationships graph: ThreatType -> TechniqueID, Hash -> TechniqueID
G = nx.DiGraph()

for row in results:
    rec_id = row["record_id"]
    threat_type = row.get("threat_type")
    technique_id = row.get("technique_id")
    tech_name = row.get("technique_name")
    matrix = row.get("matrix")
    hashes = row.get("hashes") or []

    # Nodes
    if threat_type:
        G.add_node(("threat_type", threat_type), label="threat_type", name=threat_type)
    if technique_id:
        G.add_node(("technique", technique_id), label="technique", name=technique_id, tech_name=tech_name, matrix=matrix)
    for h in hashes:
        G.add_node(("hash", h), label="hash", name=h)

    # Edges
    if threat_type and technique_id:
        G.add_edge(("threat_type", threat_type), ("technique", technique_id), relation="HAS_TECHNIQUE")
    for h in hashes:
        if technique_id:
            G.add_edge(("hash", h), ("technique", technique_id), relation="INDICATES_TECHNIQUE")
        # Optionally link record->hash for provenance
        G.add_edge(("hash", h), ("threat_type", threat_type) if threat_type else ("record", rec_id), relation="ASSOCIATED_WITH")

# Export graph
gexf_path = os.path.join(REL_OUT_DIR, "ttp_relations.gexf")
nx.write_gexf(G, gexf_path)
print(f"🕸️ Graph exported to: {gexf_path}")

# Export edge list JSON
edges = [
    {
        "source_type": s[0],
        "source": s[1],
        "target_type": t[0],
        "target": t[1],
        "relation": data.get("relation"),
    }
    for s, t, data in G.edges(data=True)
]
edges_path = os.path.join(REL_OUT_DIR, "ttp_edges.json")
with open(edges_path, "w", encoding="utf-8") as f:
    json.dump(edges, f, indent=2, ensure_ascii=False)
print(f"🔗 Edge list saved to: {edges_path}")

print("✅ Done.")


📂 Loading records from: ../data/raw/threat_intelligence_multi_source_20250726_231524.json
✅ Loaded 257 records
🧠 Running TTP classification...


 20%|█▉        | 51/257 [19:30<1:18:49, 22.96s/it]  


KeyboardInterrupt: 

In [12]:
# Summaries
from collections import Counter

# Threat type ↔ Technique ID
pair_counts_tt = Counter()
for row in results:
    threat_type = row.get("threat_type")
    technique_id = row.get("technique_id")
    if threat_type and technique_id:
        pair_counts_tt[(threat_type, technique_id)] += 1

summary_tt = pd.DataFrame(
    [(tt, tid, c) for (tt, tid), c in pair_counts_tt.items()],
    columns=["threat_type", "technique_id", "count"],
).sort_values("count", ascending=False)

# Hash ↔ Technique ID
pair_counts_ht = Counter()
for row in results:
    technique_id = row.get("technique_id")
    for h in row.get("hashes") or []:
        if technique_id:
            pair_counts_ht[(h, technique_id)] += 1

summary_ht = pd.DataFrame(
    [(h, tid, c) for (h, tid), c in pair_counts_ht.items()],
    columns=["hash", "technique_id", "count"],
).sort_values("count", ascending=False)

# Save
summary_tt_csv = os.path.join(REL_OUT_DIR, "summary_threattype_technique.csv")
summary_ht_csv = os.path.join(REL_OUT_DIR, "summary_hash_technique.csv")
summary_tt.to_csv(summary_tt_csv, index=False)
summary_ht.to_csv(summary_ht_csv, index=False)

print(f"📄 Saved: {summary_tt_csv}")
print(f"📄 Saved: {summary_ht_csv}")

# Display top rows
display(summary_tt.head(20))
display(summary_ht.head(20))


📄 Saved: ../models/ttp_batch_analysis_2025-08-09_10-15-44/summary_threattype_technique.csv
📄 Saved: ../models/ttp_batch_analysis_2025-08-09_10-15-44/summary_hash_technique.csv


Unnamed: 0,threat_type,technique_id,count


Unnamed: 0,hash,technique_id,count


In [None]:
# Quick test overrides
TOP_K = 3  # limit techniques per record
print(f"TOP_K overridden to: {TOP_K}")


wandb-core(10761) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(10762) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(10765) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(10766) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [35]:
from typing import Iterable

# Settings for multi-technique extraction
TOP_K = 6  # max techniques to extract per record
SAVE_JSON_PATH = "../data/TTP-classification/extended_relationships_output.json"

# Local fallback in case _coerce_text wasn't executed in earlier cells
try:
    _coerce_text
except NameError:
    def _coerce_text(raw):
        if raw is None:
            return ""
        if isinstance(raw, str):
            return raw
        if isinstance(raw, list):
            return " ".join(str(x) for x in raw if isinstance(x, (str, int, float)))
        if isinstance(raw, dict):
            return " ".join(str(v) for v in raw.values() if isinstance(v, (str, int, float)))
        return str(raw)

def classify_multiple_techniques(threat_text: str, top_k: int = TOP_K) -> List[Dict[str, Any]]:
    """
    Query the model to extract up to top_k techniques from text.
    Returns list of {id, name, description, matrix} using technique ids as labels.
    """
    if not threat_text or not isinstance(threat_text, str):
        return []

    system_prompt = (
        "You are a cybersecurity expert specializing in MITRE ATT&CK framework. "
        "Read the text and return a JSON object with a 'techniques' array listing up to N relevant techniques. "
        "Each technique must include: id, name, description, matrix. Use MITRE technique ids strictly (e.g., T1218, T1218.011)."
    )
    user_input = (
        "Extract up to N MITRE ATT&CK techniques from the following content. "
        f"N={top_k}. Return JSON with key 'techniques'.\n\nContent:\n{threat_text}"
    )

    prompt = (
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{user_input}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    resp = tokenizer.decode(outputs[0], skip_special_tokens=True)
    content = resp.split("<|im_start|>assistant\n")[-1]

    # Parse JSON
    techniques: List[Dict[str, Any]] = []
    try:
        json_match = re.search(r"\{[\s\S]*\}", content)
        data = json.loads(json_match.group(0) if json_match else content)
        items = data.get("techniques") or data
        if isinstance(items, dict) and "techniques" in items:
            items = items["techniques"]
        if isinstance(items, list):
            for t in items[:top_k]:
                tid = (t.get("id") or t.get("technique_id") or "").strip()
                if not re.match(r"^T\d{4}(?:\.\d{3})?$", tid):
                    continue
                techniques.append({
                    "id": tid,
                    "name": t.get("name"),
                    "description": t.get("description"),
                    "matrix": t.get("matrix"),
                })
    except Exception:
        pass

    # Deduplicate by id
    seen = set()
    uniq: List[Dict[str, Any]] = []
    for t in techniques:
        if t["id"] in seen:
            continue
        seen.add(t["id"])
        uniq.append(t)
    return uniq


def build_extended_relationships(record: Dict[str, Any], techniques: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Construct extended_relationships.relationships using the requested format:
    [ [subject, predicate, object], ... ]
    - Map high-level threat types/entities to technique ids
    - Also include 'associated_with' links from the main threat label to all techniques
    """
    title = _first_present(record, FIELD_MAP["title"], default=None)
    link = record.get("link") or record.get("url")
    raw_content = _first_present(record, FIELD_MAP["text"], default="")
    content_text = _coerce_text(raw_content)

    # Use existing entity extraction section if present
    extraction = record.get("extraction") or {}
    ents = extraction.get("entities") or []

    # Collect subjects from entities and threat_type field
    subjects: List[str] = []
    if ents:
        for item in ents:
            if isinstance(item, (list, tuple)) and len(item) >= 2:
                subjects.append(str(item[0]))
    tt = _first_present(record, FIELD_MAP["threat_type"], default=None)
    if tt:
        subjects.append(str(tt))

    subjects = [s for s in {s.strip(): None for s in subjects}.keys() if s]

    # Relationships
    rels: List[List[str]] = []
    # Example mapping rules: ransomware uses T1486 etc. We simply create generic relations here.
    for subj in subjects:
        for t in techniques:
            rels.append([subj, "associated_with", t["id"]])
    # Add specific 'uses' if subject text suggests usage verbs in content
    verbs = ["use", "uses", "abuse", "abuses", "leverage", "leverages", "employ", "employs"]
    text_lower = content_text.lower() if isinstance(content_text, str) else _coerce_text(content_text).lower()
    if any(v in text_lower for v in verbs):
        for subj in subjects:
            for t in techniques:
                rels.append([subj, "uses", t["id"]])

    return {
        "title": title,
        "link": link,
        "content": content_text,
        "extraction": extraction or None,
        "ttp_classification": {
            "instruction": None,
            "input": None,
            "output": {"techniques": techniques},
        },
        "extended_relationships": {"relationships": rels},
    }


print("🧩 Building extended relationships per record...")
extended_docs: List[Dict[str, Any]] = []
for rec in tqdm(records):
    raw_content = _first_present(rec, FIELD_MAP["text"], default="")
    content_text = _coerce_text(raw_content)
    techniques = classify_multiple_techniques(content_text, top_k=TOP_K)
    extended = build_extended_relationships(rec, techniques)
    extended_docs.append(extended)

# Save to data/TTP-classification
os.makedirs("../data/TTP-classification", exist_ok=True)
with open(SAVE_JSON_PATH, "w", encoding="utf-8") as f:
    json.dump(extended_docs, f, ensure_ascii=False, indent=2)
print(f"💾 Extended relationships saved to: {SAVE_JSON_PATH}")


🧩 Building extended relationships per record...


100%|██████████| 5/5 [04:53<00:00, 58.80s/it]

💾 Extended relationships saved to: ../data/TTP-classification/extended_relationships_output.json





In [40]:
import os, json, re
from typing import Dict, Any, List

SRC_ENTITY_JSON = "../data/entity-extraction/entity_extraction_results_Qwen/Qwen2.5-1.5B-Instruct_test_2025-08-04_09-45-09_250_300.json"
DST_JSON = "../data/TTP-classification/qwen-ttp-classifiation/extended_relationships_output_250_300.json"
os.makedirs(os.path.dirname(DST_JSON), exist_ok=True)

print(f"📂 Loading entity-extraction file: {SRC_ENTITY_JSON}")
with open(SRC_ENTITY_JSON, "r", encoding="utf-8") as f:
    entity_records = json.load(f)
print(f"✅ Loaded {len(entity_records)} records")

augmented: List[Dict[str, Any]] = []

for rec in tqdm(entity_records):
    title = rec.get("title")
    link = rec.get("link") or rec.get("url")
    content_text = _coerce_text(rec.get("content"))

    # techniques via the multi-tech classifier
    techniques = classify_multiple_techniques(content_text, top_k=TOP_K if 'TOP_K' in globals() else 6)

    # relationships: subjects from extraction.entities and threat_type if exists
    extraction = rec.get("extraction") or {}
    ents = extraction.get("entities") or []
    subjects = []
    for item in ents:
        if isinstance(item, (list, tuple)) and len(item) >= 2:
            subjects.append(str(item[0]))
    threat_type = rec.get("threat_type")
    if threat_type:
        subjects.append(str(threat_type))
    subjects = [s for s in {s.strip(): None for s in subjects}.keys() if s]

    # uses only
    rels: List[List[str]] = []
    verbs = ["use", "uses", "abuse", "abuses", "leverage", "leverages", "employ", "employs"]
    if any(v in content_text.lower() for v in verbs):
        for subj in subjects:
            for t in (techniques or [])[:1]:
                rels.append([subj, "uses", t["id"]])

    # dedup
    rels = [list(x) for x in dict.fromkeys(tuple(r) for r in rels)]

    # append without changing original fields
    out = dict(rec)
    out["ttp_classification"] = {
        "instruction": None,
        "input": None,
        "output": {"techniques": techniques},
    }
    out["extended_relationships"] = {"relationships": rels}

    augmented.append(out)

with open(DST_JSON, "w", encoding="utf-8") as f:
    json.dump(augmented, f, ensure_ascii=False, indent=2)
print(f"💾 Written: {DST_JSON}")


📂 Loading entity-extraction file: ../data/entity-extraction/entity_extraction_results_Qwen/Qwen2.5-1.5B-Instruct_test_2025-08-04_09-45-09_250_300.json
✅ Loaded 50 records


100%|██████████| 50/50 [31:07<00:00, 37.34s/it] 

💾 Written: ../data/TTP-classification/qwen-ttp-classifiation/extended_relationships_output.json





In [None]:
def build_extended_relationships(record: Dict[str, Any], techniques: List[Dict[str, Any]]) -> Dict[str, Any]:
    title = _first_present(record, FIELD_MAP["title"], default=None)
    link = record.get("link") or record.get("url")
    raw_content = _first_present(record, FIELD_MAP["text"], default="")
    content_text = _coerce_text(raw_content)

    extraction = record.get("extraction") or {}
    ents = extraction.get("entities") or []

    subjects: List[str] = []
    if ents:
        for item in ents:
            if isinstance(item, (list, tuple)) and len(item) >= 2:
                subjects.append(str(item[0]))
    tt = _first_present(record, FIELD_MAP["threat_type"], default=None)
    if tt:
        subjects.append(str(tt))
    subjects = [s for s in {s.strip(): None for s in subjects}.keys() if s]

    # Relationships: uses only
    rels: List[List[str]] = []
    verbs = ["use", "uses", "abuse", "abuses", "leverage", "leverages", "employ", "employs"]
    text_lower = content_text.lower()
    if any(v in text_lower for v in verbs):
        for subj in subjects:
            for t in (techniques or [])[:1]:
                rels.append([subj, "uses", t["id"]])
    rels = [list(x) for x in dict.fromkeys(tuple(r) for r in rels)]

    return {
        "title": title,
        "link": link,
        "content": content_text,
        "extraction": extraction or None,
        "ttp_classification": {
            "instruction": None,
            "input": None,
            "output": {"techniques": techniques},
        },
        "extended_relationships": {"relationships": rels},
    }
