In [None]:
"""cyberthreat-pathfinder-notebook.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1r5NiN210Lm73pXlanKmmdw6SY6oNEvNE

CyberThreat PathFinder: An Agentic Graph Intelligence System

# Setup and Environment
"""

In [None]:
# Install nx-arangodb via pip
!pip install nx-arangodb

In [None]:
# Check if you have an NVIDIA GPU
# Note: If this returns "command not found", then GPU-based algorithms via cuGraph are unavailable
!nvidia-smi
!nvcc --version

In [None]:
# Install nx-cugraph via pip (if GPU is available)
!pip install nx-cugraph-cu12 --extra-index-url=https://pypi.nvidia.com # Requires CUDA-capable GPU

In [None]:
# Install cuGraph
!pip install cugraph-cu12 --extra-index-url=https://pypi.nvidia.com

In [None]:
# Install necessary dependencies
!pip install langchain langchain-community langgraph ollama

In [None]:
# IPython magic (Colab only): install and load the colab-xterm extension.
!pip install colab-xterm
%load_ext colabxterm

In [None]:
# IPython magic (Colab only): open a terminal in the cell output, then run the commands below in it.
%xterm
# See this guide for setting up Ollama in Colab: https://medium.com/@abonia/running-ollama-in-google-colab-free-tier-545609258453
# curl https://ollama.ai/install.sh | sh
# ollama serve &
# ollama pull mistral:instruct

In [None]:
!ollama list

In [None]:
!pip install -U langchain-ollama

In [None]:
import ollama

In [None]:
# Smoke-test the local Ollama server with one low-temperature generation.
response = ollama.generate(
    model='mistral:instruct',
    options={'temperature': 0.1},
    prompt='Give me a brief overview of what Ollama is.',
)
print(response['response'])

In [None]:
# Install data processing tools
!pip install --quiet kaggle openpyxl stix2 taxii2-client

In [None]:
# Import the required modules
import networkx as nx
import nx_arangodb as nxadb
from arango import ArangoClient
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from random import randint
import re
from langchain_community.graphs import ArangoGraph
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_core.tools import tool
from langchain.agents import initialize_agent, AgentType

In [None]:
from langchain_community.llms import Ollama
from langchain.agents import AgentExecutor, Tool
from langchain.agents import create_react_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain.tools import tool

In [None]:
# Additional imports for functionality
import json
from datetime import datetime
from IPython.display import display, HTML
import warnings
import os
import time
import requests
import io
import zipfile
from langchain_community.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

In [None]:
# Import cuGraph - new imports for GPU acceleration
import cugraph
import cudf
from cugraph.structure.graph_implementation.simpleGraph import simpleGraphImpl
import cupy as cp

In [None]:
# Silence warnings
warnings.filterwarnings('ignore')

In [None]:
"""# Initialize Local LLM"""

In [None]:
# LLM backed by the local Ollama server started in the terminal above.
# NOTE(review): langchain-ollama is installed earlier (pip install -U langchain-ollama);
# consider switching to its OllamaLLM class.
llm = Ollama(
    base_url="http://localhost:11434",
    model="mistral:instruct",
    temperature=0.1,  # low temperature for more repeatable answers
)

In [None]:
# Quick end-to-end check that the LangChain wrapper can reach the model.
test_response = llm.invoke("What is a cybersecurity threat graph?")
print("\nLLM Test Response:", test_response, sep="\n")

In [None]:
"""# Connect to ArangoDB"""

In [None]:
from google.colab import userdata
# ArangoDB connection settings come from Colab's Secrets panel, so no credentials
# are hard-coded in the notebook.
ARANGO_URL = userdata.get('ARANGO_URL')
ARANGO_USERNAME = userdata.get('ARANGO_USERNAME')
ARANGO_PASSWORD = userdata.get('ARANGO_PASSWORD')
ARANGO_DB = userdata.get('ARANGO_DB')  # name of the project database (created on first connect if missing)

In [None]:
# Echo the database name loaded from secrets as a quick sanity check.
ARANGO_DB

