In [None]:
import asyncio
import subprocess
import sys
from typing import Annotated

import pandas as pd
from semantic_kernel import Kernel
from semantic_kernel.agents import ChatCompletionAgent, ChatHistoryAgentThread
from semantic_kernel.connectors.ai.open_ai import OpenAIChatPromptExecutionSettings, OpenAIChatCompletion
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents import FunctionCallContent, FunctionResultContent
from semantic_kernel.functions import kernel_function, KernelArguments

from jaws.jaws_config import *
from jaws.jaws_utils import dbms_connection

# Open the graph-database driver for the database named by DATABASE, which is
# supplied by the star import of jaws.jaws_config (it is passed explicitly here).
# `driver` is used later by PerformAnalysis.fetch_data.
driver = dbms_connection(DATABASE)
# Kernel that the OpenAI chat service is registered on in the agent-setup cell.
kernel = Kernel()

In [2]:
async def handle_intermediate_steps(message: ChatMessageContent) -> None:
    """Print every intermediate item an agent emits while it is working.

    Used as the ``on_intermediate_message`` callback of ``expert_advisor.invoke``
    so that tool calls and tool results are visible as they happen.

    Args:
        message: An intermediate chat message; each entry of ``message.items``
            is printed on its own line. A message with no items prints nothing.
    """
    for item in message.items or []:
        if isinstance(item, FunctionCallContent):
            print(f"Function Call:> {item.name} with arguments: {item.arguments}")
        elif isinstance(item, FunctionResultContent):
            print(f"Function Result:> {item.result} for function: {item.name}")
        else:
            # Print the item itself, not message.content: the latter is a
            # message-level value and would be repeated once per such item.
            print(f"{message.role}: {item}")

In [3]:
class GatherData:
    """Kernel plugin exposing the data-gathering steps of the JAWS pipeline.

    Each kernel function runs one of the ``./jaws/*.py`` scripts with the
    ``--agent`` flag and hands the script's output back to the calling agent.
    """

    # Upper bound (seconds) advertised to the model in capture_packets' description.
    MAX_CAPTURE_SECONDS = 60

    @staticmethod
    def _run_script(script: str, *args: str) -> str:
        """Run ``./jaws/<script> <args...> --agent`` and return its stdout.

        The notebook's own interpreter (``sys.executable``) is used instead of
        whatever ``python`` resolves to on PATH, so the script sees the same
        environment as the kernel. On a non-zero exit code the exit status and
        stderr are appended, so the agent is told the step failed instead of
        receiving bare (possibly empty) stdout.
        """
        proc = subprocess.run(
            [sys.executable, f"./jaws/{script}", *args, "--agent"],
            capture_output=True,
            text=True,
        )
        if proc.returncode != 0:
            return (
                f"{proc.stdout}"
                f"[ERROR] {script} exited with code {proc.returncode}.\n"
                f"{proc.stderr}"
            )
        return proc.stdout

    @kernel_function(description="Step 1: List available network interfaces. You will never want to select interfaces such as; 'lo', 'docker0', 'wlo1', etc.")
    def list_interfaces(self) -> Annotated[str, "The list of available network interfaces."]:
        """Return the interface listing printed by ``jaws_capture.py --list``."""
        return self._run_script('jaws_capture.py', '--list')

    @kernel_function(description="Step 2: Captures packets into the database. Choose a duration depending on the amount of data you want to capture. Recommended not to exceed 60 seconds.")
    def capture_packets(self, interface: str, duration: int) -> Annotated[str, "A process status message once the process is complete."]:
        """Capture traffic on ``interface`` for ``duration`` seconds via ``jaws_capture.py``.

        ``duration`` is clamped to ``[1, MAX_CAPTURE_SECONDS]``: the model may
        request longer than advertised, and a non-positive value is not a
        meaningful capture window.
        """
        duration = max(1, min(int(duration), self.MAX_CAPTURE_SECONDS))
        return self._run_script('jaws_capture.py', '--interface', interface, '--duration', str(duration))

    @kernel_function(description="Step 3:Documents organizations by sending IP addresses to ipinfo.io.")
    def document_organizations(self) -> Annotated[str, "A process status message once the process is complete."]:
        """Run ``jaws_ipinfo.py`` and return its status output."""
        return self._run_script('jaws_ipinfo.py')

    @kernel_function(description="Step 4: Processes network traffic data into embeddings for reporting and analysis.")
    def compute_embeddings(self) -> Annotated[str, "A process status message once the process is complete."]:
        """Run ``jaws_compute.py`` and return its status output."""
        return self._run_script('jaws_compute.py')
    

