In [6]:
import pandas as pd
import numpy as np
import pickle
import json
from sagemaker.predictor import Predictor
from sagemaker.serializers import CSVSerializer
from sklearn.metrics import f1_score, precision_score, recall_score
from tabulate import tabulate
from tqdm import tqdm

def _load_pickle(path):
    """Return the object unpickled from the file at *path*."""
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Trained preprocessing models
scaler = _load_pickle("scaler.pkl")
pca = _load_pickle("pca.pkl")

# Per-column frequency maps used to encode categorical features
freq_maps = _load_pickle("frequency_maps.pkl")

# Client for the "kmeans" endpoint; payloads are serialized as CSV
predictor = Predictor(endpoint_name="kmeans", serializer=CSVSerializer())

# Dataset to run inference on
dataset = pd.read_csv("./fraudValidation.csv")

def encode_categorical_features(dataframe, frequency_maps=None):
    """
    Replace every object-dtype column with a frequency-encoded ``<col>_freq`` column.

    Each object-dtype column ``col`` is mapped through ``frequency_maps[col]``;
    values missing from the map (unseen categories) become 0. A column that
    has no entry in ``frequency_maps`` at all triggers a printed warning and is
    encoded as all zeros. The original object-dtype columns are removed.

    Args:
        dataframe: Input frame. It is not modified; a new frame is returned.
        frequency_maps: Optional mapping ``{column_name: {value: frequency}}``.
            When omitted, the module-level ``freq_maps`` is used.

    Returns:
        A new DataFrame holding the non-object columns in their original
        order, followed by the ``<col>_freq`` columns in the order the
        object-dtype columns appeared.
    """
    if frequency_maps is None:
        # Backward-compatible default: the pre-saved maps loaded at module level.
        frequency_maps = freq_maps

    cat_cols = dataframe.select_dtypes(include=["object"]).columns
    freq_frames = {}

    for col in cat_cols:
        if col in frequency_maps:
            # map() yields NaN for categories absent from the saved map.
            freq_frames[col + "_freq"] = dataframe[col].map(frequency_maps[col]).fillna(0)
        else:
            print(f"Warning: Column '{col}' not found in frequency map. Filling with 0.")
            freq_frames[col + "_freq"] = pd.Series(0, index=dataframe.index)

    # Non-inplace drop: the caller's frame is never touched, even when there
    # are no categorical columns to encode.
    encoded = dataframe.drop(columns=cat_cols)
    if freq_frames:
        freq_df = pd.DataFrame(freq_frames, index=dataframe.index)
        encoded = pd.concat([encoded, freq_df], axis=1)
    return encoded

def preprocess(df):
    """
    Turn a raw transactions frame into PCA-projected features and labels.

    The feature columns are frequency-encoded with
    ``encode_categorical_features``, scaled with the module-level ``scaler``
    and projected with the module-level ``pca``. The label column is split
    off before encoding so it never passes through the categorical encoder.

    Args:
        df: DataFrame containing at least the feature columns listed below
            and the ``is_fraud`` label column. It is not modified.

    Returns:
        Tuple ``(X_pca, y)``: the output of ``pca.transform`` and the
        ``is_fraud`` values as an array.

    Raises:
        KeyError: If a required column is missing from ``df``.
    """
    feature_cols = ['trans_date_trans_time', 'cc_num', 'merchant', 'category', 'amt', 'zip', 'trans_num']
    label_col = 'is_fraud'

    y = df[label_col].values
    # A single copy is enough to isolate the caller's frame from the encoder.
    X = encode_categorical_features(df[feature_cols].copy())

    print("✅ Columns used for scaling and PCA:")
    print(X.columns.tolist())

    X_scaled = scaler.transform(X)
    X_pca = pca.transform(X_scaled)
    return X_pca, y

# Inference
# A sample whose distance to its cluster exceeds this threshold is predicted
# as fraud (1); otherwise it is predicted as non-fraud (0).
threshold = 3.0
X_pca, y_true = preprocess(dataset)
total = len(y_true)

correct = 0
results = []

# Run inference with progress bar (one endpoint request per row)
for i, (row, actual) in enumerate(tqdm(zip(X_pca, y_true), total=total, desc="Inferencing..."), start=1):
    payload = ",".join(map(str, row)) + "\n"
    resp = json.loads(predictor.predict(payload))['predictions'][0]
    distance = resp.get('distance_to_cluster')
    if distance is None:
        # Fail with a clear message instead of an opaque TypeError on the
        # comparison below.
        raise ValueError(f"Sample {i}: endpoint response has no 'distance_to_cluster': {resp!r}")
    predicted = 1 if distance > threshold else 0
    is_correct = predicted == actual
    match = "✅" if is_correct else "❌"
    if is_correct:
        correct += 1
    results.append([i, actual, f"{distance:.4f}", predicted, match])

print(" Inference completed on all rows.\n")

# Output results
headers = ["Sample", "Actual", "Distance", "Predicted", "Match"]
print(tabulate(results, headers=headers))

# Accuracy (guarded so an empty dataset reports 0% instead of dividing by zero)
accuracy = correct / total if total else 0.0
print(f"\n Accuracy: {accuracy:.2%} ({correct}/{total})")


✅ Columns used for scaling and PCA:
['cc_num', 'amt', 'zip', 'trans_date_trans_time_freq', 'merchant_freq', 'category_freq', 'trans_num_freq']


Inferencing...: 100%|██████████| 24/24 [00:00<00:00, 81.31it/s]


✅ Inference completed on all rows.

  Sample    Actual    Distance    Predicted  Match
--------  --------  ----------  -----------  -------
       1         1      3.1375            1  ✅
       2         1      3.2945            1  ✅
       3         0      2.9709            0  ✅
       4         0      2.8925            0  ✅
       5         0      3.8269            1  ❌
       6         0      0.7626            0  ✅
       7         1      3.6948            1  ✅
       8         0      2.3455            0  ✅
       9         1      3.354             1  ✅
      10         0      0.4069            0  ✅
      11         1      3.2889            1  ✅
      12         0      2.2282            0  ✅
      13         0      3.7642            1  ❌
      14         1      4.236             1  ✅
      15         1      3.609             1  ✅
      16         0      2.6543            0  ✅
      17         1      3.7471            1  ✅
      18         1      4.3039            1  ✅
      19      