In [1]:
import os
import sys
import re
# Get the absolute path to the parent directory (assumes this file is in 'condensation')
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.insert(0, parent_dir)
# autoreload modules
%load_ext autoreload
%autoreload 2
from typing import List, Optional, Dict, Any
import datetime
from dataclasses import dataclass
import pandas as pd
from chatbot_api.providers.openai import OpenAIProvider
from chatbot_api.base import Role, Message
import json

In [2]:
class Argument:
    """One condensed argument: a high-level claim plus the comments behind it.

    Attributes:
        main_point: The broad, high-level point of the argument.
        subpoints: Supporting details or reasoning for the main point.
        score_distribution: Likert score distribution of the supporting comments.
        source_indices: Indices of the source comments that support the argument,
            at least in some way.
    """

    def __init__(self,
                 main_point: str,
                 subpoints: List[str],
                 score_distribution: List[int],
                 source_indices: List[int]):
        # Values are stored exactly as given: no copies are made and nothing is validated.
        self.main_point = main_point
        self.subpoints = subpoints
        self.score_distribution = score_distribution
        self.source_indices = source_indices

@dataclass
class BatchMetadata:
    """Metadata recorded for one batch-processing iteration."""
    batch_number: int                      # Position of the batch within the processing run
    timestamp: datetime.datetime           # When the batch was recorded (a datetime instance, not the module)
    processed_indices: List[int]           # Identifiers of the comments included in this batch
    processed_comments: List[str]          # Texts of the comments included in this batch
    arguments_after_batch: List[Argument]  # Argument set after this batch was merged in
    batch_likert_scores: Optional[List[int]] = None  # Likert answers for this batch's comments, if provided
    metrics: Optional[Dict[str, Any]] = None  # For storing additional metrics like processing time, token usage etc.