In [None]:
# Connect to ArangoDB
def connect_to_arangodb():
    """Open the project database on the configured ArangoDB server, creating it if needed.

    Reads the module-level ARANGO_URL, ARANGO_USERNAME, ARANGO_PASSWORD and
    ARANGO_DB settings.

    Returns:
        ``(db, client)`` on success, or ``(None, None)`` if any step fails. The
        error is printed rather than raised.
    """
    try:
        auth = {"username": ARANGO_USERNAME, "password": ARANGO_PASSWORD, "verify": True}
        client = ArangoClient(hosts=ARANGO_URL)

        # Creating a database requires a handle on the _system database.
        system_db = client.db("_system", **auth)
        if not system_db.has_database(ARANGO_DB):
            system_db.create_database(ARANGO_DB)
            print(f"Created database: {ARANGO_DB}")

        project_db = client.db(ARANGO_DB, **auth)
    except Exception as err:
        print(f"Error connecting to ArangoDB: {err}")
        return None, None

    print(f"Successfully connected to ArangoDB at {ARANGO_URL}")
    return project_db, client

In [None]:
db, client = connect_to_arangodb()
# Fail fast: every later cell needs a live connection. Without this check, a None
# handle would surface as a confusing AttributeError much further down.
if db is None:
    raise RuntimeError("Could not connect to ArangoDB; check the ARANGO_* secrets and the server URL.")

In [None]:
"""# Download and Process CVE and MITRE ATT&CK Data"""

In [None]:
# Download CVE data (Kaggle first, then the CISA KEV list, then a small synthetic fallback)
def download_cve_data():
    """Download (if not already cached) and load a CVE dataset as a DataFrame.

    Sources, in order of preference:
      1. The Kaggle dataset ``andrewkronser/cve-common-vulnerabilities-and-exposures``
         (requires Kaggle API credentials, e.g. an uploaded kaggle.json), unzipped
         into ``cve_data/``; ``cve_data/cve.csv`` is the file that gets loaded.
      2. The CISA Known Exploited Vulnerabilities CSV. It is saved as
         ``cve_data/cve.csv`` when the Kaggle download fails, or as
         ``cve_data/known_exploited_vulnerabilities.csv`` when ``cve.csv``
         exists but cannot be read.
      3. A three-row synthetic DataFrame, so later cells still have data.

    HTTP downloads are written to disk only after a successful (2xx) response,
    so an error page is never cached as if it were CSV data.

    Returns:
        pandas.DataFrame whose columns depend on which source succeeded.
    """
    cve_path = 'cve_data/cve.csv'
    kev_path = 'cve_data/known_exploited_vulnerabilities.csv'
    kev_url = "https://www.cisa.gov/sites/default/files/csv/known_exploited_vulnerabilities.csv"

    def _download(url, path):
        # Raise on non-2xx before touching the file so bad responses are never cached.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        with open(path, 'wb') as f:
            f.write(response.content)

    # Location for Kaggle API credentials (upload kaggle.json here if needed).
    os.makedirs('/root/.kaggle', exist_ok=True)
    os.makedirs('cve_data', exist_ok=True)

    if not os.path.exists(cve_path):
        print("Downloading CVE dataset from Kaggle...")
        try:
            import kaggle
            kaggle.api.authenticate()
            kaggle.api.dataset_download_files('andrewkronser/cve-common-vulnerabilities-and-exposures', path='cve_data', unzip=True)
        except Exception as e:
            print(f"Kaggle download failed: {e}")
            print("Kaggle API not available, downloading from alternative source...")
            try:
                _download(kev_url, cve_path)
            except (requests.RequestException, OSError) as download_error:
                print(f"Error downloading CISA KEV data: {download_error}")

    print("Loading CVE data...")
    try:
        cve_df = pd.read_csv(cve_path)
        print(f"Successfully loaded CVE data with {len(cve_df)} entries")
        return cve_df
    except Exception as e:
        print(f"Could not load {cve_path}: {e}")

    # Fall back to the CISA KEV file, fetching it first if it is not cached yet
    # (previously this path was read but never written, so the branch was dead).
    try:
        if not os.path.exists(kev_path):
            _download(kev_url, kev_path)
        cve_df = pd.read_csv(kev_path)
        print(f"Successfully loaded CISA KEV data with {len(cve_df)} entries")
        return cve_df
    except Exception as e:
        print(f"Error loading CVE data: {e}")
        # Minimal synthetic dataset so the rest of the notebook can still run.
        return pd.DataFrame({
            'cve_id': ['CVE-2023-0001', 'CVE-2023-0002', 'CVE-2023-0003'],
            'description': [
                'Remote code execution vulnerability in web server',
                'SQL injection vulnerability in database application',
                'Cross-site scripting vulnerability in web application'
            ],
            'published_date': ['2023-01-15', '2023-02-20', '2023-03-10'],
            'cvss_score': [9.8, 8.5, 7.2]
        })

