<a href="https://colab.research.google.com/github/ntjz-kakarot/llms-for-mitre-flows/blob/main/CTI_to_Graph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup the environment

In [2]:
!pip install -qU \
  pdfplumber \
  python-docx \
  bitsandbytes \
  accelerate

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/56.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.4/56.4 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m239.6/239.6 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m102.2/102.2 MB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.3/297.3 kB[0m [31m27.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m55.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.8/2.8 MB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [7]:
# Standard library imports
import io
import json
import logging
import re
import time
import uuid
import locale
from dataclasses import dataclass
from io import BytesIO

# Third-party library imports for data manipulation and analysis
import numpy as np
import pandas as pd

# Text and file processing libraries
import pdfplumber
import html
from bs4 import BeautifulSoup
from PIL import Image
import textwrap
from docx import Document as docx

# Machine Learning, NLP, and PyTorch libraries
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig

# Visualization and display libraries
from IPython.display import HTML, display
import graphviz

# STIX2 Cyber Threat Intelligence objects
from stix2 import Bundle, AttackPattern, Infrastructure, Relationship

# Google Colab specific utilities
from google.colab import userdata
from ipywidgets import FileUpload, widgets


In [8]:
# Override the system's preferred encoding to "UTF-8"
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"

# Set 'HUGGINGFACE_AUTH_TOKEN' in the Google Colab Secrets
HF_TOKEN = userdata.get('HUGGINGFACE_AUTH_TOKEN')
if HF_TOKEN is None:
    raise ValueError("Hugging Face authentication token not found. Set the HUGGINGFACE_AUTH_TOKEN environment variable.")

In [9]:
# This has to be installed after setting encoding to "UTF-8"
!pip install -q stix2

## Class Definitions

### Document Processor Class

In [10]:
class DocumentProcessor:
    def __init__(self):
        self.pages_to_exclude = set()
        self.files_content = {}
        self.upload_widget = FileUpload(multiple=True)
        self.setup_widgets()

    def setup_widgets(self):
        start_page_input = widgets.IntText(description="Start Page:")
        end_page_input = widgets.IntText(description="End Page:")
        exclude_pages_button = widgets.Button(description="Exclude Pages")
        exclude_pages_button.on_click(lambda b: self.on_exclude_pages_click(start_page_input, end_page_input))

        process_files_button = widgets.Button(description="Process Files")
        process_files_button.on_click(lambda b: self.on_process_files_click())

        display(self.upload_widget, start_page_input, end_page_input, exclude_pages_button, process_files_button)

    def on_exclude_pages_click(self, start_page_input, end_page_input):
        start_page = start_page_input.value
        end_page = end_page_input.value
        self.pages_to_exclude.update(range(start_page, end_page + 1))
        print(f"Pages {start_page} to {end_page} will be excluded.")

    def on_process_files_click(self):
        self.process_files()
        print("Files processed and stored.")

    def parse_text(self, file_name: str, content: io.BytesIO) -> str:
        if file_name.endswith('.pdf'):
            with pdfplumber.open(content) as pdf:
                text = " ".join(
                    page.extract_text() for i, page in enumerate(pdf.pages, start=1) if i not in self.pages_to_exclude
                )
        elif file_name.endswith('.html'):
            text = BeautifulSoup(content.read().decode('utf-8'), features="html.parser").get_text()
        elif file_name.endswith('.txt'):
            text = content.read().decode('utf-8')
        elif file_name.endswith('.docx'):
            doc = docx.Document(content)
            text = " ".join(
                paragraph.text for i, paragraph in enumerate(doc.paragraphs, start=1) if i not in self.pages_to_exclude
            )
        else:
            raise ValueError(f"Unsupported file type: {file_name}")

        cleaned_text = re.sub(r'\s+', ' ', text).strip()
        return cleaned_text

    def process_files(self):
        for name, content in self.upload_widget.value.items():
            text = self.parse_text(name, io.BytesIO(content['content']))
            self.files_content[name] = text

    def get_processed_texts(self):
        return self.files_content

    def display_processed_texts(self):
        for doc_name, doc_content in self.files_content.items():
            word_count = len(doc_content.split())
            print(f"Document: {doc_name}")
            print(f"Word Count: {word_count}")
            print(f"Content: {doc_content[:100]}...")
            print("-" * 100)

### Model Loader Class

In [11]:
class ModelLoader:
    def __init__(self, model_id: str, hf_token: str):
        self.model_id = model_id
        self.hf_token = hf_token
        self.device = torch.device(f'cuda:{torch.cuda.current_device()}') if torch.cuda.is_available() else torch.device('cpu')
        self.bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4', bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype='bfloat16')
        self.model, self.tokenizer = self.load_model_and_tokenizer()

    def load_model_and_tokenizer(self):
        try:
            model_config = transformers.AutoConfig.from_pretrained(
                self.model_id,
                token=self.hf_token,
            )
            model = transformers.AutoModelForCausalLM.from_pretrained(
                self.model_id,
                trust_remote_code=True,
                config=model_config,
                quantization_config=self.bnb_config,
                device_map='auto',
                token=self.hf_token
            ).eval()

            tokenizer = transformers.AutoTokenizer.from_pretrained(
                self.model_id,
                token=self.hf_token
            )

            print(f"Model loaded on {'GPU: ' + torch.cuda.get_device_name(self.device) if torch.cuda.is_available() else 'CPU'} - {self.device}")
            return model, tokenizer
        except Exception as e:
            print(f"Failed to load model or tokenizer: {e}")
            return None, None

### Answer and Prompt Processor Class

In [12]:
@dataclass
class Answer:
    answer: str
    elapse: int

class PromptProcessor:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, prompt: str) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=4096,
            do_sample=False,
            num_return_sequences=1,
            repetition_penalty=1.0,
            length_penalty=1.0
        )
        prompt_length = inputs['input_ids'].shape[1]
        return self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    def run_prompt(self, prompt: str) -> Answer:
        start_time = time.time()
        answer_text = self.generate(prompt)
        elapse_time = round(time.time() - start_time)
        return Answer(answer_text, elapse_time)

    def process_answer(self, answer: Answer) -> dict:
        print(f"Time to Generate: {answer.elapse} seconds")
        try:
            parsed_json = json.loads(answer.answer)
            print(json.dumps(parsed_json, indent=4))
            return parsed_json
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON from the answer: {e}")
            return {}

