In [1]:
import os
import re
import joblib
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from typing import Dict, List, Tuple
from torch.utils.data import Dataset, DataLoader
from sentence_transformers import SentenceTransformer
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

class CyberCrimeDataset(Dataset):
    """Dataset of texts and their labels at each level of the category hierarchy.

    Every argument is copied into a plain ``list`` on construction so that
    ``__getitem__`` is always positional. Without the copy, a pandas Series
    carrying a non-default index (for example a filtered or shuffled frame)
    would be indexed by label and either raise ``KeyError`` or return the
    wrong row.

    Args:
        texts: Input texts, one per sample.
        main_categories: Top-level label per sample (returned as 'category_names').
        categories: Label per sample returned as 'retagged_category'.
        sub_categories: Label per sample returned as 'retagged_sub_category'.
        sub_category_names: Label per sample returned as 'sub_category_names'.
    """
    def __init__(self, texts: List[str], main_categories: List[str], 
                 categories: List[str], sub_categories: List[str],
                 sub_category_names: List[str]):
        # list(...) accepts lists, tuples, Series and arrays alike and
        # guarantees positional access afterwards.
        self.texts = list(texts)
        self.main_categories = list(main_categories)
        self.categories = list(categories)
        self.sub_categories = list(sub_categories)
        self.sub_category_names = list(sub_category_names)
        
    def __len__(self):
        """Number of samples, defined by the number of texts."""
        return len(self.texts)
    
    def __getitem__(self, idx):
        """Return the sample at position ``idx`` as a dict of text plus labels."""
        return {
            'text': self.texts[idx],
            'category_names': self.main_categories[idx],
            'retagged_category': self.categories[idx],
            'retagged_sub_category': self.sub_categories[idx],
            'sub_category_names': self.sub_category_names[idx]
        }

def custom_collate(batch):
    """Merge a list of per-sample dicts into one dict of parallel lists.

    Values are gathered as plain Python lists, in sample order; nothing is
    stacked or converted. Only the five fields listed below are collected.
    """
    field_names = (
        'text',
        'category_names',
        'retagged_category',
        'retagged_sub_category',
        'sub_category_names',
    )
    return {field: [sample[field] for sample in batch] for field in field_names}

def evaluate_predictions(y_true: List[str], y_pred: List[str], level_name: str):
    """
    Evaluate predictions for a given hierarchy level

    Prints the accuracy, the classification report and the confusion matrix
    for every (true, predicted) pair in which both labels are present.

    Labels are paired by position, so lists and pandas Series (whatever their
    index) are handled the same way. If the inputs differ in length, the
    surplus tail of the longer one is ignored. A pair is dropped when either
    label is missing, where "missing" means None, NaN or pd.NA.
    
    Args:
        y_true: List of true labels
        y_pred: List of predicted labels
        level_name: Name of the hierarchy level being evaluated

    Returns:
        The accuracy over the retained pairs, or 0.0 when no pair is usable.
    """
    print(f"\n=== {level_name} Metrics ===")
    
    # pd.isna covers None as well as the NaN that pandas substitutes for
    # missing labels; a float NaN mixed with string labels would otherwise
    # make the sklearn metrics fail.
    valid_pairs = [
        (true_label, pred_label)
        for true_label, pred_label in zip(y_true, y_pred)
        if not pd.isna(true_label) and not pd.isna(pred_label)
    ]
    
    if not valid_pairs:
        print(f"No valid predictions for {level_name}")
        return 0.0
    
    filtered_true = [true_label for true_label, _ in valid_pairs]
    filtered_pred = [pred_label for _, pred_label in valid_pairs]
    
    # Calculate accuracy
    accuracy = accuracy_score(filtered_true, filtered_pred)
    print(f"Accuracy: {accuracy:.4f}")
    
    # Generate and print classification report
    print("\nClassification Report:")
    print(classification_report(filtered_true, filtered_pred))
    
    # Generate confusion matrix
    cm = confusion_matrix(filtered_true, filtered_pred)
    print("\nConfusion Matrix:")
    print(cm)
    
    return accuracy

def process_batch(batch: Dict, encoder, models: Dict, selectors: Dict, 
                 label_encoders: Dict, category_to_sub_category: Dict,
                 master_mapper: Dict) -> Dict[str, List[str]]:
    """
    Process a batch of texts through the model hierarchy

    The main category is predicted for the whole batch in one call. The
    category and sub-category levels are then predicted one text at a time,
    because the model used depends on the label predicted one level above.

    Failure handling:
        * A missing model/selector/encoder, or a category with no entry in
          ``category_to_sub_category`` (any ``KeyError``), marks the lower
          levels of that single text as 'unknown'.
        * Any other exception marks the whole batch as 'unknown'. The result
          lists are replaced rather than extended, so each list always holds
          exactly one entry per input text even when the failure happens
          part-way through the batch.
    
    Args:
        batch: Dictionary containing batch data (only 'text' is read)
        encoder: Sentence transformer encoder
        models: Dictionary of trained models
        selectors: Dictionary of feature selectors
        label_encoders: Dictionary of label encoders
        category_to_sub_category: Mapping of categories to subcategories
        master_mapper: Master mapping dictionary
    
    Returns:
        Dictionary containing predictions for all hierarchy levels
    """
    def model_key(prefix: str, label: str) -> str:
        # Shared label -> artifact-key normalisation for both hierarchy levels.
        return f'{prefix}_{label.replace(" ", "_").replace("/", "_").replace("&", "and")}'

    texts = [str(text).lower() for text in batch['text']]
    
    batch_results = {
        'pred_category_names': [],
        'pred_retagged_category': [],
        'pred_retagged_sub_category': [],
        'pred_sub_category_names': []
    }

    # Nothing to encode; avoids a spurious "error" from reshaping an empty array.
    if not texts:
        return batch_results
    
    try:
        # Encode all texts in batch
        text_embeddings = encoder.encode(texts, show_progress_bar=False)
        text_embeddings = text_embeddings.reshape(len(texts), -1)
        
        # Predict main categories
        main_features = selectors['category_names'].transform(text_embeddings)
        main_cat_pred = models['category_names'].predict(main_features)
        main_categories = label_encoders['category_names'].inverse_transform(main_cat_pred)
        
        # Process each text in batch
        for idx, category_names in enumerate(main_categories):
            batch_results['pred_category_names'].append(category_names)
            
            try:
                # Predict category
                category_model_key = model_key('category', category_names)
                single_embedding = text_embeddings[idx:idx+1]
                
                category_features = selectors[category_model_key].transform(single_embedding)
                cat_pred = models[category_model_key].predict(category_features)
                category = label_encoders[category_model_key].inverse_transform(cat_pred)[0]
                
                # Predict subcategory only when the category offers a choice;
                # otherwise take its single listed sub-category name.
                if category in category_to_sub_category and len(category_to_sub_category[category]) > 1:
                    sub_model_key = model_key('sub_category_names', category)
                    sub_features = selectors[sub_model_key].transform(single_embedding)
                    mapped_sub_cat_pred = models[sub_model_key].predict(sub_features)
                    sub_category_names = label_encoders[sub_model_key].inverse_transform(mapped_sub_cat_pred)[0]
                else:
                    sub_category_names = category_to_sub_category[category][0]
                sub_category = find_immediate_key(master_mapper, sub_category_names)
                
            except KeyError as e:
                print(f"Warning: Model not found for prediction chain: {e}")
                category = "unknown"
                sub_category_names = "unknown"
                sub_category = "unknown"
            
            batch_results['pred_retagged_category'].append(category)
            batch_results['pred_retagged_sub_category'].append(sub_category)
            batch_results['pred_sub_category_names'].append(sub_category_names)
    
    except Exception as e:
        print(f"Error in processing batch: {str(e)}")
        # Replace (not extend) so partially filled lists cannot leave the
        # predictions longer than, or misaligned with, the batch.
        for key in batch_results:
            batch_results[key] = ['unknown'] * len(texts)
    
    return batch_results

