In [None]:
# Setup
import asyncio
import subprocess
import sys
from datetime import datetime, timedelta
from typing import Annotated, AsyncGenerator

import gradio as gr
import pandas as pd
from semantic_kernel import Kernel
from semantic_kernel.agents import ChatCompletionAgent, ChatHistoryAgentThread, GroupChatOrchestration, RoundRobinGroupChatManager, OrchestrationHandoffs, HandoffOrchestration
from semantic_kernel.agents.runtime import InProcessRuntime
from semantic_kernel.connectors.ai.open_ai import OpenAIChatPromptExecutionSettings, OpenAIChatCompletion
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents import FunctionCallContent, FunctionResultContent
from semantic_kernel.functions import kernel_function, KernelArguments

# Project-local imports; the star import stays last so its names keep overriding the ones above.
from jaws.jaws_config import *
from jaws.jaws_utils import dbms_connection

# Database not passed, uses the database set in jaws_config.py
driver = dbms_connection(DATABASE)
kernel = Kernel()
# One shared, default execution-settings object; every agent below wraps it in KernelArguments.
settings = OpenAIChatPromptExecutionSettings()
# Explicit service_ids: without them the id falls back to the model id, so if
# OPENAI_REASONING_MODEL == OPENAI_MODEL the second add_service would collide.
reasoning_service = OpenAIChatCompletion(ai_model_id=OPENAI_REASONING_MODEL, api_key=OPENAI_API_KEY, service_id="reasoning")
kernel.add_service(reasoning_service)
lang_service = OpenAIChatCompletion(ai_model_id=OPENAI_MODEL, api_key=OPENAI_API_KEY, service_id="language")
kernel.add_service(lang_service)

In [2]:
# Tools
class ListInterfaces:
    """Semantic Kernel plugin exposing the interface-listing mode of jaws_capture.py."""

    @kernel_function(description="List available network interfaces. You will never want to select interfaces such as; 'lo', 'docker0', 'wlo1', etc.")
    def list_interfaces(self) -> Annotated[str, "A list of available network interfaces."]:
        """Run ``jaws_capture.py --list --agent`` and return its output.

        Returns:
            The script's stdout on success, or an ``[ERROR]`` message carrying the
            exit code, stdout and stderr when the script exits non-zero.
        """
        # sys.executable = the notebook kernel's interpreter; a bare 'python' may be a
        # different environment (or missing entirely on python3-only systems).
        proc = subprocess.run(
            [sys.executable, './jaws/jaws_capture.py', '--list', '--agent'],
            capture_output=True,
            text=True,
        )
        # Surface failures to the agent instead of silently returning empty stdout.
        if proc.returncode != 0:
            return f"[ERROR] jaws_capture.py exited with code {proc.returncode}.\n{proc.stdout}{proc.stderr}"
        return proc.stdout


class CapturePackets:
    """Semantic Kernel plugin that captures live traffic into the database via jaws_capture.py."""

    @kernel_function(description="Captures packets into the database. Choose a duration depending on the amount of data you want to capture. Recommended not to exceed 60 seconds.")
    def capture_packets(self, interface: str, duration: int) -> Annotated[str, "A process status message once the process is complete."]:
        """Capture on ``interface`` for ``duration`` seconds and return the script's status output.

        Args:
            interface: Network interface name chosen by the model.
            duration: Capture length in seconds; clamped into [1, 60].

        Returns:
            The script's stdout on success, or an ``[ERROR]`` message when the
            interface name is rejected or the script exits non-zero.
        """
        # Clamp both ends: previously only the upper bound was enforced, so 0 or a
        # negative value from the model went straight to the capture script.
        duration = max(1, min(int(duration), 60))
        # The interface name is model-generated; refuse empty values and anything that
        # would be parsed as a command-line flag by the capture script.
        if not interface or interface.startswith('-'):
            return f"[ERROR] Invalid interface name: {interface!r}"
        # sys.executable = the notebook kernel's interpreter rather than whatever 'python' is on PATH.
        proc = subprocess.run(
            [sys.executable, './jaws/jaws_capture.py', '--interface', interface, '--duration', str(duration), '--agent'],
            capture_output=True,
            text=True,
        )
        if proc.returncode != 0:
            return f"[ERROR] jaws_capture.py exited with code {proc.returncode}.\n{proc.stdout}{proc.stderr}"
        return proc.stdout


