# NN Multiclass

In [2]:
import pandas as pd
from imblearn.over_sampling import SMOTE
from sklearn.utils import resample
from rich import print
from tabulate import tabulate

In [3]:
def _oversample_small_classes(small_df, targets, label_cols, random_state=42):
    """
    Grow every under-represented class in ``small_df`` to its target size.

    Parameters:
    - small_df: rows of the under-represented classes only (all columns)
    - targets: dict mapping category -> desired row count (>= current count)
    - label_cols: non-feature columns that must not be fed to SMOTE
    - random_state: seed used for SMOTE / random oversampling

    Returns:
    - DataFrame containing the original rows plus the added rows, with the
      same column order as ``small_df``
    """
    class_sizes = small_df['category'].value_counts()
    feature_cols = small_df.drop(label_cols, axis=1).columns

    # SMOTE interpolates between a sample and its k nearest same-class
    # neighbours, so every class needs at least k + 1 rows; it also refuses
    # a target vector that contains a single class.
    k_neighbors = min(5, int(class_sizes.min()) - 1)
    if len(class_sizes) < 2 or k_neighbors < 1:
        # Fallback: duplicate existing rows at random. Real rows keep their
        # real labels and no interpolation is required.
        pieces = [small_df]
        for category, target in targets.items():
            class_rows = small_df[small_df['category'] == category]
            n_extra = int(target) - len(class_rows)
            if n_extra > 0:
                pieces.append(resample(class_rows,
                                       replace=True,
                                       n_samples=n_extra,
                                       random_state=random_state))
        return pd.concat(pieces, ignore_index=True)

    smote = SMOTE(sampling_strategy=targets,
                  k_neighbors=k_neighbors,
                  random_state=random_state)
    X_resampled, y_resampled = smote.fit_resample(
        small_df[feature_cols],
        small_df['category']
    )

    # Reconstruct DataFrame (positional assignment avoids index alignment)
    augmented = pd.DataFrame(X_resampled, columns=feature_cols)
    augmented['category'] = pd.Series(y_resampled).to_numpy()

    # Synthetic rows have no true 'attack' value: give them the most frequent
    # attack of their own category instead of one global label.
    typical_attack = small_df.groupby('category')['attack'].agg(
        lambda s: s.mode(dropna=False).iloc[0]
    )
    attack = augmented['category'].map(typical_attack).tolist()

    # Over-samplers are expected to return the original samples first,
    # followed by the synthetic ones; confirm that before restoring the
    # original attack labels so real rows are never relabelled.
    n_original = len(small_df)
    leading_labels = augmented['category'].iloc[:n_original].to_numpy()
    if (leading_labels == small_df['category'].to_numpy()).all():
        attack[:n_original] = small_df['attack'].tolist()
    augmented['attack'] = attack

    return augmented[small_df.columns]


def balance_dataset(df, max_ratio=1.5, min_ratio=0.67):
    """
    Balance dataset using a conservative, median-anchored approach.

    Classes larger than ``median * max_ratio`` are randomly downsampled
    (without replacement) to that size. Classes smaller than
    ``median * min_ratio`` are oversampled with SMOTE up to
    ``min(2 * current_count, median * min_ratio)`` rows; when SMOTE cannot
    run (a single small class, or a class with fewer than two rows) plain
    random oversampling is used instead. All other classes are kept as is.
    A before/after table of class counts is printed.

    Parameters:
    - df: DataFrame with 'category' and 'attack' columns; every other column
      is treated as a feature for SMOTE (assumed numeric -- TODO confirm)
    - max_ratio: upper bound for a class size, as a multiple of the median
      class size (default 1.5)
    - min_ratio: lower bound for a class size, as a multiple of the median
      class size (default 0.67)

    Returns:
    - Balanced DataFrame with a fresh 0..n-1 index
    """

    # Get category counts
    category_counts = df['category'].value_counts()
    if category_counts.empty:
        # Nothing to balance; the median below would be NaN.
        return df.copy()

    # Calculate target counts
    median_count = category_counts.median()
    min_samples = int(median_count * min_ratio)  # Lower bound
    max_samples = int(median_count * max_ratio)  # Upper bound

    final_dfs = []

    # Handle undersampling for large classes.
    # replace=False: downsampling must pick distinct rows; the sklearn
    # default (replace=True) is a bootstrap that would introduce duplicates.
    for category in category_counts[category_counts > max_samples].index:
        category_df = df[df['category'] == category]
        downsampled = resample(category_df,
                               replace=False,
                               n_samples=max_samples,
                               random_state=42)
        final_dfs.append(downsampled)

    # Handle oversampling for small classes
    small_categories = category_counts[category_counts < min_samples].index
    if len(small_categories) > 0:
        small_df = df[df['category'].isin(small_categories)]
        # At most double a class, and never exceed the lower bound.
        targets = {
            cat: int(min(category_counts[cat] * 2, min_samples))
            for cat in small_categories
        }
        final_dfs.append(
            _oversample_small_classes(small_df, targets, ['category', 'attack'])
        )

    # Keep medium-sized classes as is
    medium_mask = (category_counts >= min_samples) & (category_counts <= max_samples)
    for category in category_counts[medium_mask].index:
        final_dfs.append(df[df['category'] == category])

    # Combine all data
    final_df = pd.concat(final_dfs, ignore_index=True)

    # Get final counts and prepare comparison
    final_counts = final_df['category'].value_counts()
    comparison_data = [
        [category, category_counts[category], final_counts.get(category, 0)]
        for category in sorted(category_counts.index)
    ]

    # Print comparison table
    print(tabulate(
        comparison_data,
        headers=['Category', 'Original', 'After Balance'],
        tablefmt='psql'
    ))

    return final_df

# Usage: read the labelled training set, rebalance its classes around the
# median class size, and write the result next to the input file.
df = pd.read_csv('dataset/train_labeled.csv')
balanced_df = balance_dataset(
    df,
    min_ratio=0.5,   # classes below 50% of the median get oversampled
    max_ratio=1.1,   # classes above 110% of the median get downsampled
)
balanced_df.to_csv('dataset/train_smote.csv', index=False)