# In [None]:
# notebooks/01_eda.ipynb
# ==============================================================================
# Intelligent Document Classification System
# Exploratory Data Analysis Notebook
# ==============================================================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
# Silence all warnings for a cleaner notebook output.
# NOTE(review): this also hides pandas/sklearn deprecation warnings; consider
# narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

# Set style. The 'seaborn-v0_8-*' style names only exist in matplotlib >= 3.6;
# older releases ship the same style as 'seaborn-darkgrid', so fall back to it
# instead of aborting the notebook on an older environment.
try:
    plt.style.use('seaborn-v0_8-darkgrid')
except OSError:
    plt.style.use('seaborn-darkgrid')
sns.set_palette("husl")

# ==============================================================================
# 1. Data Loading
# ==============================================================================

print("📊 Loading datasets...")

# Load the pre-split datasets. Paths are relative to the working directory
# (presumably notebooks/, per the header above — confirm when running).
train_df = pd.read_csv('../data/raw/train.csv')
val_df = pd.read_csv('../data/raw/val.csv')
test_df = pd.read_csv('../data/raw/test.csv')

print(f"Training samples: {len(train_df):,}")
print(f"Validation samples: {len(val_df):,}")
print(f"Test samples: {len(test_df):,}")

# Combine all splits for overall analysis. No column records which split a row
# came from, so later sections cannot distinguish train/val/test rows.
all_data = pd.concat([train_df, val_df, test_df], ignore_index=True)

# ==============================================================================
# 2. Basic Data Exploration
# ==============================================================================

print("\n🔍 Dataset Overview:")
print("=" * 50)
print(f"Total documents: {len(all_data):,}")
print(f"Number of features: {all_data.shape[1]}")
print(f"Columns: {list(all_data.columns)}")

# Display sample data (`display` is provided by IPython/Jupyter).
print("\n📄 Sample documents:")
display(all_data.head())

# Data types
print("\n📈 Data Types:")
print(all_data.dtypes)

# Missing values: per-column count and percentage of all rows. The frame is
# kept as a module-level name because the summary section reads it.
print("\n❌ Missing Values Analysis:")
missing_counts = all_data.isnull().sum()
missing_data = pd.DataFrame({
    'Column': all_data.columns,
    'Missing_Values': missing_counts,
    'Missing_Percentage': (missing_counts / len(all_data)) * 100
})
columns_with_missing = missing_data[missing_data['Missing_Values'] > 0]
# Printing an empty filtered frame shows a confusing "Empty DataFrame" dump.
if columns_with_missing.empty:
    print("No missing values found.")
else:
    print(columns_with_missing)

# ==============================================================================
# 3. Target Variable Analysis
# ==============================================================================

print("\n🎯 Target Variable Analysis:")
print("=" * 50)

# Name of the label column. Adjust based on your data.
target_col = 'category'
if target_col in all_data.columns:
    # Documents per class, most frequent first (value_counts drops missing labels).
    class_dist = all_data[target_col].value_counts()

    print(f"Number of unique classes: {len(class_dist)}")
    print(f"Class distribution:\n{class_dist}")

    # Set3 has only 12 colours. Repeat it explicitly so every class gets a
    # colour and the bar and pie charts agree even with more than 12 classes.
    palette = px.colors.qualitative.Set3
    class_colors = [palette[i % len(palette)] for i in range(len(class_dist))]

    # Visualize class distribution: absolute counts (bar) and shares (pie).
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=('Class Distribution', 'Class Proportions'),
        specs=[[{'type': 'bar'}, {'type': 'pie'}]]
    )

    fig.add_trace(
        go.Bar(x=class_dist.index, y=class_dist.values,
               marker_color=class_colors,
               text=class_dist.values,
               textposition='auto'),
        row=1, col=1
    )

    fig.add_trace(
        go.Pie(labels=class_dist.index, values=class_dist.values,
               hole=0.3,
               marker_colors=class_colors),
        row=1, col=2
    )

    fig.update_layout(height=500, showlegend=True,
                      title_text="Target Class Distribution")
    fig.show()

    # Ratio of the largest to the smallest class count (1.0 = perfectly balanced).
    imbalance_ratio = class_dist.max() / class_dist.min()
    print(f"\n⚠️ Class imbalance ratio: {imbalance_ratio:.2f}")
    if imbalance_ratio > 10:
        print("Warning: Severe class imbalance detected!")

# ==============================================================================
# 4. Text Analysis
# ==============================================================================

print("\n📝 Text Analysis:")
print("=" * 50)

import re  # used by the sentence-count heuristic below

