In [16]:
# Setup
import asyncio
import os
import subprocess
import sys
from typing import Annotated, AsyncGenerator

import gradio as gr
import pandas as pd
from semantic_kernel import Kernel
from semantic_kernel.agents import ChatCompletionAgent, ChatHistoryAgentThread, GroupChatOrchestration, RoundRobinGroupChatManager
from semantic_kernel.agents.runtime import InProcessRuntime
from semantic_kernel.connectors.ai.open_ai import OpenAIChatPromptExecutionSettings, OpenAIChatCompletion
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents import FunctionCallContent, FunctionResultContent
from semantic_kernel.functions import kernel_function, KernelArguments

from jaws.jaws_config import *
from jaws.jaws_utils import dbms_connection

# Database not passed, uses the database set in jaws_config.py
# DATABASE and the OPENAI_* names below presumably come from the
# `from jaws.jaws_config import *` star import above — confirm there.
# `driver` is a module-level connection shared by the tools (FetchTraffic opens
# sessions on it); no close() call is visible anywhere in this notebook.
driver = dbms_connection(DATABASE)
# NOTE(review): the agents defined later receive `service=` directly and are not
# given this `kernel`, so these registrations appear unused here — confirm
# before relying on them.
kernel = Kernel()
# One settings instance; it is wrapped in a KernelArguments for BOTH agents later.
settings = OpenAIChatPromptExecutionSettings()
# Chat service on the reasoning model — used by the LeadAnalyst agent.
reasoning_service = OpenAIChatCompletion(ai_model_id=OPENAI_REASONING_MODEL, api_key=OPENAI_API_KEY)
kernel.add_service(reasoning_service)
# Chat service on the general language model — used by the NetworkAnalyst agent.
lang_service = OpenAIChatCompletion(ai_model_id=OPENAI_MODEL, api_key=OPENAI_API_KEY)
kernel.add_service(lang_service)

In [17]:
# Semantic Kernel Helper Functions
async def handle_intermediate_steps(message: ChatMessageContent) -> None:
    """Print the intermediate items of an agent message.

    Intended as an ``on_intermediate_message`` callback. Each function call and
    each function result gets its own line. Every other item prints the whole
    message as ``role: content``. A message with several non-function items
    therefore prints its content several times.
    """
    items = message.items or []
    for part in items:
        if isinstance(part, FunctionCallContent):
            line = f"Function Call:> {part.name} with arguments: {part.arguments}"
        elif isinstance(part, FunctionResultContent):
            line = f"Function Result:> {part.result} for function: {part.name}"
        else:
            line = f"{message.role}: {message.content}"
        print(line)

In [18]:
# Tools
# Upper bound (seconds) applied to any capture duration the agent requests.
_MAX_CAPTURE_SECONDS = 60


def _run_jaws_script(script: str, *args: str) -> str:
    """Run one of the ``./jaws`` pipeline scripts in ``--agent`` mode.

    Args:
        script: File name of the script inside ``./jaws``. The path is resolved
            relative to the current working directory, as before.
        *args: Extra CLI arguments, placed before the trailing ``--agent`` flag.

    Returns:
        The script's stdout. If the script exits with a non-zero status, stdout
        is followed by an ``[ERROR]`` line (script name and exit code) and the
        captured stderr. This lets the calling agent see the failure instead of
        receiving an empty string.
    """
    # sys.executable runs the script under the kernel's own interpreter. A bare
    # "python" on PATH may be a different environment (missing deps) or may not
    # exist at all (e.g. only "python3" installed).
    completed = subprocess.run(
        [sys.executable, f"./jaws/{script}", *args, "--agent"],
        capture_output=True,
        text=True,
    )
    if completed.returncode != 0:
        return (
            f"{completed.stdout}\n"
            f"[ERROR] {script} exited with status {completed.returncode}\n"
            f"{completed.stderr}"
        ).strip()
    return completed.stdout