class DocumentOrganizations:
    """Semantic Kernel plugin that runs jaws_ipinfo.py to attach organization data to captured IPs."""

    @kernel_function(description="Enriches data with organization ownership by looking up IP addresses.")
    def document_organizations(self) -> Annotated[str, "A process status message once the process is complete."]:
        """Run ``jaws_ipinfo.py --agent`` and return its status output.

        Returns:
            The script's stdout on success, or an ``[ERROR]`` message carrying the
            exit code, stdout and stderr when the script exits non-zero.
        """
        # sys.executable = the notebook kernel's interpreter rather than whatever 'python' is on PATH.
        proc = subprocess.run(
            [sys.executable, './jaws/jaws_ipinfo.py', '--agent'],
            capture_output=True,
            text=True,
        )
        if proc.returncode != 0:
            return f"[ERROR] jaws_ipinfo.py exited with code {proc.returncode}.\n{proc.stdout}{proc.stderr}"
        return proc.stdout


class ComputeEmbeddings:
    """Semantic Kernel plugin that runs jaws_compute.py to build embeddings from captured traffic."""

    @kernel_function(description="Transforms the network traffic data into embeddings for analysis.")
    def compute_embeddings(self) -> Annotated[str, "A process status message once the process is complete."]:
        """Run ``jaws_compute.py --agent`` and return its status output.

        Returns:
            The script's stdout on success, or an ``[ERROR]`` message carrying the
            exit code, stdout and stderr when the script exits non-zero.
        """
        # sys.executable = the notebook kernel's interpreter rather than whatever 'python' is on PATH.
        proc = subprocess.run(
            [sys.executable, './jaws/jaws_compute.py', '--agent'],
            capture_output=True,
            text=True,
        )
        if proc.returncode != 0:
            return f"[ERROR] jaws_compute.py exited with code {proc.returncode}.\n{proc.stdout}{proc.stderr}"
        return proc.stdout
    

class AnomalyDetection:
    """Semantic Kernel plugin that runs jaws_finder.py and returns the anomalies it reports."""

    # NOTE: 'anomoly' is misspelled, but the method name is the tool name the agents
    # (and possibly their prompts) see, so it is intentionally left unchanged.
    @kernel_function(description="Analyzes the network traffic data and embeddings and returns a list of anomalies.")
    def anomoly_detection(self) -> Annotated[str, "A string containing a list of anomalies."]:
        """Run ``jaws_finder.py --agent`` and return its output.

        Returns:
            The script's stdout on success, or an ``[ERROR]`` message carrying the
            exit code, stdout and stderr when the script exits non-zero.
        """
        # sys.executable = the notebook kernel's interpreter rather than whatever 'python' is on PATH.
        proc = subprocess.run(
            [sys.executable, './jaws/jaws_finder.py', '--agent'],
            capture_output=True,
            text=True,
        )
        if proc.returncode != 0:
            return f"[ERROR] jaws_finder.py exited with code {proc.returncode}.\n{proc.stdout}{proc.stderr}"
        return proc.stdout
    

class FetchData:
    """Semantic Kernel plugin that reads recent TRAFFIC nodes directly from the graph database."""

    @kernel_function(description="Fetches the latest (10 minutes) traffic data from the database and returns it as a string.")
    def fetch_traffic(self) -> Annotated[str, "A string containing a list of current traffic data."]:
        """Return up to 100 distinct traffic rows from the last 10 minutes, newest first.

        Each row is a dict keyed by the query's RETURN aliases; the whole list is
        stringified so the agent receives plain text. Uses the module-level
        ``driver`` and ``DATABASE``.
        """
        cypher = """
        MATCH (traffic:TRAFFIC)
        WHERE traffic.TIMESTAMP > datetime() - duration({minutes: 10})
        RETURN DISTINCT
            traffic.IP_ADDRESS AS ip_address,
            traffic.PORT AS port,
            traffic.ORGANIZATION AS org,
            traffic.HOSTNAME AS hostname,
            traffic.LOCATION AS location,
            traffic.TOTAL_SIZE AS total_size,
            traffic.OUTLIER AS outlier,
            traffic.TIMESTAMP AS timestamp
        ORDER BY traffic.TIMESTAMP DESC
        LIMIT 100
        """
        # RETURN aliases, in the order each row dict should list them.
        columns = ('ip_address', 'port', 'org', 'hostname', 'location', 'total_size', 'outlier', 'timestamp')
        with driver.session(database=DATABASE) as session:
            # Records are consumed inside the session, before it closes.
            rows = [{column: record[column] for column in columns} for record in session.run(cypher)]
            return str(rows)