# Name of the document text column. Adjust based on your data.
text_col = 'text'
if text_col in all_data.columns:
    # Missing text is treated as an empty document; astype(str) alone would
    # turn NaN into the literal string "nan" and count it as a one-word text.
    texts = all_data[text_col].fillna('').astype(str)
    # Whitespace-tokenised words per document (str.split with no separator).
    word_lists = texts.str.split()

    # Per-document statistics, stored as new columns on all_data so later
    # sections (correlation, data quality, summary) can reuse them.
    all_data['text_length'] = texts.str.len()
    all_data['word_count'] = word_lists.str.len()
    # Heuristic sentence count: non-blank segments between runs of '.', '!' or
    # '?'. Blank segments (e.g. the one after a final period) are not counted.
    all_data['sentence_count'] = texts.apply(
        lambda doc: sum(1 for part in re.split(r'[.!?]+', doc) if part.strip())
    )
    all_data['avg_word_length'] = word_lists.apply(
        lambda words: np.mean([len(word) for word in words]) if words else 0
    )

    # Summary statistics
    text_stats = all_data[['text_length', 'word_count', 'sentence_count', 'avg_word_length']].describe()
    print("Text Statistics:")
    print(text_stats)

    # Visualize text length distributions in a 2x2 grid of histograms.
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('Text Length Distribution', 'Word Count Distribution',
                        'Sentence Count Distribution', 'Average Word Length'),
        specs=[[{'type': 'histogram'}, {'type': 'histogram'}],
               [{'type': 'histogram'}, {'type': 'histogram'}]]
    )

    metrics = ['text_length', 'word_count', 'sentence_count', 'avg_word_length']
    titles = ['Characters', 'Words', 'Sentences', 'Avg Word Length']

    for i, (metric, title) in enumerate(zip(metrics, titles)):
        row = i // 2 + 1
        col = i % 2 + 1

        # Skip the near-white first shades of Blues so every histogram is
        # visible; plotly's Blues scale has 9 entries, so indices 2-8 are valid.
        fig.add_trace(
            go.Histogram(x=all_data[metric],
                         nbinsx=50,
                         marker_color=px.colors.sequential.Blues[2 + i * 2],
                         name=title),
            row=row, col=col
        )

        # Add mean line
        mean_val = all_data[metric].mean()
        fig.add_vline(x=mean_val, line_dash="dash", line_color="red",
                      annotation_text=f"Mean: {mean_val:.1f}",
                      row=row, col=col)

    fig.update_layout(height=600, showlegend=False,
                      title_text="Text Feature Distributions")
    fig.show()

    # Text length by category
    if target_col in all_data.columns:
        fig = px.box(all_data, x=target_col, y='word_count',
                     color=target_col,
                     title="Word Count Distribution by Category",
                     labels={'word_count': 'Number of Words', target_col: 'Category'})
        fig.update_layout(height=500)
        fig.show()

# ==============================================================================
# 5. N-gram Analysis
# ==============================================================================

print("\n🔤 N-gram Analysis:")
print("=" * 50)

from collections import Counter
from sklearn.feature_extraction.text import CountVectorizer
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords', quiet=True)


def get_top_ngrams(corpus, ngram_range=(1, 1), n=10):
    """Return the ``n`` most frequent n-grams in ``corpus``.

    Tokenisation and English stop-word removal are done by scikit-learn's
    CountVectorizer; the vocabulary is capped at the 10,000 most frequent terms.

    Args:
        corpus: Iterable of document strings.
        ngram_range: ``(min_n, max_n)`` n-gram sizes to count.
        n: Number of top n-grams to return.

    Returns:
        List of ``(ngram, count)`` tuples sorted by descending count. Empty when
        no tokens remain (empty documents, only stop words, or documents too
        short for the requested n-gram size).
    """
    vec = CountVectorizer(ngram_range=ngram_range,
                          stop_words='english',
                          max_features=10000)
    try:
        bag_of_words = vec.fit_transform(corpus)
    except ValueError:
        # CountVectorizer raises ValueError ("empty vocabulary") when nothing
        # is left to count; report that as "no n-grams" instead of crashing.
        return []
    # Column sums give each vocabulary entry's total count across the corpus.
    totals = np.asarray(bag_of_words.sum(axis=0)).ravel()
    words_freq = [(word, int(totals[idx])) for word, idx in vec.vocabulary_.items()]
    words_freq.sort(key=lambda pair: pair[1], reverse=True)
    return words_freq[:n]