class ListInterfaces:
    # Plugin wrapping `jaws_capture.py --list`.
    @kernel_function(description="List available network interfaces. You will never want to select interfaces such as; 'lo', 'docker0', 'wlo1', etc.")
    def list_interfaces(self) -> Annotated[str, "A list of available network interfaces."]:
        return _run_jaws_script("jaws_capture.py", "--list")


class CapturePackets:
    # Plugin wrapping `jaws_capture.py --interface ... --duration ...`.
    @kernel_function(description="Captures packets into the database. Choose a duration depending on the amount of data you want to capture. Recommended not to exceed 60 seconds.")
    def capture_packets(self, interface: str, duration: int) -> Annotated[str, "A process status message once the process is complete."]:
        # Clamp into [1, _MAX_CAPTURE_SECONDS]. The agent may request a longer
        # capture than allowed, or a zero/negative duration.
        duration = max(1, min(int(duration), _MAX_CAPTURE_SECONDS))
        return _run_jaws_script("jaws_capture.py", "--interface", interface, "--duration", str(duration))


class DocumentOrganizations:
    # Plugin wrapping `jaws_ipinfo.py` (IP -> organization enrichment).
    @kernel_function(description="Enriches data with organizations by looking up IP addresses.")
    def document_organizations(self) -> Annotated[str, "A process status message once the process is complete."]:
        return _run_jaws_script("jaws_ipinfo.py")


class ComputeEmbeddings:
    # Plugin wrapping `jaws_compute.py` (traffic embeddings).
    @kernel_function(description="Transforms the network traffic data into embeddings for analysis.")
    def compute_embeddings(self) -> Annotated[str, "A process status message once the process is complete."]:
        return _run_jaws_script("jaws_compute.py")


class AnomalyDetection:
    # Plugin wrapping `jaws_finder.py`. The misspelled method name is kept on
    # purpose: it is the tool name the LLM sees, so renaming it changes the tool.
    @kernel_function(description="Transforms the network traffic data and embeddings into a list of anomalies.")
    def anomoly_detection(self) -> Annotated[str, "A string containing a list of anomalies."]:
        return _run_jaws_script("jaws_finder.py")
    

class FetchTraffic:
    # Look-back window and row cap for the query. They are class attributes
    # rather than tool parameters so the tool schema shown to the LLM is unchanged.
    WINDOW_MINUTES = 10
    MAX_ROWS = 100

    # Values are passed as query parameters, not string-formatted into the query.
    # ORDER BY uses the projected alias: after RETURN DISTINCT only projected
    # names are unambiguously in scope.
    QUERY = """
    MATCH (traffic:TRAFFIC)
    WHERE traffic.TIMESTAMP > datetime() - duration({minutes: $window_minutes})
    RETURN DISTINCT
        traffic.IP_ADDRESS AS ip_address,
        traffic.PORT AS port,
        traffic.ORGANIZATION AS org,
        traffic.HOSTNAME AS hostname,
        traffic.LOCATION AS location,
        traffic.TOTAL_SIZE AS total_size,
        traffic.OUTLIER AS outlier,
        traffic.TIMESTAMP AS timestamp
    ORDER BY timestamp DESC
    LIMIT $max_rows
    """

    @kernel_function(description="Fetches the latest traffic data from the database and returns it as a string.")
    def fetch_traffic(self) -> Annotated[str, "A string containing a list of current traffic data."]:
        # Returns str() of a list of dicts, one per row, keyed by the query
        # aliases above; an empty result is returned as "[]".
        with driver.session(database=DATABASE) as session:
            result = session.run(
                self.QUERY,
                window_minutes=self.WINDOW_MINUTES,
                max_rows=self.MAX_ROWS,
            )
            # Materialize inside the session; results are not readable after it closes.
            rows = [record.data() for record in result]
        for row in rows:
            # str() of a neo4j temporal value is ISO-8601 text, which is much
            # shorter for the LLM than the verbose repr str(list) would embed.
            if row.get("timestamp") is not None:
                row["timestamp"] = str(row["timestamp"])
        return str(rows)