In [None]:
def download_mitre_attack_data():
    """Download (if not already cached) and parse the MITRE ATT&CK Enterprise STIX bundle.

    The bundle is cached at ``mitre_data/enterprise-attack.json`` and written to
    disk only after a successful (2xx) HTTP response. If it cannot be downloaded
    or parsed, a tiny built-in bundle is returned instead. Its attack-pattern
    objects carry ``external_references`` with ATT&CK IDs so that technique IDs
    can be derived from them.

    Returns:
        dict with an ``'objects'`` list of STIX objects.
    """
    path = 'mitre_data/enterprise-attack.json'
    os.makedirs('mitre_data', exist_ok=True)

    if not os.path.exists(path):
        print("Downloading MITRE ATT&CK Enterprise framework...")
        url = "https://raw.githubusercontent.com/mitre/cti/master/enterprise-attack/enterprise-attack.json"
        try:
            response = requests.get(url, timeout=120)
            response.raise_for_status()
            with open(path, 'wb') as f:
                f.write(response.content)
        except (requests.RequestException, OSError) as e:
            # Leave no file behind; loading below falls through to the built-in bundle.
            print(f"Error downloading MITRE ATT&CK data: {e}")

    print("Loading MITRE ATT&CK data...")
    try:
        with open(path, 'r', encoding='utf-8') as f:
            attack_data = json.load(f)
        print(f"Successfully loaded MITRE ATT&CK data with {len(attack_data['objects'])} objects")
        return attack_data
    except Exception as e:
        print(f"Error loading MITRE ATT&CK data: {str(e)}")

    try:
        print("Trying alternative loading approach...")
        # Replace undecodable bytes instead of failing on them.
        with open(path, 'r', encoding='utf-8', errors='replace') as f:
            attack_data = json.load(f)
        print("Successfully loaded MITRE ATT&CK data with alternative method")
        return attack_data
    except Exception as e2:
        print(f"Alternative loading also failed: {str(e2)}")

    print("Using fallback MITRE data")
    return {
        'objects': [
            {
                'type': 'attack-pattern',
                'id': 'attack-pattern--t1059',
                'name': 'Command and Scripting Interpreter',
                'description': 'Adversaries may abuse command and script interpreters to execute commands',
                'kill_chain_phases': [{'kill_chain_name': 'mitre-attack', 'phase_name': 'execution'}],
                'external_references': [{'source_name': 'mitre-attack', 'external_id': 'T1059'}]
            },
            {
                'type': 'attack-pattern',
                'id': 'attack-pattern--t1566',
                'name': 'Phishing',
                'description': 'Adversaries may send phishing messages to gain access to victim systems',
                'kill_chain_phases': [{'kill_chain_name': 'mitre-attack', 'phase_name': 'initial-access'}],
                'external_references': [{'source_name': 'mitre-attack', 'external_id': 'T1566'}]
            },
            {
                'type': 'intrusion-set',
                'id': 'intrusion-set--apt28',
                'name': 'APT28',
                'description': 'APT28 is a threat group that has been attributed to Russia',
            }
        ]
    }