if text_col in all_data.columns:
    # NOTE(review): NLTK's stop-word list is loaded but not used below;
    # get_top_ngrams relies on scikit-learn's built-in English list instead.
    stop_words = set(stopwords.words('english'))

    # Missing text is treated as empty rather than the token "nan".
    corpus = all_data[text_col].fillna('').astype(str)
    top_n = 15

    # Analyze for different n-gram sizes
    ngram_ranges = [(1, 1), (2, 2), (3, 3)]
    ngram_names = ['Unigrams', 'Bigrams', 'Trigrams']

    for ngram_range, name in zip(ngram_ranges, ngram_names):
        top_ngrams = get_top_ngrams(corpus, ngram_range=ngram_range, n=top_n)
        if not top_ngrams:
            print(f"\nNo {name.lower()} found.")
            continue

        print(f"\nTop {top_n} {name}:")
        for word, freq in top_ngrams:
            print(f"  {word}: {freq}")

        # Horizontal bars: frequencies along x, n-grams along y. Lists are
        # reversed so the most frequent n-gram is drawn at the top.
        words, frequencies = zip(*top_ngrams)

        fig = px.bar(x=list(frequencies)[::-1], y=list(words)[::-1],
                     orientation='h',
                     title=f"Top {top_n} {name}",
                     labels={'x': 'Frequency', 'y': f'{name}'},
                     color=list(frequencies)[::-1],
                     color_continuous_scale='Viridis')
        fig.update_layout(height=400)
        fig.show()

# ==============================================================================
# 6. Category-Wise Analysis
# ==============================================================================

print("\n📊 Category-Wise Analysis:")
print("=" * 50)

if target_col in all_data.columns and text_col in all_data.columns:
    # WordCloud is optional: import it once and catch only the import failure,
    # so unrelated errors are not mislabelled as "not installed".
    try:
        from wordcloud import WordCloud
    except ImportError:
        WordCloud = None
        print("WordCloud not available. Install with: pip install wordcloud")

    # Get top words per category
    unique_categories = all_data[target_col].unique()

    for category in unique_categories[:5]:  # Limit to first 5 categories for brevity
        # Missing text is treated as empty rather than the string "nan".
        category_docs = (
            all_data.loc[all_data[target_col] == category, text_col]
            .fillna('')
            .astype(str)
        )
        if len(category_docs) == 0:
            continue

        # get_top_ngrams can raise ValueError ("empty vocabulary") when a
        # category's documents contain only stop words or no text at all.
        try:
            top_words = get_top_ngrams(category_docs, ngram_range=(1, 1), n=10)
        except ValueError:
            top_words = []

        print(f"\nTop words for category '{category}':")
        if not top_words:
            print("  (no words found after stop-word removal)")
        for word, freq in top_words:
            print(f"  {word}: {freq}")

        if WordCloud is None:
            continue

        # Word cloud for each category. WordCloud raises ValueError when the
        # text yields no words, so skip that category instead of aborting.
        text = ' '.join(category_docs)
        try:
            wordcloud = WordCloud(width=800, height=400,
                                  background_color='white',
                                  max_words=50).generate(text)
        except ValueError as err:
            print(f"Skipping word cloud for '{category}': {err}")
            continue

        plt.figure(figsize=(10, 5))
        plt.imshow(wordcloud, interpolation='bilinear')
        plt.axis('off')
        plt.title(f"Word Cloud for '{category}'")
        plt.show()

# ==============================================================================
# 7. Data Quality Analysis
# ==============================================================================

print("\n🔍 Data Quality Analysis:")
print("=" * 50)

has_text = text_col in all_data.columns
# Denominator for percentages; avoids ZeroDivisionError on an empty dataset.
n_docs = max(len(all_data), 1)

# Check for duplicates: compare the text column when present, else whole rows.
duplicates = int(all_data.duplicated(subset=[text_col] if has_text else None).sum())
print(f"Duplicate documents: {duplicates} ({duplicates/n_docs*100:.2f}%)")

# empty_docs / short_docs stay defined (as 0) without a text column because
# the summary section reads them.
empty_docs = 0
short_docs = 0
if has_text:
    # Missing text counts as empty; astype(str) alone would turn NaN into "nan".
    texts = all_data[text_col].fillna('').astype(str)

    # Check for empty documents
    empty_docs = int(texts.str.strip().eq('').sum())
    print(f"Empty documents: {empty_docs} ({empty_docs/n_docs*100:.2f}%)")

    # Check for very short documents. Compare numeric word counts directly:
    # str(x).split() on an integer count always yields exactly one token.
    if 'word_count' in all_data.columns:
        word_counts = all_data['word_count']
    else:
        word_counts = texts.str.split().str.len()
    short_docs = int((word_counts < 10).sum())
    print(f"Very short documents (<10 words): {short_docs} ({short_docs/n_docs*100:.2f}%)")
else:
    print(f"Text column '{text_col}' not found - skipping empty/short document checks.")

# ==============================================================================
# 8. Correlation Analysis
# ==============================================================================

print("\n📈 Correlation Analysis:")
print("=" * 50)