In [19]:
# Prompts
# Standalone single-agent variant: the whole pipeline plus the final report.
# Not passed to either agent defined in this notebook.
ANALYST_PROMPT = """You are an expert IT Professional, Sysadmin, and Analyst. Your task is to capture network packets and perform ETL (Extract, Transform, and Load) on the network data to prepare it for analysis. Once the network traffic data is prepared, your task is to analyze it for anomalies and patterns. You have access to several tools to accomplish this, but the process is fairly linear and looks something like this:

1. Use tool: ListInterfaces() to list and select an interface. You will never want to select interfaces such as 'lo', 'docker0', 'wlo1', etc.
2. Use tool: CapturePackets() to capture network traffic. This is a critical step, as packet data is the foundation of the analysis.
3. Use tool: DocumentOrganizations() to document organizations using captured IP address data. This is an important step, as it enriches the packet data with organization ownership information.
4. Use tool: ComputeEmbeddings() to compute embeddings from traffic data. This is an important step, as it transforms the data into traffic embeddings.
5. Use tool: AnomalyDetection() to analyze the traffic data for anomalies and patterns.
6. Use tool: FetchTraffic() to fetch the final enriched and transformed network traffic data from the database to be used for analysis and reporting.
7. Return a report of your findings using the following format:

Executive Summary:
A concise summary of the traffic analysis, including a description of the cluster plot.

Traffic Patterns:
Identify and describe the regular traffic patterns. Highlight any anomalies or unusual patterns.

Recommendations:
1. Recommendations: List detailed recommendations for enhancing security based on the traffic patterns identified.
2. Rationale: Provide a rationale for each recommendation, explaining how it addresses specific issues identified in the traffic analysis.
"""

# Instructions for the NetworkAnalyst agent: capture + ETL only; analysis is left to the lead.
ANALYST_MANAGED_PROMPT = """You are an expert IT Professional, Sysadmin, and Analyst. Your task is to capture network packets and perform ETL (Extract, Transform, and Load) on the network data to prepare it for downstream analysis. You have access to several tools, but the process is fairly linear and looks something like this:

1. Use tool: ListInterfaces() to list and select an interface. You will never want to select interfaces such as 'lo', 'docker0', 'wlo1', etc.
2. Use tool: CapturePackets() to capture network traffic. This is a critical step, as packet data is the foundation of the analysis.
3. Use tool: DocumentOrganizations() to document organizations using captured IP address data. This is an important step, as it enriches the packet data with organization ownership information.
4. Use tool: ComputeEmbeddings() to compute embeddings from traffic data. This is an important step, as it transforms the data into traffic embeddings.
"""

# Instructions for the LeadAnalyst agent. Tool names match the plugins it is
# given (FetchTraffic, AnomalyDetection). FetchTraffic returns str() of a list,
# so "no data" appears as "[]".
MANAGER_PROMPT = """You are an expert IT Professional, Sysadmin, and Analyst. Your task is to review data from network traffic to identify patterns and make recommendations for security configurations.

You can use the FetchTraffic() tool to check if there is any data available. If data exists, you can use the AnomalyDetection() tool to detect anomalies.

If there is no data, or an empty list ([]) is returned, you should leverage the NetworkAnalyst agent you manage to capture and process network traffic. It is very expensive to collect and store network traffic data, so do not recommend that the NetworkAnalyst agent collect more than 60 seconds of data.

Since data is being collected over short periods of time, you should always consider collecting fresh data before performing your analysis. It is recommended that you run FetchTraffic() to see what data is available, but do not limit yourself to these outputs, as they may be outdated; consider requesting fresh data from the NetworkAnalyst agent.

When you have access to fresh data, return a brief report in the following format:

Executive Summary:
A concise summary of the traffic analysis, including a description of the cluster plot.

Traffic Patterns:
Identify and describe the regular traffic patterns. Highlight any anomalies or unusual patterns.

Recommendations:
1. Recommendations: List detailed recommendations for enhancing security based on the traffic patterns identified.
2. Rationale: Provide a rationale for each recommendation, explaining how it addresses specific issues identified in the traffic analysis.
"""