In [None]:
# Convert MITRE ATT&CK STIX objects into documents suitable for ArangoDB
def process_mitre_attack_data(attack_data):
    """Extract techniques, threat actors and actor->technique links from a STIX bundle.

    Args:
        attack_data: dict with an ``'objects'`` list of STIX objects, as returned
            by ``download_mitre_attack_data``.

    Returns:
        dict with three lists:
          * ``'techniques'``: ``{'_key', 'name', 'tactic', 'description'}`` for each
            ``attack-pattern`` that has an external ID (e.g. ``'T1059'``);
          * ``'threat_actors'``: ``{'_key', 'name', 'type', 'description'}`` for each
            ``intrusion-set``, keyed by the part of its STIX id after ``'--'``;
          * ``'relationships'``: ``{'source', 'target', 'type': 'uses'}`` for each
            intrusion-set -> attack-pattern ``uses`` relationship. ``target`` is the
            technique's ATT&CK ID when known, otherwise the part of its STIX id
            after ``'--'``.
    """
    objects = attack_data['objects']

    def _mitre_external_id(obj):
        # Upper-cased external_id of the 'mitre-attack' reference, or None if there is none.
        for ref in obj.get('external_references') or []:
            if ref.get('source_name') == 'mitre-attack':
                return ref.get('external_id', '').upper()
        return None

    # Index ATT&CK IDs by STIX id once, instead of rescanning every object for
    # every relationship (quadratic on the full Enterprise bundle).
    mitre_ids = {}
    for obj in objects:
        if 'id' in obj:
            external_id = _mitre_external_id(obj)
            if external_id is not None:
                mitre_ids[obj['id']] = external_id

    techniques = []
    threat_actors = []
    relationships = []

    for obj in objects:
        obj_type = obj.get('type')

        if obj_type == 'attack-pattern':
            technique_id = _mitre_external_id(obj)
            if technique_id is None:
                # No 'mitre-attack' reference: fall back to the first reference, if any.
                refs = obj.get('external_references') or [{}]
                technique_id = refs[0].get('external_id', '').upper()
            if not technique_id:
                continue

            tactic = ''
            for phase in obj.get('kill_chain_phases') or []:
                if phase.get('kill_chain_name') == 'mitre-attack':
                    tactic = phase.get('phase_name', '').replace('-', ' ').title()
                    break

            techniques.append({
                '_key': technique_id,
                'name': obj.get('name', ''),
                'tactic': tactic,
                'description': obj.get('description', '')
            })

        elif obj_type == 'intrusion-set':
            description = obj.get('description', '')
            threat_actors.append({
                '_key': obj['id'].split('--')[1],
                'name': obj.get('name', ''),
                # Keyword heuristic: descriptions mentioning "government" count as nation state.
                'type': 'Nation State' if 'government' in description.lower() else 'Criminal',
                'description': description
            })

        elif obj_type == 'relationship':
            source_ref = obj.get('source_ref', '')
            target_ref = obj.get('target_ref', '')
            if (obj.get('relationship_type') == 'uses'
                    and source_ref.startswith('intrusion-set')
                    and target_ref.startswith('attack-pattern')):
                relationships.append({
                    'source': source_ref.split('--')[1],
                    'target': mitre_ids.get(target_ref, target_ref.split('--')[1]),
                    'type': 'uses'
                })

    return {
        'techniques': techniques,
        'threat_actors': threat_actors,
        'relationships': relationships
    }

In [None]:
# Process CVE data
def process_cve_data(cve_df, limit=1000):
    """Normalise a CVE DataFrame into ArangoDB ``Vulnerabilities`` documents.

    Recognised input layouts:
      * CISA KEV (``cveID``, ``shortDescription``, ``dateAdded``, ``cwes``); a
        keyword-based CVSS estimate is added because KEV ships no score;
      * the synthetic fallback frame (``cve_id``, ``description``,
        ``published_date``, ``cvss_score``);
      * frames already using ``CVE ID`` / ``Description`` / ``Published`` /
        ``CVSS Score`` / ``CWE ID``.
    Other layouts yield no documents because no ``CVE ID`` column is found.

    Args:
        cve_df: source DataFrame (not modified).
        limit: maximum number of input rows to process (the first ``limit``
            rows); ``None`` processes every row.

    Returns:
        list of dicts with keys ``_key`` (CVE ID with '-' replaced by '_'),
        ``name``, ``description``, ``cvss_score`` (float, 0.0 if unknown),
        ``published_date``, ``severity`` and ``cwe_id``. Missing text fields
        become ``''``. Rows without a CVE ID are skipped.
    """
    print(f"Processing {len(cve_df)} CVE entries...")

    def _estimate_cvss(description):
        # Keyword heuristic on the description text, used when no CVSS column exists.
        desc = description.lower() if isinstance(description, str) else ''
        if any(term in desc for term in ['critical', 'remote code execution', 'rce']):
            return 9.0
        elif any(term in desc for term in ['high', 'arbitrary code', 'privilege escalation']):
            return 7.5
        elif any(term in desc for term in ['medium', 'information disclosure', 'cross-site']):
            return 5.0
        return 4.0

    def _as_text(value):
        # Empty CSV cells arrive as None/NaN; NaN is not valid JSON, so store '' instead.
        if value is None or (isinstance(value, float) and value != value):
            return ''
        return value

    estimate_scores = False
    if 'cveID' in cve_df.columns and 'CVE ID' not in cve_df.columns:
        print("Detected CISA KEV format, mapping columns...")
        cve_df = cve_df.rename(columns={
            'cveID': 'CVE ID',
            'shortDescription': 'Description',
            'dateAdded': 'Published',
            'cwes': 'CWE ID'
        })
        estimate_scores = 'CVSS Score' not in cve_df.columns
    elif 'cve_id' in cve_df.columns and 'CVE ID' not in cve_df.columns:
        # Schema of the synthetic fallback frame returned by download_cve_data.
        print("Detected fallback CVE format, mapping columns...")
        cve_df = cve_df.rename(columns={
            'cve_id': 'CVE ID',
            'description': 'Description',
            'published_date': 'Published',
            'cvss_score': 'CVSS Score'
        })

    if limit is not None:
        cve_df = cve_df.head(limit)

    if estimate_scores:
        print("Adding default CVSS scores based on description keywords...")
        cve_df = cve_df.assign(**{'CVSS Score': cve_df['Description'].map(_estimate_cvss)})

    vulnerabilities = []
    for row in cve_df.to_dict('records'):
        cve_id = row.get('CVE ID', '')
        # NaN/None IDs are floats/None, not strings, and would crash .replace below.
        if not isinstance(cve_id, str) or not cve_id.strip():
            continue
        cve_id = cve_id.strip()

        try:
            cvss_score = float(row.get('CVSS Score', 0))
        except (TypeError, ValueError):
            cvss_score = 0.0
        if cvss_score != cvss_score:  # NaN -> treat as unknown
            cvss_score = 0.0

        if cvss_score >= 9.0:
            severity = 'Critical'
        elif cvss_score >= 7.0:
            severity = 'High'
        elif cvss_score >= 4.0:
            severity = 'Medium'
        elif cvss_score > 0:
            severity = 'Low'
        else:
            severity = 'Unknown'

        vulnerabilities.append({
            '_key': cve_id.replace('-', '_'),
            'name': cve_id,
            'description': _as_text(row.get('Description', '')),
            'cvss_score': cvss_score,
            'published_date': _as_text(row.get('Published', '')),
            'severity': severity,
            'cwe_id': _as_text(row.get('CWE ID', ''))
        })

    print(f"Created {len(vulnerabilities)} vulnerability objects")
    return vulnerabilities

