<a href="https://colab.research.google.com/github/kushalshah0/Detecting-AI-Generated-Phishing-Emails-Using-BERT/blob/main/ai_generated_phishing_email_detection_FastAPI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Mount Google Drive
# The model artifacts checked below (RNN checkpoints, tokenizer, BERT folder)
# are read from this Drive mount.
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
#@title Install Dependencies
!pip install fastapi uvicorn pyngrok transformers torch tensorflow pickle-mixin

Collecting pyngrok
  Downloading pyngrok-7.5.0-py3-none-any.whl.metadata (8.1 kB)
Collecting pickle-mixin
  Downloading pickle-mixin-1.0.2.tar.gz (5.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading pyngrok-7.5.0-py3-none-any.whl (24 kB)
Building wheels for collected packages: pickle-mixin
  Building wheel for pickle-mixin (setup.py) ... [?25l[?25hdone
  Created wheel for pickle-mixin: filename=pickle_mixin-1.0.2-py3-none-any.whl size=5988 sha256=105804ba2e3184bb8ed1158dda336614b7614f958ee902805ff1fa164eb607a4
  Stored in directory: /root/.cache/pip/wheels/69/e2/5c/da8f96a08c63469bc8b10e206cd4c78e8886d8acb8699f84c2
Successfully built pickle-mixin
Installing collected packages: pickle-mixin, pyngrok
Successfully installed pickle-mixin-1.0.2 pyngrok-7.5.0


In [3]:
#@title Verify model paths
import os

SAMPLE_NAME = "sample1"

BASE_PATH = f"/content/drive/MyDrive/Detect_AI_Phishing_Project/{SAMPLE_NAME}"

paths = {
    "LSTM": f"{BASE_PATH}/lstm_model.pt",
    "GRU": f"{BASE_PATH}/gru_model.pt",
    "BERT": f"{BASE_PATH}/bert/final_model",
    "Tokenizer": f"{BASE_PATH}/rnn_tokenizer.pkl"
}

missing = []
for name, path in paths.items():
    found = os.path.exists(path)
    print(name, "✅" if found else "❌", path)
    if not found:
        missing.append(name)

# Fail fast: models.py loads every one of these at import time, so a missing
# artifact would otherwise only show up later as a uvicorn startup crash.
if missing:
    raise FileNotFoundError(
        f"Missing model artifacts: {', '.join(missing)} (expected under {BASE_PATH})"
    )


LSTM ✅ /content/drive/MyDrive/Detect_AI_Phishing_Project/sample1/lstm_model.pt
GRU ✅ /content/drive/MyDrive/Detect_AI_Phishing_Project/sample1/gru_model.pt
BERT ✅ /content/drive/MyDrive/Detect_AI_Phishing_Project/sample1/bert/final_model
Tokenizer ✅ /content/drive/MyDrive/Detect_AI_Phishing_Project/sample1/rnn_tokenizer.pkl


In [4]:
#@title Create FastAPI project structure
import os
from pathlib import Path

API_DIR = "/content/api"
os.makedirs(API_DIR, exist_ok=True)

files = ["main.py", "models.py", "schemas.py"]
for f in files:
    # touch() creates a missing file but, unlike open(..., "w"), never truncates
    # a module already written by the %%writefile cells below — re-running this
    # cell must not wipe the API code.
    Path(API_DIR, f).touch(exist_ok=True)

print("FastAPI files created:", files)

FastAPI files created: ['main.py', 'models.py', 'schemas.py']


In [5]:
#@title Write schemas.py
%%writefile /content/api/schemas.py
from pydantic import BaseModel

class EmailRequest(BaseModel):
    text: str
    model: str  # bert | lstm | gru

class PredictionResponse(BaseModel):
    model: str
    prediction: str
    confidence: float

Overwriting /content/api/schemas.py


In [6]:
#@title Write models.py
%%writefile /content/api/models.py
import torch
import torch.nn as nn
import pickle
import numpy as np
from transformers import BertTokenizer, BertForSequenceClassification

# Artifacts live in a per-sample sub-folder of the project directory on Drive.
SAMPLE_NAME = "sample1"
BASE_PATH = "/content/drive/MyDrive/Detect_AI_Phishing_Project/" + SAMPLE_NAME
# Every model below is placed on, and its weights mapped to, this device.
DEVICE = torch.device("cpu")

# LSTM classifier; layout must mirror the architecture used to train lstm_model.pt.
class LSTMModel(nn.Module):
    """Embedding -> single-layer LSTM -> linear head on the final hidden state.

    The attribute names ``embedding``, ``lstm`` and ``fc`` double as the
    state_dict keys, so they must not be renamed.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: token-id input is laid out as (batch, seq_len).
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        """Map token ids ``text`` to raw logits (one row per sequence)."""
        lstm_out = self.lstm(self.embedding(text))
        final_hidden = lstm_out[1][0]  # h_n: (num_layers=1, ...) stacked per layer
        return self.fc(final_hidden[-1])  # state of the last (only) layer

# GRU classifier; layout must mirror the architecture used to train gru_model.pt.
class GRUModel(nn.Module):
    """Embedding -> single-layer GRU -> linear head on the final hidden state.

    The attribute names ``embedding``, ``gru`` and ``fc`` double as the
    state_dict keys, so they must not be renamed.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: token-id input is laid out as (batch, seq_len).
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        """Map token ids ``text`` to raw logits (one row per sequence)."""
        final_hidden = self.gru(self.embedding(text))[1]  # h_n, stacked per layer
        return self.fc(final_hidden[-1])  # state of the last (only) layer

# LOAD TOKENIZER
# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only ever point this at a tokenizer you produced yourself.
with open(f"{BASE_PATH}/rnn_tokenizer.pkl", "rb") as f:
    rnn_tokenizer = pickle.load(f)

MAX_LEN = 200  # length every RNN input is padded/truncated to
# RNN hyper-parameters. They must match the tensor shapes stored in
# lstm_model.pt / gru_model.pt, otherwise load_state_dict() raises.
MODEL_VOCAB_SIZE = 20000    # rows of embedding.weight
EMBEDDING_DIM = 128         # columns of embedding.weight
HIDDEN_DIM = 128            # RNN hidden size
OUTPUT_DIM = 1              # single logit, turned into a probability with sigmoid


def _load_rnn(model_cls, filename):
    """Build ``model_cls`` with the shared hyper-parameters, load its weights
    from ``BASE_PATH/filename`` and return it in eval mode on DEVICE."""
    model = model_cls(MODEL_VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM).to(DEVICE)
    # The checkpoint is a plain state_dict, so refuse to unpickle anything
    # else (weights_only=True) — this also silences torch's FutureWarning.
    state_dict = torch.load(
        f"{BASE_PATH}/{filename}", map_location=DEVICE, weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


# LOAD LSTM / GRU
lstm_model = _load_rnn(LSTMModel, "lstm_model.pt")
gru_model = _load_rnn(GRUModel, "gru_model.pt")

# LOAD BERT
BERT_DIR = f"{BASE_PATH}/bert/final_model"
bert_tokenizer = BertTokenizer.from_pretrained(BERT_DIR)
bert_model = BertForSequenceClassification.from_pretrained(BERT_DIR).to(DEVICE)
bert_model.eval()

# HELPERS
def preprocess_rnn(text):
    """Encode ``text`` as a (1, MAX_LEN) LongTensor of token ids for the RNNs.

    Ids come from ``rnn_tokenizer.texts_to_sequences``; any id that does not
    fit the embedding table (>= MODEL_VOCAB_SIZE) becomes 0. The first MAX_LEN
    ids are kept and the rest of the row is filled with 0 on the right.
    """
    # NOTE(review): confirm this right-padding / keep-the-head truncation
    # matches the preprocessing used when the RNNs were trained.
    token_ids = rnn_tokenizer.texts_to_sequences([text])[0][:MAX_LEN]
    row = [tid if tid < MODEL_VOCAB_SIZE else 0 for tid in token_ids]
    row.extend([0] * (MAX_LEN - len(row)))
    return torch.tensor([row], dtype=torch.long)

def predict_rnn(model, text):
    """Classify ``text`` with an RNN model that emits a single logit.

    Returns ``(label, confidence)``: "Phishing" when sigmoid(logit) >= 0.5,
    otherwise "Legitimate"; confidence is the probability of that label.
    """
    with torch.no_grad():
        phishing_prob = torch.sigmoid(model(preprocess_rnn(text))).item()
    if phishing_prob >= 0.5:
        return "Phishing", phishing_prob
    return "Legitimate", 1 - phishing_prob

def predict_bert(text):
    """Classify ``text`` with the fine-tuned BERT model.

    Returns ``(label, confidence)``: "Phishing" when class index 1 has the
    highest softmax probability, otherwise "Legitimate"; confidence is that
    highest probability. Inputs longer than 512 tokens are truncated.
    """
    encoded = bert_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(DEVICE)
    with torch.no_grad():
        logits = bert_model(**encoded).logits
        top_prob, top_class = torch.softmax(logits, dim=1).max(dim=1)
    verdict = "Phishing" if top_class.item() == 1 else "Legitimate"
    return verdict, top_prob.item()

Overwriting /content/api/models.py


In [7]:
#@title Write main.py
%%writefile /content/api/main.py
from fastapi import FastAPI, HTTPException
from schemas import EmailRequest, PredictionResponse
from models import predict_rnn, predict_bert, lstm_model, gru_model

app = FastAPI(
    title="AI-Generated Phishing Detection API",
    version="1.0",
)

# Model name -> callable taking the email text and returning (label, confidence).
_PREDICTORS = {
    "bert": predict_bert,
    "lstm": lambda text: predict_rnn(lstm_model, text),
    "gru": lambda text: predict_rnn(gru_model, text),
}


@app.post("/predict", response_model=PredictionResponse)
def predict(request: EmailRequest):
    """Classify ``request.text`` with the model named by ``request.model``.

    The model name is matched case-insensitively; an unknown name results in
    HTTP 400 listing the accepted values.
    """
    model_name = request.model.lower()
    predictor = _PREDICTORS.get(model_name)
    if predictor is None:
        raise HTTPException(
            status_code=400,
            detail="Invalid model. Choose from: bert, lstm, gru",
        )

    label, confidence = predictor(request.text)
    return PredictionResponse(
        model=model_name,
        prediction=label,
        confidence=round(confidence, 4),
    )

Overwriting /content/api/main.py


In [13]:
#@title Run FastAPI
import subprocess
import time

# Kill any processes running on port 8000 or any uvicorn process
!pkill -f uvicorn || true
!fuser -k 8000/tcp || true

%cd /content/api

uvicorn_process = subprocess.Popen(
    ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE
)

print("FastAPI server starting...")
time.sleep(5)
print("FastAPI server should be running.")

^C
/content/api
FastAPI server starting...
FastAPI server should be running.


In [None]:
#@title Display FastAPI Server Logs
import os

SERVER_LOG = "/content/api/server.log"


def show_server_logs(log_path: str = SERVER_LOG) -> None:
    """Print the output the background uvicorn server has produced so far.

    If ``uvicorn_process`` was started with PIPE'd stdout/stderr, whatever is
    currently buffered in those pipes is drained without blocking (which also
    frees pipe space the server needs to keep logging). Otherwise the log file
    at ``log_path`` is printed, if it exists.
    """
    proc = globals().get("uvicorn_process")
    if proc is not None and proc.poll() is not None:
        print(f"[uvicorn exited with code {proc.returncode}]")

    pipes = [
        stream
        for stream in (getattr(proc, "stdout", None), getattr(proc, "stderr", None))
        if stream is not None
    ]
    if pipes:
        for stream in pipes:
            os.set_blocking(stream.fileno(), False)  # read what's there; never wait for EOF
            data = stream.read()  # None when nothing is buffered
            if data:
                print(data.decode(errors="replace"), end="")
    elif os.path.exists(log_path):
        with open(log_path, errors="replace") as fp:
            print(fp.read() or "(log is empty)")
    else:
        print("No server output found; run the 'Run FastAPI' cell first.")


show_server_logs()

In [16]:
#@title Test API Locally (with model selector)
import requests

API_URL = "http://127.0.0.1:8000/predict"
REQUEST_TIMEOUT_S = 60  # CPU inference can be slow, but never wait forever

email_text = "Dear Customer,   Your bank account has been temporarily suspended due to suspicious activity.  Please click the link below to verify your account immediately:  http://verify-bank-login.com" #@param {type:"string"}
model_selector = "bert" #@param ["bert", "lstm", "gru"]

payload = {
    "text": email_text,
    "model": model_selector
}

try:
    response = requests.post(API_URL, json=payload, timeout=REQUEST_TIMEOUT_S)
except requests.exceptions.RequestException as e:
    print(f"Error connecting to local API (is the server running?): {e}")
else:
    print("Status Code:", response.status_code)
    try:
        print("Response:", response.json())
    except ValueError:
        # e.g. a plain-text "Internal Server Error" body on an HTTP 500
        print("Response (non-JSON):", response.text[:500])

Status Code: 200
Response: {'model': 'bert', 'prediction': 'Phishing', 'confidence': 0.9971}


In [None]:
#@title Configure ngrok and Expose FastAPI
import os
from getpass import getpass

from pyngrok import ngrok

NGROK_PORT = 8000  # must match the port uvicorn listens on


def _read_ngrok_token() -> str:
    """Return the ngrok authtoken without it ever being written into the notebook.

    Checks the NGROK_AUTH_TOKEN environment variable, then the Colab secret of
    the same name, and finally prompts for it interactively.
    """
    token = os.environ.get("NGROK_AUTH_TOKEN")
    if token:
        return token
    try:
        from google.colab import userdata
        token = userdata.get("NGROK_AUTH_TOKEN")
    except Exception:  # best effort: not on Colab, secret missing, or access denied
        token = None
    return token or getpass("ngrok authtoken: ")


ngrok.set_auth_token(_read_ngrok_token())
print("ngrok authtoken set.")

# Drop tunnels left over from a previous run so re-running the cell is idempotent.
ngrok.kill()

# NOTE: the tunnel makes the API publicly reachable; the endpoint has no auth.
public_url_tunnel = ngrok.connect(NGROK_PORT)
print("Public API URL:", public_url_tunnel.public_url)
print("Swagger Docs:", public_url_tunnel.public_url + "/docs")

ngrok authtoken set.
Public API URL: https://albert-unsheeting-lacteally.ngrok-free.dev
Swagger Docs: https://albert-unsheeting-lacteally.ngrok-free.dev/docs


In [None]:
#@title Test API via ngrok URL
import requests

REQUEST_TIMEOUT_S = 60  # a stalled tunnel or server must not hang the cell forever

# Ensure public_url_tunnel is available from the previous cell execution
if 'public_url_tunnel' not in globals():
    print("Error: ngrok tunnel not established. Please run the 'Configure ngrok and Expose FastAPI' cell first.")
else:
    NGROK_API_URL = public_url_tunnel.public_url + "/predict"

    # Use the same input parameters as the local test
    email_text = "Dear Customer,   Your bank account has been temporarily suspended due to suspicious activity.  Please click the link below to verify your account immediately:  http://verify-bank-login.com" #@param {type:"string"}
    model_selector = "bert" #@param ["bert", "lstm", "gru"]

    payload = {
        "text": email_text,
        "model": model_selector
    }

    try:
        response = requests.post(NGROK_API_URL, json=payload, timeout=REQUEST_TIMEOUT_S)
    except requests.exceptions.RequestException as e:
        print(f"Error connecting to ngrok URL: {e}")
    else:
        print("Status Code:", response.status_code)
        # Decode separately: requests' JSON error subclasses RequestException,
        # so a non-JSON body would otherwise be misreported as a connection error.
        try:
            print("Response:", response.json())
        except ValueError:
            # e.g. an HTML error page from the tunnel or a plain-text HTTP 500
            print("Response (non-JSON):", response.text[:500])

Status Code: 200
Response: {'model': 'bert', 'prediction': 'Phishing', 'confidence': 0.9971}
