## Setup

In [5]:
import os
from openai import AzureOpenAI
import argparse
import sys
import json
from azure.search.documents import SearchClient
from azure.core.credentials import AzureKeyCredential
import csv
from azure.storage.filedatalake import (DataLakeServiceClient, FileSystemClient, DataLakeDirectoryClient)
from azure.identity import (DefaultAzureCredential, ClientSecretCredential)
import pandas as pd
from io import (BytesIO, StringIO)
import ipaddress



#Load environment variables
from dotenv import load_dotenv
load_dotenv()

True

In [146]:
# Run `az login` beforehand so that DefaultAzureCredential can pick up the Azure CLI login.
# NOTE(review): `credential` is not referenced by any other cell visible in this notebook
# (Search uses an API key, Storage uses the account key) - confirm whether it is still needed.
credential = DefaultAzureCredential()

In [2]:
# Instantiate the Azure OpenAI client; endpoint, key and API version are read from environment variables
client = AzureOpenAI(
  azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT"), 
  api_key=os.getenv("AZURE_OPENAI_API_KEY"),  
  api_version=os.getenv("AZURE_OPENAI_API_VERSION")
)

# Deployment name; passed as `model=` to the chat completion calls further down
model = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME")


In [3]:
# Initialize the Azure Cognitive Search client.
# Endpoint, index name and API key are read from environment variables; authentication is key-based.
search_client = SearchClient(
    endpoint=os.getenv("AZURE_AI_SEARCH_ENDPOINT"),
    index_name=os.getenv("AZURE_AI_SEARCH_INDEX"),
    credential=AzureKeyCredential(os.getenv("AZURE_AI_SEARCH_API_KEY"))
)

## Function Definitions

In [None]:
#TO DO: Find a way to work with IP Ranges
def create_ip_range_filter(ip_range):
    """
    Creates a prefix search query string for finding IP addresses within a given IPv4 range.

    The number of fixed leading octets is derived from the prefix length, so the
    wildcard position follows the CIDR mask instead of always assuming a /24:

        /8  -> "destination_ip:a.*"
        /16 -> "destination_ip:a.b.*"
        /24 -> "destination_ip:a.b.c.*"
        /32 -> "destination_ip:a.b.c.d"   (exact address, no wildcard)

    Prefix lengths that do not fall on an octet boundary (e.g. /20) are rounded
    down to the previous boundary, so the query matches a superset of the range.

    Args:
        ip_range (str): The IP range in CIDR notation (e.g., "20.0.0.0/24").
            Host bits are allowed ("20.0.0.7/24" is treated as "20.0.0.0/24").

    Returns:
        str: A prefix search query string, or a message starting with
             "Invalid IP range:" when the input cannot be used.
    """
    try:
        # strict=False accepts ranges whose host bits are set
        network = ipaddress.ip_network(ip_range, strict=False)
    except ValueError as e:
        return f"Invalid IP range: {e}"

    # A dotted-octet prefix only makes sense for IPv4
    if network.version != 4:
        return f"Invalid IP range: {ip_range!r} is not an IPv4 range"

    octets = str(network.network_address).split('.')
    fixed_octets = network.prefixlen // 8  # number of octets fully covered by the mask

    if fixed_octets == 4:
        # Single host: match the address exactly
        return f"destination_ip:{network.network_address}"
    if fixed_octets == 0:
        # Mask shorter than /8: no octet is fixed, so every address is a candidate.
        # NOTE(review): confirm the search backend accepts a bare "*" as a fielded term.
        return "destination_ip:*"

    prefix = '.'.join(octets[:fixed_octets])
    return f"destination_ip:{prefix}.*"

# Example usage: a /24 range produces a wildcard over the last octet
ip_range_filter = create_ip_range_filter("20.0.0.0/24")
print(ip_range_filter)


destination_ip:20.0.0.*