# Calculate correlations between the text features from the text analysis.
numeric_features = ['text_length', 'word_count', 'sentence_count', 'avg_word_length']
# all() takes a single iterable; require every feature that is correlated below.
if all(feature in all_data.columns for feature in numeric_features):
    correlation_matrix = all_data[numeric_features].corr()

    # Pin the diverging colour scale to the full correlation range [-1, 1] so
    # zero correlation sits at the neutral midpoint.
    fig = px.imshow(correlation_matrix,
                    text_auto='.2f',
                    aspect="auto",
                    color_continuous_scale='RdBu',
                    zmin=-1, zmax=1,
                    title="Correlation Matrix of Text Features")
    fig.update_layout(height=500)
    fig.show()

    print("\nFeature Correlations:")
    print(correlation_matrix)

# ==============================================================================
# 9. Temporal Analysis (if date column exists)
# ==============================================================================

print("\n📅 Temporal Analysis:")
print("=" * 50)

# Candidate date columns: any column whose name mentions "date" or "time".
# str() guards against non-string column labels (e.g. integer headers).
date_columns = [col for col in all_data.columns
                if 'date' in str(col).lower() or 'time' in str(col).lower()]
if date_columns:
    date_col = date_columns[0]
    # Parse in place; unparseable values become NaT, which value_counts and
    # groupby drop by default.
    all_data[date_col] = pd.to_datetime(all_data[date_col], errors='coerce')
    months = all_data[date_col].dt.to_period('M')

    # Documents over time
    docs_over_time = months.value_counts().sort_index()

    fig = px.line(x=docs_over_time.index.astype(str),
                  y=docs_over_time.values,
                  title="Documents Over Time",
                  labels={'x': 'Month', 'y': 'Number of Documents'})
    fig.update_layout(height=400)
    fig.show()

    # Category distribution over time: one column per category, one row per month.
    if target_col in all_data.columns:
        monthly_data = all_data.groupby([months, target_col]).size().unstack(fill_value=0)

        if not monthly_data.empty:
            # Plot against month-start timestamps rather than the raw
            # PeriodIndex, which plotly does not reliably serialise.
            monthly_data.index = monthly_data.index.to_timestamp()

            fig = px.line(monthly_data,
                          title="Category Distribution Over Time",
                          labels={'value': 'Number of Documents',
                                  'variable': 'Category',
                                  target_col: 'Category'})
            fig.update_layout(height=500)
            fig.show()

# ==============================================================================
# 10. Summary and Recommendations
# ==============================================================================

print("\n📋 EDA Summary and Recommendations:")
print("=" * 50)

# imbalance_ratio only exists when the target analysis ran on a target column.
ratio = globals().get('imbalance_ratio')
has_ratio = ratio is not None and pd.notna(ratio)

summary = {
    "Total Documents": len(all_data),
    # nunique() ignores missing labels, matching value_counts in the target analysis.
    "Number of Classes": all_data[target_col].nunique() if target_col in all_data.columns else 0,
    "Class Imbalance Ratio": f"{ratio:.2f}" if has_ratio else "N/A",
    "Average Word Count": f"{all_data['word_count'].mean():.1f}" if 'word_count' in all_data.columns else "N/A",
    "Missing Values": f"{missing_data['Missing_Values'].sum()} total",
    "Duplicate Documents": duplicates,
    "Empty Documents": empty_docs
}

for key, value in summary.items():
    print(f"✓ {key}: {value}")

print("\n🎯 Key Insights & Recommendations:")
print("-" * 40)

insights = []

# Class imbalance check: only comment on balance when a ratio was computed;
# without a target column there is nothing to call "balanced".
if has_ratio:
    if ratio > 3:
        insights.append("🔸 Class imbalance detected - consider using class weights or oversampling/undersampling")
    else:
        insights.append("✅ Class distribution is relatively balanced")

# Text length insights (thresholds are in average words per document)
if 'word_count' in all_data.columns:
    avg_words = all_data['word_count'].mean()
    if avg_words < 50:
        insights.append("🔸 Documents are very short - consider using character-level models or n-grams")
    elif avg_words > 1000:
        insights.append("🔸 Documents are very long - consider truncation or hierarchical models")
    else:
        insights.append("✅ Document length is suitable for most models")

# Data quality
if duplicates > 0:
    insights.append(f"🔸 {duplicates} duplicates found - consider removing them")
if empty_docs > 0:
    insights.append(f"🔸 {empty_docs} empty documents found - consider removing them")

# Missing values
if missing_data['Missing_Values'].sum() > 0:
    insights.append("🔸 Missing values found - need to handle them in preprocessing")

for i, insight in enumerate(insights, 1):
    print(f"{i}. {insight}")

print("\n📊 Next Steps:")
print("1. Address data quality issues (duplicates, empty docs)")
print("2. Handle class imbalance if significant")
print("3. Preprocess text data (cleaning, normalization)")
print("4. Split data appropriately for training")
print("5. Extract features based on text characteristics")