In [20]:
# Agents
# Worker agent: runs the capture/ETL tools (list interfaces, capture, enrich, embed).
network_analyst = ChatCompletionAgent(
    service=lang_service,
    name="NetworkAnalyst",
    description="An expert IT Professional, Sysadmin, and Analyst. Tasked with capturing network packets and performing ETL (Extract, Transform, and Load) on the network data to prepare it for analysis.",
    instructions=ANALYST_MANAGED_PROMPT,
    plugins=[ListInterfaces(), CapturePackets(), DocumentOrganizations(), ComputeEmbeddings()],
    arguments=KernelArguments(settings),
)

# For managed orchestration.
# Lead agent: reads back processed traffic, runs anomaly detection, and writes the report.
# NOTE(review): the same `settings` object is wrapped for both agents — confirm
# SK does not mutate it per service if the two models need different options.
lead_network_analyst = ChatCompletionAgent(
    service=reasoning_service,
    name="LeadAnalyst",
    description="An expert IT Professional, Sysadmin, and Analyst. Tasked with reviewing network traffic data to identify patterns, anomalies, and make recommendations for security configurations.",
    instructions=MANAGER_PROMPT,
    plugins=[FetchTraffic(), AnomalyDetection()],
    arguments=KernelArguments(settings),
)

In [21]:
# Group Chat Orchestration
# Participants in the group chat; the round-robin manager presumably follows this order.
members = [network_analyst, lead_network_analyst]
# Turn budget handed to RoundRobinGroupChatManager. With two members this is
# presumably one turn each — confirm against the SK version in use.
max_rounds = 2


def response_callback(message: ChatMessageContent) -> None:
    """Echo each agent reply to stdout as the orchestration produces it."""
    speaker = f"**{message.name}**"
    print(f"{speaker}\n{message.content}")


async def group_chat(input: str) -> str:
    """Run one round-robin group chat over ``members`` for the given task.

    A fresh orchestration and ``InProcessRuntime`` are created per call, and the
    runtime is always stopped when the call ends.

    Args:
        input: Task text handed to the orchestration. The name shadows the
            ``input`` builtin inside this function; it is kept for keyword callers.

    Returns:
        The text of the orchestration's final response ("" if it has none).
        If the orchestration raises, the error is printed here and a short
        ``[ERROR]`` message naming the exception type is returned. Previously an
        empty string was returned, leaving the chat UI blank with no sign of failure.
    """
    group_chat_orchestration = GroupChatOrchestration(
        members=members,
        manager=RoundRobinGroupChatManager(max_rounds=max_rounds),
        agent_response_callback=response_callback,
    )

    runtime = InProcessRuntime()
    runtime.start()

    print(f"[INPUT] | {input}")

    try:
        print(f"[ORCHESTRATION] | [GROUP CHAT] {len(members)} | [ROUND ROBIN] {max_rounds}")
        orchestration_result = await group_chat_orchestration.invoke(
            task=input,
            runtime=runtime
        )

        response = await orchestration_result.get()

        # Prefer the message text, then the wrapped provider object's text, then repr.
        if hasattr(response, 'content'):
            response_text = response.content
        elif hasattr(response, 'inner_content') and hasattr(response.inner_content, 'content'):
            response_text = response.inner_content.content
        else:
            response_text = str(response)
        if response_text is None:
            response_text = ""

        print(f"[RESPONSE] | {response_text}")
        return response_text

    except Exception as e:
        print(f"\n[ERROR] | {str(e)}")
        # Full details stay in the notebook log. Only the exception type reaches
        # the UI, which may be served over a public share link.
        return f"[ERROR] | The analysis failed ({type(e).__name__}). See the notebook output for details."

    finally:
        await runtime.stop_when_idle()