### Graph Generator Class

In [13]:
class GraphGenerator:
    VIZ_IGNORE_COMMON_PROPERTIES = (
        "created",
        "external_references",
        "id",
        "modified",
        "revoked",
        "spec_version",
        "type",
    )

    VIZ_IGNORE_SDOS = ("attack-flow", "extension-definition")

    def convert_json_to_stix(self, json_data):
        # Create a list to store STIX objects
        stix_objects = []

        # Create a dictionary to store the mapping of action names to their corresponding STIX IDs
        action_id_mapping = {}

        # Iterate over each action in the JSON data
        for action in json_data:
            # Create an AttackPattern object for each action
            attack_pattern = AttackPattern(
                id="attack-pattern--" + str(uuid.uuid4()),
                created="2023-04-08T00:00:00.000Z",
                modified="2023-04-08T00:00:00.000Z",
                name=action["action_name"],
                description=action["evidence"],
                kill_chain_phases=[
                    {
                        "kill_chain_name": "mitre-attack",
                        "phase_name": action["tactic_id"].lower()
                    }
                ],
                external_references=[
                    {
                        "source_name": "mitre-attack",
                        "external_id": action["technique_id"]
                    }
                ]
            )
            stix_objects.append(attack_pattern)

            # Store the mapping of action name to its STIX ID
            action_id_mapping[action["action_name"]] = attack_pattern.id

            # Create an Infrastructure object for each affected asset
            for asset in action["affected_assets"]:
                infrastructure = Infrastructure(
                    id="infrastructure--" + str(uuid.uuid4()),
                    created="2023-04-08T00:00:00.000Z",
                    modified="2023-04-08T00:00:00.000Z",
                    name=asset
                )
                stix_objects.append(infrastructure)

                # Create a Relationship object between the AttackPattern and Infrastructure
                relationship = Relationship(
                    id="relationship--" + str(uuid.uuid4()),
                    created="2023-04-08T00:00:00.000Z",
                    modified="2023-04-08T00:00:00.000Z",
                    relationship_type="uses",
                    source_ref=attack_pattern.id,
                    target_ref=infrastructure.id
                )
                stix_objects.append(relationship)

        # Create Relationship objects for the proceeding actions
        for action in json_data:
            source_action_id = action_id_mapping[action["action_name"]]
            for proceeding_action in action["proceeding_actions"]:
                if proceeding_action in action_id_mapping:
                    target_action_id = action_id_mapping[proceeding_action]
                    relationship = Relationship(
                        id="relationship--" + str(uuid.uuid4()),
                        created="2023-04-08T00:00:00.000Z",
                        modified="2023-04-08T00:00:00.000Z",
                        relationship_type="proceeds-to",
                        source_ref=source_action_id,
                        target_ref=target_action_id
                    )
                    stix_objects.append(relationship)

        # Create a STIX bundle with the list of STIX objects
        bundle = Bundle(objects=stix_objects)

        return bundle

    def label_escape(self, text):
        return graphviz.escape(html.escape(text))

    def convert2dot(self, StixBundle):
        dotObj = graphviz.Digraph()
        dotObj.body = self._get_body_label(StixBundle)
        ignored_ids = self.get_ignored_ids(StixBundle)

        for o in StixBundle.objects:
            if o.type == "attack-action":
                dotObj.node(o.id, self._get_action_label(o), shape="plaintext")
                for ref in o.get("asset_refs", []):
                    dotObj.edge(o.id, ref, "asset")
                for ref in o.get("effect_refs", []):
                    dotObj.edge(o.id, ref, "effect")

            elif o.type == "attack-asset":
                dotObj.node(o.id, self._get_asset_label(o), shape="plaintext")
                if object_ref := o.get("object_ref"):
                    dotObj.edge(o.id, object_ref, "object")

            elif o.type == "attack-condition":
                dotObj.node(o.id, self._get_condition_label(o), shape="plaintext")
                for ref in o.get("on_true_refs", []):
                    dotObj.edge(o.id, ref, "on_true")
                for ref in o.get("on_false_refs", []):
                    dotObj.edge(o.id, ref, "on_false")

            elif o.type == "attack-operator":
                dotObj.node(o.id, o['operator'], shape="circle", style="filled", fillcolor="#ff9900")
                for ref in o.get("effect_refs", []):
                    dotObj.edge(o.id, ref, "effect")

            elif o.type == "relationship":
                dotObj.edge(o.source_ref, o.target_ref, o.relationship_type)

            elif o.id not in ignored_ids:
                dotObj.node(o.id, self._get_builtin_label(o), shape="plaintext")

        return dotObj

    def _get_body_label(self, StixBundle):
        flow = self.get_flow_object(StixBundle)
        if flow is None:
            # Handle the case when there is no "attack-flow" object
            return ['\tlabel="No attack-flow object found";\n', '\tlabelloc="t";\n']

        author = self.get_obj(StixBundle, flow.created_by_ref)

        description = "<br/>".join(
            textwrap.wrap(
                self.label_escape(flow.get("description", "(missing description)")), width=100
            )
        )
        lines = [
            f'<font point-size="24">{self.label_escape(flow.name)}</font>',
            f"<i>{description}</i>",
            f'<font point-size="10">Author: {self.label_escape(author.get("name", "(missing)"))} &lt;{self.label_escape(author.get("contact_information", "n/a"))}&gt;</font>',
            f'<font point-size="10">Created: {flow.get("created", "(missing)")}</font>',
            f'<font point-size="10">Modified: {flow.get("modified", "(missing)")}</font>',
        ]
        label = "<br/>".join(lines)

        return [f"\tlabel=<{label}>;\n", '\tlabelloc="t";\n']

    def _get_action_label(self, StixObject):
        if tid := StixObject.get("technique_id", None):
            heading = f"Action: {tid}"
        else:
            heading = "Action"
        description = "<br/>".join(
            textwrap.wrap(self.label_escape(StixObject.get("description", "")), width=40)
        )
        confidence = 95
        name = self.label_escape(StixObject.name)

        return "".join(
            [
                '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="5">',
                f'<TR><TD BGCOLOR="#99ccff" COLSPAN="2"><B>{heading}</B></TD></TR>',
                f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Name</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{name}</TD></TR>',
                f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Description</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{description}</TD></TR>',
                f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Confidence</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{confidence}</TD></TR>',
                "</TABLE>>",
            ]
        )

    def _get_asset_label(self, StixObject):
        name = self.label_escape(StixObject.name)
        description = "<br/>".join(
            textwrap.wrap(self.label_escape(StixObject.get("description", "")), width=40)
        )
        return "".join(
            [
                '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="5">',
                f'<TR><TD BGCOLOR="#cc99ff" COLSPAN="2"><B>Asset: {name}</B></TD></TR>',
                f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Description</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{description}</TD></TR>',
                "</TABLE>>",
            ]
        )

    def _get_condition_label(self, StixObject):
        description = "<br/>".join(
            textwrap.wrap(self.label_escape(StixObject.description), width=40)
        )
        return "".join(
            [
                '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="5">',
                '<TR><TD BGCOLOR="#99ff99" COLSPAN="2"><B>Condition</B></TD></TR>',
                f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Description</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{description}</TD></TR>'
                "</TABLE>>",
            ]
        )

    def _get_builtin_label(self, StixObject):
        title = StixObject.type.replace("-", " ").title()
        lines = [
            '<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="5">',
            f'<TR><TD BGCOLOR="#cccccc" COLSPAN="2"><B>{title}</B></TD></TR>',
        ]
        for key, value in StixObject.items():
            if key in self.VIZ_IGNORE_COMMON_PROPERTIES:
                continue
            pretty_key = key.replace("_", " ").title()
            if isinstance(value, list):
                value = ", ".join(str(v) for v in value)
            pretty_value = self.label_escape(str(value))
            lines.append(
                f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>{pretty_key}</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{pretty_value}</TD></TR>'
            )
        lines.append("</TABLE>>")
        return "".join(lines)

    def get_flow_object(self, StixBundle):
        for obj in StixBundle.objects:
            if obj.type == "attack-flow":
                return obj
        return None

    def get_obj(self, StixBundle, id):
        for obj in StixBundle.objects:
            if obj.id == id:
                return obj

    def get_ignored_ids(self, StixBundle):
        ignored = set()

        # Ignore flow creator identity:
        flow = self.get_flow_object(StixBundle)
        if flow is not None and (flow_creator := flow.get("created_by_ref", None)):
            ignored.add(flow_creator)

        # Ignore by SDO type:
        for obj in StixBundle.objects:
            if obj.type in self.VIZ_IGNORE_SDOS:
                ignored.add(obj.id)

            # Ignore extension creator identity:
            if obj.type == "extension-definition" and (ext_creator := obj.get("created_by_ref", None)):
                ignored.add(ext_creator)

        return ignored

