# In [None]:
# Clinical Text Summarization - Complete End-to-End Pipeline
# Single Model Implementation with T5

import pandas as pd
import numpy as np
import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import T5Tokenizer, T5ForConditionalGeneration, pipeline
import torch
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# Make ROUGE scoring optional: use rouge_score when it imports, otherwise try
# a one-off pip install, and record the outcome in ROUGE_AVAILABLE so the
# evaluation step can fall back to its alternative metrics.
try:
    from rouge_score import rouge_scorer
except ImportError:
    print("rouge_score not available. Installing...")
    import subprocess
    import sys
    try:
        # Install into the interpreter that is running this script.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "rouge_score"])
        from rouge_score import rouge_scorer
    except Exception as install_error:
        print(f" Could not install rouge_score: {install_error}")
        print(" Using alternative evaluation metrics...")
        ROUGE_AVAILABLE = False
    else:
        ROUGE_AVAILABLE = True
        print("rouge_score installed successfully!")
else:
    ROUGE_AVAILABLE = True

# Download required NLTK data (only the resources that are not installed yet).
# sent_tokenize needs the Punkt sentence models: older NLTK releases load them
# from 'punkt', while NLTK 3.9+ loads them from 'punkt_tab', so both are
# requested. nltk.download() reports a failed download instead of raising.
for _resource_path, _package_name in (
    ('tokenizers/punkt', 'punkt'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('corpora/stopwords', 'stopwords'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package_name)

class ClinicalSummarizer:
    """End-to-end clinical text summarization pipeline around a pretrained T5 model.

    The steps are exposed as individual methods (load, preprocess, visualize,
    prepare model, split, generate, evaluate, show samples) and are chained by
    ``run_complete_pipeline``. No fine-tuning is performed: summaries are
    produced by the pretrained checkpoint named in ``model_name``.
    """

    # Section headers stripped from the lower-cased text during cleaning.
    _SECTION_HEADERS = (
        'subjective:', 'objective:', 'assessment:', 'plan:', 'medications:',
        'allergies:', 'heent:', 'neck:', 'lungs:', 'history:', 'physical:',
    )

    # Progress is reported each time this many texts have been processed.
    _PROGRESS_INTERVAL = 50

    def __init__(self, model_name="t5-small"):
        """Create an empty pipeline; the model and data are loaded later.

        Args:
            model_name: Name of the pretrained T5 checkpoint passed to
                ``from_pretrained`` in ``prepare_model``.
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.summarizer = None
        self.data = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

    def load_and_explore_data(self, filepath):
        """Step 1: Load the CSV at ``filepath`` and print a short overview.

        Args:
            filepath: Path (or anything ``pandas.read_csv`` accepts) of the
                dataset.

        Returns:
            The loaded DataFrame, which is also stored in ``self.data``.
        """
        print("=" * 60)
        print("STEP 1: DATA LOADING AND EXPLORATION")
        print("=" * 60)

        print(f"Loading data from {filepath}...")
        self.data = pd.read_csv(filepath)

        print("✓ Data loaded successfully!")
        print(f"✓ Dataset shape: {self.data.shape}")
        print(f"✓ Columns: {list(self.data.columns)}")

        print("\nDataset Info:")
        print(f"- Total records: {len(self.data)}")
        print(f"- Memory usage: {self.data.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

        # Show the first record, truncating long field values for readability.
        print("\nSample Record:")
        if len(self.data) > 0:
            sample = self.data.iloc[0]
            for col in self.data.columns:
                raw_value = str(sample[col])
                value = raw_value[:100] + "..." if len(raw_value) > 100 else raw_value
                print(f"- {col}: {value}")

        return self.data

    def preprocess_data(self, min_words=50):
        """Step 2: Clean the transcriptions and derive target summaries.

        Keeps the ``transcription`` column plus whichever of ``keywords``,
        ``medical_specialty`` and ``sample_name`` exist, drops rows without a
        transcription, adds ``clean_text`` / ``word_count`` / ``char_count``,
        removes texts shorter than ``min_words`` words and builds
        ``target_summary`` from the keywords (or extractively when there is no
        ``keywords`` column). Rows whose summary is 10 characters or shorter
        are dropped.

        Args:
            min_words: Minimum number of words a cleaned text must have.

        Returns:
            The preprocessed DataFrame, which is also stored in ``self.data``.

        Raises:
            ValueError: If no data has been loaded or the ``transcription``
                column is missing.
        """
        if self.data is None:
            raise ValueError("No data loaded. Call load_and_explore_data() first.")
        if 'transcription' not in self.data.columns:
            raise ValueError("Input data must contain a 'transcription' column.")

        print("\n" + "=" * 60)
        print("STEP 2: DATA PREPROCESSING")
        print("=" * 60)

        print("Processing clinical text data...")

        # Select relevant columns
        required_cols = ['transcription']
        optional_cols = ['keywords', 'medical_specialty', 'sample_name']

        available_cols = [col for col in required_cols + optional_cols if col in self.data.columns]
        self.data = self.data[available_cols].copy()

        print(f"✓ Selected columns: {available_cols}")

        # Remove missing transcriptions
        initial_count = len(self.data)
        self.data = self.data.dropna(subset=['transcription'])
        print(f"✓ Removed {initial_count - len(self.data)} records with missing transcriptions")

        # Clean text
        print("✓ Cleaning text data...")
        self.data['clean_text'] = self.data['transcription'].apply(self._clean_clinical_text)

        # Calculate text statistics
        self.data['word_count'] = self.data['clean_text'].apply(lambda x: len(x.split()))
        self.data['char_count'] = self.data['clean_text'].apply(len)

        # Filter very short texts. Copy so that the column assignment below
        # writes to an independent frame rather than a filtered selection.
        initial_count = len(self.data)
        self.data = self.data[self.data['word_count'] >= min_words].copy()
        print(f"✓ Removed {initial_count - len(self.data)} texts with < {min_words} words")

        # Create target summaries
        print("✓ Creating target summaries...")
        if 'keywords' in self.data.columns:
            self.data['target_summary'] = self.data['keywords'].apply(self._keywords_to_summary)
        else:
            self.data['target_summary'] = self.data['clean_text'].apply(self._extractive_summary)

        # Remove records with empty (or near-empty) summaries
        initial_count = len(self.data)
        self.data = self.data[self.data['target_summary'].str.len() > 10].copy()
        print(f"✓ Removed {initial_count - len(self.data)} records with inadequate summaries")

        print("\nFinal dataset statistics:")
        print(f"- Records: {len(self.data)}")
        print(f"- Avg words per text: {self.data['word_count'].mean():.1f}")
        print(f"- Avg words per summary: {self.data['target_summary'].apply(lambda x: len(x.split())).mean():.1f}")

        return self.data

    def _clean_clinical_text(self, text):
        """Normalize one clinical note for summarization.

        Lower-cases the text, removes the known section headers, replaces
        every character that is not a word character, whitespace, hyphen or
        period with a space, deletes standalone numbers and finally collapses
        all whitespace runs to a single space.

        Args:
            text: Raw note; missing values (``None``/NaN) yield ``""``.

        Returns:
            The cleaned, stripped text.
        """
        if pd.isna(text):
            return ""

        text = str(text).lower()

        for header in self._SECTION_HEADERS:
            text = text.replace(header, '')

        text = re.sub(r'[^\w\s\-\.]', ' ', text)  # Keep only alphanumeric, spaces, hyphens, periods
        text = re.sub(r'\b\d+\b', '', text)  # Remove standalone numbers
        # Collapse whitespace last: the two substitutions above leave runs of
        # spaces behind, which would otherwise survive into the output.
        text = re.sub(r'\s+', ' ', text)

        return text.strip()

    def _keywords_to_summary(self, keywords):
        """Turn a comma/semicolon separated keyword string into one sentence.

        Args:
            keywords: Keyword string; missing or empty values yield ``""``.

        Returns:
            ``"Patient presents with ..."`` built from at most the first four
            keywords, or ``""`` when there are none.
        """
        if pd.isna(keywords) or keywords == "":
            return ""

        keyword_list = [k.strip() for k in re.split(r'[,;]', str(keywords).lower()) if k.strip()]

        if not keyword_list:
            return ""

        if len(keyword_list) == 1:
            return f"Patient presents with {keyword_list[0]}."

        *leading, last = keyword_list[:4]  # Take first 4 keywords
        return f"Patient presents with {', '.join(leading)} and {last}."

    def _extractive_summary(self, text):
        """Build a two-sentence extractive summary.

        Args:
            text: Text to summarize.

        Returns:
            ``text`` unchanged when it has at most two sentences, otherwise
            the first sentence joined with the middle one.
        """
        sentences = sent_tokenize(text)
        if len(sentences) <= 2:
            return text

        return ' '.join([sentences[0], sentences[len(sentences) // 2]])

    def visualize_data(self):
        """Step 3: Plot length distributions and print summary statistics.

        Draws a 2x2 figure: text length histogram, summary length histogram,
        top-8 medical specialties (when that column exists) and a scatter of
        text length against summary length.
        """
        print("\n" + "=" * 60)
        print("STEP 3: DATA ANALYSIS AND VISUALIZATION")
        print("=" * 60)

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Text length distribution
        axes[0,0].hist(self.data['word_count'], bins=30, alpha=0.7, color='skyblue')
        axes[0,0].set_title('Distribution of Text Length (Words)')
        axes[0,0].set_xlabel('Word Count')
        axes[0,0].set_ylabel('Frequency')

        # Summary length distribution
        summary_lengths = self.data['target_summary'].apply(lambda x: len(x.split()))
        axes[0,1].hist(summary_lengths, bins=20, alpha=0.7, color='lightcoral')
        axes[0,1].set_title('Distribution of Summary Length (Words)')
        axes[0,1].set_xlabel('Word Count')
        axes[0,1].set_ylabel('Frequency')

        # Medical specialty distribution (if available)
        if 'medical_specialty' in self.data.columns:
            top_specialties = self.data['medical_specialty'].value_counts().head(8)
            axes[1,0].pie(top_specialties.values, labels=top_specialties.index, autopct='%1.1f%%')
            axes[1,0].set_title('Top Medical Specialties')
        else:
            axes[1,0].text(0.5, 0.5, 'Medical Specialty\nData Not Available',
                          ha='center', va='center', transform=axes[1,0].transAxes)
            axes[1,0].set_title('Medical Specialties')

        # Text vs Summary length scatter
        axes[1,1].scatter(self.data['word_count'], summary_lengths, alpha=0.6)
        axes[1,1].set_xlabel('Original Text Length (Words)')
        axes[1,1].set_ylabel('Summary Length (Words)')
        axes[1,1].set_title('Text Length vs Summary Length')

        plt.tight_layout()
        plt.show()

        print("Dataset Statistics:")
        print(f"- Text length: Mean={self.data['word_count'].mean():.1f}, Median={self.data['word_count'].median():.1f}")
        print(f"- Summary length: Mean={summary_lengths.mean():.1f}, Median={summary_lengths.median():.1f}")

    def prepare_model(self):
        """Step 4: Load the tokenizer and model and build the pipeline.

        Populates ``self.tokenizer``, ``self.model`` (moved to
        ``self.device``) and ``self.summarizer``.
        """
        print("\n" + "=" * 60)
        print("STEP 4: MODEL PREPARATION")
        print("=" * 60)

        print(f"Loading T5 model: {self.model_name}")

        # Load tokenizer and model
        self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
        self.model.to(self.device)

        print(f"✓ Model loaded successfully on {self.device}")

        # Create summarization pipeline (device index 0 for CUDA, -1 for CPU)
        self.summarizer = pipeline(
            "summarization",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if self.device.type == 'cuda' else -1
        )

        print("✓ Summarization pipeline created")

        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"✓ Model parameters: {total_params:,}")

    def split_data(self, test_size=0.2):
        """Step 5: Split cleaned texts and target summaries into train/test.

        Args:
            test_size: Fraction of the records placed in the test split.

        Returns:
            ``(X_train, X_test, y_train, y_test)`` as lists of strings; the
            split uses a fixed ``random_state`` of 42.
        """
        print("\n" + "=" * 60)
        print("STEP 5: DATA SPLITTING")
        print("=" * 60)

        texts = self.data['clean_text'].tolist()
        summaries = self.data['target_summary'].tolist()

        X_train, X_test, y_train, y_test = train_test_split(
            texts, summaries, test_size=test_size, random_state=42
        )

        print("✓ Data split completed:")
        print(f"  - Training samples: {len(X_train)}")
        print(f"  - Testing samples: {len(X_test)}")
        print(f"  - Test ratio: {test_size:.1%}")

        return X_train, X_test, y_train, y_test

    def generate_summaries(self, texts, batch_size=8):
        """Step 6: Generate one summary per input text.

        Texts are walked in chunks of ``batch_size`` but summarized one at a
        time. A text whose generation raises is reported and replaced by the
        placeholder ``"Summary generation failed."`` so the output stays
        aligned with the input.

        Args:
            texts: Sequence of texts to summarize.
            batch_size: Chunk size used when walking ``texts``.

        Returns:
            List of summaries, one per input text, in input order.

        Raises:
            ValueError: If the model has not been loaded or ``batch_size`` is
                not positive.
        """
        if self.summarizer is None:
            raise ValueError("Model not loaded. Call prepare_model() first.")
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        print("\n" + "=" * 60)
        print("STEP 6: SUMMARY GENERATION")
        print("=" * 60)

        total = len(texts)
        print(f"Generating summaries for {total} texts...")

        summaries = []

        for start in range(0, total, batch_size):
            for text in texts[start:start + batch_size]:
                try:
                    # Limit input length to the first 512 characters.
                    input_text = f"summarize: {text[:512]}"

                    summary = self.summarizer(
                        input_text,
                        max_length=100,
                        min_length=20,
                        do_sample=False,
                        early_stopping=True
                    )[0]['summary_text']

                except Exception as e:
                    # Best effort: keep going and mark this text as failed.
                    print(f"Error generating summary: {e}")
                    summary = "Summary generation failed."

                summaries.append(summary)

            # Report whenever this chunk crossed a multiple of the progress
            # interval, and always after the final chunk.
            processed = min(start + batch_size, total)
            crossed_interval = (
                processed // self._PROGRESS_INTERVAL > start // self._PROGRESS_INTERVAL
            )
            if crossed_interval or processed >= total:
                print(f"✓ Processed {processed}/{total} texts")

        print("✓ Summary generation completed!")
        return summaries

    def evaluate_model(self, predictions, references):
        """Step 7: Score generated summaries against reference summaries.

        Uses ROUGE when ``ROUGE_AVAILABLE`` is true, otherwise a word-overlap
        and length-ratio fallback. Both variants print and plot their results.

        Args:
            predictions: Generated summaries.
            references: Reference summaries, paired with ``predictions`` by
                position.

        Returns:
            ``{'rouge1', 'rouge2', 'rougeL'}`` averages when ROUGE is used,
            otherwise ``{'word_overlap', 'length_ratio',
            'length_consistency'}``.
        """
        print("\n" + "=" * 60)
        print("STEP 7: MODEL EVALUATION")
        print("=" * 60)

        if ROUGE_AVAILABLE:
            return self._evaluate_with_rouge(predictions, references)
        return self._evaluate_with_overlap(predictions, references)

    def _evaluate_with_rouge(self, predictions, references):
        """Compute, print and plot average ROUGE-1/2/L F-measures."""
        print("Calculating ROUGE scores...")

        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

        rouge1_scores = []
        rouge2_scores = []
        rougeL_scores = []

        for pred, ref in zip(predictions, references):
            # The reference is passed first and the prediction second.
            scores = scorer.score(str(ref), str(pred))
            rouge1_scores.append(scores['rouge1'].fmeasure)
            rouge2_scores.append(scores['rouge2'].fmeasure)
            rougeL_scores.append(scores['rougeL'].fmeasure)

        avg_rouge1 = np.mean(rouge1_scores)
        avg_rouge2 = np.mean(rouge2_scores)
        avg_rougeL = np.mean(rougeL_scores)

        print("✓ ROUGE Evaluation Results:")
        print(f"  - ROUGE-1: {avg_rouge1:.4f}")
        print(f"  - ROUGE-2: {avg_rouge2:.4f}")
        print(f"  - ROUGE-L: {avg_rougeL:.4f}")

        self._plot_evaluation(
            bar_title='Average ROUGE Scores',
            bar_labels=['ROUGE-1', 'ROUGE-2', 'ROUGE-L'],
            bar_values=[avg_rouge1, avg_rouge2, avg_rougeL],
            bar_colors=['skyblue', 'lightcoral', 'lightgreen'],
            hist_title='Distribution of ROUGE Scores',
            hist_series=[('ROUGE-1', rouge1_scores), ('ROUGE-L', rougeL_scores)],
        )

        return {
            'rouge1': avg_rouge1,
            'rouge2': avg_rouge2,
            'rougeL': avg_rougeL
        }

    def _evaluate_with_overlap(self, predictions, references):
        """Compute, print and plot the fallback word-overlap and length metrics."""
        print("Using alternative evaluation metrics...")

        overlap_scores = []
        length_ratios = []

        for pred, ref in zip(predictions, references):
            pred_tokens = str(pred).split()
            ref_tokens = str(ref).split()
            pred_words = set(str(pred).lower().split())
            ref_words = set(str(ref).lower().split())

            # Share of distinct predicted words that also occur in the
            # reference (a precision-like metric).
            if pred_words:
                overlap = len(pred_words.intersection(ref_words)) / len(pred_words)
            else:
                overlap = 0.0
            overlap_scores.append(overlap)

            # Predicted length relative to the reference length, in words.
            if ref_tokens:
                length_ratio = len(pred_tokens) / len(ref_tokens)
            else:
                length_ratio = 1.0
            length_ratios.append(length_ratio)

        avg_overlap = np.mean(overlap_scores)
        avg_length_ratio = np.mean(length_ratios)
        length_consistency = 1 - abs(1 - avg_length_ratio)

        print("✓ Alternative Evaluation Results:")
        print(f"  - Word Overlap Score: {avg_overlap:.4f}")
        print(f"  - Average Length Ratio: {avg_length_ratio:.4f}")
        print(f"  - Length Consistency: {length_consistency:.4f}")

        self._plot_evaluation(
            bar_title='Alternative Evaluation Metrics',
            bar_labels=['Word Overlap', 'Length Consistency'],
            bar_values=[avg_overlap, length_consistency],
            bar_colors=['skyblue', 'lightcoral'],
            hist_title='Distribution of Scores',
            hist_series=[('Word Overlap', overlap_scores), ('Length Ratio', length_ratios)],
        )

        return {
            'word_overlap': avg_overlap,
            'length_ratio': avg_length_ratio,
            'length_consistency': length_consistency
        }

    def _plot_evaluation(self, bar_title, bar_labels, bar_values, bar_colors,
                         hist_title, hist_series):
        """Draw the evaluation figure: average bars (left), score histograms (right).

        Args:
            bar_title: Title of the bar chart.
            bar_labels: Metric names shown on the bar chart.
            bar_values: Average value per metric, aligned with ``bar_labels``.
            bar_colors: Bar colors, aligned with ``bar_labels``.
            hist_title: Title of the histogram panel.
            hist_series: ``(label, values)`` pairs, one histogram each.
        """
        plt.figure(figsize=(10, 6))

        plt.subplot(1, 2, 1)
        bars = plt.bar(bar_labels, bar_values, color=bar_colors)
        plt.title(bar_title)
        plt.ylabel('Score')
        plt.ylim(0, 1)

        # Add value labels on bars
        for bar, score in zip(bars, bar_values):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{score:.3f}', ha='center', va='bottom')

        plt.subplot(1, 2, 2)
        for label, values in hist_series:
            plt.hist(values, bins=20, alpha=0.7, label=label)
        plt.xlabel('Score')
        plt.ylabel('Frequency')
        plt.title(hist_title)
        plt.legend()

        plt.tight_layout()
        plt.show()

    def show_sample_results(self, original_texts, generated_summaries, reference_summaries, n_samples=3):
        """Step 8: Print a few texts next to their generated and reference summaries.

        Args:
            original_texts: Source texts; only the first 300 characters of
                each are printed.
            generated_summaries: Generated summaries, aligned by position.
            reference_summaries: Reference summaries, aligned by position.
            n_samples: Maximum number of samples to print.
        """
        print("\n" + "=" * 60)
        print("STEP 8: SAMPLE RESULTS")
        print("=" * 60)

        for i in range(min(n_samples, len(original_texts))):
            print(f"\n{'='*20} SAMPLE {i+1} {'='*20}")
            print("\nORIGINAL TEXT:")
            print(f"{original_texts[i][:300]}...")
            print("\nGENERATED SUMMARY:")
            print(f"{generated_summaries[i]}")
            print("\nREFERENCE SUMMARY:")
            print(f"{reference_summaries[i]}")
            print("-" * 60)

    def _headline_metric_line(self, evaluation_results):
        """Return the one-line metric shown in the final pipeline summary.

        Works with either result dictionary that ``evaluate_model`` returns,
        so the summary does not depend on ROUGE being available.
        """
        if 'rouge1' in evaluation_results:
            return f"✓ Average ROUGE-1 Score: {evaluation_results['rouge1']:.4f}"
        if 'word_overlap' in evaluation_results:
            return f"✓ Word Overlap Score: {evaluation_results['word_overlap']:.4f}"
        return "✓ Evaluation completed"

    def run_complete_pipeline(self, filepath):
        """Run steps 1-8 in order on the dataset at ``filepath``.

        Only the first 50 test texts (fewer if the test split is smaller) are
        summarized and evaluated.

        Args:
            filepath: Path of the CSV dataset.

        Returns:
            Dict with ``'model'`` (this instance), ``'summaries'``,
            ``'evaluation'`` and ``'test_data'`` as ``(texts, references)``.

        Raises:
            Exception: Any error from a step is reported and re-raised.
        """
        print("🚀 STARTING CLINICAL TEXT SUMMARIZATION PIPELINE")
        print("=" * 80)

        try:
            # Step 1: Load and explore data
            self.load_and_explore_data(filepath)

            # Step 2: Preprocess data
            self.preprocess_data()

            # Step 3: Visualize data
            self.visualize_data()

            # Step 4: Prepare model
            self.prepare_model()

            # Step 5: Split data
            X_train, X_test, y_train, y_test = self.split_data()

            # Step 6: Generate summaries (using subset for demo)
            test_subset = min(50, len(X_test))  # Limit for demo
            test_texts = X_test[:test_subset]
            test_references = y_test[:test_subset]

            generated_summaries = self.generate_summaries(test_texts)

            # Step 7: Evaluate model
            evaluation_results = self.evaluate_model(generated_summaries, test_references)

            # Step 8: Show sample results
            self.show_sample_results(test_texts, generated_summaries, test_references)

            print("\n" + "=" * 80)
            print("🎉 PIPELINE COMPLETED SUCCESSFULLY!")
            print("=" * 80)
            print(f"✓ Processed {len(self.data)} clinical records")
            print(f"✓ Generated {len(generated_summaries)} summaries")
            # The result keys differ between the ROUGE and fallback
            # evaluations, so pick the headline metric from what is present.
            print(self._headline_metric_line(evaluation_results))

            return {
                'model': self,
                'summaries': generated_summaries,
                'evaluation': evaluation_results,
                'test_data': (test_texts, test_references)
            }

        except Exception as e:
            print(f"❌ Pipeline failed with error: {str(e)}")
            raise

    def summarize_new_text(self, clinical_text):
        """Clean and summarize a single new clinical note.

        Args:
            clinical_text: Raw clinical text.

        Returns:
            The generated summary string.

        Raises:
            ValueError: If the summarization pipeline has not been created.
        """
        if self.summarizer is None:
            raise ValueError("Model not loaded. Please run the pipeline first.")

        cleaned_text = self._clean_clinical_text(clinical_text)

        # Same prompt and 512-character input limit as generate_summaries.
        input_text = f"summarize: {cleaned_text[:512]}"
        summary = self.summarizer(
            input_text,
            max_length=100,
            min_length=20,
            do_sample=False
        )[0]['summary_text']

        return summary

# Main execution function
def main(data_file="mtsamples.csv"):
    """Run the complete pipeline on ``data_file`` and summarize a sample note.

    Args:
        data_file: Path of the CSV dataset to process.

    Returns:
        The result dictionary from ``run_complete_pipeline``, or ``None`` when
        a ``FileNotFoundError`` was raised (usage hints are printed instead).

    Raises:
        Exception: Any other error is reported and re-raised.
    """
    # Initialize the summarizer
    summarizer = ClinicalSummarizer()

    try:
        # Run the complete pipeline
        results = summarizer.run_complete_pipeline(data_file)

        # Example of using the loaded model on new text
        print("\n" + "=" * 60)
        print("TESTING WITH NEW CLINICAL TEXT")
        print("=" * 60)

        new_text = """
        SUBJECTIVE: This 45-year-old male presents with chest pain that started
        2 hours ago. The pain is described as crushing and radiates to the left arm.
        Patient has history of hypertension and diabetes.
        OBJECTIVE: Vital signs stable, BP 140/90, HR 88. EKG shows ST elevation
        in leads II, III, aVF.
        ASSESSMENT: Acute myocardial infarction.
        PLAN: Immediate cardiac catheterization and PCI.
        """

        new_summary = summarizer.summarize_new_text(new_text)
        print(f"New Clinical Text Summary: {new_summary}")

        return results

    except FileNotFoundError:
        print(f"❌ Error: File '{data_file}' not found!")
        print("\n📋 To use this pipeline:")
        print("1. Place your CSV file in the same directory")
        print("2. Ensure it has a 'transcription' column with clinical text")
        print("3. Optional: 'keywords' column for reference summaries")
        print("4. Update the 'data_file' variable with your filename")
        print("5. Run the script again")
        return None

    except Exception as e:
        print(f"❌ Unexpected error: {str(e)}")
        raise

if __name__ == "__main__":
    # Print a banner naming the pipeline and the packages it relies on.
    # Nothing is installed here; the list is informational only.
    for banner_line in (
        "Clinical Text Summarization Pipeline",
        "Required packages: pandas, numpy, nltk, matplotlib, seaborn, transformers, torch, rouge-score",
        "-" * 80,
    ):
        print(banner_line)

    # Run the main pipeline
    results = main()