In [3]:
# Agents
# Operator: reasoning model with every data tool (list, capture, enrich, embed, detect).
operator = ChatCompletionAgent(
    name="Operator",
    description="You are the eyes of the network. You are tasked with sampling and reviewing network traffic data to identify patterns, anomalies, escalating to Lead Analyst and their team for further probing and reporting up the chain.",
    instructions=OPERATOR_PROMPT,
    service=reasoning_service,
    arguments=KernelArguments(settings),
    plugins=[
        ListInterfaces(),
        CapturePackets(),
        DocumentOrganizations(),
        ComputeEmbeddings(),
        AnomalyDetection(),
    ],
)

# NetworkAnalyst: language model limited to data collection/preparation (no anomaly detection).
network_analyst = ChatCompletionAgent(
    name="NetworkAnalyst",
    description="An expert IT Professional, Sysadmin, and Analyst. Tasked with capturing network packets and performing ETL(Extract, Transform, and Load) with the data to prepare it for analysis.",
    instructions=ANALYST_MANAGED_PROMPT,
    service=lang_service,
    arguments=KernelArguments(settings),
    plugins=[
        ListInterfaces(),
        CapturePackets(),
        DocumentOrganizations(),
        ComputeEmbeddings(),
    ],
)

# LeadAnalyst: reasoning model with analysis tools (anomaly detection + recent-traffic query).
lead_network_analyst = ChatCompletionAgent(
    name="LeadAnalyst",
    description="An expert IT Professional, Sysadmin, and Analyst. Tasked with reviewing network traffic data to identify patterns, anomalies, and make recommendations for security configurations.",
    instructions=MANAGER_PROMPT,
    service=reasoning_service,
    arguments=KernelArguments(settings),
    plugins=[
        AnomalyDetection(),
        FetchData(),
    ],
)

# Handoff routes: Operator escalates to LeadAnalyst; LeadAnalyst and NetworkAnalyst pass work back and forth.
handoffs = (
    OrchestrationHandoffs()
    .add_many(
        source_agent=operator.name,
        target_agents={
            lead_network_analyst.name: "Escalate to the Lead Analyst if the operator detects suspicious activity.",
        },
    )
    .add(
        source_agent=lead_network_analyst.name,
        target_agent=network_analyst.name,
        description="Requests fresh network traffic data from the Network Analyst to enrich the existing data.",
    )
    .add(
        source_agent=network_analyst.name,
        target_agent=lead_network_analyst.name,
        description="Transfer back to Lead Analyst after completing data collection tasks.",
    )
)

In [4]:
# Orchestration
# Team roster shared by both orchestrations; list order is the round-robin speaking order.
members = [operator, lead_network_analyst, network_analyst]
# NOTE(review): if RoundRobinGroupChatManager counts one agent turn per round, a value of 2
# presumably means only Operator and LeadAnalyst speak and NetworkAnalyst never does — confirm intended.
max_rounds = 2

async def handle_intermediate_steps(message: ChatMessageContent) -> None:
    """Agent response callback: print each message an agent produces during orchestration.

    Function calls and function results get one line per item. The message's
    text is printed at most once, the first time a non-function item is seen.
    Previously it was re-printed for every such item.

    Args:
        message: The agent message emitted by the orchestration.
    """
    text_printed = False
    for item in message.items or []:
        if isinstance(item, FunctionCallContent):
            print(f"Function Call:> {item.name} with arguments: {item.arguments}")
        elif isinstance(item, FunctionResultContent):
            print(f"Function Result:> {item.result} for function: {item.name}")
        elif not text_printed:
            # message.content is the whole message's text, not just this item's,
            # so printing it once covers every remaining non-function item.
            print(f"{message.role}: {message.content}")
            text_printed = True