class PerformAnalysis:
    """Kernel plugin exposing the analysis steps of the JAWS pipeline."""

    # Columns projected by fetch_data's query, in RETURN order; each becomes a
    # key of the per-row dict handed back to the agent.
    _FETCH_FIELDS = (
        'ip_address',
        'port',
        'org',
        'hostname',
        'location',
        'total_size',
        'outlier',
        'timestamp',
    )

    @kernel_function(description="Step 1: Processes the network traffic data and embeddings into a set of plots for detecting anomalies.")
    def anomoly_detection(self) -> Annotated[str, "A process status message once the process is complete."]:
        """Run ``./jaws/jaws_finder.py --agent`` and return its status output.

        The (misspelled) method name is kept because it is the tool name the
        model sees. The kernel's own interpreter (``sys.executable``) is used
        rather than ``python`` on PATH. On a non-zero exit the exit status and
        stderr are appended, so a failure is not reported as bare stdout.
        """
        output = subprocess.run(
            [sys.executable, './jaws/jaws_finder.py', '--agent'],
            capture_output=True,
            text=True,
        )
        if output.returncode != 0:
            return (
                f"{output.stdout}"
                f"[ERROR] jaws_finder.py exited with code {output.returncode}.\n"
                f"{output.stderr}"
            )
        return output.stdout

    @kernel_function(description="Step 2: Fetches the latest data from the database and returns it as a string.")
    def fetch_data(self) -> Annotated[str, "A string containing the latest data from the database."]:
        """Return up to 100 distinct TRAFFIC rows from the last 10 minutes, newest first.

        Returns:
            ``str`` of a list of dicts keyed by ``_FETCH_FIELDS`` (``"[]"`` when
            nothing matched).
        """
        # ORDER BY uses the projected alias: after RETURN DISTINCT, Cypher only
        # permits ordering by projected items.
        query = """
        MATCH (traffic:TRAFFIC)
        WHERE traffic.TIMESTAMP > datetime() - duration({minutes: 10})
        RETURN DISTINCT
            traffic.IP_ADDRESS AS ip_address,
            traffic.PORT AS port,
            traffic.ORGANIZATION AS org,
            traffic.HOSTNAME AS hostname,
            traffic.LOCATION AS location,
            traffic.TOTAL_SIZE AS total_size,
            traffic.OUTLIER AS outlier,
            traffic.TIMESTAMP AS timestamp
        ORDER BY timestamp DESC
        LIMIT 100
        """
        with driver.session(database=DATABASE) as session:
            result = session.run(query)
            # Materialize while the session is open; results are not readable after close.
            data = [{field: record[field] for field in self._FETCH_FIELDS} for record in result]
        return str(data)

In [4]:
# One OpenAI chat service, registered on the kernel and shared by both agents.
service = OpenAIChatCompletion(ai_model_id=OPENAI_MODEL, api_key=OPENAI_API_KEY)
kernel.add_service(service)

# Default execution settings; each agent below wraps them in its own KernelArguments.
settings = OpenAIChatPromptExecutionSettings()

# Data-gathering agent: owns the GatherData tools (list/capture/ipinfo/embeddings).
network_analyst = ChatCompletionAgent(
    name="NetworkAnalyst",
    service=service,
    instructions=ANALYST_SYSTEM_PROMPT,
    plugins=[GatherData()],
    arguments=KernelArguments(settings=settings),
)

# User-facing agent: can delegate to network_analyst (exposed as a plugin)
# and run the PerformAnalysis tools itself.
expert_advisor = ChatCompletionAgent(
    name="ExpertAdvisor",
    service=service,
    instructions=ADVISOR_PROMPT,
    plugins=[network_analyst, PerformAnalysis()],
    arguments=KernelArguments(settings=settings),
)

In [None]:
async def main() -> None:
    thread: ChatHistoryAgentThread = None
    print("[GREETINGS] How can I help you today? I am a network analyst and advisor, I am capable of gathering data, analyzing it, and providing a report.")
    while True:
        user_input = input("[INPUT] ")

        if user_input.lower().strip() == "exit":
            print("\n\n[EXITING]")
            return False

        async for response in expert_advisor.invoke(
            messages=user_input,
            thread=thread,
            on_intermediate_message=handle_intermediate_steps,
        ):
            print(f"# {response.role}: {response}")
            thread = response.thread

await main()