In [4]:
import pandas as pd
import google.generativeai as genai
import os
import chromadb
import uuid
from collections import Counter 
import random
from dotenv import load_dotenv

In [None]:
# >>>>READ<<<<:
# ADD GOOGLE API KEY in ./.env file, e.g.
# GOOGLE_API_KEY=...
# load_dotenv() loads the variables from the .env file into the process
# environment; the next cell reads GOOGLE_API_KEY from os.environ.
load_dotenv()

In [8]:
# init
# Load the labelled splits; later cells read the columns
# category, sub_category and crimeaditionalinfo from both frames.
train_data = pd.read_csv("./train.csv")
test_data = pd.read_csv("./test.csv")

# prepare google genai -- fail early with an actionable message instead of a
# bare KeyError when the key was not loaded from ./.env
google_api_key = os.environ.get('GOOGLE_API_KEY')
if not google_api_key:
    raise RuntimeError(
        "GOOGLE_API_KEY is not set; add it to ./.env and re-run the load_dotenv() cell."
    )
genai.configure(api_key=google_api_key)

In [None]:
# prepare dataset for embeddings
# new doc
def mod_func(row):
    """Attach a 'document' field to a training row and return the row.

    The document states the row's category and sub-category followed by the
    complaint text; it is the text that gets embedded and stored in the
    vector DB. A missing complaint text (NaN) is rendered as an empty string
    rather than the literal "nan".
    """
    complaint_text = row.crimeaditionalinfo if pd.notna(row.crimeaditionalinfo) else ""
    # Template (including the trailing newline + indent) is kept unchanged so
    # new documents match the ones already stored in the collection.
    row['document'] = f"""The following complaint document written by a user is categorized as '{row.category}' and sub-categorized as '{row.sub_category}':
{complaint_text}
    """
    return row

prepared_train_data = train_data.apply(mod_func, axis=1)
# Length (in characters) of the exact text that will be embedded. Measuring
# the full document -- not just the three raw fields -- includes the template
# prefix, so the context-length check below is not underestimated.
prepared_train_data['combined_length'] = prepared_train_data['document'].str.len()

prepared_train_data.head()

In [None]:
# sanity check that no row exceeds the embeddings context length.
# NOTE(review): combined_length counts characters; if the model limit is in
# tokens (assumed), a character count is a conservative over-estimate.
embeddings_context_length = 2048
overlong_rows = prepared_train_data[prepared_train_data.combined_length > embeddings_context_length]
if overlong_rows.empty:
    print("Success! All rows are within the context length.")
else:
    # A bare expression inside an if-block is not rendered, so display explicitly.
    print(f"{len(overlong_rows)} row(s) exceed the context length:")
    display(overlong_rows)

In [None]:
# prepare vectordb

# start chromdb server
# chroma run --host localhost --port 8000 --path ./my_chroma_data

def collection_exists(collection_name):
    """Return True if the Chroma server already has a collection named `collection_name`.

    Depending on the chromadb version, list_collections() yields either
    Collection objects (with a .name) or plain name strings; both are handled.
    """
    existing_collections = client.list_collections()
    return any(
        getattr(collection, "name", collection) == collection_name
        for collection in existing_collections
    )

client = chromadb.HttpClient()
collection_name = "cyber-ai-embeddings"
collection_already_existed = collection_exists(collection_name)
# One call instead of check-then-create/get; safe if the collection appears meanwhile.
embeddings_collection = client.get_or_create_collection(collection_name)
print("Collection loaded..." if collection_already_existed else "Collection created...")

In [15]:
%%script true
# >>>>PLEASE READ!!!!!<<<<
# THIS CELL EXECUTION SHOULD BE A ONE TIME PROCESS
# REMOVE THE ABOVE "%%SCRIPT TRUE" TO RUN THIS CELL
# THIS CELL PREPARES AND POPULATES ALL THE VECTOR EMBEDDINGS OF THE TRAIN DATA
# 
step = 200
max_length = len(prepared_train_data)
for i in range(0, max_length, step):
    start_idx = i
    end_idx = i + step
    documents = []
    metadatas = []
    document_ids = []
    for index, row in prepared_train_data[start_idx:end_idx].iterrows():
        documents.append(row.document)
        metadatas.append({"category": row.category, "scategory": row.sub_category})
        document_ids.append(str(uuid.uuid4()))

    result = genai.embed_content(model="models/text-embedding-004", content=documents)

    print("inserting results in vectordb")
    embeddings = result["embedding"]
    data_to_insert = embeddings_collection.add(
        documents=documents,
        metadatas=metadatas,
        ids=document_ids,
        embeddings=embeddings,
    )
    print(f"data inserted from index {start_idx} to {end_idx}")
