In [1]:
import numpy as np
from collections import defaultdict, Counter
import math
from typing import List, Dict, Tuple, Set, Optional, Any
import networkx as nx
import matplotlib.pyplot as plt
import random

class POSGraphWithAttention:
    """
    Implements a graph-based structure for POS sequence processing with predictive coding
    and attention mechanisms for dynamic precision weighting.
    """
    
    def __init__(self, predefined_boundaries: Optional[Dict[Tuple[str, str], float]] = None):
        """
        Initialize the POS graph with attention mechanisms.
        
        Args:
            predefined_boundaries: Optional dictionary of predefined boundary probabilities
        """
        # Main transition graph
        self.graph = nx.DiGraph()
        
        # Higher-order chunk graph for learned patterns
        self.chunk_graph = nx.DiGraph()
        
        # Track n-gram counts for training
        self.unigram_counts = Counter()
        self.bigram_counts = defaultdict(Counter)
        self.trigram_counts = defaultdict(lambda: defaultdict(Counter))
        
        # Boundary probabilities
        self.boundary_probs = defaultdict(float)
        
        # Predefined linguistic rules
        self.predefined_boundaries = predefined_boundaries or {
            ('NOUN', 'VERB'): 0.9,      # NP to VP transition
            ('VERB', 'DET'): 0.8,       # VP to NP transition
            ('PUNCT', 'DET'): 0.95,     # Punctuation followed by determiner
            ('NOUN', 'PREP'): 0.7,      # NP to PP transition
            ('VERB', 'PREP'): 0.6,      # VP to PP transition
            ('ADJ', 'NOUN'): 0.2,       # Within NP (low boundary probability)
            ('DET', 'ADJ'): 0.1,        # Within NP (very low boundary probability)
        }
        
        # Thresholds
        self.hard_boundary_threshold = 0.75
        self.soft_boundary_threshold = 0.4
        
        # Discovered chunks
        self.common_chunks = {}
        
        # Add special start and end nodes
        self.graph.add_node("<START>", pos_type="special")
        self.graph.add_node("<END>", pos_type="special")
        
        # Attention mechanisms
        self.attention_weights = {}  # POS tag attention weights
        self.chunk_attention_weights = {}  # Chunk attention weights
        self.surprisal_history = []  # Track surprisal for adaptive attention
        self.learning_rate = 0.1  # Base learning rate
        self.attention_learning_rate = 0.05  # For updating attention weights
        
        # Precision weighting parameters
        self.base_precision = 1.0
        self.max_precision = 5.0
        self.min_precision = 0.2
        
        # Context tracking
        self.context_history = []  # Track recent contexts for attention modulation

    def train(self, pos_sequences: List[List[str]], epochs: int = 1):
        """
        Train the POS graph on a corpus of POS tag sequences with attention.
        
        Args:
            pos_sequences: List of POS tag sequences, each representing a sentence
            epochs: Number of training epochs
        """
        print(f"Training on {len(pos_sequences)} sequences for {epochs} epochs")
        
        # 1. Initialize attention weights uniformly
        self._initialize_attention_weights(pos_sequences)
        
        # 2. Build initial graph and collect n-gram statistics
        self._build_initial_graph(pos_sequences)
        
        # 3. Iterative training with attention modulation
        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")
            
            # Shuffle sequences for stochastic training
            random.shuffle(pos_sequences)
            
            epoch_surprisal = 0.0
            
            # Process each sequence with attention
            for sequence in pos_sequences:
                # Forward pass - calculate predictions and errors
                sequence_surprisal, prediction_errors = self._forward_pass(sequence)
                epoch_surprisal += sequence_surprisal
                
                # Update attention weights based on prediction errors
                self._update_attention_weights(sequence, prediction_errors)
                
                # Update graph weights with attention-modulated learning
                self._update_graph_weights(sequence, prediction_errors)
            
            avg_surprisal = epoch_surprisal / len(pos_sequences)
            self.surprisal_history.append(avg_surprisal)
            print(f"  Average surprisal: {avg_surprisal:.4f}")
            
            # Recalculate edge weights based on updated counts
            self._calculate_edge_weights()
            
            # After each epoch, update boundary probabilities
            self._calculate_boundary_probabilities()
        
        # 4. Identify common chunks with attention influence
        self._identify_common_chunks(pos_sequences)
        
        # 5. Build chunk graph with attention-weighted connections
        self._build_chunk_graph()
        
        print(f"Training complete.")
        print(f"POS graph has {len(self.graph.nodes)} nodes and {len(self.graph.edges)} edges")
        print(f"Chunk graph has {len(self.chunk_graph.nodes)} nodes and {len(self.chunk_graph.edges)} edges")
        
        # Report top attention weights
        self._report_attention_weights()

    def _initialize_attention_weights(self, pos_sequences: List[List[str]]):
        """Initialize attention weights for all POS tags."""
        # Extract all unique POS tags from sequences
        unique_pos = set()
        for sequence in pos_sequences:
            unique_pos.update(sequence)
        
        # Initialize with uniform weights
        for pos in unique_pos:
            self.attention_weights[pos] = 1.0
            
        # Special start and end nodes
        self.attention_weights["<START>"] = 1.0
        self.attention_weights["<END>"] = 1.0

    def _build_initial_graph(self, pos_sequences: List[List[str]]):
        """Build the initial graph structure and collect statistics."""
        # First pass - add all nodes and count statistics
        for sequence in pos_sequences:
            # Add nodes for each unique POS tag
            for pos in sequence:
                if not self.graph.has_node(pos):
                    self.graph.add_node(pos, pos_type="basic", precision=self.base_precision)
                self.unigram_counts[pos] += 1
            
            # Count bigrams and add edges
            for i in range(len(sequence) - 1):
                pos1, pos2 = sequence[i], sequence[i+1]
                self.bigram_counts[pos1][pos2] += 1
                
                # Ensure edge exists (weight will be calculated later)
                if not self.graph.has_edge(pos1, pos2):
                    self.graph.add_edge(pos1, pos2, weight=0, count=0, boundary_prob=0, precision=self.base_precision)
                
                # Increment edge count
                self.graph[pos1][pos2]["count"] += 1
            
            # Add connections from start and to end
            if sequence:
                if not self.graph.has_edge("<START>", sequence[0]):
                    self.graph.add_edge("<START>", sequence[0], weight=0, count=0, boundary_prob=0, precision=self.base_precision)
                self.graph["<START>"][sequence[0]]["count"] += 1
                
                if not self.graph.has_edge(sequence[-1], "<END>"):
                    self.graph.add_edge(sequence[-1], "<END>", weight=0, count=0, boundary_prob=0, precision=self.base_precision)
                self.graph[sequence[-1]]["<END>"]["count"] += 1
            
            # Count trigrams
            for i in range(len(sequence) - 2):
                pos1, pos2, pos3 = sequence[i], sequence[i+1], sequence[i+2]
                self.trigram_counts[pos1][pos2][pos3] += 1

    def _forward_pass(self, sequence: List[str]) -> Tuple[float, List[float]]:
        """
        Process a sequence and calculate prediction errors with current model.
        
        Args:
            sequence: A sequence of POS tags
            
        Returns:
            Tuple of (total_surprisal, list_of_prediction_errors)
        """
        total_surprisal = 0.0
        prediction_errors = []
        
        # Start with START node
        current_pos = "<START>"
        
        # Process each position in the sequence
        for pos in sequence:
            # Calculate prediction probability for this position
            prediction_prob = 0.0
            if self.graph.has_edge(current_pos, pos):
                prediction_prob = self.graph[current_pos][pos].get("weight", 0.0)
            
            # Calculate surprisal (-log probability)
            if prediction_prob > 0:
                surprisal = -math.log2(prediction_prob)
            else:
                surprisal = 10.0  # High surprisal for unseen transitions
                
            # Get current precision for this transition
            precision = self.base_precision
            if self.graph.has_edge(current_pos, pos):
                precision = self.graph[current_pos][pos].get("precision", self.base_precision)
            
            # Calculate precision-weighted prediction error
            prediction_error = surprisal * precision
            
            prediction_errors.append(prediction_error)
            total_surprisal += surprisal
            
            # Update current position
            current_pos = pos
        
        return total_surprisal, prediction_errors

    def _update_attention_weights(self, sequence: List[str], prediction_errors: List[float]):
        """
        Update attention weights based on prediction errors.
        Higher errors lead to increased attention.
        
        Args:
            sequence: The POS sequence
            prediction_errors: Corresponding prediction errors for transitions
        """
        # Normalize prediction errors to [0,1] range for attention updates
        if prediction_errors:
            max_error = max(prediction_errors)
            if max_error > 0:
                normalized_errors = [error / max_error for error in prediction_errors]
            else:
                normalized_errors = [0.0] * len(prediction_errors)
            
            # Update attention for START node and first token
            self.attention_weights["<START>"] = (1 - self.attention_learning_rate) * self.attention_weights["<START>"] + \
                                      self.attention_learning_rate * normalized_errors[0]
            
            # Update attention weights for each POS tag based on prediction errors
            for i, pos in enumerate(sequence):
                # Current position's error influences its attention
                if i < len(normalized_errors):
                    error_weight = normalized_errors[i]
                    
                    # Update attention weight with learning rate
                    self.attention_weights[pos] = (1 - self.attention_learning_rate) * self.attention_weights.get(pos, 1.0) + \
                                          self.attention_learning_rate * error_weight
                    
                    # Ensure attention weights stay in reasonable range
                    self.attention_weights[pos] = max(0.2, min(3.0, self.attention_weights[pos]))

    def _update_graph_weights(self, sequence: List[str], prediction_errors: List[float]):
        """
        Update graph edge weights based on attention-modulated learning.
        
        Args:
            sequence: The POS sequence
            prediction_errors: Corresponding prediction errors
        """
        current_pos = "<START>"
        
        for i, pos in enumerate(sequence):
            # Get attention weights for current and next position
            current_attention = self.attention_weights.get(current_pos, 1.0)
            target_attention = self.attention_weights.get(pos, 1.0)
            
            # Combined attention effect (geometric mean)
            combined_attention = math.sqrt(current_attention * target_attention)
            
            # Attention-modulated learning rate
            effective_lr = self.learning_rate * combined_attention
            
            # Update precision for this transition based on prediction error
            if i < len(prediction_errors):
                error = prediction_errors[i]
                new_precision = self.base_precision * (1 + 0.1 * error)
                new_precision = max(self.min_precision, min(self.max_precision, new_precision))
                
                if self.graph.has_edge(current_pos, pos):
                    # Slowly update precision for this edge
                    old_precision = self.graph[current_pos][pos].get("precision", self.base_precision)
                    updated_precision = 0.9 * old_precision + 0.1 * new_precision
                    self.graph[current_pos][pos]["precision"] = updated_precision
            
            # Update current position
            current_pos = pos

    def _calculate_edge_weights(self):
        """Calculate edge weights (transition probabilities) based on counts."""
        # For each node, calculate outgoing transition probabilities
        for node in self.graph.nodes():
            if node == "<END>":
                continue  # End node has no outgoing edges
                
            # Get total count of outgoing transitions
            outgoing_edges = list(self.graph.out_edges(node, data=True))
            total_count = sum(data["count"] for _, _, data in outgoing_edges)
            
            if total_count > 0:
                # Calculate probability for each outgoing edge
                for _, target, data in outgoing_edges:
                    prob = data["count"] / total_count
                    self.graph[node][target]["weight"] = prob

    def _calculate_boundary_probabilities(self):
        """Calculate boundary probabilities with attention influence."""
        for source, target, data in self.graph.edges(data=True):
            if source in ("<START>", "<END>") or target in ("<START>", "<END>"):
                continue  # Skip special nodes
                
            # Calculate surprisal for this transition
            prob = data.get("weight", 0)
            if prob > 0:
                surprisal = -math.log2(prob)
                
                # Get attention weights for source and target
                source_attention = self.attention_weights.get(source, 1.0)
                target_attention = self.attention_weights.get(target, 1.0)
                
                # Attention-modulated boundary probability
                # Higher attention at either end means more salient boundary
                attention_factor = (source_attention + target_attention) / 2
                
                # Normalize surprisal to a boundary probability between 0 and 1
                raw_boundary_prob = 1 / (1 + math.exp(-(surprisal - 1)))
                
                # Adjust boundary probability based on attention
                boundary_prob = raw_boundary_prob * attention_factor
                
                # Consider predefined boundaries if available
                if (source, target) in self.predefined_boundaries:
                    predefined_prob = self.predefined_boundaries[(source, target)]
                    alpha = 0.3  # Weight for predefined rules
                    boundary_prob = alpha * predefined_prob + (1 - alpha) * boundary_prob
                
                # Store in graph and in lookup dictionary
                boundary_prob = max(0.0, min(1.0, boundary_prob))  # Ensure [0,1] range
                self.graph[source][target]["boundary_prob"] = boundary_prob
                self.boundary_probs[(source, target)] = boundary_prob

    def _identify_common_chunks(self, pos_sequences: List[List[str]]):
        """Identify common chunks using attention-weighted statistics."""
        # Use a sliding window approach to find potential chunks
        chunk_candidates = Counter()
        
        # Try different chunk sizes
        for size in range(2, 5):  # 2-grams to 4-grams
            for sequence in pos_sequences:
                if len(sequence) < size:
                    continue
                    
                for i in range(len(sequence) - size + 1):
                    chunk = tuple(sequence[i:i+size])
                    chunk_candidates[chunk] += 1
        
        print(f"Found {len(chunk_candidates)} potential chunks")
        
        # For small training sets, lower the threshold
        total_sentences = len(pos_sequences)
        min_occurrences = max(2, int(total_sentences * 0.05))
        
        # Lower the cohesion threshold for small datasets
        cohesion_threshold = 0.6 if total_sentences < 20 else 0.7
        
        # Count qualifying chunks
        qualifying_chunks = 0
        for chunk, count in chunk_candidates.items():
            if count >= min_occurrences:
                qualifying_chunks += 1
                
                # Calculate internal cohesion with attention weighting
                internal_boundaries = 0
                chunk_attention = 1.0  # Start with neutral attention
                
                for i in range(len(chunk) - 1):
                    pos1, pos2 = chunk[i], chunk[i+1]
                    
                    # Get boundary probability
                    boundary_prob = self.boundary_probs.get((pos1, pos2), 0.5)
                    
                    # Apply attention weighting - average of the two positions
                    pos1_attention = self.attention_weights.get(pos1, 1.0)
                    pos2_attention = self.attention_weights.get(pos2, 1.0)
                    avg_attention = (pos1_attention + pos2_attention) / 2
                    
                    # Accumulate attention-weighted boundary probability
                    internal_boundaries += boundary_prob
                    
                    # Calculate overall chunk attention (product of position attentions)
                    chunk_attention *= (pos1_attention * 0.5 + 0.5)  # Dampen the effect
                
                avg_internal_boundary = internal_boundaries / (len(chunk) - 1)
                cohesion = 1 - avg_internal_boundary
                
                # Boost cohesion for chunks with high attention
                attention_adjusted_cohesion = cohesion * (0.8 + 0.2 * chunk_attention)
                
                # Only keep reasonably cohesive chunks
                if attention_adjusted_cohesion > cohesion_threshold:
                    chunk_name = f"{'_'.join(chunk)}"
                    self.common_chunks[chunk] = {
                        "name": chunk_name,
                        "elements": chunk,
                        "count": count,
                        "cohesion": attention_adjusted_cohesion,
                        "attention": chunk_attention,
                        "activation": 0.0  # Initial activation level
                    }
                    
                    # Initialize chunk attention weight
                    self.chunk_attention_weights[chunk] = chunk_attention
                    
        print(f"{qualifying_chunks} chunks met frequency criteria, {len(self.common_chunks)} met cohesion criteria")

    def _build_chunk_graph(self):
        """Build higher-order graph representing transitions between chunks."""
        # Add nodes for each chunk
        for chunk_tuple, chunk_info in self.common_chunks.items():
            chunk_name = chunk_info["name"]
            self.chunk_graph.add_node(
                chunk_name, 
                pos_type="chunk", 
                elements=chunk_info["elements"],
                cohesion=chunk_info["cohesion"],
                attention=chunk_info["attention"]
            )
        
        # Connect chunks that can follow each other
        for chunk1_tuple, chunk1_info in self.common_chunks.items():
            for chunk2_tuple, chunk2_info in self.common_chunks.items():
                # Check if chunk2 can follow chunk1 (overlap or adjacency)
                if self._can_follow(chunk1_tuple, chunk2_tuple):
                    # Calculate transition probability
                    # This is simplified - would need corpus analysis for accurate probabilities
                    transition_prob = 0.1  # Default low probability
                    
                    # If we have trigram data, use it to estimate transition probability
                    if len(chunk1_tuple) >= 2 and len(chunk2_tuple) >= 1:
                        last1, last2 = chunk1_tuple[-2], chunk1_tuple[-1]
                        first = chunk2_tuple[0]
                        
                        if last2 in self.trigram_counts.get(last1, {}):
                            total = sum(self.trigram_counts[last1][last2].values())
                            if total > 0:
                                count = self.trigram_counts[last1][last2].get(first, 0)
                                transition_prob = count / total
                    
                    # Apply attention weighting to transition
                    chunk1_attention = self.chunk_attention_weights.get(chunk1_tuple, 1.0)
                    chunk2_attention = self.chunk_attention_weights.get(chunk2_tuple, 1.0)
                    
                    # Higher attention on both chunks strengthens their connection
                    attention_factor = (chunk1_attention + chunk2_attention) / 2
                    weighted_prob = transition_prob * attention_factor
                    
                    # Add edge with weight
                    chunk1_name = chunk1_info["name"]
                    chunk2_name = chunk2_info["name"]
                    self.chunk_graph.add_edge(
                        chunk1_name, 
                        chunk2_name, 
                        weight=weighted_prob,
                        attention=attention_factor
                    )

    def _can_follow(self, chunk1: Tuple[str, ...], chunk2: Tuple[str, ...]) -> bool:
        """
        Determine if chunk2 can follow chunk1 in a sequence.
        Either through overlap or adjacency.
        """
        # Check if there's an overlap
        for overlap_size in range(1, min(len(chunk1), len(chunk2))):
            if chunk1[-overlap_size:] == chunk2[:overlap_size]:
                return True
        
        # Check if there's an edge from the last element of chunk1 to the first of chunk2
        last_of_chunk1 = chunk1[-1]
        first_of_chunk2 = chunk2[0]
        
        return self.graph.has_edge(last_of_chunk1, first_of_chunk2)

    def _report_attention_weights(self):
        """Report the top and bottom attention weights."""
        # Sort attention weights
        sorted_pos = sorted(self.attention_weights.items(), key=lambda x: x[1], reverse=True)
        
        # Report top attention weights
        print("\nTop 5 POS tags by attention:")
        for pos, weight in sorted_pos[:5]:
            print(f"  {pos}: {weight:.3f}")
            
        # Report bottom attention weights if we have enough
        if len(sorted_pos) > 5:
            print("\nBottom 5 POS tags by attention:")
            for pos, weight in sorted_pos[-5:]:
                print(f"  {pos}: {weight:.3f}")
                
        # Report chunk attention if available
        if self.chunk_attention_weights:
            sorted_chunks = sorted(self.chunk_attention_weights.items(), 
                                   key=lambda x: x[1], reverse=True)
            
            print("\nTop 5 chunks by attention:")
            for chunk, weight in sorted_chunks[:min(5, len(sorted_chunks))]:
                print(f"  {chunk}: {weight:.3f}")

    def segment(self, pos_sequence: List[str]) -> List[List[str]]:
        """
        Segment a POS sequence into chunks based on boundary probabilities.
        
        Args:
            pos_sequence: List of POS tags for a sentence
            
        Returns:
            List of chunks, where each chunk is a list of POS tags
        """
        chunks = []
        current_chunk = [pos_sequence[0]]
        
        for i in range(1, len(pos_sequence)):
            pos1, pos2 = pos_sequence[i-1], pos_sequence[i]
            
            # Get boundary probability
            boundary_prob = self.boundary_probs.get((pos1, pos2), 0.2)  # Default if unseen
            
            # Apply attention weighting
            pos1_attention = self.attention_weights.get(pos1, 1.0)
            pos2_attention = self.attention_weights.get(pos2, 1.0)
            attention_factor = (pos1_attention + pos2_attention) / 2
            
            # Attention-modulated boundary decision
            effective_boundary = boundary_prob * attention_factor
            
            if effective_boundary > self.hard_boundary_threshold:
                # Hard boundary - create a new chunk
                chunks.append(current_chunk)
                current_chunk = [pos2]
            else:
                # Continue current chunk
                current_chunk.append(pos2)
        
        # Add the last chunk
        if current_chunk:
            chunks.append(current_chunk)
        
        return chunks

    def predict_next_pos(self, context: List[str], top_n: int = 3) -> List[Tuple[str, float]]:
        """
        Predict the next POS tag with attention-modulated probabilities.
        
        Args:
            context: List of preceding POS tags
            top_n: Number of top predictions to return
            
        Returns:
            List of (pos_tag, probability) pairs, sorted by probability
        """
        # Update context history
        self.context_history.append(context)
        if len(self.context_history) > 5:  # Keep only recent history
            self.context_history = self.context_history[-5:]
        
        if not context:
            # No context, use connections from start node
            predictions = []
            for target, data in self.graph.out_edges("<START>", data=True):
                if target != "<END>":
                    # Apply attention weighting
                    base_prob = data.get("weight", 0.0)
                    target_attention = self.attention_weights.get(target, 1.0)
                    adjusted_prob = base_prob * target_attention
                    predictions.append((target, adjusted_prob))
            
            # Normalize probabilities
            total_prob = sum(prob for _, prob in predictions)
            if total_prob > 0:
                predictions = [(tag, prob/total_prob) for tag, prob in predictions]
                
            return sorted(predictions, key=lambda x: x[1], reverse=True)[:top_n]
        
        # Use the last tag for prediction
        last_pos = context[-1]
        
        if self.graph.has_node(last_pos):
            # Apply attention to outgoing predictions
            predictions = []
            
            # Get attention for the source position
            source_attention = self.attention_weights.get(last_pos, 1.0)
            
            # First try to use chunk-based prediction if we have matching chunks
            chunk_predictions = self._predict_from_chunks(context)
            
            if chunk_predictions:
                # If we have chunk-based predictions, give them more weight
                # but also include some direct edge predictions
                direct_predictions = []
                for _, target, data in self.graph.out_edges(last_pos, data=True):
                    if target != "<END>":
                        base_prob = data.get("weight", 0.0)
                        target_attention = self.attention_weights.get(target, 1.0)
                        # Combined attention effect (geometric mean)
                        combined_attention = math.sqrt(source_attention * target_attention)
                        adjusted_prob = base_prob * combined_attention
                        direct_predictions.append((target, adjusted_prob))
                
                # Normalize direct predictions
                total_direct = sum(prob for _, prob in direct_predictions)
                if total_direct > 0:
                    direct_predictions = [(tag, prob/total_direct) for tag, prob in direct_predictions]
                
                # Combine chunk-based and direct predictions with 3:1 weighting
                combined = {}
                for tag, prob in chunk_predictions:
                    combined[tag] = prob * 0.75
                
                for tag, prob in direct_predictions:
                    if tag in combined:
                        combined[tag] += prob * 0.25
                    else:
                        combined[tag] = prob * 0.25
                
                predictions = [(tag, prob) for tag, prob in combined.items()]
            else:
                # No chunk predictions, use direct edge predictions
                for _, target, data in self.graph.out_edges(last_pos, data=True):
                    if target != "<END>":
                        base_prob = data.get("weight", 0.0)
                        target_attention = self.attention_weights.get(target, 1.0)
                        # Combined attention effect
                        combined_attention = math.sqrt(source_attention * target_attention)
                        adjusted_prob = base_prob * combined_attention
                        predictions.append((target, adjusted_prob))
            
            # Normalize probabilities
            total_prob = sum(prob for _, prob in predictions)
            if total_prob > 0:
                predictions = [(tag, prob/total_prob) for tag, prob in predictions]
                
            return sorted(predictions, key=lambda x: x[1], reverse=True)[:top_n]
        else:
            # Unseen POS tag
            return [("<UNK>", 1.0)]  # Return unknown with full probability

    def _predict_from_chunks(self, context: List[str]) -> List[Tuple[str, float]]:
        """
        Generate predictions based on chunk matching.
        
        Args:
            context: The context sequence
            
        Returns:
            List of (pos_tag, probability) tuples
        """
        if len(context) < 1:
            return []
            
        # Try to match the end of the context with the beginning of chunks
        matches = []
        max_match_length = 0
        
        for chunk_tuple, chunk_info in self.common_chunks.items():
            for match_length in range(min(len(context), len(chunk_tuple)), 0, -1):
                if context[-match_length:] == chunk_tuple[:match_length]:
                    if match_length > max_match_length:
                        max_match_length = match_length
                        matches = [(chunk_tuple, chunk_info, match_length)]
                    elif match_length == max_match_length:
                        matches.append((chunk_tuple, chunk_info, match_length))
        
        if not matches:
            return []
            
        # Generate predictions based on matched chunks
        predictions = Counter()
        total_weight = 0
        
        for chunk_tuple, chunk_info, match_length in matches:
            # If match is complete, this chunk can't help with prediction
            if match_length >= len(chunk_tuple):
                continue
                
            # The next element in the chunk is the prediction
            next_pos = chunk_tuple[match_length]
            
            # Weight by chunk cohesion and attention
            weight = chunk_info["cohesion"] * chunk_info.get("attention", 1.0)
            predictions[next_pos] += weight
            total_weight += weight
        
        # Normalize predictions
        if total_weight > 0:
            return [(pos, weight/total_weight) for pos, weight in predictions.items()]
        else:
            return []

    def predictive_processing(self, pos_sequence: List[str]) -> Tuple[List[Dict[str, Any]], List[List[str]]]:
        """
        Process a sequence using predictive coding principles with attention modulation.
        
        Args:
            pos_sequence: List of POS tags
            
        Returns:
            Tuple of (recognized chunks, segmented sequence)
        """
        # First pass: recognize chunks
        recognized_chunks = self.recognize_chunks(pos_sequence)
        
        # Second pass: resolve overlaps with attention-weighted resolution
        non_overlapping = self._resolve_chunk_overlaps(recognized_chunks, len(pos_sequence))
        
        # Third pass: final segmentation based on chunks and boundaries
        segmentation = self._create_final_segmentation(pos_sequence, non_overlapping)
        
        return non_overlapping, segmentation
    
    def recognize_chunks(self, pos_sequence: List[str]) -> List[Dict[str, Any]]:
        """
        Recognize known chunks in a POS sequence with attention modulation.
        
        Args:
            pos_sequence: List of POS tags
            
        Returns:
            List of recognized chunks with their properties
        """
        recognized = []
        
        # Try to match chunks of different sizes
        for i in range(len(pos_sequence)):
            for size in range(4, 1, -1):  # Try larger chunks first (4, 3, 2)
                if i + size <= len(pos_sequence):
                    chunk_tuple = tuple(pos_sequence[i:i+size])
                    if chunk_tuple in self.common_chunks:
                        # Calculate activation based on cohesion and attention
                        chunk_info = self.common_chunks[chunk_tuple]
                        base_activation = chunk_info["cohesion"]
                        chunk_attention = self.chunk_attention_weights.get(chunk_tuple, 1.0)
                        
                        # Attention-modulated activation
                        activation = base_activation * chunk_attention
                        
                        recognized.append({
                            "chunk": chunk_info,
                            "start": i,
                            "end": i + size,
                            "activation": activation
                        })
        
        # Sort by start position
        recognized.sort(key=lambda x: x["start"])
        
        return recognized
        
    def _resolve_chunk_overlaps(self, chunks: List[Dict[str, Any]], seq_length: int) -> List[Dict[str, Any]]:
        """
        Resolve overlapping chunks by selecting the most activated ones with attention influence.
        
        Args:
            chunks: List of recognized chunks
            seq_length: Length of the original sequence
            
        Returns:
            List of non-overlapping chunks
        """
        # If no chunks, return empty list
        if not chunks:
            return []
            
        # Sort by activation (influenced by attention) to prioritize strongest chunks
        sorted_chunks = sorted(chunks, key=lambda x: x["activation"], reverse=True)
        
        # Track which positions are covered
        covered = [False] * seq_length
        
        # Select non-overlapping chunks
        selected = []
        
        for chunk in sorted_chunks:
            start, end = chunk["start"], chunk["end"]
            
            # Check if this chunk overlaps with already selected ones
            overlap = False
            for i in range(start, end):
                if covered[i]:
                    overlap = True
                    break
            
            if not overlap:
                # Add chunk and mark positions as covered
                selected.append(chunk)
                for i in range(start, end):
                    covered[i] = True
        
        # Sort by start position
        selected.sort(key=lambda x: x["start"])
        
        return selected

    def visualize_pos_graph(self, filename: str = "pos_graph_attention.png"):
        """
        Visualize the POS transition graph with attention weighting.
        
        Args:
            filename: Output file name
        """
        # Check if graph is empty
        if len(self.graph) <= 2:  # Only START and END nodes
            print("POS graph is empty or contains only special nodes - no visualization created")
            return
            
        # Create a copy without special nodes for cleaner visualization
        g = self.graph.copy()
        
        # Only remove special nodes if they exist
        if "<START>" in g:
            g.remove_node("<START>")
        if "<END>" in g:
            g.remove_node("<END>")
            
        if len(g.edges()) == 0:
            print("POS graph has no edges - no visualization created")
            return
        
        # Set up the plot
        plt.figure(figsize=(12, 10))
        
        # Define node positions using spring layout
        pos = nx.spring_layout(g, seed=42)
        
        # Node sizes based on attention weights
        node_sizes = []
        node_colors = []
        for node in g.nodes():
            attention = self.attention_weights.get(node, 1.0)
            node_sizes.append(400 + 300 * attention)  # Base size + attention effect
            
            # Color gradient from cool to warm based on attention
            # Low attention: blue (0.0), High attention: red (1.0)
            node_colors.append((min(1.0, attention/2), 0.2, max(0.0, 1.0-attention/2)))
        
        # Draw nodes with size and color based on attention
        nx.draw_networkx_nodes(g, pos, node_size=node_sizes, node_color=node_colors)
        
        # Prepare edge attributes
        edge_width = []
        edge_color = []
        
        for source, target, data in g.edges(data=True):
            # Default to 0.1 if weight is missing or zero
            weight = data.get("weight", 0.1)
            if weight == 0:
                weight = 0.1
                
            # Adjust width by precision weighting
            precision = data.get("precision", self.base_precision)
            precision_factor = precision / self.base_precision
            
            edge_width.append(weight * 5 * precision_factor)
            
            # Edge color based on boundary probability
            edge_color.append(data.get("boundary_prob", 0.5))
        
        # Draw edges
        nx.draw_networkx_edges(
            g, pos, width=edge_width, 
            edge_color=edge_color, edge_cmap=plt.cm.Reds,
            connectionstyle="arc3,rad=0.1"
        )
        
        # Add labels
        nx.draw_networkx_labels(g, pos, font_size=10)
        
        # Edge labels (probability + precision)
        edge_labels = {}
        for u, v, d in g.edges(data=True):
            weight = d.get("weight", 0.0)
            precision = d.get("precision", self.base_precision)
            edge_labels[(u, v)] = f"{weight:.2f}\n(p:{precision:.1f})"
                
        nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, font_size=8)
        
        # Add a color bar for boundary probabilities
        fig = plt.gcf()
        ax = plt.gca()
        sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds)
        sm.set_array([])
        fig.colorbar(sm, ax=ax, label="Boundary Probability")
        
        plt.title("POS Transition Graph with Attention")
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(filename)
        print(f"POS graph visualization saved to {filename}")
        plt.close()
        
    def visualize_chunk_graph(self, filename: str = "chunk_graph_attention.png"):
        """
        Visualize the chunk transition graph with attention weighting.
        
        Args:
            filename: Output file name
        """
        if len(self.chunk_graph) == 0:
            print("Chunk graph is empty - no visualization created")
            return
            
        if len(self.chunk_graph.edges()) == 0:
            print("Chunk graph has no edges - adding artificial edges for visualization")
            # Create some artificial edges just for visualization
            nodes = list(self.chunk_graph.nodes())
            if len(nodes) > 1:
                for i in range(len(nodes)-1):
                    self.chunk_graph.add_edge(nodes[i], nodes[i+1], weight=0.1, attention=1.0)
            
        # Set up the plot
        plt.figure(figsize=(14, 12))
        
        # Define node positions using spring layout
        pos = nx.spring_layout(self.chunk_graph, seed=42)
        
        # Draw nodes with size and color based on attention and cohesion
        node_sizes = []
        node_colors = []
        for node in self.chunk_graph.nodes():
            cohesion = self.chunk_graph.nodes[node].get("cohesion", 0.5)
            attention = self.chunk_graph.nodes[node].get("attention", 1.0)
            
            if cohesion <= 0:
                cohesion = 0.5  # Ensure minimum size
                
            # Size based on cohesion and attention
            node_sizes.append(cohesion * attention * 1000)
            
            # Color based on attention (from cool to warm)
            node_colors.append((min(1.0, attention/2), 0.4, max(0.0, 1.0-attention/2)))
        
        nx.draw_networkx_nodes(
            self.chunk_graph, pos, 
            node_size=node_sizes, 
            node_color=node_colors
        )
        
        # Draw edges with width and color based on weight and attention
        if len(self.chunk_graph.edges()) > 0:
            edge_width = []
            edge_color = []
            
            for _, _, data in self.chunk_graph.edges(data=True):
                weight = data.get("weight", 0.1)
                attention = data.get("attention", 1.0)
                
                if weight <= 0:
                    weight = 0.1  # Ensure minimum width
                    
                # Width affected by both weight and attention
                edge_width.append(weight * attention * 10)
                
                # Edge color based on attention
                edge_color.append(attention)
            
            nx.draw_networkx_edges(
                self.chunk_graph, pos, width=edge_width, 
                edge_color=edge_color, edge_cmap=plt.cm.YlOrRd, alpha=0.7,
                connectionstyle="arc3,rad=0.1"
            )
        
        # Add labels
        nx.draw_networkx_labels(self.chunk_graph, pos, font_size=9)
        
        # Add a color bar for edge attention
        fig = plt.gcf()
        ax = plt.gca()
        sm = plt.cm.ScalarMappable(cmap=plt.cm.YlOrRd)
        sm.set_array([])
        fig.colorbar(sm, ax=ax, label="Attention Weight")
        
        plt.title("Chunk Transition Graph with Attention")
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(filename)
        print(f"Chunk graph visualization saved to {filename}")
        plt.close()
        
    def _create_final_segmentation(self, pos_sequence: List[str], chunks: List[Dict[str, Any]]) -> List[List[str]]:
        """
        Create final segmentation based on recognized chunks and boundary probabilities
        with attention modulation.
        
        Args:
            pos_sequence: Original POS sequence
            chunks: Non-overlapping chunks
            
        Returns:
            List of segments (chunks)
        """
        # If no chunks recognized, fall back to boundary-based segmentation
        if not chunks:
            return self.segment(pos_sequence)
        
        # Create segmentation based on recognized chunks and boundaries
        segmentation = []
        current_pos = 0
        
        for chunk in chunks:
            start, end = chunk["start"], chunk["end"]
            
            # If there's a gap before this chunk, segment it using boundaries
            if start > current_pos:
                gap_sequence = pos_sequence[current_pos:start]
                gap_segments = self.segment(gap_sequence)
                
                # Adjust segment positions
                adjusted_segments = []
                for segment in gap_segments:
                    adjusted_segments.append(segment)
                
                segmentation.extend(adjusted_segments)
            
            # Add the recognized chunk
            segmentation.append(pos_sequence[start:end])
            current_pos = end
        
        # Handle any remaining sequence after the last chunk
        if current_pos < len(pos_sequence):
            remaining = pos_sequence[current_pos:]
            remaining_segments = self.segment(remaining)
            segmentation.extend(remaining_segments)
        
        return segmentation