In [6]:
"""
Queries the Azure Cognitive Search index for risky IP ranges.

Args:
    ip_range (str): The IP range or address to search for in the index.

Returns:
    str: A JSON-formatted string containing the risk details if found, 
         or a message indicating no risky IP range was found.
"""
def query_risky_ip_ranges(ip_range):
    results = search_client.search(
        search_text=ip_range,  # Perform a keyword search using the provided IP range
        query_type="full",  # Use simple query type for keyword search
        search_fields=["ip_ranges"],  # Field to search for the IP range
        select=["risk_id", "description", "cvss_score", "ports", "protocols", "ip_ranges", "last_updated"],  # Select relevant fields
        top=1  # Limit to top 1 results
    )

    top_result = next(iter(results), None)
    if top_result:
        return json.dumps({
            "risk_id": top_result["risk_id"],
            "description": top_result["description"],
            "cvss_score": top_result["cvss_score"],
            "ports": top_result["ports"],
            "protocols": top_result["protocols"],
            "ip_ranges": top_result["ip_ranges"],
            "last_updated": top_result["last_updated"]
        }, indent=2)

    return "No risky IP range found for this input."

In [32]:
# Lookup of an address that is listed in the index (the output below shows risk R401)
test_ipranges_risky = query_risky_ip_ranges("193.163.125.241")
print(test_ipranges_risky)

{
  "risk_id": "R401",
  "description": "Reported abusive IP",
  "cvss_score": "10.0",
  "ports": null,
  "protocols": null,
  "ip_ranges": "193.163.125.241",
  "last_updated": "2025-03-03T00:00:00Z"
}


In [33]:
# Lookup of an address that is not listed in the index (the output below shows the "not found" message)
test_ipranges_unrisky = query_risky_ip_ranges("193.163.125.245")
print(test_ipranges_unrisky)

No risky IP range found for this input.


In [None]:
#TO DO: comma separation in ports
#TO DO: * for all ports
def query_risky_ports(port, protocol):
    """
    Queries the Azure Cognitive Search index for risky ports.

    The lookup is an exact-match filter on the `ports` and `protocols` fields.
    Both values are trimmed and OData-escaped before being placed in the filter,
    and the protocol is lower-cased so that "TCP" and "tcp" behave the same.

    Args:
        port (str): The port number to search for in the index.
        protocol (str): The protocol (e.g., "tcp", "udp") to search for in the index.

    Returns:
        str: A JSON-formatted string containing the risk details if found,
             or a message indicating no risky port was found.
    """
    def _odata_string(value):
        # Inside an OData string literal a single quote is escaped by doubling it;
        # without this, a quote in the input would terminate the literal early.
        return str(value).strip().replace("'", "''")

    # Fields fetched from the index and echoed back, in this order, in the JSON result
    risk_fields = ["risk_id", "description", "cvss_score", "ports", "protocols", "ip_ranges", "last_updated"]

    port_literal = _odata_string(port)
    # `eq` is case-sensitive. NOTE(review): assumes protocols are stored lower-case
    # in the index (the sample records show "tcp") - confirm for the full data set.
    protocol_literal = _odata_string(protocol).lower()

    results = search_client.search(
        search_text="*",  # Match all documents; the filter below does the selection
        filter=f"ports eq '{port_literal}' and protocols eq '{protocol_literal}'",  # Filter by the specified port and protocol
        query_type="full",  # Full Lucene query syntax
        select=risk_fields,  # Select relevant fields
        top=1  # Only the best match is used
    )

    top_result = next(iter(results), None)
    if top_result:
        return json.dumps({field: top_result[field] for field in risk_fields}, indent=2)

    return "No risky port found for this input."

In [34]:
# Lookup of a port/protocol pair that is listed in the index (the output below shows risk R002)
test_ports_risky = query_risky_ports("20", "tcp")
print(test_ports_risky)


{
  "risk_id": "R002",
  "description": "FTP Data- Unencrypted file transfer",
  "cvss_score": "5.0",
  "ports": "20",
  "protocols": "tcp",
  "ip_ranges": "0.0.0.0/0",
  "last_updated": "2025-03-03T00:00:00Z"
}


In [35]:
# Lookup of a port/protocol pair that is not listed in the index (the output below shows the "not found" message)
test_ports_unrisky = query_risky_ports("145", "tcp")
print(test_ports_unrisky)

