In [1]:
import torch
import pickle
import numpy as np
from transformers import BertTokenizer, BertModel
import torch.nn as nn

# ==================== ATTENTION LAYER (same as training) ====================
class AttentionLayer(nn.Module):
    """Attention pooling that collapses a token sequence into one vector.

    A small feed-forward scorer (Linear -> Tanh -> Linear) assigns one scalar
    score to every position; the scores are softmax-normalised along the
    sequence axis (dim 1) and used as weights for a weighted sum of the input.

    Args:
        hidden_size: Size of the last dimension of the tensors passed to
            ``forward`` (defaults to 768).
    """

    def __init__(self, hidden_size=768):
        super().__init__()
        # The attribute name and layer order are kept as-is: parameters are
        # stored under "attention.0.*" and "attention.2.*" in the state dict.
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.Tanh(),
            nn.Linear(256, 1),
        )

    def forward(self, bert_output, attention_mask=None):
        """Pool ``bert_output`` along dim 1 using learned attention weights.

        Args:
            bert_output: Tensor expected as (batch, seq_len, hidden_size).
            attention_mask: Optional (batch, seq_len) tensor; positions whose
                value is zero/False are excluded from the softmax. When
                omitted, every position (padding included) takes part, which
                is the original behaviour of this layer.

        Returns:
            Tuple ``(attended_output, attention_weights)``: the weighted sum
            over dim 1, and the weights with a trailing dimension of size 1.
        """
        attention_scores = self.attention(bert_output)

        if attention_mask is not None:
            keep = attention_mask.to(
                device=attention_scores.device, dtype=torch.bool
            ).unsqueeze(-1)
            # Use the most negative finite value rather than -inf so that a
            # row with every position masked yields uniform weights, not NaN.
            lowest = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(~keep, lowest)

        attention_weights = torch.softmax(attention_scores, dim=1)
        attended_output = torch.sum(attention_weights * bert_output, dim=1)
        return attended_output, attention_weights


# ==================== LOAD MODELS ====================
def load_models(model_dir=None):
    """Load every artifact needed for inference and put the networks in eval mode.

    Args:
        model_dir: Optional directory containing ``tokenizer``,
            ``bert_model.pth``, ``attention_layer.pth``, ``xgb_model.pkl`` and
            ``label_encoder.pkl``. When ``None`` (default) the bare relative
            names are used, i.e. they resolve against the current working
            directory.

    Returns:
        Tuple ``(bert_model, attention_layer, tokenizer, xgb_model,
        label_encoder, device)``.
    """
    import os  # local import: only needed to build artifact paths

    def _artifact(name):
        # Keep the bare name when no directory is given so the default call
        # passes exactly the same strings as before.
        return name if model_dir is None else os.path.join(model_dir, name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"💻 Using device: {device}")

    # Load BERT: build the 'bert-base-uncased' model, then overwrite its
    # weights with the saved state dict.
    tokenizer = BertTokenizer.from_pretrained(_artifact('tokenizer'))
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    # NOTE(security): torch.load unpickles the file on PyTorch versions where
    # weights_only is not the default; only load checkpoints you trust.
    bert_model.load_state_dict(
        torch.load(_artifact('bert_model.pth'), map_location=device)
    )
    bert_model.to(device).eval()

    # Load attention layer
    attention_layer = AttentionLayer()
    attention_layer.load_state_dict(
        torch.load(_artifact('attention_layer.pth'), map_location=device)
    )
    attention_layer.to(device).eval()

    # Load XGBoost and label encoder
    # NOTE(security): pickle.load can execute arbitrary code; these files must
    # come from a trusted source.
    with open(_artifact('xgb_model.pkl'), 'rb') as f:
        xgb_model = pickle.load(f)
    with open(_artifact('label_encoder.pkl'), 'rb') as f:
        label_encoder = pickle.load(f)

    print("✅ All models loaded successfully!")
    return bert_model, attention_layer, tokenizer, xgb_model, label_encoder, device