async def request_report():
    """Ask the agent group for a situation report, wrapped as one chat message.

    Returns a single-element list in Gradio "messages" format (role, content,
    metadata title). Used as the button's click handler output, it becomes the
    chatbot's value.
    """
    task = "The command center is reporting suspicious activity at your endpoint, please probe the network traffic and report back ASAP."
    report = await group_chat(task)
    message = {
        "role": "assistant",
        "content": report,
        "metadata": {"title": "🔮 Situation Report"},
    }
    return [message]

#response = await request_report()
#print(response)

In [None]:
# Interface
# NOTE: the button below triggers a live packet capture on this host, and a
# Gradio share link is publicly reachable. Sharing can therefore be disabled
# (JAWS_GRADIO_SHARE=0), and basic auth enabled (JAWS_GRADIO_USER plus
# JAWS_GRADIO_PASSWORD), via environment variables. The defaults keep the
# previous behavior: share on, no auth.
_share = os.environ.get("JAWS_GRADIO_SHARE", "1").strip().lower() not in {"0", "false", "no", "off"}
_auth_user = os.environ.get("JAWS_GRADIO_USER")
_auth_password = os.environ.get("JAWS_GRADIO_PASSWORD")
_auth = (_auth_user, _auth_password) if _auth_user and _auth_password else None

with gr.Blocks(title="Network Traffic Analysis") as INTERFACE:
    with gr.Column():
        with gr.Row():
            # Seeded with a welcome message. request_report returns a one-message
            # list, so the report replaces this message when it arrives.
            CHATBOT = gr.Chatbot(
                value=[{
                    "role": "assistant",
                    "content": f"[ORCHESTRATION] | [GROUP CHAT] {len(members)} | [ROUND ROBIN] {max_rounds}\nClick 'Request Report' to start the network analysis.\nThe process consists of the following activities:\n1. List and select an interface.\n2. Capture network traffic on the selected interface.\n3. Document organizations using captured ip address data.\n4. Compute embeddings from traffic data.\n5. Analyze the traffic data for anomalies and patterns.\n6. Return a situation report.",
                    "metadata": {"title": "🪬 Network Traffic Analysis"}
                }],
                type="messages",
                show_label=False,
                autoscroll=True,
                resizable=True,
                show_copy_button=True,
                height=320
            )
        with gr.Row():
            ANALYZE_BUTTON = gr.Button(
                "🔮 Request Report"
            )

    ANALYZE_BUTTON.click(
        fn=request_report,
        outputs=[CHATBOT]
    )

INTERFACE.launch(share=_share, auth=_auth)

In [15]:
# Hand Off Orchestration | NOT IN USE
async def hand_off() -> None:
    """Interactive console chat with ``lead_network_analyst`` on one persistent thread.

    Reads prompts with ``input()``. The loop ends when the user types
    "that is all" (case- and surrounding-whitespace-insensitive) or when stdin
    is closed (EOF). Blank lines are skipped instead of being sent to the agent.
    The thread returned by the agent is reused across turns and deleted on exit.

    Returns:
        None. Previously the exit path returned ``False``, contradicting the annotation.
    """
    thread: ChatHistoryAgentThread | None = None
    print("[READY] | [ExpertAdvisor]")
    try:
        while True:
            try:
                user_input = input("[INPUT] ")
            except EOFError:
                # stdin closed: leave cleanly instead of raising.
                print("\n\n[EXITING]")
                return

            command = user_input.strip()
            if not command:
                continue
            if command.lower() == "that is all":
                print("\n\n[EXITING]")
                return

            print("[ORCHESTRATION] | [HAND OFF] [THREAD]")
            async for response in lead_network_analyst.invoke(
                messages=user_input,
                thread=thread,
                on_intermediate_message=handle_intermediate_steps,
            ):
                print(f"# {response.role}: {response}")
                thread = response.thread
    finally:
        # Release the conversation thread created by the first invoke, if any.
        if thread is not None:
            await thread.delete()