No risky port found for this input.


In [13]:
"""
Queries the Azure Cognitive Search index for public risk information associated with a firewall rule.

Args:
    source_ip (str): The source IP address or range to search for in the index.
    destination_ip (str): The destination IP address or range to search for in the index.
    port (str): The port number to search for in the index.
    protocol (str): The protocol (e.g., "tcp", "udp") to search for in the index.

Returns:
    str: A JSON-formatted string containing the risk details for the source IP, destination IP, and port,
         or a message indicating no risk was found for each component.
"""
def query_firewall_rule_public_risk(source_ip, destination_ip, port, protocol):
    source_ip_risk = query_risky_ip_ranges(source_ip)
    destination_ip_risk = query_risky_ip_ranges(destination_ip)
    port_risk = query_risky_ports(port, protocol)
    return json.dumps(
        {
        "source_ip": source_ip_risk,
        "destination_ip": destination_ip_risk,
        "port": port_risk
    }, indent=2)

In [36]:
# Combined lookup: per the output below, the source IP is listed, the destination IP is not, and port 20/tcp is listed
test_firewall_rule_public_risk = query_firewall_rule_public_risk("193.163.125.241","193.163.125.245", "20", "tcp")
print(test_firewall_rule_public_risk)

{
  "source_ip": "{\n  \"risk_id\": \"R401\",\n  \"description\": \"Reported abusive IP\",\n  \"cvss_score\": \"10.0\",\n  \"ports\": null,\n  \"protocols\": null,\n  \"ip_ranges\": \"193.163.125.241\",\n  \"last_updated\": \"2025-03-03T00:00:00Z\"\n}",
  "destination_ip": "No risky IP range found for this input.",
  "port": "{\n  \"risk_id\": \"R002\",\n  \"description\": \"FTP Data- Unencrypted file transfer\",\n  \"cvss_score\": \"5.0\",\n  \"ports\": \"20\",\n  \"protocols\": \"tcp\",\n  \"ip_ranges\": \"0.0.0.0/0\",\n  \"last_updated\": \"2025-03-03T00:00:00Z\"\n}"
}


## Define the tools for function calling