print("Vector embeddings prepared...")

In [16]:
def predict_categories(document,sample_size = 1):
    predicted_category = ""
    predicted_scategory = ""
    result = genai.embed_content(model="models/text-embedding-004", content=[document])
    test_embeddings = result['embedding'][0]
    # query db
    query_result = embeddings_collection.query(query_embeddings=test_embeddings,n_results=sample_size)
    # print results
    
    sampled_categories = []
    sampled_scategories = []
    for i in range(sample_size):
        if 'category' in query_result['metadatas'][0][0]:
            sampled_categories.append(query_result['metadatas'][0][0]['category'])
        else:
            sampled_categories.append(None)
        if 'scategory' in query_result['metadatas'][0][0]:
            sampled_scategories.append(query_result['metadatas'][0][0]['scategory'])
        else:
            sampled_scategories.append(None)
            
    
    # get the most occurred
    predicted_category = Counter(sampled_categories).most_common(1)[0][0]
    predicted_scategory = Counter(sampled_scategories).most_common(1)[0][0]
    return (predicted_category,predicted_scategory,sampled_categories,sampled_scategories)

In [39]:
# TEST SINGLE
# Pick one random test row. The query text deliberately omits the expected
# category/sub-category: including them would leak the answer into the
# embedding and inflate the apparent accuracy.
test_idx = random.randrange(len(test_data))
row = test_data.iloc[test_idx]
print(f"Expected Category: {row.category}")
print(f"Expected Sub-category: {row.sub_category}")

complaint_text = row.crimeaditionalinfo if pd.notna(row.crimeaditionalinfo) else ""
query_document = f"The following complaint document was written by a user:\n{complaint_text}"
predict_category, predict_scategory, sampled_categories, sampled_scategories = predict_categories(query_document, 10)

print(f"Predicted Category: {predict_category}")
print(f"Predicted Sub-category: {predict_scategory}")

Expected Category: Online Financial Fraud
Expected Sub-category: Fraud CallVishing
Predicted Category: Online Financial Fraud
Predicted Sub-category: Fraud CallVishing


In [None]:
# TEST ALL
# Evaluate on a random sample of the test set. The query text omits the
# expected labels so the nearest-neighbour vote relies on the complaint text
# alone; putting the labels in the query would leak the answer.
sample_size = min(50, len(test_data))
assert sample_size > 0, "test.csv has no rows to evaluate"
sample_seed = 42  # fixed so re-runs score the same rows; set None for a fresh sample
test_test_data = test_data.sample(n=sample_size, random_state=sample_seed).reset_index(drop=True)
failure_indexes = []              # rows with any mistake
failure_indexes_category = []     # rows whose category was wrong
failure_indexes_scategory = []    # rows with correct category but wrong sub-category
failure_predicted_categories = [] # [category, sub-category] predicted for each failed row
for index, row in test_test_data.iterrows():
    print(f"{index+1}/{sample_size}", end='\r')
    complaint_text = row.crimeaditionalinfo if pd.notna(row.crimeaditionalinfo) else ""
    document = f"The following complaint document was written by a user:\n{complaint_text}"
    predict_category, predict_scategory, sampled_categories, sampled_scategories = predict_categories(document, 10)

    if row.category == predict_category:
        # A missing expected sub-category counts as a match.
        if pd.isna(row.sub_category) or row.sub_category == predict_scategory:
            continue
        print("Failed Sub-category---------")
        print(f"Index: {index}")
        print(f"Expected: {row.sub_category}")
        print(f"Predicted: {predict_scategory}")
        print("----------------------------")
        failure_indexes_scategory.append(index)
    else:
        print("Failed Category-------------")
        print(f"Index: {index}")
        print(f"Expected: {row.category}")
        print(f"Predicted: {predict_category}")
        print("----------------------------")
        failure_indexes_category.append(index)

    failure_predicted_categories.append([predict_category, predict_scategory])
    failure_indexes.append(index)

# Note: sub-category failures are only counted for rows whose category was
# correct, so the last rate is conditional on a correct category prediction.
print(f"Success Rate Overall: {100 - len(failure_indexes)/sample_size*100}%")
print(f"Success Rate (Category only): {100 - len(failure_indexes_category)/sample_size*100}%")
print(f"Success Rate (Sub-category only): {100 - len(failure_indexes_scategory)/sample_size*100}%")