## Prompt Engineering

In [31]:
prompt = """<s>[INST] <<SYS>>
You are a cybersecurity analyst with expertise in the MITRE ATT&CK framework. Your task is to process the given section of a Cyber Threat Intelligence (CTI) report, which has been parsed and includes only the relevant parts pertaining to an attack's timeline. You will analyze the text sentence-by-sentence, and when necessary, consider additional sentences for context, to identify specific actions taken by attackers.
For each action identified, you will output information in JSON format with the following structure:

1. Action Name: The specific action taken by the attacker, as described in the text, like 'Vulnerability Scanning', 'Exploit Public-Facing Application', etc.
2. Tactic ID: The Tactic ID from the MITRE ATT&CK framework that categorizes the overarching goal of the action (e.g., TA0001 for Initial Access).
3. Technique ID/Sub-technique ID: The specific Technique or Sub-technique ID from the MITRE ATT&CK framework that the action corresponds to.
4. Evidence: A direct quote from the CTI report that supports the identification of the action.
5. Affected Asset(s): The asset(s) targeted or compromised by the action, based on the report's context.
6. Exploited Vulnerability: Name the specific vulnerability exploited by the action, if applicable.
7. Proceeding Actions: List the actions that directly follow the identified one, in chronological order, to help construct an attack flow graph.

Your response should be structured as follows (sample JSON for guidance):

{{
  "action_name": "Example Action",
  "tactic_id": "TA000X",
  "technique_id": "TXXXX",
  "sub_technique_id": "TXXXX.YYY",
  "evidence": "Exact text from the CTI report evidencing the action.",
  "affected_assets": ["Example Affected Asset"],
  "exploited_vulnerability": "CVE-XXXX-XXXX",
  "proceeding_actions": ["Next Action in the flow"]
}}

Ensure each identified action from the CTI report is processed into a separate JSON object.
<</SYS>>

What are the specific attacker actions described in the report, and how do they map to the MITRE ATT&CK framework? For each identified action, provide the action name, tactic ID, technique ID (and sub-technique ID if applicable), supporting evidence from the report, affected asset(s), exploited vulnerability (if any), and proceeding actions in the attack flow. [/INST]
[
{{
"action_name": "Vulnerability Scanning",
"tactic_id": "TA0001",
"technique_id": "T1595",
"sub_technique_id": "T1595.002",
"evidence": "On January 5th, 2023 at 2:15 AM PST, our security monitoring systems detected suspicious network scanning activity originating from the IP address 192.168.1.100. The scans were targeting our public-facing web server, specifically probing for known vulnerabilities in our content management system (CMS).",
"affected_assets": ["Public-facing web server", "Content Management System (CMS)"],
"exploited_vulnerability": null,
"proceeding_actions": ["Exploit Public-Facing Application"]
}},
{{
"action_name": "Exploit Public-Facing Application",
"tactic_id": "TA0001",
"technique_id": "T1190",
"sub_technique_id": null,
"evidence": "At 2:30 AM PST, the attacker successfully exploited a SQL injection vulnerability (CVE-2022-1234) in our CMS, gaining unauthorized access to the web server.",
"affected_assets": ["Public-facing web server", "Content Management System (CMS)"],
"exploited_vulnerability": "CVE-2022-1234",
"proceeding_actions": ["Server Software Component", "Command and Scripting Interpreter"]
}},
{{
"action_name": "Server Software Component",
"tactic_id": "TA0003",
"technique_id": "T1505",
"sub_technique_id": "T1505.003",
"evidence": "They proceeded to upload a malicious PHP webshell, granting them persistent backdoor access.",
"affected_assets": ["Public-facing web server"],
"exploited_vulnerability": null,
"proceeding_actions": ["Account Discovery", "Network Service Discovery"]
}},
{{
"action_name": "Command and Scripting Interpreter",
"tactic_id": "TA0002",
"technique_id": "T1059",
"sub_technique_id": "T1059.004",
"evidence": "Using the webshell, at 3:00 AM PST, the attacker performed discovery commands to enumerate the server's configuration and network information.",
"affected_assets": ["Public-facing web server"],
"exploited_vulnerability": null,
"proceeding_actions": ["Account Discovery", "Network Service Discovery"]
}},
{{
"action_name": "Account Discovery",
"tactic_id": "TA0007",
"technique_id": "T1087",
"sub_technique_id": null,
"evidence": "They identified a privileged service account used by the web server to access the backend database.",
"affected_assets": ["Privileged service account"],
"exploited_vulnerability": null,
"proceeding_actions": ["Brute Force"]
}},
] </s>

<s>[INST]
What are the specific attacker actions described in the report, and how do they map to the MITRE ATT&CK framework? For each identified action, provide the action name, tactic ID, technique ID (and sub-technique ID if applicable), supporting evidence from the report, affected asset(s), exploited vulnerability (if any), and proceeding actions in the attack flow. Output your findings in the specified JSON format:
{{
"action_name": "Example Action",
"tactic_id": "TA000X",
"technique_id": "TXXXX",
"sub_technique_id": "TXXXX.YYY",
"evidence": "Exact text from the CTI report evidencing the action.",
"affected_assets": ["Example Affected Asset"],
"exploited_vulnerability": "CVE-XXXX-XXXX",
"proceeding_actions": ["Next Action in the flow"]
}}
Your output should only be the JSON objects, with no other text.

Report: {BODY}
[/INST]"""