# Example usage
if __name__ == "__main__":
    # Sample POS sequences for training - Create more examples with repeating patterns
    # to increase likelihood of chunk detection
    training_data = [
        ["DET", "ADJ", "NOUN", "VERB", "DET", "NOUN"],
        ["PRON", "VERB", "PREP", "DET", "NOUN"],
        ["DET", "NOUN", "VERB", "ADV", "ADJ"],
        ["DET", "ADJ", "ADJ", "NOUN", "VERB", "PREP", "DET", "NOUN"],
        ["PRON", "VERB", "DET", "NOUN", "CONJ", "VERB", "ADV"],
        # More examples with repeating patterns to help chunk detection
        ["DET", "ADJ", "NOUN", "VERB", "PREP", "DET", "NOUN"],  # Repeat
        ["DET", "ADJ", "NOUN", "VERB", "DET", "NOUN"],          # Repeat
        ["DET", "NOUN", "VERB", "PREP", "DET", "ADJ", "NOUN"],
        ["PRON", "VERB", "ADV", "CONJ", "VERB", "DET", "NOUN"],
        ["DET", "ADJ", "NOUN", "VERB", "ADV", "PREP", "PRON"],
        ["NOUN", "VERB", "DET", "ADJ", "NOUN", "PREP", "DET", "NOUN"],
        ["DET", "NOUN", "VERB", "ADJ", "CONJ", "ADV"],
        # Even more repetition of common patterns
        ["DET", "ADJ", "NOUN", "VERB", "DET", "NOUN"],          # Repeat
        ["PRON", "VERB", "PREP", "DET", "NOUN"],                # Repeat
        ["DET", "ADJ", "NOUN", "VERB", "PREP", "DET", "NOUN"],  # Repeat
    ]
    
    print(f"Training on {len(training_data)} sentences")
    
    # Initialize and train the graph with attention mechanisms
    pos_graph = POSGraphWithAttention()
    pos_graph.train(training_data, epochs=3)  # Multiple epochs to allow attention to develop
    
    # Test on a new sentence
    test_sentence = ["DET", "ADJ", "NOUN", "VERB", "PREP", "DET", "NOUN"]
    
    print("\nTest sentence:", test_sentence)
    
    # Recognition and segmentation
    chunks, segments = pos_graph.predictive_processing(test_sentence)
    
    print("\nRecognized chunks:")
    if chunks:
        for chunk in chunks:
            print(f"  {chunk['chunk']['elements']} (Position {chunk['start']}-{chunk['end']}, " +
                  f"Activation: {chunk['activation']:.3f})")
    else:
        print("  No chunks recognized")
    
    print("\nFinal segmentation:", segments)
    
    # Predictions
    context = ["DET", "ADJ"]
    predictions = pos_graph.predict_next_pos(context)
    print(f"\nTop predictions after {context}:")
    for pos, prob in predictions:
        print(f"  {pos}: {prob:.2f} (Attention: {pos_graph.attention_weights.get(pos, 1.0):.2f})")
    
    # Visualize graphs with attention
    pos_graph.visualize_pos_graph()
    pos_graph.visualize_chunk_graph()
    
    # Plot surprisal history to show learning progress
    if pos_graph.surprisal_history:
        plt.figure(figsize=(10, 6))
        plt.plot(pos_graph.surprisal_history, 'b-o')
        plt.title('Learning Progress - Average Surprisal per Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Average Surprisal')
        plt.grid(True)
        plt.savefig('learning_progress.png')
        print("Learning progress plot saved to learning_progress.png")
        plt.close()