def run_inference_pipeline(test_df: pd.DataFrame, 
                         encoder, models: Dict, 
                         selectors: Dict, 
                         label_encoders: Dict,
                         category_to_sub_category: Dict,
                         master_mapper: Dict,
                         batch_size: int = 64,
                         output_path: str = 'prediction_results.csv') -> pd.DataFrame:
    """
    Run the complete inference pipeline

    Feeds ``test_df`` through ``process_batch`` in batches, collects the true
    and predicted labels for every hierarchy level, writes them to
    ``output_path`` and prints the final metrics.
    
    Args:
        test_df: DataFrame containing test data. Must provide the columns
            'content_processed', 'retagged_category', 'retagged_sub_category'
            and 'sub_category_names'; 'category_names' is optional and is
            filled with 'unknown' when absent.
        encoder: Sentence transformer encoder
        models: Dictionary of trained models
        selectors: Dictionary of feature selectors
        label_encoders: Dictionary of label encoders
        category_to_sub_category: Mapping of categories to subcategories
        master_mapper: Master mapping dictionary
        batch_size: Batch size for processing
        output_path: CSV file the per-row results are written to
        
    Returns:
        DataFrame with one row per input text and a true_*/pred_* column pair
        for each hierarchy level (metrics are printed, not returned)
    """
    # Every column is converted to a plain list so the dataset indexes by
    # position; passing a Series would make lookups depend on test_df's index.
    if 'category_names' in test_df.columns:
        main_categories = test_df['category_names'].tolist()
    else:
        main_categories = ['unknown'] * len(test_df)

    # Create dataset and dataloader
    dataset = CyberCrimeDataset(
        texts=test_df['content_processed'].apply(lambda x: str(x).lower()).tolist(),
        main_categories=main_categories,
        categories=test_df['retagged_category'].tolist(),
        sub_categories=test_df['retagged_sub_category'].tolist(),
        sub_category_names=test_df["sub_category_names"].tolist(),
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep row order so results line up with test_df
        collate_fn=custom_collate,
        num_workers=0
    )

    # (batch field, results suffix) for every hierarchy level
    levels = ['category_names', 'retagged_category', 'retagged_sub_category', 'sub_category_names']

    # Initialize results dictionary
    results = {
        'true_category_names': [],
        'pred_category_names': [],
        'true_retagged_category': [],
        'pred_retagged_category': [],
        'true_retagged_sub_category': [],
        'pred_retagged_sub_category': [],
        'true_sub_category_names': [],
        'pred_sub_category_names': []
    }

    # Process batches
    print("\nProcessing batches...")
    total_batches = len(dataloader)
    # Report intermediate metrics roughly ten times over the whole run.
    report_every = max(1, total_batches // 10)

    with tqdm(total=total_batches, desc="Processing") as pbar:
        for batch_idx, batch in enumerate(dataloader):
            # Get predictions for batch
            batch_predictions = process_batch(
                batch, encoder, models, selectors, 
                label_encoders, category_to_sub_category, 
                master_mapper
            )
            
            # Store results
            for level in levels:
                results[f'true_{level}'].extend(batch[level])
                results[f'pred_{level}'].extend(batch_predictions[f'pred_{level}'])
            
            # Update progress and show intermediate metrics
            pbar.update(1)
            if (batch_idx + 1) % report_every == 0:
                show_intermediate_metrics(results, batch_idx, total_batches, pbar)

    # Convert results to DataFrame and calculate final metrics
    results_df = pd.DataFrame(results)
    results_df.to_csv(output_path, index=False)
    
    if results_df.empty:
        print("\nNo rows to evaluate; skipping final metrics.")
        return results_df

    print("\nCalculating final metrics...")
    calculate_final_metrics(results_df)
    
    return results_df

def show_intermediate_metrics(results: Dict, batch_idx: int, 
                            total_batches: int, pbar: tqdm):
    """Write the running accuracy of each hierarchy level via the progress bar.

    Levels whose true_/pred_ lists are missing from ``results``, or that have
    no pair with both labels set, are skipped without output.
    """
    pbar.write(f"\nBatch {batch_idx + 1}/{total_batches}")

    level_names = ('retagged_category', 'category_names',
                   'sub_category_names', 'retagged_sub_category')

    for level in level_names:
        true_key, pred_key = f'true_{level}', f'pred_{level}'
        if true_key not in results or pred_key not in results:
            continue

        # zip stops at the shorter list, so only positions that have both a
        # true label and a prediction are compared.
        usable_pairs = [
            (truth, guess)
            for truth, guess in zip(results[true_key], results[pred_key])
            if truth is not None and guess is not None
        ]
        if not usable_pairs:
            continue

        kept_true = [truth for truth, _ in usable_pairs]
        kept_pred = [guess for _, guess in usable_pairs]
        current_accuracy = accuracy_score(kept_true, kept_pred)
        pbar.write(f"Current {level} Accuracy: {current_accuracy:.4f}")

def clean_json_mapping(json_mapping):
    """
    Normalise every string key and value of a (nested) JSON-like mapping.

    Each string is transformed as follows:
        1. spaces become underscores;
        2. every character other than letters, digits, '_' and '.' is
           temporarily marked with '*';
        3. dots are removed;
        4. each run of '*' and/or '_' collapses into a single '_'.
    The result therefore never contains '*'. For example,
    "cyber attack/ dependent crimes" becomes "cyber_attack_dependent_crimes".

    Dicts and lists are processed recursively at any depth, including dicts
    or lists nested inside a list. Non-string leaves are returned unchanged.
    If two keys of one dict normalise to the same string, the later one wins.
    
    Args:
        json_mapping (dict): Input JSON mapping to clean
        
    Returns:
        dict: Cleaned JSON mapping (a new object; the input is not modified)
    """
    def clean_string(s):
        if not isinstance(s, str):
            return s
        
        # Replace spaces with underscore
        s = s.replace(' ', '_')
        
        # Mark special characters; dots survive this step and are dropped
        # next, so "a.b" becomes "ab" rather than "a_b".
        s = re.sub(r'[^a-zA-Z0-9_.]', '*', s)
        s = s.replace(".", "")
        
        # Collapse any run of markers/underscores into one underscore
        return re.sub(r'[*_]+', '_', s)
    
    def process_value(value):
        if isinstance(value, dict):
            return {clean_string(k): process_value(v) for k, v in value.items()}
        if isinstance(value, list):
            return [process_value(item) for item in value]
        return clean_string(value)
    
    return process_value(json_mapping)

def calculate_final_metrics(results_df: pd.DataFrame):
    """Calculate and display final metrics for all hierarchy levels

    Prints per-level metrics through ``evaluate_predictions`` and then the
    share of rows predicted correctly on every compared level at once.

    The main-category level takes part (in both the per-level report and the
    overall figure) only when ground truth is available for it, i.e. the
    'true_category_names' column exists and its first value is not the
    'unknown' placeholder. Including it unconditionally would raise KeyError
    when the column is absent and would compare predictions against the
    placeholder otherwise.

    Args:
        results_df: DataFrame with true_*/pred_* columns for
            'retagged_category', 'sub_category_names' and
            'retagged_sub_category', plus optionally 'category_names'.

    Returns:
        None. All results are printed.
    """
    if results_df.empty:
        # .iloc[0] below would fail and the metrics have nothing to score.
        print("No predictions available; skipping final metrics")
        return

    has_main_truth = (
        'true_category_names' in results_df.columns
        and results_df['true_category_names'].iloc[0] != 'unknown'
    )

    if has_main_truth:
        evaluate_predictions(
            results_df['true_category_names'],
            results_df['pred_category_names'],
            'Main Category'
        )

    evaluate_predictions(
        results_df['true_retagged_category'],
        results_df['pred_retagged_category'],
        'retagged_category'
    )

    evaluate_predictions(
        results_df['true_sub_category_names'],
        results_df['pred_sub_category_names'],
        'Mapped Sub-Category'
    )

    evaluate_predictions(
        results_df['true_retagged_sub_category'],
        results_df['pred_retagged_sub_category'],
        'Sub-Category'
    )

    # Calculate overall accuracy across all levels that have ground truth.
    # The retagged sub-category level is not part of this figure, as before.
    all_levels_correct = (
        (results_df['true_retagged_category'] == results_df['pred_retagged_category']) &
        (results_df['true_sub_category_names'] == results_df['pred_sub_category_names'])
    )
    if has_main_truth:
        all_levels_correct = all_levels_correct & (
            results_df['true_category_names'] == results_df['pred_category_names']
        )
    overall_accuracy = all_levels_correct.mean()

    print("\n=== Overall Results ===")
    print(f"Complete Hierarchy Accuracy: {overall_accuracy:.4f}")
    
# Mappings
# Main category -> list of the categories that belong to it.
# The literal is passed through clean_json_mapping, so the stored keys and
# values are the normalised forms (spaces and other non-alphanumeric
# characters collapsed to '_'), not the raw strings written below.
# NOTE(review): this table is not referenced in the visible part of the
# notebook; confirm whether a later cell still needs it.
category_names_to_category = clean_json_mapping({
            "women/child related crime": [
                "child pornography cpchild sexual abuse material csam",
                "crime against women & children",
                "online cyber trafficking",
                "rapegang rape rgrsexually abusive content",
                "sexually explicit act",
                "sexually obscene material"
            ],
            "financial fraud crimes": [
                "cryptocurrency crime",
                "online financial fraud",
                "online gambling  betting"
            ],
            "other cyber crime": [
                "any other cyber crime",
                "cyber attack/ dependent crimes",
                "cyber terrorism",
                "hacking  damage to computercomputer system etc",
                "online and social media related crime",
                "report unlawful content"
            ]
        })

# Category -> list of sub-category names, normalised by clean_json_mapping
# (so lookups must use the cleaned form of a category name).
# process_batch and predict_single read this table: a category with more than
# one entry is routed to its own sub-category model, while a category with a
# single entry takes that entry ([0]) directly without running a model.
# NOTE(review): presumably the label encoders emit category names in the same
# cleaned form; verify against the training notebook.
category_to_sub_category = clean_json_mapping({
            "any other cyber crime": [
                "other",
                "supply chain attacks"
            ],
            "child pornography cpchild sexual abuse material csam": [
                "child pornography cpchild sexual abuse material csam"
            ],
            "crime against women & children": [
                "sexual harassment",
                "computer generated csam/csem"
            ],
            "cryptocurrency crime": [
                "cryptocurrency fraud"
            ],
            "cyber attack/ dependent crimes": [
                "sql injection",
                "ransomware attack",
                "malware attack",
                "malicious code attacks (specifically mentioning virus, worm, trojan, bots, spyware, cryptominers)",
                "data breach/theft",
                "data leaks",
                "hacking/defacement",
                "zero-day exploits",
                "malicious mobile app attacks",
                "denial of service (dos)/distributed denial of service (ddos) attacks",
                "tampering with computer source documents"
            ],
            "cyber terrorism": [
                "cyber terrorism",
                "cyber espionage"
            ],
            "hacking  damage to computercomputer system etc": [
                "email hacking",
                "unauthorised accessdata breach",
                "compromise of critical systems/information",
                "targeted scanning/probing of critical networks/systems",
                "attacks on servers (database mail dns) and network devices (routers)",
                "attacks on critical infrastructure, scada, operational technology systems, and wireless networks",
                "attacks or suspicious activities affecting cloud computing systems servers software and applications",
                "attacks or malicious suspicious activities affecting systems related to big data blockchain virtual assets and robotics",
                "attacks on internet of things (iot) devices and associated systems, networks, and servers",
                "attacks on systems related to artificial intelligence (ai) and machine learning (ml)",
                "damage to computer computer systems etc",
                "web application vulnerabilities",
            ],
            "online cyber trafficking": [
                "online trafficking"
            ],
            "online financial fraud": [
                "upi related frauds",
                "aadhar enabled payment system (aeps) fraud",
                "business email compromiseemail takeover",
                "debitcredit card fraudsim swap fraud",
                "ewallet related fraud",
                "fraud callvishing",
                "internet banking related fraud",
                "attacks or incidents affecting digital payment systems"
            ],
            "online gambling  betting": [
                "online gambling  betting"
            ],
            "online and social media related crime": [
                "intimidating email",
                "provocative speech for unlawful acts",
                "email phishing",
                "online job fraud",
                "profile hacking identity theft",
                "identity theft, spoofing, and phishing attacks",
                "unauthorized social media access",
                "cheating by impersonation",
                "fake mobile apps",
                "online matrimonial fraud",
                "cyber bullying  stalking  sexting",
                "fakeimpersonating profile"
            ],
            "rapegang rape rgrsexually abusive content": [
                "rapegang rape rgrsexually abusive content"
            ],
            "report unlawful content": [
                "against interest of sovereignty or integrity of india",
                "disinformation or misinformation campaigns"
            ],
            "sexually explicit act": [
                "sexually explicit act"
            ],
            "sexually obscene material": [
                "sale publishing and transmitting obscene material/sexually explicit material"
            ]
        })

def preprocess_text(text):
    """Lower-case ``text`` (coerced with ``str``) and drop every character
    that is neither an ASCII letter nor whitespace."""
    lowered = str(text).lower()
    return re.sub(r'[^a-zA-Z\s]', '', lowered)

def load_models(models_path='models/',
                encoder_name: str = 'paraphrase-MiniLM-L3-v2') -> Tuple[object, Dict, Dict, Dict]:
    """
    Load all saved models, encoders, and selectors from the specified path

    Every model is stored as three files sharing one key:
    ``<key>_model.joblib``, ``<key>_encoder.joblib`` and
    ``<key>_selector.joblib``. The 'category_names' (main category) triple is
    mandatory. Every other triple found in the directory is optional: if any
    of its three files fails to load, a warning is printed and none of the
    three is registered, so the returned dictionaries always share the same
    keys.

    SECURITY: joblib.load unpickles its input and can execute arbitrary code.
    Only point ``models_path`` at artifacts from a trusted source.
    
    Args:
        models_path (str): Path to directory containing saved models
        encoder_name (str): Sentence-transformer model to load; must be the
            one used when the classifiers were trained
        
    Returns:
        Tuple containing:
        - sentence encoder
        - dictionary of trained models
        - dictionary of label encoders
        - dictionary of feature selectors

    Raises:
        RuntimeError: If the main category components cannot be loaded.
    """
    print("Loading models...")
    models = {}
    label_encoders = {}
    selectors = {}
    models_dir = Path(models_path)
    model_suffix = '_model.joblib'

    def load_components(key):
        # Load the whole triple before anything is registered, so a failure
        # on the second or third file cannot leave a half-registered key.
        return (
            joblib.load(models_dir / f'{key}{model_suffix}'),
            joblib.load(models_dir / f'{key}_encoder.joblib'),
            joblib.load(models_dir / f'{key}_selector.joblib'),
        )
    
    # Load the same sentence transformer used in training
    encoder = SentenceTransformer(encoder_name)
    
    # Load main category model components
    try:
        main_model, main_encoder, main_selector = load_components('category_names')
    except Exception as e:
        raise RuntimeError(f"Failed to load main category model components: {str(e)}") from e
    models['category_names'] = main_model
    label_encoders['category_names'] = main_encoder
    selectors['category_names'] = main_selector
    
    # Load all category and subcategory models
    model_files = sorted(models_dir.glob(f'*{model_suffix}'))
    for file in tqdm(model_files, desc="Loading models"):
        # Strip only the trailing suffix; str.replace would also rewrite an
        # earlier occurrence inside the name or the directory path.
        key = file.name[:-len(model_suffix)]
        if key == 'category_names':
            continue
        try:
            model, label_encoder, selector = load_components(key)
        except Exception as e:
            print(f"Warning: Failed to load model components for {key}: {str(e)}")
            continue
        models[key] = model
        label_encoders[key] = label_encoder
        selectors[key] = selector
    
    return encoder, models, label_encoders, selectors

def predict_single(text: str, encoder, models: Dict, selectors: Dict, 
                  label_encoders: Dict, category_to_sub_category: Dict,
                  master_mapper: Dict) -> Dict[str, str]:
    """
    Process a single text through the hierarchical model chain

    The chain stops at the first level that cannot be resolved and reports
    'unknown' from there down, while keeping what was already predicted:
        * no category model for the predicted main category: only the main
          category is reported;
        * predicted category has no (or an empty) entry in
          ``category_to_sub_category``: main category and category are
          reported;
        * no sub-category model for a category with several sub-categories:
          the first listed sub-category is used as a fallback.
    Any other exception yields 'unknown' on every level.
    
    Args:
        text (str): Input text to classify
        encoder: Sentence transformer encoder
        models (Dict): Dictionary of trained models
        selectors (Dict): Dictionary of feature selectors
        label_encoders (Dict): Dictionary of label encoders
        category_to_sub_category (Dict): Mapping from categories to subcategories
        master_mapper (Dict): Master mapping of categories
        
    Returns:
        Dict with the keys 'pred_category_names', 'pred_retagged_category',
        'pred_retagged_sub_category' and 'pred_sub_category_names'
    """
    def model_key(prefix: str, label: str) -> str:
        # Shared label -> artifact-key normalisation for both hierarchy levels.
        return f'{prefix}_{label.replace(" ", "_").replace("/", "_").replace("&", "and")}'

    try:
        # Preprocess and encode text
        processed_text = preprocess_text(text)
        text_embedding = encoder.encode([processed_text], show_progress_bar=False)
        text_embedding = text_embedding.reshape(1, -1)
        
        # Predict main category
        main_features = selectors['category_names'].transform(text_embedding)
        main_cat_pred = models['category_names'].predict(main_features)
        category_names = label_encoders['category_names'].inverse_transform(main_cat_pred)[0]

        # Filled in level by level; whatever stays unresolved remains 'unknown'.
        result = {
            'pred_category_names': category_names,
            'pred_retagged_category': 'unknown',
            'pred_retagged_sub_category': 'unknown',
            'pred_sub_category_names': 'unknown'
        }
        
        # Predict category based on main category
        category_model_key = model_key('category', category_names)
        try:
            category_features = selectors[category_model_key].transform(text_embedding)
            cat_pred = models[category_model_key].predict(category_features)
            category = label_encoders[category_model_key].inverse_transform(cat_pred)[0]
        except KeyError:
            print(f"Warning: No category model found for {category_names}")
            return result
        result['pred_retagged_category'] = category

        # Without this check an unmapped category would raise below and the
        # broad handler would throw away the two levels already predicted.
        sub_options = category_to_sub_category.get(category)
        if not sub_options:
            print(f"Warning: No sub-categories mapped for {category}")
            return result
        
        # Default to the first listed sub-category; a model only overrides it
        # when the category offers more than one option.
        sub_category_names = sub_options[0]
        if len(sub_options) > 1:
            sub_model_key = model_key('sub_category_names', category)
            try:
                sub_features = selectors[sub_model_key].transform(text_embedding)
                mapped_sub_cat_pred = models[sub_model_key].predict(sub_features)
                sub_category_names = label_encoders[sub_model_key].inverse_transform(mapped_sub_cat_pred)[0]
            except KeyError:
                print(f"Warning: No subcategory model found for {category}")
        
        result['pred_retagged_sub_category'] = find_immediate_key(master_mapper, sub_category_names)
        result['pred_sub_category_names'] = sub_category_names
        return result
        
    except Exception as e:
        print(f"Error in prediction: {str(e)}")
        return {
            'pred_category_names': 'unknown',
            'pred_retagged_category': 'unknown',
            'pred_retagged_sub_category': 'unknown',
            'pred_sub_category_names': 'unknown'
        }

def find_immediate_key(dictionary, search_value):
    """
    Return the second-level key whose value collection contains ``search_value``.

    The outer keys are ignored; groups are scanned in insertion order and the
    first match wins.

    Args:
    dictionary (dict): Two-level mapping of outer key -> {inner key -> values}
    search_value (str): The value to look for inside each ``values`` collection

    Returns:
    str or None: The matching inner key, or None when nothing contains the value
    """
    matching_keys = (
        inner_key
        for group in dictionary.values()
        for inner_key, members in group.items()
        if search_value in members
    )
    return next(matching_keys, None)

def save_detailed_results(results_df: pd.DataFrame, test_df: pd.DataFrame,
                          output_dir: str = '.'):
    """Save detailed analysis of the results

    Writes two CSV files into ``output_dir``:
        * detailed_prediction_results.csv - the text, all true/pred columns
          and one boolean correctness column per compared level;
        * prediction_errors.csv - the subset of rows that is wrong on at
          least one of those levels.

    Rows are matched by position: row i of ``results_df`` is taken to describe
    row i of ``test_df``. Both indexes are reset before joining, because
    pd.concat(axis=1) aligns on index labels and a ``test_df`` with a
    non-default index would otherwise yield NaN-padded, misaligned rows.

    Args:
        results_df: Predictions with true_*/pred_* columns for
            'category_names', 'retagged_category' and 'retagged_sub_category'.
        test_df: Source data; only 'content_processed' is used.
        output_dir: Existing directory the two CSV files are written to.
    """
    # Combine original text with predictions (positional alignment)
    detailed_results = pd.concat([
        test_df['content_processed'].reset_index(drop=True),
        results_df.reset_index(drop=True)
    ], axis=1)
    
    # Add correctness columns
    correctness_columns = {
        'category_names_correct': 'category_names',
        'category_correct': 'retagged_category',
        'sub_category_correct': 'retagged_sub_category',
    }
    for flag_column, level in correctness_columns.items():
        detailed_results[flag_column] = (
            detailed_results[f'true_{level}'] == detailed_results[f'pred_{level}']
        )
    
    # Save to CSV
    detailed_results.to_csv(
        os.path.join(output_dir, 'detailed_prediction_results.csv'), index=False
    )
    
    # Save error analysis: rows that miss on at least one level
    fully_correct = (
        detailed_results['category_names_correct'] &
        detailed_results['category_correct'] &
        detailed_results['sub_category_correct']
    )
    error_cases = detailed_results[~fully_correct]
    error_cases.to_csv(os.path.join(output_dir, 'prediction_errors.csv'), index=False)

def analyze_examples(results_df: pd.DataFrame, test_df: pd.DataFrame, n_examples: int = 5,
                     random_state=None):
    """
    Print a random handful of predictions next to their true labels.

    Rows are drawn without replacement by position; the same position is used
    to read the text from ``test_df`` and the labels from ``results_df``, so
    both frames must be in the same row order.

    Args:
        results_df: Frame holding the ``true_*`` / ``pred_*`` columns for
            ``category_names``, ``retagged_category`` and
            ``retagged_sub_category``
        test_df: Frame holding the ``content_processed`` text column
        n_examples: Maximum number of rows to print (capped at the number of
            rows in ``results_df``)
        random_state: Optional seed for a reproducible sample. When None,
            NumPy's global random state is used.
    """
    print("\n=== Example Predictions ===")

    # Sample some random examples
    n_rows = len(results_df)
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    indices = rng.choice(n_rows, min(n_examples, n_rows), replace=False)

    max_chars = 200  # length of the text preview
    for idx in indices:
        raw_text = test_df['content_processed'].iloc[idx]
        # A missing text would otherwise fail on slicing; show it as empty
        text = '' if pd.isna(raw_text) else str(raw_text)
        # Only mark the preview with an ellipsis when something was cut off
        preview = (text[:max_chars] + "...") if len(text) > max_chars else text

        print("\nText:")
        print(preview)

        print("\nPredictions:")
        print(f"Main Category: {results_df['pred_category_names'].iloc[idx]} "
              f"(True: {results_df['true_category_names'].iloc[idx]})")
        print(f"Category: {results_df['pred_retagged_category'].iloc[idx]} "
              f"(True: {results_df['true_retagged_category'].iloc[idx]})")
        print(f"Sub-Category: {results_df['pred_retagged_sub_category'].iloc[idx]} "
              f"(True: {results_df['true_retagged_sub_category'].iloc[idx]})")
        print("-" * 80)

In [2]:
# Three-level label mapping: top-level category -> sub-category -> list of raw
# label strings grouped under that sub-category. Each sub-category lists its
# own name among its raw labels so that a lookup by value can find it.
# Repeated spaces and run-together words in the keys are kept as written here.
master_mapper = {
    "any other cyber crime": {
        "other": [
            "other",
            "supply chain attacks"
        ]
    },
    "child pornography cpchild sexual abuse material csam": {
        "child pornography cpchild sexual abuse material csam": [
            "child pornography cpchild sexual abuse material csam"
        ]
    },
    "crime against women & children": {
        "sexual harassment": [
            "sexual harassment"
        ],
        "computer generated csam/csem": [
            "computer generated csam/csem"
        ]
    },
    "cryptocurrency crime": {
        "cryptocurrency fraud": [
            "cryptocurrency fraud"
        ]
    },
    "cyber attack/ dependent crimes": {
        "sql injection": [
            "sql injection"
        ],
        "ransomware attack": [
            "ransomware attack"
        ],
        "malware attack": [
            "malware attack",
            "malicious code attacks (specifically mentioning virus, worm, trojan, bots, spyware, cryptominers)"
        ],
        "data breach/theft": [
            "data breach/theft",
            "data leaks"
        ],
        "hacking/defacement": [
            "hacking/defacement",
            "zero-day exploits",
            "malicious mobile app attacks"
        ],
        "denial of service (dos)/distributed denial of service (ddos) attacks": [
            "denial of service (dos)/distributed denial of service (ddos) attacks"
        ],
        "tampering with computer source documents": [
            "tampering with computer source documents"
        ]
    },
    "cyber terrorism": {
        "cyber terrorism": [
            "cyber terrorism",
            "cyber espionage"
        ]
    },
    "hacking  damage to computercomputer system etc": {
        "email hacking": [
            "email hacking"
        ],
        "unauthorised accessdata breach": [
            "unauthorised accessdata breach",
            "compromise of critical systems/information",
            "targeted scanning/probing of critical networks/systems",
            "attacks on servers (database mail dns) and network devices (routers)",
            "attacks on critical infrastructure, scada, operational technology systems, and wireless networks",
            "attacks or suspicious activities affecting cloud computing systems servers software and applications",
            "attacks or malicious suspicious activities affecting systems related to big data blockchain virtual assets and robotics",
            "attacks on internet of things (iot) devices and associated systems, networks, and servers",
            "attacks on systems related to artificial intelligence (ai) and machine learning (ml)"
        ],
        "damage to computer computer systems etc": [
            "damage to computer computer systems etc"
        ],
        "website defacementhacking": [
            # The sub-category's own name was missing here, unlike every other
            # entry, so a raw label equal to it could not be matched.
            "website defacementhacking",
            "web application vulnerabilities"
        ]
    },
    "online cyber trafficking": {
        "online trafficking": [
            "online trafficking"
        ]
    },
    "online financial fraud": {
        "upi related frauds": [
            "upi related frauds",
            "aadhar enabled payment system (aeps) fraud"
        ],
        "business email compromiseemail takeover": [
            "business email compromiseemail takeover"
        ],
        "debitcredit card fraudsim swap fraud": [
            "debitcredit card fraudsim swap fraud"
        ],
        "ewallet related fraud": [
            "ewallet related fraud"
        ],
        "fraud callvishing": [
            "fraud callvishing"
        ],
        "internet banking related fraud": [
            "internet banking related fraud",
            "attacks or incidents affecting digital payment systems"
        ]
    },
    "online gambling  betting": {
        "online gambling  betting": [
            "online gambling  betting"
        ]
    },
    "online and social media related crime": {
        "intimidating email": [
            "intimidating email"
        ],
        "provocative speech for unlawful acts": [
            "provocative speech for unlawful acts"
        ],
        "email phishing": [
            "email phishing"
        ],
        "online job fraud": [
            "online job fraud"
        ],
        "profile hacking identity theft": [
            "profile hacking identity theft",
            "identity theft, spoofing, and phishing attacks",
            "unauthorized social media access"
        ],
        "cheating by impersonation": [
            "cheating by impersonation",
            "fake mobile apps"
        ],
        "online matrimonial fraud": [
            "online matrimonial fraud"
        ],
        "cyber bullying  stalking  sexting": [
            "cyber bullying  stalking  sexting"
        ],
        "fakeimpersonating profile": [
            "fakeimpersonating profile"
        ]
    },
    "rapegang rape rgrsexually abusive content": {
        "rapegang rape rgrsexually abusive content": [
            "rapegang rape rgrsexually abusive content"
        ]
    },
    "report unlawful content": {
        "against interest of sovereignty or integrity of india": [
            "against interest of sovereignty or integrity of india",
            "disinformation or misinformation campaigns"
        ]
    },
    "sexually explicit act": {
        "sexually explicit act": [
            "sexually explicit act"
        ]
    },
    "sexually obscene material": {
        "sexually obscene material": [
            "sale publishing and transmitting obscene material/sexually explicit material",
            "sexually obscene material"
        ]
    }
}

# Rebind master_mapper to the output of clean_json_mapping; the helper is not
# defined in this cell, so it must already exist from an earlier one.
# NOTE(review): presumably this normalises the label strings — confirm against
# the helper's definition.
master_mapper = clean_json_mapping(master_mapper)

In [3]:
# Load test data from a CSV given by relative path, i.e. resolved against the
# kernel's current working directory
print("Loading test data...")
test_df = pd.read_csv('final_test_dataset.csv')

# Bare last expression: the notebook renders the frame as this cell's output
test_df

Loading test data...


Unnamed: 0.1,Unnamed: 0,category,sub_category,sub_category_names,category_names,retagged_sub_category,retagged_category,content_processed
0,0,hacking damage to computercomputer system etc,damage to computer computer systems etc,other,other_cyber_crime,other,any_other_cyber_crime,please read above attached complaint overview ...
1,1,online and social media related crime,cheating by impersonation,other,other_cyber_crime,other,any_other_cyber_crime,so agriculture qr neft dr narender sco satyawan
2,2,cyber attack/ dependent crimes,tampering with computer source documents,Email_Phishing,other_cyber_crime,email_phishing,any_other_cyber_crime,this all happened a few days after i accidenta...
3,3,cyber attack/ dependent crimes,denial of service (dos)/distributed denial of ...,Email_Phishing,other_cyber_crime,email_phishing,any_other_cyber_crime,this all happened a few days after i accidenta...
4,4,cyber attack/ dependent crimes,tampering with computer source documents,Email_Phishing,other_cyber_crime,email_phishing,any_other_cyber_crime,this all happened a few days after i accidenta...
...,...,...,...,...,...,...,...,...
19164,18637,online and social media related crime,online job fraud,online_job_fraud,other_cyber_crime,online_job_fraud,online_and_social_media_related_crime,i recieved call on th jan at pm regarding job ...
19165,18638,online and social media related crime,cheating by impersonation,online_gambling_betting,financial_fraud_crimes,online_gambling_betting,online_gambling_betting,firstly she say you will get profit by investi...
19166,18639,online and social media related crime,cheating by impersonation,fraud_callvishing,financial_fraud_crimes,fraud_callvishing,online_financial_fraud,on date i saw a advertisement of car altoon fa...
19167,18640,online and social media related crime,cheating by impersonation,cheating_by_impersonation,other_cyber_crime,cheating_by_impersonation,online_and_social_media_related_crime,i surfed facebook and my eye caught one facebo...


In [4]:
# Normalise the text column in one pass: missing entries become '' and every
# value is then cast to str
test_df['content_processed'] = test_df['content_processed'].fillna('').astype(str)

In [5]:
# Print per-column non-null counts and dtypes for the loaded frame
test_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 19169 entries, 0 to 19168
Data columns (total 8 columns):
 #   Column                 Non-Null Count  Dtype 
---  ------                 --------------  ----- 
 0   Unnamed: 0             19169 non-null  int64 
 1   category               19169 non-null  object
 2   sub_category           19168 non-null  object
 3   sub_category_names     19169 non-null  object
 4   category_names         19169 non-null  object
 5   retagged_sub_category  19169 non-null  object
 6   retagged_category      19169 non-null  object
 7   content_processed      19169 non-null  object
dtypes: int64(1), object(7)
memory usage: 1.2+ MB


In [None]:
# Load the encoder, models, label encoders and selectors stored under models/
encoder, models, label_encoders, selectors = load_models(models_path='models/')

# Run full inference pipeline over test_df with a batch size of 256.
# NOTE(review): category_to_sub_category is not assigned in any cell visible in
# this part of the notebook — confirm an earlier cell defines it, otherwise this
# cell raises NameError on a fresh kernel.
results_df = run_inference_pipeline(
    test_df=test_df,
    encoder=encoder,
    models=models,
    selectors=selectors,
    label_encoders=label_encoders,
    category_to_sub_category=category_to_sub_category,
    master_mapper=master_mapper,
    batch_size=256
)

# Save detailed results (CSV files written to the current working directory)
save_detailed_results(results_df, test_df)

Loading models...


Loading models: 100%|██████████| 13/13 [00:01<00:00, 12.30it/s]



Processing batches...


Processing:   9%|▉         | 7/75 [01:00<09:46,  8.62s/it]


Batch 7/75
Current retagged_category Accuracy: 0.7031
Current category_names Accuracy: 0.9503
Current sub_category_names Accuracy: 0.5022
Current retagged_sub_category Accuracy: 0.5117


Processing:  19%|█▊        | 14/75 [02:00<08:40,  8.52s/it]


Batch 14/75
Current retagged_category Accuracy: 0.8401
Current category_names Accuracy: 0.9637
Current sub_category_names Accuracy: 0.5614
Current retagged_sub_category Accuracy: 0.5664


Processing:  28%|██▊       | 21/75 [03:03<08:06,  9.00s/it]


Batch 21/75
Current retagged_category Accuracy: 0.8668
Current category_names Accuracy: 0.9520
Current sub_category_names Accuracy: 0.5800
Current retagged_sub_category Accuracy: 0.5863


Processing:  37%|███▋      | 28/75 [04:01<06:28,  8.26s/it]


Batch 28/75
Current retagged_category Accuracy: 0.8991
Current category_names Accuracy: 0.9630
Current sub_category_names Accuracy: 0.6667
Current retagged_sub_category Accuracy: 0.6715


Processing:  47%|████▋     | 35/75 [05:05<06:04,  9.11s/it]


Batch 35/75
Current retagged_category Accuracy: 0.9146
Current category_names Accuracy: 0.9658
Current sub_category_names Accuracy: 0.6913
Current retagged_sub_category Accuracy: 0.6952


Processing:  56%|█████▌    | 42/75 [06:06<04:51,  8.84s/it]


Batch 42/75
Current retagged_category Accuracy: 0.9131
Current category_names Accuracy: 0.9558
Current sub_category_names Accuracy: 0.6785
Current retagged_sub_category Accuracy: 0.6819


Processing:  65%|██████▌   | 49/75 [07:11<03:43,  8.59s/it]


Batch 49/75
Current retagged_category Accuracy: 0.9209
Current category_names Accuracy: 0.9575
Current sub_category_names Accuracy: 0.6775
Current retagged_sub_category Accuracy: 0.6804


Processing:  75%|███████▍  | 56/75 [08:15<02:54,  9.16s/it]


Batch 56/75
Current retagged_category Accuracy: 0.9238
Current category_names Accuracy: 0.9558
Current sub_category_names Accuracy: 0.6793
Current retagged_sub_category Accuracy: 0.6819


Processing:  84%|████████▍ | 63/75 [09:42<01:52,  9.34s/it]


Batch 63/75
Current retagged_category Accuracy: 0.8953
Current category_names Accuracy: 0.9272
Current sub_category_names Accuracy: 0.6479
Current retagged_sub_category Accuracy: 0.6514


Processing:  93%|█████████▎| 70/75 [10:34<00:37,  7.59s/it]


Batch 70/75
Current retagged_category Accuracy: 0.8732
Current category_names Accuracy: 0.9065
Current sub_category_names Accuracy: 0.6247
Current retagged_sub_category Accuracy: 0.6289


Processing: 100%|██████████| 75/75 [11:11<00:00,  8.96s/it]



Calculating final metrics...

=== Main Category Metrics ===
Accuracy: 0.8955

Classification Report:
                           precision    recall  f1-score   support

   financial_fraud_crimes       0.92      0.95      0.94     14122
        other_cyber_crime       0.81      0.78      0.79      4845
women_child_related_crime       0.94      0.16      0.28       202

                 accuracy                           0.90     19169
                macro avg       0.89      0.63      0.67     19169
             weighted avg       0.89      0.90      0.89     19169


Confusion Matrix:
[[13372   750     0]
 [ 1083  3760     2]
 [    9   160    33]]

=== retagged_category Metrics ===
Accuracy: 0.8619

Classification Report:


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                                                      precision    recall  f1-score   support

                               any_other_cyber_crime       0.86      0.01      0.02       519
child_pornography_cpchild_sexual_abuse_material_csam       0.65      0.34      0.45        32
                        crime_against_women_children       0.00      0.00      0.00         2
                                cryptocurrency_crime       1.00      0.17      0.29        18
                       cyber_attack_dependent_crimes       0.00      0.00      0.00        51
                                     cyber_terrorism       1.00      0.33      0.50         3
       hacking_damage_to_computercomputer_system_etc       0.13      0.21      0.16       124
               online_and_social_media_related_crime       0.75      0.76      0.75      4147
                            online_cyber_trafficking       1.00      0.10      0.18        10
                              online_financial_fraud       

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                                                                                                                         precision    recall  f1-score   support

                                                                                              Cheating_by_Impersonation       0.00      0.00      0.00         5
                                                                                                          Email_Hacking       0.00      0.00      0.00        74
                                                                                                         Email_Phishing       0.00      0.00      0.00       218
                                                                               aadhar_enabled_payment_system_aeps_fraud       0.01      0.25      0.03         8
                                                                  against_interest_of_sovereignty_or_integrity_of_india       0.00      0.00      0.00         0
                          attacks

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                                                                  precision    recall  f1-score   support

           against_interest_of_sovereignty_or_integrity_of_india       0.00      0.00      0.00         1
                         business_email_compromiseemail_takeover       1.00      0.04      0.07        77
                                       cheating_by_impersonation       0.32      0.12      0.17       867
            child_pornography_cpchild_sexual_abuse_material_csam       0.65      0.34      0.45        32
                                            cryptocurrency_fraud       1.00      0.17      0.29        18
                                  cyber_blackmailing_threatening       0.00      0.00      0.00         2
                                 cyber_bullying_stalking_sexting       0.59      0.91      0.72      2136
                                                 cyber_terrorism       1.00      0.33      0.50         3
                         damage_to_computer_c

In [7]:
# Optional: Analyze specific examples. Prints up to 5 randomly chosen rows (the
# default n_examples); the sample is drawn from NumPy's global random state, so
# the rows shown differ between runs unless that state is seeded beforehand.
analyze_examples(results_df, test_df)


=== Example Predictions ===

Text:
i received a spam video call from unknown mo no as usual on picking up the vc front camera captured my face and shehe started taking screenshot with her vulgar video and my face from from front camera...

Predictions:
Main Category: other_cyber_crime (True: other_cyber_crime)
Category: online_and_social_media_related_crime (True: online_and_social_media_related_crime)
Sub-Category: cyber_bullying_stalking_sexting (True: cyber_bullying_stalking_sexting)
--------------------------------------------------------------------------------

Text:
u brutus youtube channel...

Predictions:
Main Category: other_cyber_crime (True: other_cyber_crime)
Category: online_and_social_media_related_crime (True: any_other_cyber_crime)
Sub-Category: cyber_bullying_stalking_sexting (True: other)
--------------------------------------------------------------------------------

Text:
today i have unknowingly touched a link carrying the notification for kyc updation and it as