async def operator_activity(input: str) -> str:
    """Ask the Operator agent, on its own, to handle ``input`` and return its reply text.

    Unlike ``orchestration``, this makes a single direct ``get_response`` call
    on one agent rather than invoking an orchestration. For that reason no
    ``InProcessRuntime`` is started here. The original started one and only
    passed it along as an extra keyword argument.

    Args:
        input: The task/prompt for the Operator.

    Returns:
        The Operator's reply as a string, or ``""`` if the call raised (the
        error is printed).
    """
    print(f"[INPUT] | {input}")

    try:
        print("[ORCHESTRATION] | [SOLO]\n")
        response = await operator.get_response(messages=input)
        return str(response.content)

    except Exception as e:
        # Best effort: a failed probe must not crash the Gradio timer callback.
        print(f"\n[ERROR] | {str(e)}")
        return ""


async def orchestration(input: str, mode: str = "group_chat") -> str:
    """Run the agent team on ``input`` and return the final response text.

    Args:
        input: Task description handed to the orchestration.
        mode: ``"group_chat"`` (default) runs a round-robin group chat over
            ``members`` for ``max_rounds`` turns. ``"handoff"`` runs a
            HandoffOrchestration over ``members`` using ``handoffs``.

    Returns:
        The final response text, or ``""`` if the orchestration raised (the
        error is printed).

    Raises:
        ValueError: If ``mode`` is not ``"group_chat"`` or ``"handoff"``.
    """
    # Build only the orchestration that will actually run. The original always
    # constructed an unused HandoffOrchestration as well.
    if mode == "group_chat":
        selected = GroupChatOrchestration(
            members=members,
            manager=RoundRobinGroupChatManager(max_rounds=max_rounds),
            agent_response_callback=handle_intermediate_steps
        )
        banner = f"[ORCHESTRATION] | [GROUP CHAT] {len(members)} | [ROUND ROBIN] {max_rounds}\n"
    elif mode == "handoff":
        selected = HandoffOrchestration(
            members=members,
            handoffs=handoffs,
            agent_response_callback=handle_intermediate_steps
        )
        banner = f"[ORCHESTRATION] | [HANDOFF] {len(members)}\n"
    else:
        raise ValueError(f"Unknown orchestration mode: {mode!r} (expected 'group_chat' or 'handoff')")

    runtime = InProcessRuntime()
    runtime.start()

    print(f"[INPUT] | {input}")

    try:
        print(banner)
        result = await selected.invoke(
            task=input,
            runtime=runtime
        )
        response = await result.get()

        # Defensive extraction: prefer .content, then inner_content.content, else str().
        if hasattr(response, 'content'):
            response_text = response.content
        elif hasattr(response, 'inner_content') and hasattr(response.inner_content, 'content'):
            response_text = response.inner_content.content
        else:
            response_text = str(response)

        # The UI expects a string even when the final message carried no text.
        return "" if response_text is None else str(response_text)

    except Exception as e:
        # Best effort: a failed report must not crash the Gradio callback.
        print(f"\n[ERROR] | {str(e)}")
        return ""

    finally:
        await runtime.stop_when_idle()

In [None]:
# Interface
def next_report(minutes):
    """Build the status-bar label announcing when the next automatic report is due.

    Args:
        minutes: Minutes from now until the next report.

    Returns:
        A label such as ``"⏲️ Automatic Report | 10 Minutes | 14:05:00"``.
    """
    due_at = datetime.now() + timedelta(minutes=minutes)
    clock = due_at.strftime('%H:%M:%S')
    return "⏲️ Automatic Report | {} Minutes | {}".format(minutes, clock)