## Model Init

In [15]:
MODEL_ID = 'mistralai/Mistral-7B-Instruct-v0.2'
model_loader = ModelLoader(MODEL_ID, HF_TOKEN)

config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.46k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

Model loaded on GPU: Tesla T4 - cuda:0


## Run \#1 - ANU

In [16]:
doc_processor = DocumentProcessor()

FileUpload(value={}, description='Upload', multiple=True)

IntText(value=0, description='Start Page:')

IntText(value=0, description='End Page:')

Button(description='Exclude Pages', style=ButtonStyle())

Button(description='Process Files', style=ButtonStyle())

Pages 1 to 5 will be excluded.
Pages 10 to 20 will be excluded.
Files processed and stored.


In [17]:
doc_processor.display_processed_texts()

Document: ANU3.pdf
Word Count: 2176
Content: D E T A I L E D T I M E L I N E O F T H E D A T A B R E A C H Overview This section provides a chron...
----------------------------------------------------------------------------------------------------


In [22]:
doc_name = list(doc_processor.get_processed_texts().keys())[0]
processed_texts = list(doc_processor.get_processed_texts().values())[0]

In [24]:
p1 = prompt.format(BODY=processed_texts)
prompt_processor = PromptProcessor(model_loader.model, model_loader.tokenizer)
answer = prompt_processor.run_prompt(p1)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [25]:
processed_answer = prompt_processor.process_answer(answer)