In [3]:
class ArgumentProcessor:
    """Condense free-text comments into hierarchical arguments with an LLM.

    Comments are sent to the LLM in batches of ``batch_size``. The first batch is
    summarised from scratch with ``INITIAL_BATCH_TEMPLATE``. Each later batch is
    merged into the running argument set with ``SYNTHESIS_TEMPLATE``. After every
    batch a ``BatchMetadata`` snapshot is appended to ``processing_history``.
    """

    # Matches the "[source_indices: 1,2,3]" tag the prompts ask the LLM to emit.
    # Whitespace after the colon is optional and the list may be empty, so minor
    # formatting drift in the LLM output does not break parsing.
    _INDICES_PATTERN = r'\[source_indices:\s*([\d,\s]*)\]'

    def __init__(self, llm_provider: OpenAIProvider, batch_size: int):
        """
        Args:
            llm_provider: Provider whose async ``generate`` method is awaited with a
                list of ``Message`` objects; the result's ``.content`` is parsed.
            batch_size: Number of comments sent to the LLM per request.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.llm_provider = llm_provider
        self.batch_size = batch_size
        self.current_arguments: List[Argument] = []  # Arguments after the most recent batch
        self.processing_history: List[BatchMetadata] = []  # One snapshot per processed batch

        # Prompt for the first batch: extract arguments from scratch.
        self.INITIAL_BATCH_TEMPLATE = """
            ### Instructions:
            1. Review the following numbered comments on the topic: "{topic}"
            2. Identify and categorize arguments into main topics.
                - Main arguments are broad, high-level points.
                - Subpoints provide supporting details or reasoning for the main argument.
            3. For each argument and subpoint, include the indices of the source comments in square brackets.
            4. Provide the output in Finnish.

            ### Output Format: 
            <ARGUMENTS>
            MAIN: [Main Argument] [source_indices: 1,2,3]
            SUB: [Supporting subpoint] [source_indices: 1,2]
            SUB: [Another supporting subpoint] [source_indices: 3]
            </ARGUMENTS>

            ### Comments to analyze:
            {comments_text}
            """

        # Prompt for later batches: merge new comments into the existing arguments.
        self.SYNTHESIS_TEMPLATE = """
            ### Instructions:
            1. Review the existing arguments and new comments on the topic: "{topic}"
            2. Perform a comprehensive synthesis:
                a. First, consider if the new comments suggest any major themes or perspectives that aren't captured in the existing arguments
                b. Evaluate if existing main arguments should be:
                - Combined if they represent closely related ideas
                - Split if they contain distinct themes that deserve separate focus
                - Reworded to better capture the full scope of ideas, including new comments
                - Removed if they're no longer representative of the broader discussion
                c. Only after restructuring the main arguments, organize supporting points under them
            3. For each argument and subpoint, include ALL relevant source indices (both from existing and new comments)
            4. Provide the output in Finnish.

            ### Important:
            - Don't feel constrained by the existing argument structure
            - New comments might reveal better ways to organize and present the overall discussion
            - Main arguments should reflect the most important themes across ALL comments

            ### Existing Arguments:
            {existing_arguments}

            ### New Comments to Analyze:
            {new_comments}

            ### Output Format:
            <ARGUMENTS>
            MAIN: [Main Argument] [source_indices: 1,2,3]
            SUB: [Supporting subpoint] [source_indices: 1,2]
            SUB: [Another supporting subpoint] [source_indices: 3]
            </ARGUMENTS>
            """

    async def process_all_comments(self,
                                   comments: List[str],
                                   comment_indices: List[int],
                                   topic: str,
                                   likert_answers: Optional[List[int]] = None) -> List[Argument]:
        """Process all comments in batches, synthesising arguments across batches.

        Every call is a fresh run: ``processing_history`` and ``current_arguments``
        are cleared before the first batch.

        Args:
            comments: Comment texts to analyse.
            comment_indices: Identifier of each comment, parallel to ``comments``.
                These numbers are shown to the LLM and echoed back as
                ``source_indices``.
            topic: Topic statement inserted into the prompts.
            likert_answers: Optional Likert answer per comment, parallel to
                ``comments``. Used for each argument's ``score_distribution``.

        Returns:
            The arguments after the last batch (also kept in ``current_arguments``).

        Raises:
            ValueError: If ``comment_indices`` or ``likert_answers`` differ in
                length from ``comments``.
        """
        if len(comment_indices) != len(comments):
            raise ValueError(
                f"comment_indices has {len(comment_indices)} items but comments has {len(comments)}"
            )
        if likert_answers is not None and len(likert_answers) != len(comments):
            raise ValueError(
                f"likert_answers has {len(likert_answers)} items but comments has {len(comments)}"
            )

        # Reset all run state. Leftover arguments from a previous call would make the
        # first batch take the synthesis path against stale arguments.
        self.processing_history = []
        self.current_arguments = []

        # Likert score keyed by comment identifier, for every comment seen so far.
        # The LLM cites comment identifiers (not positions within the batch), and
        # synthesised arguments keep citing comments from earlier batches.
        # Score lookups therefore go through this run-wide mapping.
        likert_by_index: Dict[int, Any] = {}

        for batch_num, start in enumerate(range(0, len(comments), self.batch_size)):
            stop = start + self.batch_size
            batch_comments = comments[start:stop]
            batch_indices = comment_indices[start:stop]
            batch_likert = likert_answers[start:stop] if likert_answers is not None else None
            if batch_likert is not None:
                likert_by_index.update(zip(batch_indices, batch_likert))

            start_time = datetime.datetime.now()

            if not self.current_arguments:
                # First batch (or every earlier batch failed to parse): start from scratch.
                new_arguments = await self._process_initial_batch(
                    batch_comments, batch_indices, topic, likert_by_index
                )
            else:
                # Later batches: merge the new comments into the existing arguments.
                new_arguments = await self._synthesize_batch(
                    batch_comments, batch_indices, topic, likert_by_index
                )

            parse_failed = not new_arguments
            if parse_failed and self.current_arguments:
                # An unparseable response (e.g. missing <ARGUMENTS> tags) would
                # otherwise silently discard everything synthesised so far.
                new_arguments = list(self.current_arguments)

            self.current_arguments = new_arguments

            processing_time = (datetime.datetime.now() - start_time).total_seconds()
            self.processing_history.append(BatchMetadata(
                batch_number=batch_num,
                timestamp=datetime.datetime.now(),
                processed_indices=batch_indices,
                processed_comments=batch_comments,
                arguments_after_batch=list(new_arguments),
                batch_likert_scores=batch_likert,
                metrics={
                    'processing_time_seconds': processing_time,
                    'num_comments': len(batch_comments),
                    'num_arguments': len(new_arguments),
                    'total_subpoints': sum(len(arg.subpoints) for arg in new_arguments),
                    'parse_failed': parse_failed,
                },
            ))

        return self.current_arguments

    def get_argument_evolution(self, argument_index: int) -> List[str]:
        """Return the main point at position ``argument_index`` after each batch.

        Arguments are matched purely by list position. The LLM may reorder, merge or
        split arguments between batches, so consecutive entries can describe
        different themes. Batches in which that position did not exist are skipped,
        so positions in the returned list are not batch numbers.
        """
        return [
            batch.arguments_after_batch[argument_index].main_point
            for batch in self.processing_history
            if argument_index < len(batch.arguments_after_batch)
        ]

    def get_batch_statistics(self) -> Dict[str, List[Any]]:
        """Collect per-batch statistics from ``processing_history`` as parallel lists."""
        stats: Dict[str, List[Any]] = {
            'batch_numbers': [],
            'timestamps': [],
            'num_comments': [],
            'num_arguments': [],
            'processing_times': [],
            'total_subpoints': []
        }

        for batch in self.processing_history:
            stats['batch_numbers'].append(batch.batch_number)
            stats['timestamps'].append(batch.timestamp)
            stats['num_comments'].append(batch.metrics['num_comments'])
            stats['num_arguments'].append(batch.metrics['num_arguments'])
            stats['processing_times'].append(batch.metrics['processing_time_seconds'])
            stats['total_subpoints'].append(batch.metrics['total_subpoints'])

        return stats

    def export_processing_history(self, filepath: str):
        """Write ``processing_history`` to ``filepath`` as UTF-8, indented JSON.

        Each batch becomes one object holding its number, ISO timestamp, processed
        comment identifiers and texts, Likert answers, resulting arguments and
        metrics.
        """
        history_data = []
        for batch in self.processing_history:
            arguments_data = [
                {
                    'main_point': arg.main_point,
                    'subpoints': arg.subpoints,
                    'source_indices': arg.source_indices,
                    'score_distribution': arg.score_distribution
                }
                for arg in batch.arguments_after_batch
            ]
            likert = batch.batch_likert_scores
            history_data.append({
                'batch_number': batch.batch_number,
                'timestamp': batch.timestamp.isoformat(),
                'processed_indices': batch.processed_indices,
                'processed_comments': batch.processed_comments,
                'batch_likert_scores': list(likert) if likert is not None else None,
                'arguments': arguments_data,
                'metrics': batch.metrics
            })

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(history_data, f, ensure_ascii=False, indent=2)

    @staticmethod
    def _format_comments(indices: List[int], comments: List[str]) -> str:
        """Render comments as "[identifier] text" lines, the numbering the LLM cites back."""
        return "\n".join(f"[{idx}] {comment}" for idx, comment in zip(indices, comments))

    async def _process_initial_batch(self,
                                     comments: List[str],
                                     indices: List[int],
                                     topic: str,
                                     likert_answers: Optional[Dict[int, Any]] = None) -> List[Argument]:
        """Extract arguments from the first batch of comments.

        ``likert_answers`` is forwarded to ``_get_relevant_scores``. It may be a
        dict keyed by comment identifier, a positional list, or ``None``.
        """
        prompt = self.INITIAL_BATCH_TEMPLATE.format(
            topic=topic,
            comments_text=self._format_comments(indices, comments)
        )

        response = await self.llm_provider.generate([Message(Role.USER, prompt)])
        return await self._parse_arguments(response.content, likert_answers)

    async def _synthesize_batch(self,
                                new_comments: List[str],
                                new_indices: List[int],
                                topic: str,
                                likert_answers: Optional[Dict[int, Any]] = None) -> List[Argument]:
        """Merge a new batch of comments into ``current_arguments`` via the LLM.

        ``likert_answers`` is forwarded to ``_get_relevant_scores``. It may be a
        dict keyed by comment identifier, a positional list, or ``None``.
        """
        prompt = self.SYNTHESIS_TEMPLATE.format(
            topic=topic,
            existing_arguments=self._format_arguments_for_synthesis(self.current_arguments),
            new_comments=self._format_comments(new_indices, new_comments)
        )

        response = await self.llm_provider.generate([Message(Role.USER, prompt)])
        return await self._parse_arguments(response.content, likert_answers)

    def _format_arguments_for_synthesis(self, arguments: List[Argument]) -> str:
        """Render arguments in the MAIN/SUB line format used by the prompts.

        MAIN lines carry the argument's full ``source_indices``. SUB lines carry the
        subpoint's own indices when ``_parse_arguments`` recorded them. Otherwise,
        e.g. for ``Argument`` objects built elsewhere, they fall back to the
        argument-wide indices.
        """
        def join(indices: List[int]) -> str:
            return ','.join(map(str, indices))

        lines: List[str] = []
        for arg in arguments:
            lines.append(f"MAIN: {arg.main_point} [source_indices: {join(arg.source_indices)}]")
            per_subpoint = getattr(arg, 'subpoint_indices', None)
            if per_subpoint is None or len(per_subpoint) != len(arg.subpoints):
                per_subpoint = [arg.source_indices] * len(arg.subpoints)
            for sub, sub_indices in zip(arg.subpoints, per_subpoint):
                lines.append(f"SUB: {sub} [source_indices: {join(sub_indices)}]")
        return "\n".join(lines)

    @staticmethod
    def _parse_index_list(raw: str) -> List[int]:
        """Extract sorted, unique integers from an index list such as '3, 1,2,' (stray separators are ignored)."""
        return sorted({int(token) for token in re.findall(r'\d+', raw)})

    def _build_argument(self,
                        main_point: str,
                        subpoints: List[str],
                        subpoint_indices: List[List[int]],
                        indices: set,
                        likert_answers: Any) -> Argument:
        """Create an ``Argument`` with a score distribution and per-subpoint indices."""
        source_indices = sorted(indices)
        argument = Argument(
            main_point=main_point,
            subpoints=subpoints,
            score_distribution=self._calculate_score_distribution(
                self._get_relevant_scores(source_indices, likert_answers)
            ),
            source_indices=source_indices,
        )
        # ``Argument`` has no field for per-subpoint indices. Keep them as an extra
        # attribute so the next synthesis prompt shows each subpoint's own indices
        # rather than the whole argument's.
        argument.subpoint_indices = subpoint_indices
        return argument

    async def _parse_arguments(self, response: str, likert_answers: Any) -> List[Argument]:
        """Parse an LLM response into ``Argument`` objects with score distributions.

        Only the text between ``<ARGUMENTS>`` and ``</ARGUMENTS>`` is read. A line
        starting with ``MAIN:`` opens a new argument, and a ``SUB:`` line attaches a
        subpoint to the most recent argument. Each line's ``[source_indices: ...]``
        tag is removed from the text; its indices are added to the argument's
        ``source_indices`` (sorted, de-duplicated).

        Args:
            response: Raw LLM response text.
            likert_answers: Score lookup forwarded to ``_get_relevant_scores``.

        Returns:
            The parsed arguments, or an empty list if no ``<ARGUMENTS>`` block is
            present.
        """
        match = re.search(r'<ARGUMENTS>(.*?)</ARGUMENTS>', response, re.DOTALL)
        if not match:
            return []

        arguments: List[Argument] = []
        current_main: Optional[str] = None
        current_subpoints: List[str] = []
        current_subpoint_indices: List[List[int]] = []
        current_indices: set = set()

        for raw_line in match.group(1).splitlines():
            line = raw_line.strip()
            if not line:
                continue

            indices_match = re.search(self._INDICES_PATTERN, line)
            indices = self._parse_index_list(indices_match.group(1)) if indices_match else []
            content = re.sub(self._INDICES_PATTERN, '', line).strip()

            if line.startswith('MAIN:'):
                if current_main is not None:
                    arguments.append(self._build_argument(
                        current_main, current_subpoints, current_subpoint_indices,
                        current_indices, likert_answers
                    ))
                # Strip only the leading label, not later occurrences inside the text.
                current_main = content[len('MAIN:'):].strip()
                current_subpoints = []
                current_subpoint_indices = []
                current_indices = set(indices)
            elif line.startswith('SUB:') and current_main is not None:
                # SUB lines before the first MAIN have no argument to attach to and are dropped.
                current_subpoints.append(content[len('SUB:'):].strip())
                current_subpoint_indices.append(indices)
                current_indices.update(indices)

        if current_main is not None:
            arguments.append(self._build_argument(
                current_main, current_subpoints, current_subpoint_indices,
                current_indices, likert_answers
            ))

        return arguments

    def _get_relevant_scores(self, indices: List[int], all_likert_scores: Any) -> List[int]:
        """Return the Likert scores of the comments identified by ``indices``.

        ``all_likert_scores`` is normally a dict mapping comment identifier to
        score. Identifiers missing from it (e.g. hallucinated by the LLM) are
        skipped. A list is still accepted and indexed positionally, as before.
        ``None`` yields no scores.
        """
        if all_likert_scores is None:
            return []
        if isinstance(all_likert_scores, dict):
            return [all_likert_scores[idx] for idx in indices if idx in all_likert_scores]
        return [all_likert_scores[idx] for idx in indices if 0 <= idx < len(all_likert_scores)]

    def _calculate_score_distribution(self, scores: List[int]) -> List[int]:
        """Count scores per value on a 1-5 scale; values outside 1-5 (or NaN) are ignored.

        Returns a 5-element list where position ``k`` holds the count of score ``k + 1``.
        """
        distribution = [0] * 5
        for score in scores:
            if 1 <= score <= 5:
                distribution[int(score) - 1] += 1
        return distribution

    async def format_arguments(self, arguments: List[Argument]) -> str:
        """Render arguments as a human-readable report.

        Each argument shows its main point, numbered subpoints, Likert score
        distribution and source indices, separated by 50-dash rules.
        """
        separator = "-" * 50
        formatted_output = [separator]
        for i, arg in enumerate(arguments, 1):
            formatted_output.append(f"Argument {i}: {arg.main_point}\n")

            if arg.subpoints:
                formatted_output.append("Supporting points:")
                for j, subpoint in enumerate(arg.subpoints, 1):
                    formatted_output.append(f"  {j}. {subpoint}")

            if arg.score_distribution:
                formatted_output.append("\nLikert Score Distribution (calculated from indices given by the LLM):")
                for score_value, count in enumerate(arg.score_distribution, 1):
                    formatted_output.append(f"  Score {score_value}: {count} answers")

            formatted_output.append(f"\nSource indices: {arg.source_indices}\n")
            formatted_output.append(separator)

        return "\n".join(formatted_output)

In [4]:
async def main(n_comments: int = 200,
               batch_size: int = 20,
               question_index: int = 10,
               topic: str = "Kun kunnan menoja ja tuloja tasapainotetaan, se on tehtävä mieluummin menoja karsimalla kuin veroja kiristämällä.",
               model: str = "gpt-4o-2024-11-20") -> None:
    """Condense one question's comments into arguments and write report files.

    Reads ``data/sources/kuntavaalit2021.csv`` under ``parent_dir``. Processes the
    first ``n_comments`` non-empty ``q{question_index}.explanation_fi`` comments in
    batches of ``batch_size``. Writes three timestamped files to
    ``condensation/results/synthesis``: the formatted arguments, the JSON
    processing history and a plain-text analysis report.

    Args:
        n_comments: Maximum number of comments to process.
        batch_size: Number of comments per LLM request.
        question_index: Question whose explanation/answer columns are analysed.
        topic: Topic statement shown to the LLM. The default pairs with question 10
            (presumably its wording), so change both together.
        model: Model name passed to ``OpenAIProvider``.
    """
    # Config
    api_key = os.getenv("OPENAI_API_KEY")
    openai_provider = OpenAIProvider(api_key, model)
    processor = ArgumentProcessor(openai_provider, batch_size)

    # Setup paths
    data_source_path = os.path.join(parent_dir, 'data', 'sources', 'kuntavaalit2021.csv')
    output_base_path = os.path.join(parent_dir, 'condensation', 'results', 'synthesis')
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(output_base_path, exist_ok=True)

    results_path = os.path.join(output_base_path, f'arguments_{timestamp}.txt')
    history_path = os.path.join(output_base_path, f'processing_history_{timestamp}.json')
    analysis_path = os.path.join(output_base_path, f'analysis_{timestamp}.txt')

    # Read and prepare data
    df = pd.read_csv(data_source_path)
    explanation_column_name = f'q{question_index}.explanation_fi'
    likert_column_name = f'q{question_index}.answer'

    # Only comments with content are processed. Their original row labels are kept
    # so the indices the LLM cites can be traced back to the CSV.
    comment_mask = df[explanation_column_name].notna()
    comment_indices = df[comment_mask].index[:n_comments].tolist()
    comments = df.loc[comment_indices, explanation_column_name].tolist()
    likert_answers = df.loc[comment_indices, likert_column_name].tolist()

    print(f"Starting processing of {len(comments)} comments in batches of {batch_size}")

    start_time = datetime.datetime.now()
    final_arguments = await processor.process_all_comments(
        comments=comments,
        comment_indices=comment_indices,
        topic=topic,
        likert_answers=likert_answers
    )
    total_processing_time = (datetime.datetime.now() - start_time).total_seconds()

    processor.export_processing_history(history_path)
    stats = processor.get_batch_statistics()

    # Format and save main results
    formatted_args = await processor.format_arguments(final_arguments)
    with open(results_path, 'w', encoding='utf-8') as f:
        f.write("Processing Summary:\n")
        f.write(f"- Total comments processed: {len(comments)}\n")
        f.write(f"- Number of batches: {len(stats['batch_numbers'])}\n")
        f.write(f"- Total processing time: {total_processing_time:.2f} seconds\n")
        f.write(f"- Final number of arguments: {len(final_arguments)}\n\n")
        f.write(formatted_args)

    # For each comment index, count how many final arguments cite it.
    index_usage: Dict[int, int] = {}
    for arg in final_arguments:
        for idx in arg.source_indices:
            index_usage[idx] = index_usage.get(idx, 0) + 1
    unused_count = sum(1 for idx in comment_indices if idx not in index_usage)

    # Save detailed analysis
    with open(analysis_path, 'w', encoding='utf-8') as f:
        f.write("=== Processing Statistics ===\n")
        f.write(f"Total processing time: {total_processing_time:.2f} seconds\n")
        batch_times = stats['processing_times']
        # Guard against zero batches (e.g. no comments with content).
        average_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0.0
        f.write(f"Average batch processing time: {average_batch_time:.2f} seconds\n")
        f.write(f"Total comments processed: {sum(stats['num_comments'])}\n\n")

        f.write("\n=== Argument Evolution ===\n")
        # Arguments are tracked by list position only. Label each version with its
        # real batch number, because a position may be missing in early batches.
        for i in range(len(final_arguments)):
            f.write(f"\nArgument {i+1} Evolution:\n")
            for batch in processor.processing_history:
                if i < len(batch.arguments_after_batch):
                    f.write(f"Batch {batch.batch_number}: {batch.arguments_after_batch[i].main_point}\n")

        f.write("\n=== Index Usage Analysis ===\n")
        for idx in sorted(index_usage):
            f.write(f"Comment index {idx} used in {index_usage[idx]} arguments\n")
        f.write(f"Comments not used in any argument: {unused_count}\n")

        f.write("\n=== Sample Validation of Final Arguments ===\n")
        for i, arg in enumerate(final_arguments, 1):
            f.write(f"\nArgument {i}: {arg.main_point}\n")
            f.write("Sample supporting comments:\n")
            for idx in arg.source_indices[:5]:  # Show first 5 supporting comments
                # Indices come from the LLM and may not exist in the data. Don't crash
                # here after all the API calls have already been made.
                if idx in df.index:
                    comment = df.loc[idx, explanation_column_name]
                else:
                    comment = "<index not found in data source>"
                f.write(f"[{idx}] {comment}\n")

        f.write("\n=== Batch-by-Batch Statistics ===\n")
        for i, batch_number in enumerate(stats['batch_numbers']):
            f.write(f"\nBatch {batch_number}:\n")
            f.write(f"- Timestamp: {stats['timestamps'][i]}\n")
            f.write(f"- Comments processed: {stats['num_comments'][i]}\n")
            f.write(f"- Arguments after batch: {stats['num_arguments'][i]}\n")
            f.write(f"- Processing time: {stats['processing_times'][i]:.2f} seconds\n")
            f.write(f"- Total subpoints: {stats['total_subpoints'][i]}\n")

    print("\nProcessing complete!")
    print(f"Results saved to: {results_path}")
    print(f"Processing history saved to: {history_path}")
    print(f"Detailed analysis saved to: {analysis_path}")

# Run the full pipeline. A bare top-level `await` relies on IPython/Jupyter's
# autoawait support; in a plain Python script use `asyncio.run(main())` instead.
await main()

Starting processing of 200 comments in batches of 20

Processing complete!
Results saved to: /Users/max/projects/learning/aalto/openvaa/argument-condensation/condensation/results/synthesis/arguments_20250101_153326.txt
Processing history saved to: /Users/max/projects/learning/aalto/openvaa/argument-condensation/condensation/results/synthesis/processing_history_20250101_153326.json
Detailed analysis saved to: /Users/max/projects/learning/aalto/openvaa/argument-condensation/condensation/results/synthesis/analysis_20250101_153326.txt