Training on 15 sentences
Training on 15 sequences for 3 epochs
Epoch 1/3
  Average surprisal: 64.6667
Epoch 2/3
  Average surprisal: 7.8583
Epoch 3/3
  Average surprisal: 7.8583
Found 83 potential chunks
27 chunks met frequency criteria, 16 met cohesion criteria
Training complete.
POS graph has 10 nodes and 27 edges
Chunk graph has 16 nodes and 139 edges

Top 5 POS tags by attention:
  <END>: 1.000
  CONJ: 0.941
  PRON: 0.922
  PREP: 0.890
  ADV: 0.859

Bottom 5 POS tags by attention:
  ADJ: 0.845
  <START>: 0.576
  VERB: 0.516
  DET: 0.409
  NOUN: 0.369

Top 5 chunks by attention:
  ('CONJ', 'VERB'): 0.971
  ('PRON', 'VERB'): 0.961
  ('PREP', 'DET'): 0.945
  ('ADJ', 'NOUN'): 0.923
  ('PRON', 'VERB', 'PREP'): 0.728

Test sentence: ['DET', 'ADJ', 'NOUN', 'VERB', 'PREP', 'DET', 'NOUN']

Recognized chunks:
  ('ADJ', 'NOUN') (Position 1-3, Activation: 0.716)
  ('PREP', 'DET') (Position 4-6, Activation: 0.747)

Final segmentation: [['DET'], ['ADJ', 'NOUN'], ['VERB'], ['PREP', 'DET'], ['NOUN