with gr.Blocks(title="Command Center | Network Traffic Monitoring") as INTERFACE:
    # Two panes, each driven by its own gr.Timer: a frequent solo Operator probe and a
    # slower full-team report. Every tick runs the agent(s) and refreshes the
    # "next report" label.

    # Operator states
    operator_chat_history = gr.State(value=[{
        "role": "assistant",
        "content": "Tasked with periodically capturing network traffic and reporting on patterns.\n1. Capture network traffic on a specific interface.\n2. Enrich the data using OSINT.\n3. Compute embeddings from traffic data.\n4. Analyze the traffic data for anomalies and patterns.",
        "metadata": {"title": "👁️ Network Operator"}
    }])
    operator_minutes = gr.State(value=10) # 10 minutes, to be finished roughly every 12-15 minutes.
    operator_seconds = gr.State(value=operator_minutes.value * 60)
    operator_timer = gr.Timer(value=operator_seconds.value, active=True)

    # Analyst states
    analyst_chat_history = gr.State(value=[{
        "role": "assistant",
        "content": "Tasked with working as a team to capture and process network traffic into a comprehensive report. Click 'Request Report' to start the network analysis or wait for the automated report.\n1. Capture network traffic on a specific interface.\n2. Enrich the data using OSINT.\n3. Compute embeddings from traffic data.\n4. Analyze the traffic data for anomalies and patterns.",
        "metadata": {"title": "🪬 Network Traffic Analyst"}
    }])
    analyst_minutes = gr.State(value=45) # 45 minutes, to be finished roughly every 60 minutes.
    analyst_seconds = gr.State(value=analyst_minutes.value * 60)
    analyst_timer = gr.Timer(value=analyst_seconds.value, active=True)

    with gr.Column():
        # Operator Section
        with gr.Row():
            operator_chatbot = gr.Chatbot(
                value=operator_chat_history.value,
                type="messages",
                show_label=False,
                autoscroll=True,
                resizable=True,
                show_copy_button=True,
                height=320
            )
        with gr.Row():
            operator_timer_status = gr.Textbox(
                value=next_report(operator_minutes.value),
                show_label=False,
                container=False,
                interactive=False,
                text_align="center"
            )

        # Analyst Section
        with gr.Row():
            analyst_chatbot = gr.Chatbot(
                value=analyst_chat_history.value,
                type="messages",
                show_label=False,
                autoscroll=True,
                resizable=True,
                show_copy_button=True,
                height=480
            )
        with gr.Row():
            analyst_timer_status = gr.Textbox(
                value=next_report(analyst_minutes.value),
                show_label=False,
                container=False,
                interactive=False,
                text_align="center"
            )
        with gr.Row():
            analyze_button = gr.Button("🔮 Request Report", variant="huggingface")

    def _append_report(history, response, keep):
        """Append a timestamped situation report to ``history``, keeping only the last ``keep`` messages.

        Returns the new history twice: once for the Chatbot and once for the State.
        """
        timestamp = datetime.now()
        formatted_response = {"role": "assistant", "content": response, "metadata": {"title": f"🔮 Situation Report | {timestamp.strftime('%Y-%m-%d %H:%M:%S')}"}}
        chat_history = (history + [formatted_response])[-keep:]
        return chat_history, chat_history

    async def operator_request(history):
        """Timer handler: run a solo Operator probe and append its report (last 10 messages kept)."""
        response = await operator_activity("Perform your network probe and report back to the command center.")
        return _append_report(history, response, keep=10)

    async def request_report(history):
        """Button/timer handler: run the team orchestration and append its report (last 3 messages kept)."""
        response = await orchestration("The command center is requesting their periodic report in from your team.")
        return _append_report(history, response, keep=3)

    operator_timer.tick(
        fn=operator_request,
        inputs=[operator_chat_history],
        outputs=[operator_chatbot, operator_chat_history]
    )

    operator_timer.tick(
        fn=next_report,
        inputs=[operator_minutes],
        outputs=[operator_timer_status]
    )

    analyze_button.click(
        fn=request_report,
        inputs=[analyst_chat_history],
        outputs=[analyst_chatbot, analyst_chat_history]
    )

    analyst_timer.tick(
        fn=request_report,
        inputs=[analyst_chat_history],
        outputs=[analyst_chatbot, analyst_chat_history]
    )

    analyst_timer.tick(
        fn=next_report,
        inputs=[analyst_minutes],
        outputs=[analyst_timer_status]
    )

INTERFACE.launch() #share=True