In [21]:
# Tool (function-calling) schemas advertised to the model.
# Each "name" must match a key of the dispatch table that maps tool names to Python callables,
# and each "required" list must match the positional parameters of the corresponding function.
tools = [
    {
        "type": "function",
        "function": {
            "name": "query_risky_ip_ranges",
            "description": "Retrieve risky IP ranges from Azure AI Search index",
            "parameters": {
                "type": "object",
                "properties": {
                    "ip_range": {
                        "type": "string",
                        "description": "The ip-range from the firewall rule request to search for in the risky IP range index",
                    },
                },
                "required": ["ip_range"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "query_risky_ports",
            "description": "Retrieve risky ports from Azure AI Search index",
            "parameters": {
                "type": "object",
                "properties": {
                    "port": {
                        "type": "string",
                        "description": "The port number from the firewall rule request to search for in the risky ports index",
                    },
                    "protocol": {
                        "type": "string",
                        "description": "The protocol from the firewall rule request to search for in the risky ports index",
                    },
                },
                # Both are mandatory: query_risky_ports(port, protocol) has no default for protocol
                "required": ["port", "protocol"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "query_firewall_rule_public_risk",
            "description": "Takes source IP, destination IP, port and protocol as input and returns whether there is any public risk information associated with the firewall rule",
            "parameters": {
                "type": "object",
                "properties": {
                    "source_ip": {
                        "type": "string",
                        "description": "The source ip-range from the firewall rule request to search for in the risky IP range index",
                    },
                    "destination_ip": {
                        "type": "string",
                        "description": "The destination ip-range from the firewall rule request to search for in the risky IP range index",
                    },
                    "port": {
                        "type": "string",
                        "description": "The port number from the firewall rule request to search for in the risky ports index",
                    },
                    "protocol": {
                        "type": "string",
                        "description": "The protocol from the firewall rule request to search for in the risky ports index",
                    },
                },
                "required": ["source_ip","destination_ip","port", "protocol"],
            },
        },
    },
]

# Dispatch table: tool name requested by the model -> Python callable that implements it.
# Keys are taken from the functions' own names, which are the names used in the tool schemas.
available_functions = {
    handler.__name__: handler
    for handler in (
        query_risky_ip_ranges,
        query_risky_ports,
        query_firewall_rule_public_risk,
    )
}

## Conversation Handling

In [22]:
def run_conversation(messages, tools, available_functions, debug=False):
    """
    Runs one tool-assisted exchange with the chat model.

    The conversation is sent together with the tool schemas. If the model answers
    directly, that response is returned. If it requests one or more tool calls,
    every requested function is executed, the assistant's tool-call message and one
    `tool` message per call are appended to `messages`, and a second request is made
    so the model can formulate its answer from the function output.

    Args:
        messages (list[dict]): Chat messages; extended in place when tools are called.
        tools (list[dict]): Tool schemas passed to the chat completion request.
        available_functions (dict): Maps tool names to the Python callables to execute.
        debug (bool): When True, prints the tool calls, their output and the
            messages of the second request.

    Returns:
        The chat completion response object on success, or a `str` error message when a
        requested function is unknown, its arguments are not valid JSON, or the
        function raises. On error `messages` is left unmodified.
    """
    def log(*args):
        if debug:
            print(*args)

    # Step 1: send the conversation and available functions to GPT
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        tools=tools,
        tool_choice="auto",
        temperature=0.2,
    )
    response_message = response.choices[0].message

    # Step 2: check if the model wants to call a function
    if not response_message.tool_calls:
        return response

    # Step 3: validate and execute every requested function.
    # Results are collected first so that `messages` is only modified once all calls succeeded.
    tool_messages = []
    for tool_call in response_message.tool_calls:
        function_name = tool_call.function.name
        log("Recommended Function Call:", tool_call)

        if function_name not in available_functions:
            return f"Function '{function_name}' not found in available_functions."

        try:
            function_args = json.loads(tool_call.function.arguments)
        except json.JSONDecodeError as e:
            return f"Failed to parse function arguments: {e}"

        try:
            function_response = available_functions[function_name](**function_args)
        except Exception as e:
            return f"Error while calling '{function_name}': {e}"

        log("Function Output:", function_response)

        if not isinstance(function_response, str):
            # Message content has to be text
            function_response = json.dumps(function_response, default=str)

        tool_messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call.id,  # ties this result to the call it answers
                "content": function_response,
            }
        )

    # Step 4: add the assistant's tool-call message, then one tool message per call.
    # The request was made with `tools`, so the reply uses the matching `tool_calls` / `tool`
    # message format rather than the legacy `function_call` / `function` one.
    messages.append(
        {
            "role": response_message.role,
            "content": response_message.content,
            "tool_calls": [
                {
                    "id": call.id,
                    "type": "function",
                    "function": {
                        "name": call.function.name,
                        "arguments": call.function.arguments,
                    },
                }
                for call in response_message.tool_calls
            ],
        }
    )
    messages.extend(tool_messages)  # extend conversation with function responses

    log("Messages passed in second request:")
    if debug:
        for msg in messages:
            print(msg)

    # Get a new response from GPT where it can see the function responses
    second_response = client.chat.completions.create(
        messages=messages,
        model=model,
    )

    return second_response

## System Message

In [None]:
# System prompt for the risk-assessment assistant.
# The function names listed here must match the names declared in the tool schemas,
# otherwise the model is steered towards a function that cannot be dispatched.
system_message = """
You are a security risk assessment assistant that helps users evaluate whether an Azure Firewall rule is risky. You interact in natural language and use available tools when needed.

You have access to an Azure AI Search index containing public risk data about IP ranges and ports. Use the appropriate function depending on what information is provided:
- `query_risky_ip_ranges(ip_range)`: use this when only the IP address or IP range is provided.
- `query_risky_ports(port, protocol)`: use this when only the port and protocol are provided.
- `query_firewall_rule_public_risk(source_ip, destination_ip, port, protocol)`: use this when all sourceIp, destination IP, port and protocol are provided.

Your goals:
- Evaluate firewall rules using available data.
- Use CVSS score and descriptions to assess risk when available.
- If no data is found or input is incomplete, ask follow-up questions or make an informed estimate.
- Assign a confidence score (0–100%) based on how complete and reliable the information is.

Risk classification based on CVSS:
- `High Risk`: score ≥ 7.0
- `Medium Risk`: 4.0 ≤ score < 7.0
- `Low Risk`: score < 4.0

Return your output in the form of a markdown table with the following columns:

| Source IP | Destination IP | Port | Protocol | Risk Level | CVSS Score | Description | Recommendation | Confidence Score |
|-----------|----------------|------|----------|------------|-------------|-------------|----------------|------------------|

- Populate each column with your findings.
- Include all user-provided input fields (IP, port, protocol) in the table.
- Risk Level should be based on CVSS or logical inference.
- Recommendation should be one of: "Block this rule", "Monitor this rule", or "Allow this rule".
- If CVSS score is unavailable, use "N/A".

Only output the final markdown table. Do not include explanations, summaries, or any additional text.
"""

## Test with a few chat inputs

In [None]:
# Single-turn smoke test: the system prompt plus one user request that names only a port and a protocol.
messages = [
    dict(role="system", content=system_message),
    dict(
        role="user",
        content="Help me find out whether this firewall rule is risky or not: port is 20, protocol is tcp",
    ),
]

# Alternative user prompts to try:
#testing for ip-range: "content": "Help me find out whether this firewall rule is risky or not: 193.163.125.240"
#testing for port and ip-range:"content": "Help me find out whether this firewall rule is risky or not: ip-range is 193.163.125.240, port is 20"
#testing for all together: "content": "Help me find out whether this firewall rule is risky or not: source-ip is 193.163.125.240, destination-ip is 193.163.125.245, and port is 20"
#testing for all together: "content": "Help me find out whether this firewall rule is risky or not: source-ip is 193.163.125.240, destination-ip is 193.163.125.245, port is 20, protocol is tcp"

result = run_conversation(messages, tools, available_functions, debug=True)

# run_conversation returns a plain string on failure, otherwise the chat completion response,
# whose first choice carries the model's text.
print("Final response:")
print(result if isinstance(result, str) else result.choices[0].message.content)

Recommended Function Call: ChatCompletionMessageToolCall(id='call_mq9oXn9lgKX8IzQVQpDHZy7C', function=Function(arguments='{"port":"20","protocol":"tcp"}', name='query_risky_ports'), type='function')
Function Output: {
  "risk_id": "R002",
  "description": "FTP Data- Unencrypted file transfer",
  "cvss_score": "5.0",
  "ports": "20",
  "protocols": "tcp",
  "ip_ranges": "0.0.0.0/0",
  "last_updated": "2025-03-03T00:00:00Z"
}
Messages passed in second request:
{'role': 'system', 'content': '\nYou are a security risk assessment assistant that helps users evaluate whether an Azure Firewall rule is risky. You interact in natural language and use available tools when needed.\n\nYou have access to an Azure Cognitive Search index containing public risk data about IP ranges and ports. Use the appropriate function depending on what information is provided:\n- `query_risky_ip_ranges(ip_range)`: use this when only the IP address or IP range is provided.\n- `query_risky_ports(port, protocol)`: use 

## Review firewall rules from CSV file

In [25]:
# Define the storage account details and the location of the CSV file.
# Account name and key are read from environment variables; the URL targets the
# account's Data Lake (dfs) endpoint.
account_name = os.getenv("AZURE_STORAGE_ACCOUNT_NAME") 
account_key = os.getenv("AZURE_STORAGE_ACCOUNT_KEY") 
account_url = f"https://{account_name}.dfs.core.windows.net"
file_system_name = "test"  # Replace with your container name
file_path = "testdata_existing_rules_artem.csv"  # Replace with the path to your CSV file in the container

In [26]:
# Authentication towards Storage account via Storage account key ###TODO: Change to identity based authentication at some point
service_client = DataLakeServiceClient(account_url=account_url, credential=account_key)
# List file systems (containers); this also serves as a quick check that the account URL and key are accepted
file_systems = service_client.list_file_systems()
for fs in file_systems:
    print("File system:", fs.name)

File system: risks
File system: test


In [None]:
# Get the file client for the CSV inside the configured file system (container)
file_client = service_client.get_file_system_client(file_system=file_system_name).get_file_client(file_path)

# Download the file and read its full content into memory
download = file_client.download_file()
downloaded_bytes = download.readall()

# Check the data: parse the raw bytes with pandas and preview the first rows
df = pd.read_csv(BytesIO(downloaded_bytes))
print(df.head())

          Name     RuleType SourceAddresses Protocols DestinationAddresses  \
0   Allow_HTTP  NetworkRule     10.0.0.0/24       TCP          20.0.0.0/24   
1  Allow_HTTPS  NetworkRule     10.0.0.0/24       TCP          20.0.0.0/24   
2    Allow_RDP  NetworkRule     10.0.0.0/24       TCP          20.0.0.0/24   
3   Allow_ICMP  NetworkRule     10.0.0.0/24      ICMP          20.0.0.0/24   

   SourceIpGroups  DestinationIpGroups DestinationPorts  DestinationFqdns  
0             NaN                  NaN          80,8080               NaN  
1             NaN                  NaN              443               NaN  
2             NaN                  NaN             3389               NaN  
3             NaN                  NaN                *               NaN  


### Read CSV data

In [28]:
# Function to read the CSV from ADLS Gen2 and return as pandas dataframe
def read_csv_from_datalake(service_client, file_system_name, file_path):
    """
    Reads a CSV file from ADLS Gen2 using the given DataLakeServiceClient and returns a cleaned pandas DataFrame.

    The header row is parsed separately from the data rows so that the column names
    can be normalised: surrounding whitespace and quote characters are removed and
    the names are lower-cased. A UTF-8 byte-order mark at the start of the file is
    dropped, and quoted header fields may contain commas.

    Args:
        service_client: Client exposing `get_file_system_client(...)`, as provided by
            `DataLakeServiceClient`.
        file_system_name (str): Name of the file system (container) holding the file.
        file_path (str): Path of the CSV file inside the file system.

    Returns:
        pandas.DataFrame: The parsed data with normalised column names. A file with a
        header but no data rows yields an empty frame that still carries the columns.
        On any failure (download, decoding, parsing) the error is printed and an empty
        DataFrame without columns is returned.
    """
    try:
        # Create the file client
        file_client = service_client.get_file_system_client(file_system_name).get_file_client(file_path)

        # Download and decode; "utf-8-sig" strips a leading BOM so it cannot end up in the first column name
        downloaded_bytes = file_client.download_file().readall()
        text_data = downloaded_bytes.decode("utf-8-sig")

        # Split into header line and data lines
        lines = text_data.splitlines()
        if not lines:
            raise ValueError("file is empty")

        # Parse the header with the csv module so quoted names containing commas stay intact
        raw_columns = next(csv.reader([lines[0]]), [])
        column_names = [
            col.strip().replace('"', '').strip().replace("'", '').lower()
            for col in raw_columns
        ]

        data_lines = lines[1:]
        if not data_lines:
            # Header only: keep the schema instead of failing on empty input
            return pd.DataFrame(columns=column_names)

        # Read data without header and assign cleaned column names
        df = pd.read_csv(StringIO("\n".join(data_lines)), sep=",", quotechar='"', header=None)
        df.columns = column_names

        return df

    except Exception as e:
        # Deliberate best-effort: report the problem and hand back an empty frame
        print(f"[ERROR] Failed to parse CSV from ADLS: {e}")
        return pd.DataFrame()


In [29]:
# Load the firewall rules through the helper and inspect the normalised (lower-cased) column names.
# NOTE(review): the file name literal repeats the value of `file_path` defined in the storage configuration cell.
test_data = read_csv_from_datalake(service_client, file_system_name, "testdata_existing_rules_artem.csv")
print("Detected columns:", test_data.columns.tolist())
print("First few rows:")
print(test_data.head())


Detected columns: ['name', 'ruletype', 'sourceaddresses', 'protocols', 'destinationaddresses', 'sourceipgroups', 'destinationipgroups', 'destinationports', 'destinationfqdns']
First few rows:
          name     ruletype sourceaddresses protocols destinationaddresses  \
0   Allow_HTTP  NetworkRule     10.0.0.0/24       TCP          20.0.0.0/24   
1  Allow_HTTPS  NetworkRule     10.0.0.0/24       TCP          20.0.0.0/24   
2    Allow_RDP  NetworkRule     10.0.0.0/24       TCP          20.0.0.0/24   
3   Allow_ICMP  NetworkRule     10.0.0.0/24      ICMP          20.0.0.0/24   

   sourceipgroups  destinationipgroups destinationports  destinationfqdns  
0             NaN                  NaN          80,8080               NaN  
1             NaN                  NaN              443               NaN  
2             NaN                  NaN             3389               NaN  
3             NaN                  NaN                *               NaN  


### Construct user message

In [30]:
def create_message(system_message, source_ip, destination_ip, port, protocol):
    """
    Builds the two-message chat payload (system + user) asking the LLM to assess one firewall rule.

    The system message is stripped of leading/trailing whitespace; the user message
    lists the rule's source IP, destination IP, port and protocol, one per line.
    """
    rule_lines = [
        f"Determine the risk of this firewall rule:",
        f"- Source IP: {source_ip}",
        f"- Destination IP: {destination_ip}",
        f"- Port: {port}",
        f"- Protocol: {protocol}",
    ]
    system_entry = {"role": "system", "content": system_message.strip()}
    user_entry = {"role": "user", "content": "\n".join(rule_lines)}
    return [system_entry, user_entry]

### Go through CSV row-by-row

In [31]:
def _cell_or_unknown(row, column):
    """Returns the row's value for `column`, or "unknown" when the column is missing or the cell is empty (NaN)."""
    value = row.get(column, "unknown")
    # row.get only falls back to the default for a missing column; an empty CSV cell
    # comes back as NaN and would otherwise be rendered as "nan" in the prompt.
    return "unknown" if pd.isna(value) else value


# Assess every firewall rule of the CSV with its own, fresh conversation
for index, row in test_data.iterrows():
    source_ip = _cell_or_unknown(row, "sourceaddresses")
    destination_ip = _cell_or_unknown(row, "destinationaddresses")
    port = _cell_or_unknown(row, "destinationports")
    protocol = _cell_or_unknown(row, "protocols")

    message = create_message(system_message, source_ip, destination_ip, port, protocol)

    print(message)

    result = run_conversation(message, tools, available_functions, debug=True)

    print("Final response:")
    if isinstance(result, str):
        print(result)  # It's an error message or string content
    else:
        print(f"assessing the risk of {source_ip}, {destination_ip}, {port}, {protocol}")
        print(result.choices[0].message.content)  # Text of the first choice of the chat completion response

    

[{'role': 'system', 'content': 'You are a security risk assessment assistant that helps users evaluate whether an Azure Firewall rule is risky. You interact in natural language and use available tools when needed.\n\nYou have access to an Azure Cognitive Search index containing public risk data about IP ranges and ports. Use the appropriate function depending on what information is provided:\n- `query_risky_ip_ranges(ip_range)`: use this when only the IP address or IP range is provided.\n- `query_risky_ports(port, protocol)`: use this when only the port and protocol are provided.\n- `query_public_risk_data(source_ip, destination_ip, port, protocol)`: use this when all sourceIp, destination IP, port and protocol are provided.\n\nYour goals:\n- Evaluate firewall rules using available data.\n- Use CVSS score and descriptions to assess risk when available.\n- If no data is found or input is incomplete, ask follow-up questions or make an informed estimate.\n- Assign a confidence score (0–10