In [1]:
import json

# Load json dataset from json
with open('edge-iiotset-ddos-train.json', 'r') as f:
    ddos_json_train = json.load(f)

with open('edge-iiotset-normal-train.json', 'r') as f:
    normal_json_train = json.load(f)

with open('edge-iiotset-ddos-test.json', 'r') as f:
    ddos_json_test = json.load(f)

with open('edge-iiotset-normal-test.json', 'r') as f:
    normal_json_test = json.load(f)

In [10]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.embeddings.ollama import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
import tiktoken
import os
import dotenv

# OpenAI models
# 1. gpt-4o | gpt-4o-2024-08-06 
# 2. gpt-4o-mini | gpt-4o-mini-2024-07-18
# 3. gpt-3.5-turbo-0125

dotenv.load_dotenv()
API_KEY = os.getenv("OPENAI_API_KEY")
llm = ChatOpenAI(openai_api_key=API_KEY, model="gpt-3.5-turbo-0125", temperature=0.0)
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

embeddings = OllamaEmbeddings(model="all-minilm")
vector_store = Chroma(
    collection_name="edge-iiotset",
    embedding_function=embeddings, 
    persist_directory="./chroma_langchain_db")
retriever = vector_store.as_retriever(
    search_type="mmr", 
    search_kwargs={"k": 5, "fetch_k": 5})

def format_docs(docs):
    return "\n\n".join(doc.page_content + "-->" + ("ATTACK" if doc.metadata["label"] == "ddos" else "BENIGN") for doc in docs)

def predict(x):
    benign_samples = ""
    attack_samples = ""
    for i in range(10):
        benign_samples += str(normal_json_train[i].values()) + "-->BENIGN\n"
        attack_samples += str(ddos_json_train[i].values()) + "-->ATTACK\n"
    system_prompt = (
        "You are intelligent network log analyzer."
        "You will be given a network log to predict ATTACK or BENIGN."
        "Use the example network logs given to predict the label."
        "Output the label ATTACK or BENIGN, nothing else."
        "\n\n"
        "Fields:" + str(normal_json_train[0].keys()) + "\n"
        "```" + benign_samples + attack_samples + "```"
    )
    messages = []
    messages.append(("system", system_prompt))
    messages.append(("user", "{input}"))
    # num_tokens = len(encoding.encode(str(messages)))
    # print("Num tokens:", num_tokens)
    prompt = ChatPromptTemplate.from_messages(messages)
    chain = (
        prompt 
        | llm 
        | StrOutputParser()
    )
    return chain.invoke({"input": x})

In [11]:
from sklearn.metrics import classification_report
from tqdm import tqdm

sample_size = 10
y_pred = []
y_true = []
for i in tqdm(range(sample_size), ncols=100, desc="Predicting attack entries..."):
    y = predict(str(ddos_json_train[i].values()))
    if y == "ATTACK":
        y_pred.append(1)
    else:
        y_pred.append(0)
    y_true.append(1)

print(classification_report(y_true, y_pred))

Predicting attack entries...: 100%|█████████████████████████████████| 10/10 [00:08<00:00,  1.21it/s]

              precision    recall  f1-score   support

           1       1.00      1.00      1.00        10

    accuracy                           1.00        10
   macro avg       1.00      1.00      1.00        10
weighted avg       1.00      1.00      1.00        10






In [12]:
sample_size = 10
y_pred = []
y_true = []
for i in tqdm(range(sample_size), ncols=100, desc="Predicting benign entries..."):
    y = predict(str(normal_json_train[i].values()))
    if y == "BENIGN":
        y_pred.append(1)
    else:
        y_pred.append(0)
    y_true.append(1)

print(classification_report(y_true, y_pred))

Predicting benign entries...: 100%|█████████████████████████████████| 10/10 [00:08<00:00,  1.19it/s]

              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           1       1.00      0.90      0.95        10

    accuracy                           0.90        10
   macro avg       0.50      0.45      0.47        10
weighted avg       1.00      0.90      0.95        10




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