In [None]:
# Generate synthetic (randomly sampled) relationships between entities for demo purposes
def generate_relationships(vulnerabilities, threat_actors, techniques, seed=None):
    """Synthesize demo assets and random edges between actors, vulnerabilities, assets and techniques.

    The edges are randomly sampled demo data, not observed intelligence.

    Args:
        vulnerabilities: dicts with at least ``'_key'`` and optionally ``'severity'``.
        threat_actors: dicts with at least ``'_key'`` and optionally ``'type'``;
            ``'Nation State'`` actors get 5-12 exploit edges, others 3-8.
        techniques: dicts with at least ``'_key'``.
        seed: optional seed for a private ``random.Random`` to make the output
            reproducible. ``None`` (default) draws from the global ``random``
            module, as before.

    Returns:
        dict with:
          * ``'assets'``: the 10 fixed asset documents below;
          * ``'exploits'``: ThreatActors -> Vulnerabilities edges. Each
            actor/vulnerability pair appears at most once; about half the picks
            come from Critical/High vulnerabilities;
          * ``'targets'``: Vulnerabilities -> Assets edges with ``impact`` and
            ``remediation_status``;
          * ``'uses'``: ThreatActors -> Techniques edges (4-10 per actor).
    """
    import random
    from datetime import datetime, timedelta

    rng = random if seed is None else random.Random(seed)
    now = datetime.now()  # one reference time for every generated date

    def _random_past_date():
        # Date 30-365 days before now, formatted YYYY-MM-DD.
        return (now - timedelta(days=rng.randint(30, 365))).strftime('%Y-%m-%d')

    def _impact(severity, criticality):
        # Combine vulnerability severity with asset criticality into an impact level.
        if severity == 'Critical' and criticality == 'Critical':
            return 'Critical'
        if severity == 'Critical' or criticality == 'Critical':
            return 'High'
        if severity == 'High' and criticality == 'High':
            return 'High'
        if severity == 'High' or criticality == 'High':
            return 'Medium'
        return 'Low'

    # Common asset types in an organization
    assets = [
        {"_key": "server001", "name": "Production Web Server", "type": "Server", "criticality": "High", "operating_system": "Ubuntu 20.04", "ip_address": "10.0.0.1"},
        {"_key": "server002", "name": "Database Server", "type": "Server", "criticality": "Critical", "operating_system": "CentOS 8", "ip_address": "10.0.0.2"},
        {"_key": "server003", "name": "Test Web Server", "type": "Server", "criticality": "Low", "operating_system": "Ubuntu 20.04", "ip_address": "10.0.0.3"},
        {"_key": "server004", "name": "Email Server", "type": "Server", "criticality": "High", "operating_system": "Windows Server 2019", "ip_address": "10.0.0.4"},
        {"_key": "server005", "name": "Domain Controller", "type": "Server", "criticality": "Critical", "operating_system": "Windows Server 2019", "ip_address": "10.0.0.5"},
        {"_key": "workstation001", "name": "CEO Laptop", "type": "Endpoint", "criticality": "Medium", "operating_system": "Windows 11", "ip_address": "10.0.1.1"},
        {"_key": "workstation002", "name": "CFO Laptop", "type": "Endpoint", "criticality": "Medium", "operating_system": "MacOS", "ip_address": "10.0.1.2"},
        {"_key": "workstation003", "name": "Developer Workstation", "type": "Endpoint", "criticality": "Low", "operating_system": "Ubuntu 22.04", "ip_address": "10.0.1.3"},
        {"_key": "router001", "name": "Main Router", "type": "Network", "criticality": "High", "operating_system": "Cisco IOS", "ip_address": "10.0.0.254"},
        {"_key": "firewall001", "name": "Perimeter Firewall", "type": "Network", "criticality": "Critical", "operating_system": "Palo Alto PAN-OS", "ip_address": "10.0.0.253"}
    ]
    assets_by_key = {a['_key']: a for a in assets}

    vuln_keys = [v['_key'] for v in vulnerabilities]
    technique_keys = [t['_key'] for t in techniques]
    asset_keys = [a['_key'] for a in assets]

    # Actors prefer high-impact vulnerabilities; fall back to all if none are rated.
    critical_vulns = [v['_key'] for v in vulnerabilities if v.get('severity', '') in ['Critical', 'High']]
    if not critical_vulns:
        critical_vulns = vuln_keys

    # ThreatActors -> Vulnerabilities
    exploits = []
    for actor in threat_actors:
        num_exploits = rng.randint(5, 12) if actor.get('type', '') == 'Nation State' else rng.randint(3, 8)

        chosen = rng.sample(critical_vulns, min(num_exploits // 2 + 1, len(critical_vulns)))
        # Top up from all vulnerabilities without repeating a pick. Drawing
        # len(chosen) extra candidates guarantees enough survive the filter.
        extra_needed = max(num_exploits - len(chosen), 0)
        candidates = rng.sample(vuln_keys, min(extra_needed + len(chosen), len(vuln_keys)))
        already = set(chosen)
        chosen += [k for k in candidates if k not in already][:extra_needed]

        for vuln_key in chosen:
            date_observed = _random_past_date()
            confidence = rng.choice(['High', 'Medium', 'Low'])
            exploits.append({
                "_from": f"ThreatActors/{actor['_key']}",
                "_to": f"Vulnerabilities/{vuln_key}",
                "confidence": confidence,
                "date_observed": date_observed
            })

    # Vulnerabilities -> Assets; more severe vulnerabilities affect more assets.
    targets = []
    statuses = ['Patched', 'In Progress', 'Unpatched']
    status_weights = [30, 30, 40]
    for vuln in vulnerabilities:
        severity = vuln.get('severity', '')
        if severity == 'Critical':
            num_affected = rng.randint(3, 7)
        elif severity == 'High':
            num_affected = rng.randint(2, 5)
        else:
            num_affected = rng.randint(1, 3)

        for asset_key in rng.sample(asset_keys, min(num_affected, len(asset_keys))):
            targets.append({
                "_from": f"Vulnerabilities/{vuln['_key']}",
                "_to": f"Assets/{asset_key}",
                "impact": _impact(severity, assets_by_key[asset_key]['criticality']),
                "remediation_status": rng.choices(statuses, weights=status_weights, k=1)[0]
            })

    # ThreatActors -> Techniques
    uses = []
    for actor in threat_actors:
        num_techniques = rng.randint(4, 10)
        for tech_key in rng.sample(technique_keys, min(num_techniques, len(technique_keys))):
            last_observed = _random_past_date()
            frequency = rng.choice(['High', 'Medium', 'Low'])
            uses.append({
                "_from": f"ThreatActors/{actor['_key']}",
                "_to": f"Techniques/{tech_key}",
                "frequency": frequency,
                "last_observed": last_observed
            })

    # NOTE: nothing guarantees that every targeted vulnerability is exploited by
    # some actor, so some vulnerability -> asset paths may have no attacker.
    return {
        "assets": assets,
        "exploits": exploits,
        "targets": targets,
        "uses": uses
    }

In [None]:
# Download and process the datasets
# cve_vulnerabilities: list of Vulnerabilities documents built by process_cve_data.
print("Downloading and processing CVE data...")
cve_df = download_cve_data()
cve_vulnerabilities = process_cve_data(cve_df)

In [None]:
# processed_attack: dict with 'techniques', 'threat_actors' and 'relationships' lists.
print("Downloading and processing MITRE ATT&CK data...")
attack_data = download_mitre_attack_data()
processed_attack = process_mitre_attack_data(attack_data)

In [None]:
# Generate relationships between the entities
# NOTE: generate_relationships samples edges with the `random` module, so the
# Exploits/Targets/Uses edges (and downstream results) differ between runs.
relationships = generate_relationships(
    cve_vulnerabilities,
    processed_attack['threat_actors'],
    processed_attack['techniques']
)

In [None]:
# Sanity-check entity and edge counts before loading into ArangoDB.
print("Processed data summary:")
print("\n".join(
    f"- {count} {label}"
    for label, count in (
        ("CVE vulnerabilities", len(cve_vulnerabilities)),
        ("MITRE techniques", len(processed_attack['techniques'])),
        ("threat actors", len(processed_attack['threat_actors'])),
        ("assets", len(relationships['assets'])),
        ("exploit relationships", len(relationships['exploits'])),
        ("target relationships", len(relationships['targets'])),
        ("technique usage relationships", len(relationships['uses'])),
    )
))

In [None]:
"""# Create Cybersecurity Graph Schema"""

In [None]:
# Set up collections for the threat graph
def setup_threat_graph_collections(db):
    """Ensure the ThreatGraph collections and named graph exist in ``db``.

    Missing vertex and edge collections are created first. If no ``ThreatGraph``
    named graph exists yet, it is created with its three edge definitions. An
    existing graph is reused as-is; its edge definitions are not checked.

    Args:
        db: ArangoDB database handle.

    Returns:
        The ``ThreatGraph`` graph object.
    """
    vertex_collections = ["Vulnerabilities", "Assets", "ThreatActors", "Techniques"]
    # edge collection -> (from-vertex collections, to-vertex collections)
    edge_definitions = {
        "Exploits": (["ThreatActors"], ["Vulnerabilities"]),
        "Targets": (["Vulnerabilities"], ["Assets"]),
        "Uses": (["ThreatActors"], ["Techniques"]),
    }

    wanted = [(name, False) for name in vertex_collections] + [(name, True) for name in edge_definitions]
    for name, is_edge in wanted:
        if db.has_collection(name):
            continue
        db.create_collection(name, edge=is_edge)
        kind_label = 'edge ' if is_edge else ''
        print(f"Created {name} {kind_label}collection")

    if db.has_graph("ThreatGraph"):
        graph = db.graph("ThreatGraph")
        print("Using existing ThreatGraph")
        return graph

    graph = db.create_graph("ThreatGraph")
    for edge_name, (from_cols, to_cols) in edge_definitions.items():
        graph.create_edge_definition(
            edge_collection=edge_name,
            from_vertex_collections=from_cols,
            to_vertex_collections=to_cols
        )
    print("Created ThreatGraph with edge definitions")
    return graph

In [None]:
# Load data into ArangoDB
def load_threat_data_to_arangodb(db, vulnerabilities, techniques, threat_actors, assets, exploits, targets, uses, batch_size=1000):
    """Bulk-import the threat-graph vertices and edges into ArangoDB.

    A collection is loaded only if it is currently empty. Re-running the cell
    therefore does not duplicate data, but a partially loaded collection is not
    completed either.

    Args:
        db: ArangoDB database handle whose collections already exist.
        vulnerabilities, techniques, threat_actors, assets: vertex documents.
        exploits, targets, uses: edge documents carrying ``_from``/``_to``.
        batch_size: documents per ``import_bulk`` call (must be >= 1).

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    def _load(collection_name, data, noun):
        # Import ``data`` in batches unless the collection already holds documents.
        collection = db.collection(collection_name)
        if collection.count() != 0:
            print(f"Collection {collection_name} already has data, skipping import")
            return
        print(f"Loading {len(data)} {noun} into {collection_name}...")
        for start in range(0, len(data), batch_size):
            collection.import_bulk(data[start:start + batch_size])
        print(f"Loaded data into {collection_name}")

    vertex_data = [
        ("Vulnerabilities", vulnerabilities),
        ("Assets", assets),
        ("ThreatActors", threat_actors),
        ("Techniques", techniques),
    ]
    edge_data = [
        ("Exploits", exploits),
        ("Targets", targets),
        ("Uses", uses),
    ]

    # Vertices first so edge _from/_to references point at existing documents.
    for name, docs in vertex_data:
        _load(name, docs, "documents")
    for name, edges in edge_data:
        _load(name, edges, "edges")

In [None]:
# Set up the graph and load the data.
# NOTE(review): processed_attack['relationships'] (ATT&CK's own actor -> technique
# "uses" links) is never loaded; the Uses edges below come from
# generate_relationships' random sampling instead.
graph = setup_threat_graph_collections(db)
load_threat_data_to_arangodb(
    db,
    vulnerabilities=cve_vulnerabilities,
    techniques=processed_attack['techniques'],
    threat_actors=processed_attack['threat_actors'],
    assets=relationships['assets'],
    exploits=relationships['exploits'],
    targets=relationships['targets'],
    uses=relationships['uses'],
)

In [None]:
"""# Create NetworkX and cuGraph from ArangoDB"""

In [None]:
# Create NetworkX graph from ArangoDB
def create_threat_graph_nx():
    """Attach an nx-arangodb graph to the existing ``ThreatGraph`` in ArangoDB.

    Uses the module-level ``db`` handle.

    NOTE(review): ``nxadb.Graph`` presumably mirrors ``networkx.Graph``
    (undirected), while the stored edges are directional (e.g. ThreatActors ->
    Vulnerabilities). Confirm whether attack-path analysis should use
    ``nxadb.DiGraph`` instead.

    Returns:
        The nx-arangodb graph, or ``None`` if it cannot be created (the error
        is printed).
    """
    try:
        threat_graph = nxadb.Graph(
            name="ThreatGraph",
            db=db,
            create=False  # intended: reuse the existing graph rather than create a new one
        )
        print("Successfully connected to ThreatGraph in ArangoDB")
        print(f"Graph has {threat_graph.number_of_nodes()} nodes and {threat_graph.number_of_edges()} edges")
    except Exception as e:
        print(f"Error creating ThreatGraph from ArangoDB: {e}")
        return None
    return threat_graph

In [None]:
# Convert a NetworkX graph to cuGraph for GPU-accelerated analytics
def convert_to_cugraph(G_nx):
    """Convert a NetworkX graph to a cuGraph graph for GPU-accelerated analytics.

    Nodes are mapped to dense integer ids (``0..n-1`` in ``G_nx.nodes()`` order)
    because the edge list is loaded with ``renumber=False``. Every edge gets
    weight 1.0.

    The returned graph carries extra Python attributes for mapping results back:
      * ``node_map``: original node id -> integer id
      * ``reverse_node_map``: integer id -> original node id
      * ``node_attrs``: integer id -> node attribute dict
      * ``edge_attrs``: edge attribute dicts in ``G_nx.edges()`` iteration order.
        NOTE(review): cuGraph may store edges in a different order; confirm
        before aligning these by index with cuGraph edge results.

    Returns:
        The cuGraph graph, or ``None`` if conversion fails (the error is printed).
    """
    try:
        print("Converting NetworkX graph to cuGraph...")

        node_map = {}
        for index, node in enumerate(G_nx.nodes()):
            node_map[node] = index
        reverse_node_map = {index: node for node, index in node_map.items()}

        # Node attributes re-keyed by integer id.
        node_attrs = {node_map[node]: attrs for node, attrs in G_nx.nodes(data=True)}

        # Walk the edge view once and derive the parallel columns from it.
        edge_triples = [(node_map[u], node_map[v], data) for u, v, data in G_nx.edges(data=True)]
        edge_attrs = [data for _, _, data in edge_triples]

        edge_frame = cudf.DataFrame()
        edge_frame['src'] = [src for src, _, _ in edge_triples]
        edge_frame['dst'] = [dst for _, dst, _ in edge_triples]
        edge_frame['weight'] = [1.0 for _ in edge_triples]  # unweighted: every edge counts as 1.0

        G_cu = cugraph.Graph()
        G_cu.from_cudf_edgelist(edge_frame, source='src', destination='dst', edge_attr='weight', renumber=False)

        # Attach lookup tables so callers can translate results back to the original graph.
        G_cu.node_map = node_map
        G_cu.reverse_node_map = reverse_node_map
        G_cu.node_attrs = node_attrs
        G_cu.edge_attrs = edge_attrs

        print(f"Successfully converted to cuGraph with {G_cu.number_of_vertices()} vertices and {G_cu.number_of_edges()} edges")
        return G_cu
    except Exception as e:
        print(f"Error converting to cuGraph: {e}")
        print("Falling back to NetworkX graph")
        return None