Time to Generate: 264 seconds
[
    {
        "action_name": "Spearphishing Email Attack",
        "tactic_id": "TA0001",
        "technique_id": "T1192",
        "sub_technique_id": null,
        "evidence": "On 9 November 2018, the actor sent a spearphishing email to a senior member of staff at the Australian National University (ANU). The email did not require the recipient to click on any link or download an attachment but still resulted in the senior staff member's credentials being sent to several external web addresses.",
        "affected_assets": [
            "Senior staff member's credentials"
        ],
        "exploited_vulnerability": null,
        "proceeding_actions": [
            "Credentials Harvesting"
        ]
    },
    {
        "action_name": "Credentials Harvesting",
        "tactic_id": "TA0008",
        "technique_id": "T1003",
        "sub_technique_id": null,
        "evidence": "The actor gained access to the senior staff member's calendar and used the i

In [26]:
graph_generator = GraphGenerator()

# Convert the JSON data to STIX
stix_bundle = graph_generator.convert_json_to_stix(processed_answer)

# Convert the STIX bundle to a Graphviz DOT object
dot_obj = graph_generator.convert2dot(stix_bundle)

# Render the DOT object as an image
img_data = dot_obj.pipe(format='png')

# Create a PIL Image object from the image data
img = Image.open(BytesIO(img_data))

# Display the image
img.show()

# Save the image to a file
img.save(f"{doc_name}.png")

## Run \#2 - Equifax

In [27]:
doc_processor = DocumentProcessor()

FileUpload(value={}, description='Upload', multiple=True)

IntText(value=0, description='Start Page:')

IntText(value=0, description='End Page:')

Button(description='Exclude Pages', style=ButtonStyle())

Button(description='Process Files', style=ButtonStyle())

Pages 1 to 7 will be excluded.
Pages 13 to 96 will be excluded.
Files processed and stored.


In [28]:
doc_processor.display_processed_texts()

Document: Equifax-Report.pdf
Word Count: 1404
Content: Timeline of Key Events March 7, 2017  Apache Struts Project Management Committee announces the CVE-...
----------------------------------------------------------------------------------------------------


In [29]:
doc_name = list(doc_processor.get_processed_texts().keys())[0]
processed_texts = list(doc_processor.get_processed_texts().values())[0]

In [32]:
p2 = prompt.format(BODY=processed_texts)
prompt_processor = PromptProcessor(model_loader.model, model_loader.tokenizer)
answer = prompt_processor.run_prompt(p2)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [33]:
processed_answer = prompt_processor.process_answer(answer)

Time to Generate: 136 seconds
[
    {
        "action_name": "Vulnerability Exploitation",
        "tactic_id": "TA0002",
        "technique_id": "T1192",
        "sub_technique_id": null,
        "evidence": "On May 13, 2017, attackers entered the Equifax network through the Apache Struts vulnerability located within the Automated Consumer Interview System (ACIS) application and dropped web shells onto the Equifax system.",
        "affected_assets": [
            "Equifax network",
            "Automated Consumer Interview System (ACIS)"
        ],
        "exploited_vulnerability": "CVE-2017-5638",
        "proceeding_actions": [
            "Data Exfiltration"
        ]
    },
    {
        "action_name": "Data Exfiltration",
        "tactic_id": "TA0010",
        "technique_id": "T1003",
        "sub_technique_id": null,
        "evidence": "Approximately 9,000 queries were performed to sensitive databases within the Equifax system between May 13, 2017 and July 30, 2017.",
       

In [34]:
graph_generator = GraphGenerator()

# Convert the JSON data to STIX
stix_bundle = graph_generator.convert_json_to_stix(processed_answer)

# Convert the STIX bundle to a Graphviz DOT object
dot_obj = graph_generator.convert2dot(stix_bundle)

# Render the DOT object as an image
img_data = dot_obj.pipe(format='png')

# Create a PIL Image object from the image data
img = Image.open(BytesIO(img_data))

# Display the image
img.show()

# Save the image to a file
img.save(f"{doc_name}.png")