# ==================== FEATURE EXTRACTION ====================
def extract_features(texts, bert_model, attention_layer, tokenizer, device,
                     max_len=128, batch_size=32):
    """Turn texts into fixed-size feature vectors (BERT + attention pooling).

    Each text is tokenized (padded/truncated to ``max_len``), run through
    ``bert_model``, and its ``last_hidden_state`` is pooled by
    ``attention_layer``. The attention layer is called with the hidden states
    only, so padded positions take part in the pooling.

    Args:
        texts: A single string or an iterable of strings. A bare string is
            treated as one text rather than iterated character by character.
        bert_model: Model called with ``input_ids``/``attention_mask`` whose
            output exposes ``last_hidden_state``.
        attention_layer: Callable returning ``(pooled, weights)`` for a
            hidden-state tensor.
        tokenizer: Tokenizer callable on a list of strings that returns
            ``input_ids`` and ``attention_mask`` tensors.
        device: Device the encoded tensors are moved to.
        max_len: Token length every text is padded/truncated to.
        batch_size: Number of texts sent through the model per forward pass.

    Returns:
        2-D ``numpy`` array with one row per input text. Empty input yields
        an array with zero rows.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    texts = [texts] if isinstance(texts, str) else list(texts)

    bert_model.eval()
    attention_layer.eval()

    if not texts:
        # Keep the result 2-D; use the model's configured width when available.
        width = getattr(getattr(bert_model, 'config', None), 'hidden_size', 0)
        return np.empty((0, width), dtype=np.float32)

    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            # padding='max_length' makes every row max_len tokens long, so
            # rows see the same inputs whatever the batch size is.
            encoding = tokenizer(
                batch,
                add_special_tokens=True,
                max_length=max_len,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt'
            )
            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)
            outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            seq_output = outputs.last_hidden_state
            attended_output, _ = attention_layer(seq_output)
            # One flat feature row per text in the batch.
            chunks.append(attended_output.cpu().numpy().reshape(len(batch), -1))
    return np.concatenate(chunks, axis=0)


# ==================== PREDICT FUNCTION ====================
def predict(text, bert_model, attention_layer, tokenizer, xgb_model,
            label_encoder, device, top_k=3):
    """Classify one complaint, print a short report, and return the result.

    Args:
        text: The complaint text to classify.
        bert_model, attention_layer, tokenizer, device: Forwarded unchanged to
            ``extract_features``.
        xgb_model: Classifier exposing ``predict_proba``.
        label_encoder: Object exposing ``inverse_transform`` that maps class
            indices back to labels.
        top_k: How many of the highest-probability classes to report
            (default 3). If fewer classes exist, all of them are listed.

    Returns:
        Dict with keys ``"label"`` (predicted label), ``"confidence"``
        (its probability as a float) and ``"top_k"`` (list of
        ``(label, probability)`` pairs, most probable first).

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    feats = extract_features([text], bert_model, attention_layer, tokenizer, device)
    probs = xgb_model.predict_proba(feats)[0]
    pred_idx = np.argmax(probs)
    pred_label = label_encoder.inverse_transform([pred_idx])[0]
    confidence = probs[pred_idx]

    # Highest probabilities first. Reversing before slicing gives the same
    # order as taking the last k of the ascending sort and flipping them.
    top_idx = np.argsort(probs)[::-1][:top_k]
    top_labels = label_encoder.inverse_transform(top_idx)
    top_probs = probs[top_idx]

    print(f"\n💬 Complaint: {text}")
    print(f"✅ Predicted Domain: {pred_label}")
    print(f"📊 Confidence: {confidence:.2%}")
    print(f"📈 Top {top_k} Predictions:")
    for i, (lbl, pr) in enumerate(zip(top_labels, top_probs), 1):
        print(f"   {i}. {lbl}: {pr:.2%}")

    return {
        "label": pred_label,
        "confidence": float(confidence),
        "top_k": [(lbl, float(pr)) for lbl, pr in zip(top_labels, top_probs)],
    }


# ==================== MAIN ====================
if __name__ == "__main__":
    # Load everything once; the unpacked names stay module-level so they
    # remain available after this block has run.
    bert_model, attention_layer, tokenizer, xgb_model, label_encoder, device = load_models()

    # Try sample complaints
    samples = [
        "My credit card payment was declined at the store",
        "Unable to access my online banking account",
        "Loan approval is taking too long",
        "ATM not working properly",
    ]

    for s in samples:
        predict(s, bert_model, attention_layer, tokenizer, xgb_model, label_encoder, device)

    # Try your own
    print("\n🎤 Enter your own complaint:")
    try:
        user_text = input("Complaint: ").strip()
    except EOFError:
        # stdin is closed (e.g. piped or redirected run): skip the
        # interactive step instead of crashing after the samples have run.
        user_text = ""
    if user_text:
        predict(user_text, bert_model, attention_layer, tokenizer, xgb_model, label_encoder, device)


💻 Using device: cpu
✅ All models loaded successfully!

💬 Complaint: My credit card payment was declined at the store
✅ Predicted Domain: credit card
📊 Confidence: 22.97%
📈 Top 3 Predictions:
   1. credit card: 22.97%
   2. debit card: 21.18%
   3. transaction failure: 11.72%

💬 Complaint: Unable to access my online banking account
✅ Predicted Domain: credit card
📊 Confidence: 32.44%
📈 Top 3 Predictions:
   1. credit card: 32.44%
   2. financial policies: 8.74%
   3. debit card: 7.68%

💬 Complaint: Loan approval is taking too long
✅ Predicted Domain: financial policies
📊 Confidence: 76.99%
📈 Top 3 Predictions:
   1. financial policies: 76.99%
   2. loan: 9.23%
   3. credit card: 8.44%

💬 Complaint: ATM not working properly
✅ Predicted Domain: atm
📊 Confidence: 57.11%
📈 Top 3 Predictions:
   1. atm: 57.11%
   2. debit card: 17.37%
   3. credit card: 8.09%

🎤 Enter your own complaint:


Complaint:  yours service is surreal



💬 Complaint: yours service is surreal
✅ Predicted Domain: customer service
📊 Confidence: 38.71%
📈 Top 3 Predictions:
   1. customer service: 38.71%
   2. credit card: 11.47%
   3. transaction failure: 11.15%
