# In [None]:  (notebook cell marker, kept as a comment so the file parses as Python)
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix
from autogluon.text import TextPredictor
import warnings
# NOTE(review): this silences every warning for the whole process, including
# deprecation notices from the imported libraries — consider narrowing it.
warnings.filterwarnings('ignore')

# Set up plot styling: matplotlib style sheet and seaborn colour palette.
# (No logging is configured here.)
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

class NewsClassificationProject:
    """End-to-end news article classification pipeline.

    The stages are meant to run in this order (``run_complete_pipeline``
    chains them): ``load_and_explore_data`` -> ``visualize_data_distribution``
    -> ``prepare_training_data`` -> ``train_models`` -> ``evaluate_model`` ->
    ``test_on_sample_articles``. Calling a stage before its prerequisite
    raises ``RuntimeError`` with a hint about what to call first.

    The input CSV must contain the columns listed in ``REQUIRED_COLUMNS``.
    """

    # Columns the input CSV must provide.
    REQUIRED_COLUMNS = ('title', 'content', 'category')

    # Metric the predictor optimises; also used to label a scalar result
    # returned by ``predictor.evaluate``.
    EVAL_METRIC = 'f1_macro'

    # Articles used by ``test_on_sample_articles`` when none are supplied.
    DEFAULT_SAMPLE_ARTICLES = (
        "The Federal Reserve announced a 0.25% interest rate hike following concerns about inflation rates reaching multi-decade highs.",
        "Local basketball team advances to championship finals after defeating rivals 98-87 in overtime thriller at packed arena.",
        "Breakthrough artificial intelligence research promises to revolutionize medical diagnosis with 95% accuracy in detecting early-stage cancer.",
        "Climate activists demand immediate action as global temperatures reach record highs for third consecutive year.",
        "New smartphone features include advanced camera technology and longer battery life, launching next month.",
        "President signs landmark infrastructure bill allocating $1.2 trillion for roads, bridges, and broadband expansion.",
    )

    def __init__(self, data_path: str, project_name: str = "news_classifier"):
        """Store configuration; no file is read until a stage is run.

        Args:
            data_path: Path of the CSV file holding the articles.
            project_name: Prefix used for the model directory, the results
                directory and the saved figures.
        """
        self.data_path = data_path
        self.project_name = project_name
        self.predictor = None
        self.categories = None
        self.results = {}
        # Filled in by later stages. Declared here so that running a stage
        # out of order produces a clear error instead of an AttributeError.
        self.data = None
        self.processed_data = None
        self.train_data = None
        self.test_data = None
        self.leaderboard = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _require(self, attribute: str, producer: str):
        """Return ``self.<attribute>`` or raise if ``producer`` has not run."""
        value = getattr(self, attribute, None)
        if value is None:
            raise RuntimeError(
                f"'{attribute}' is not available; call {producer}() first"
            )
        return value

    @staticmethod
    def _text_lengths(column):
        """Return per-row character counts, counting missing values as 0."""
        return column.fillna('').astype(str).str.len()

    def _ensure_leaderboard(self):
        """Return ``self.leaderboard``, building it on first use.

        The frame always has at least one row and a ``score_val`` column so
        that the plotting, saving and summary code can rely on it.
        """
        if getattr(self, 'leaderboard', None) is None:
            leaderboard_fn = getattr(getattr(self, 'predictor', None), 'leaderboard', None)
            if callable(leaderboard_fn):
                self.leaderboard = leaderboard_fn(silent=False)
            else:
                # NOTE(review): a single-network text predictor may not expose
                # a leaderboard — confirm against the installed AutoGluon
                # version. Fall back to a one-row placeholder (score unknown)
                # rather than crashing after training has already finished.
                self.leaderboard = pd.DataFrame(
                    {'score_val': [np.nan]},
                    index=[f'{self.project_name}_model'],
                )
        return self.leaderboard

    def _as_metric_dict(self, raw_performance):
        """Normalise ``predictor.evaluate`` output to a ``{metric: value}`` dict."""
        if isinstance(raw_performance, dict):
            return dict(raw_performance)
        # A bare number is taken to be the score for the configured metric.
        return {self.EVAL_METRIC: float(raw_performance)}

    # ------------------------------------------------------------------
    # Pipeline stages
    # ------------------------------------------------------------------
    def load_and_explore_data(self):
        """Load the CSV, validate its columns and print summary statistics.

        Adds ``title_length``, ``content_length`` and ``total_length`` columns
        to the loaded frame; missing titles/contents count as length 0 so the
        totals are never NaN.

        Returns:
            The loaded ``DataFrame`` (also stored as ``self.data``).

        Raises:
            ValueError: If any column in ``REQUIRED_COLUMNS`` is missing.
        """
        print("Loading and exploring dataset...")

        # Load the data
        self.data = pd.read_csv(self.data_path)
        print(f"Dataset shape: {self.data.shape}")
        print(f"Columns: {self.data.columns.tolist()}")

        # Check for required columns
        missing_cols = [col for col in self.REQUIRED_COLUMNS if col not in self.data.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")

        # Basic statistics
        print("\nDataset Statistics:")
        print(f"Total articles: {len(self.data):,}")
        print(f"Unique categories: {self.data['category'].nunique()}")
        print(f"Missing values: {self.data.isnull().sum().sum()}")

        # Category distribution (value_counts ignores missing categories)
        self.categories = self.data['category'].value_counts()
        print("\nCategory Distribution:")
        print(self.categories)

        # Text length analysis
        self.data['title_length'] = self._text_lengths(self.data['title'])
        self.data['content_length'] = self._text_lengths(self.data['content'])
        self.data['total_length'] = self.data['title_length'] + self.data['content_length']

        print("\nText Length Statistics:")
        print(f"Average title length: {self.data['title_length'].mean():.1f} characters")
        print(f"Average content length: {self.data['content_length'].mean():.1f} characters")
        print(f"Max total length: {self.data['total_length'].max():,} characters")

        return self.data

    def visualize_data_distribution(self):
        """Plot category shares and text-length distributions.

        Saves the figure as ``<project_name>_data_analysis.png`` and shows it.

        Raises:
            RuntimeError: If ``load_and_explore_data`` has not been run.
        """
        data = self._require('data', 'load_and_explore_data')
        categories = self._require('categories', 'load_and_explore_data')

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Category distribution
        axes[0, 0].pie(categories.values, labels=categories.index, autopct='%1.1f%%')
        axes[0, 0].set_title('Category Distribution')

        # Text length distribution
        axes[0, 1].hist(data['total_length'], bins=50, alpha=0.7)
        axes[0, 1].set_title('Text Length Distribution')
        axes[0, 1].set_xlabel('Total Characters')
        axes[0, 1].set_ylabel('Frequency')

        # Text length by category (five most frequent categories)
        for category in categories.head(5).index:
            category_lengths = data.loc[data['category'] == category, 'total_length']
            axes[1, 0].hist(category_lengths, bins=30, alpha=0.6, label=category)
        axes[1, 0].set_title('Text Length by Category (Top 5)')
        axes[1, 0].set_xlabel('Total Characters')
        axes[1, 0].legend()

        # Category count bar plot, labelled with the category names
        positions = list(range(len(categories)))
        axes[1, 1].bar(positions, categories.values)
        axes[1, 1].set_title('Articles per Category')
        axes[1, 1].set_xlabel('Category')
        axes[1, 1].set_ylabel('Count')
        axes[1, 1].set_xticks(positions)
        axes[1, 1].set_xticklabels(categories.index, rotation=45, ha='right')

        plt.tight_layout()
        plt.savefig(f'{self.project_name}_data_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

    def prepare_training_data(self, min_samples_per_class=100, test_size=0.2,
                              max_class_ratio=5, random_state=42):
        """Build capped, stratified training and test sets.

        Categories with fewer than ``min_samples_per_class`` articles are
        dropped; larger categories are down-sampled to at most
        ``min_samples_per_class * max_class_ratio`` articles. Title and
        content are joined into a single ``text`` column, treating a missing
        title or content as empty so the other field is not lost.

        Args:
            min_samples_per_class: Minimum article count for a category to
                be kept.
            test_size: Fraction of the capped data held out for testing.
            max_class_ratio: Cap on a category's size, as a multiple of
                ``min_samples_per_class``.
            random_state: Seed for down-sampling and for the split.

        Returns:
            ``(train_data, test_data)`` frames with ``text`` and ``category``.

        Raises:
            RuntimeError: If ``load_and_explore_data`` has not been run.
            ValueError: If no category reaches ``min_samples_per_class``.
        """
        data = self._require('data', 'load_and_explore_data')
        categories = self._require('categories', 'load_and_explore_data')

        print("\nPreparing training data...")

        # Filter categories with sufficient samples
        valid_categories = categories[categories >= min_samples_per_class].index
        if len(valid_categories) == 0:
            raise ValueError(
                f"No category has at least {min_samples_per_class} samples; "
                "lower min_samples_per_class or provide more data"
            )
        filtered_data = data[data['category'].isin(valid_categories)].copy()

        print(f"Filtered to {len(valid_categories)} categories with >= {min_samples_per_class} samples")
        print(f"Remaining data: {len(filtered_data):,} articles")

        # Combine title and content. Without fillna a single missing field
        # would turn the whole combined text into NaN.
        titles = filtered_data['title'].fillna('').astype(str)
        contents = filtered_data['content'].fillna('').astype(str)
        filtered_data['text'] = (titles + ' ' + contents).str.strip()

        # Limit class imbalance by capping the number of samples per class
        max_samples_per_class = min_samples_per_class * max_class_ratio
        capped_groups = []
        for category in valid_categories:
            group = filtered_data[filtered_data['category'] == category]
            if len(group) > max_samples_per_class:
                group = group.sample(n=max_samples_per_class, random_state=random_state)
            capped_groups.append(group)

        self.processed_data = pd.concat(capped_groups, ignore_index=True)
        print(f"Balanced dataset: {len(self.processed_data):,} articles")

        # Split data, keeping category proportions equal in both parts
        self.train_data, self.test_data = train_test_split(
            self.processed_data[['text', 'category']],
            test_size=test_size,
            stratify=self.processed_data['category'],
            random_state=random_state,
        )

        print(f"Training set: {len(self.train_data):,} articles")
        print(f"Test set: {len(self.test_data):,} articles")

        return self.train_data, self.test_data

    def train_models(self, time_limit=7200, presets='best_quality',
                     hyperparameters=None, num_cpus=8, num_gpus=1):
        """Fit a ``TextPredictor`` on the prepared training data.

        Args:
            time_limit: Training budget in seconds.
            presets: AutoGluon preset name passed to ``fit``.
            hyperparameters: Optional overrides forwarded to ``fit``. When
                ``None`` (default) the preset's own configuration is used.
            num_cpus: CPU count passed to ``fit``.
            num_gpus: GPU count passed to ``fit``; set to 0 on CPU-only hosts.

        Returns:
            The fitted predictor (also stored as ``self.predictor``).

        Raises:
            RuntimeError: If ``prepare_training_data`` has not been run.
        """
        train_data = self._require('train_data', 'prepare_training_data')

        print(f"\nTraining models with {presets} preset...")
        print(f"Time limit: {time_limit/3600:.1f} hours")

        # Initialize predictor
        self.predictor = TextPredictor(
            label='category',
            path=f'./{self.project_name}_model',
            eval_metric=self.EVAL_METRIC,  # Good for multi-class with potential imbalance
            verbosity=2,
        )

        # NOTE(review): this method used to pass a dict keyed by model class
        # names ('MultimodalTextModel', 'XGBModel') with lists of candidate
        # values, including a generative dialogue checkpoint. That looks like
        # a tabular-style search space rather than the override form a text
        # predictor's fit() takes, so it is no longer sent by default —
        # confirm the expected format against the installed AutoGluon version
        # before supplying `hyperparameters`.
        fit_kwargs = {
            'time_limit': time_limit,
            'presets': presets,
            'num_cpus': num_cpus,
            'num_gpus': num_gpus,
        }
        if hyperparameters is not None:
            fit_kwargs['hyperparameters'] = hyperparameters

        # Train models
        self.predictor.fit(train_data, **fit_kwargs)

        # Get training results (discard any leaderboard from a previous fit)
        self.leaderboard = None
        leaderboard = self._ensure_leaderboard()
        print("\nTraining completed!")
        print(f"Best model: {leaderboard.index[0]}")
        print(f"Best validation score: {leaderboard.iloc[0]['score_val']:.4f}")

        return self.predictor

    def evaluate_model(self, save_results=True):
        """Score the predictor on the test set and report per-class metrics.

        ``test_performance`` in the returned dict always contains
        ``accuracy`` and ``f1_macro``; values not supplied by
        ``predictor.evaluate`` are computed from the predictions.

        Args:
            save_results: When true, write the results to CSV files via
                ``save_evaluation_results``.

        Returns:
            Dict with ``test_performance``, ``classification_report``,
            ``confusion_matrix``, ``categories`` and ``leaderboard``.

        Raises:
            RuntimeError: If training or data preparation has not been run.
        """
        predictor = self._require('predictor', 'train_models')
        test_data = self._require('test_data', 'prepare_training_data')

        print("\nEvaluating model performance...")

        # Basic performance metrics
        test_performance = self._as_metric_dict(predictor.evaluate(test_data))
        predictions = predictor.predict(test_data)

        # Compare by position: the prediction index need not match the
        # (shuffled) test-set index.
        y_true = test_data['category'].to_numpy()
        y_pred = np.asarray(predictions)

        # Detailed classification report; `labels` pins the row order so it
        # always lines up with `target_names`.
        categories = sorted(test_data['category'].unique())
        class_report = classification_report(
            y_true,
            y_pred,
            labels=categories,
            target_names=categories,
            output_dict=True,
        )

        # Guarantee the metrics that downstream summaries read.
        test_performance.setdefault('accuracy', float(np.mean(y_pred == y_true)))
        test_performance.setdefault('f1_macro', class_report['macro avg']['f1-score'])

        print("Test Performance:")
        for metric, value in test_performance.items():
            print(f"  {metric}: {value:.4f}")

        # Print per-class metrics
        print("\nPer-Class Performance:")
        print(f"{'Category':<20} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'Support':<10}")
        print("-" * 70)

        for category in categories:
            if category in class_report:
                metrics = class_report[category]
                print(f"{category:<20} {metrics['precision']:<10.3f} "
                      f"{metrics['recall']:<10.3f} {metrics['f1-score']:<10.3f} "
                      f"{int(metrics['support']):<10}")

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred, labels=categories)

        # Visualization
        self.plot_evaluation_results(cm, categories, class_report)

        # Store results
        self.results = {
            'test_performance': test_performance,
            'classification_report': class_report,
            'confusion_matrix': cm,
            'categories': categories,
            'leaderboard': self._ensure_leaderboard(),
        }

        if save_results:
            self.save_evaluation_results()

        return self.results

    def plot_evaluation_results(self, cm, categories, class_report):
        """Plot confusion matrix, per-class F1 and leaderboard summaries.

        Saves the figure as ``<project_name>_evaluation.png`` and shows it.

        Args:
            cm: Square confusion matrix, rows = actual, columns = predicted,
                ordered like ``categories``.
            categories: Category labels in matrix order.
            class_report: ``classification_report(..., output_dict=True)``
                output keyed by category.
        """
        leaderboard = self._ensure_leaderboard()
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))

        # Confusion matrix, row-normalised. Rows with no samples stay at 0
        # instead of producing NaN/inf from a division by zero.
        cm = np.asarray(cm)
        row_totals = cm.sum(axis=1, keepdims=True)
        cm_normalized = np.divide(
            cm.astype(float),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        sns.heatmap(cm_normalized, annot=True, fmt='.2f',
                    xticklabels=categories, yticklabels=categories,
                    cmap='Blues', ax=axes[0, 0])
        axes[0, 0].set_title('Normalized Confusion Matrix')
        axes[0, 0].set_xlabel('Predicted')
        axes[0, 0].set_ylabel('Actual')

        # Per-class F1 scores; one value per category so bars and tick
        # positions always have the same length.
        positions = list(range(len(categories)))
        f1_scores = [class_report.get(cat, {}).get('f1-score', 0.0) for cat in categories]
        axes[0, 1].bar(positions, f1_scores)
        axes[0, 1].set_title('F1-Score by Category')
        axes[0, 1].set_xlabel('Category')
        axes[0, 1].set_ylabel('F1-Score')
        axes[0, 1].set_xticks(positions)
        axes[0, 1].set_xticklabels(categories, rotation=45)

        # Model performance comparison
        model_scores = leaderboard['score_val'].head(10)
        model_positions = list(range(len(model_scores)))
        axes[1, 0].barh(model_positions, model_scores.values)
        axes[1, 0].set_title('Model Performance Comparison')
        axes[1, 0].set_xlabel('Validation Score')
        axes[1, 0].set_yticks(model_positions)
        axes[1, 0].set_yticklabels(model_scores.index, fontsize=8)

        # Training time vs performance
        if 'fit_time' in leaderboard.columns:
            axes[1, 1].scatter(leaderboard['fit_time'], leaderboard['score_val'])
            axes[1, 1].set_title('Training Time vs Performance')
            axes[1, 1].set_xlabel('Training Time (seconds)')
            axes[1, 1].set_ylabel('Validation Score')
        else:
            # Nothing to plot; hide the empty axes rather than show a blank frame.
            axes[1, 1].axis('off')

        plt.tight_layout()
        plt.savefig(f'{self.project_name}_evaluation.png', dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)

    def test_on_sample_articles(self, sample_articles=None):
        """Print predictions and top-3 class probabilities for sample texts.

        Args:
            sample_articles: Iterable of article strings. Defaults to
                ``DEFAULT_SAMPLE_ARTICLES``.

        Raises:
            RuntimeError: If ``train_models`` has not been run.
        """
        predictor = self._require('predictor', 'train_models')
        if sample_articles is None:
            sample_articles = self.DEFAULT_SAMPLE_ARTICLES
        sample_articles = list(sample_articles)

        print("\nTesting on sample articles:")
        print("=" * 80)

        # Use the same schema as the training data: one 'text' column.
        sample_frame = pd.DataFrame({'text': sample_articles})
        predictions = np.asarray(predictor.predict(sample_frame))
        probabilities = predictor.predict_proba(sample_frame)

        for position, article in enumerate(sample_articles):
            print(f"\nArticle {position+1}: {article[:80]}...")
            print(f"Predicted category: {predictions[position]}")

            # Show top 3 predictions with probabilities
            top_preds = probabilities.iloc[position].sort_values(ascending=False).head(3)
            print("Top 3 predictions:")
            for rank, (category, prob) in enumerate(top_preds.items(), start=1):
                print(f"  {rank}. {category}: {prob:.3f}")
            print("-" * 60)

    def save_evaluation_results(self):
        """Write the classification report, leaderboard and confusion matrix
        as CSV files under ``<project_name>_results/``.

        Raises:
            RuntimeError: If ``evaluate_model`` has not produced results yet.
        """
        if not self.results:
            raise RuntimeError("No results to save; call evaluate_model() first")

        results_dir = Path(f'{self.project_name}_results')
        # parents=True so a project name containing sub-directories works.
        results_dir.mkdir(parents=True, exist_ok=True)

        # Save detailed results
        pd.DataFrame(self.results['classification_report']).T.to_csv(
            results_dir / 'classification_report.csv'
        )

        self._ensure_leaderboard().to_csv(results_dir / 'model_leaderboard.csv')

        # Save confusion matrix
        cm_df = pd.DataFrame(
            self.results['confusion_matrix'],
            index=self.results['categories'],
            columns=self.results['categories'],
        )
        cm_df.to_csv(results_dir / 'confusion_matrix.csv')

        print(f"Results saved to {results_dir}/")

    def run_complete_pipeline(self, data_path=None, time_limit=7200):
        """Run every stage in order with default settings.

        Args:
            data_path: Optional CSV path overriding the one given at
                construction time.
            time_limit: Training budget in seconds.

        Returns:
            ``(predictor, results)`` — the fitted predictor and the dict
            produced by ``evaluate_model``.
        """
        if data_path:
            self.data_path = data_path

        print("Starting News Classification Project")
        print("=" * 50)

        # Step 1: Load and explore data
        self.load_and_explore_data()
        self.visualize_data_distribution()

        # Step 2: Prepare training data
        self.prepare_training_data()

        # Step 3: Train models
        self.train_models(time_limit=time_limit)

        # Step 4: Evaluate performance
        self.evaluate_model()

        # Step 5: Test on samples
        self.test_on_sample_articles()

        print("\nProject completed successfully!")
        print(f"Model saved to: {self.project_name}_model/")
        print(f"Results saved to: {self.project_name}_results/")

        return self.predictor, self.results

# Example usage: run the whole pipeline on a CSV and print a short summary.
if __name__ == "__main__":
    # Initialize the project
    project = NewsClassificationProject(
        data_path='data/news_dataset.csv',  # Replace with your dataset path
        project_name='news_classifier_v1'
    )

    # Run the complete pipeline
    predictor, results = project.run_complete_pipeline(time_limit=3600)  # 1 hour

    # Additional analysis or deployment steps can be added here
    test_performance = results['test_performance']
    print("\nFinal Results Summary:")
    print(f"Best Model: {results['leaderboard'].index[0]}")
    for label, metric in (("Test Accuracy", "accuracy"), ("Test F1-Macro", "f1_macro")):
        value = test_performance.get(metric)
        # A metric the evaluation did not report is shown as "n/a"; printing
        # 0.0000 would look like a genuinely terrible score.
        shown = "n/a" if value is None else f"{value:.4f}"
        print(f"{label}: {shown}")