# In [None]:
def visualize_graphs(self, prefix="bidirectional_"):
        """Render the forward, backward and chunk graphs, plus a surprisal curve.

        Every output file name is ``prefix`` followed by a fixed suffix. The
        learning-progress plot is only produced when ``self.surprisal_history``
        is non-empty.
        """
        for graph, suffix, label in (
            (self.forward_graph, "forward_graph.png", "Forward"),
            (self.backward_graph, "backward_graph.png", "Backward"),
        ):
            self.visualize_pos_graph(graph, f"{prefix}{suffix}", label)
        self.visualize_chunk_graph(f"{prefix}chunk_graph.png")

        # Nothing to plot before any epoch has been recorded
        if not self.surprisal_history:
            return

        progress_file = f'{prefix}learning_progress.png'
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(self.surprisal_history, 'b-o')
        ax.set_title('Bidirectional Learning Progress - Average Surprisal per Epoch')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Average Surprisal')
        ax.grid(True)
        fig.savefig(progress_file)
        print(f"Learning progress plot saved to {progress_file}")
        plt.close(fig)

    def visualize_pos_graph(self, graph, filename: str = "pos_graph_attention.png", direction: str = "Forward"):
        """
        Visualize a POS transition graph with attention weighting.
        
        Args:
            graph: The graph to visualize
            filename: Output file name
            direction: Label for the graph direction
        """
        # Check if graph is empty
        if len(graph) <= 2:  # Only START and END nodes
            print(f"{direction} graph is empty or contains only special nodes - no visualization created")
            return
            
        # Create a copy without special nodes for cleaner visualization
        g = graph.copy()
        
        # Only remove special nodes if they exist
        if "<START>" in g:
            g.remove_node("<START>")
        if "<END>" in g:
            g.remove_node("<END>")
            
        if len(g.edges()) == 0:
            print(f"{direction} graph has no edges - no visualization created")
            return
        
        # Set up the plot
        plt.figure(figsize=(12, 10))
        
        # Define node positions using spring layout
        pos = nx.spring_layout(g, seed=42)
        
        # Node sizes based on attention weights
        node_sizes = []
        node_colors = []
        for node in g.nodes():
            attention = self.attention_weights.get(node, 1.0)
            node_sizes.append(400 + 300 * attention)  # Base size + attention effect
            
            # Color gradient from cool to warm based on attention
            # Low attention: blue (0.0), High attention: red (1.0)
            node_colors.append((min(1.0, attention/2), 0.2, max(0.0, 1.0-attention/2)))
        
        # Draw nodes with size and color based on attention
        nx.draw_networkx_nodes(g, pos, node_size=node_sizes, node_color=node_colors)
        
        # Prepare edge attributes
        edge_width = []
        edge_color = []
        
        for source, target, data in g.edges(data=True):
            # Default to 0.1 if weight is missing or zero
            weight = data.get("weight", 0.1)
            if weight == 0:
                weight = 0.1
                
            # Adjust width by precision weighting
            precision = data.get("precision", self.base_precision)
            precision_factor = precision / self.base_precision
            
            edge_width.append(weight * 5 * precision_factor)
            
            # Edge color based on boundary probability
            edge_color.append(data.get("boundary_prob", 0.5))
        
        # Draw edges
        nx.draw_networkx_edges(
            g, pos, width=edge_width, 
            edge_color=edge_color, edge_cmap=plt.cm.Reds,
            connectionstyle="arc3,rad=0.1"
        )
        
        # Add labels
        nx.draw_networkx_labels(g, pos, font_size=10)
        
        # Edge labels (probability + precision)
        edge_labels = {}
        for u, v, d in g.edges(data=True):
            weight = d.get("weight", 0.0)
            precision = d.get("precision", self.base_precision)
            edge_labels[(u, v)] = f"{weight:.2f}\n(p:{precision:.1f})"
                
        nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels, font_size=8)
        
        # Add a color bar for boundary probabilities
        fig = plt.gcf()
        ax = plt.gca()
        sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds)
        sm.set_array([])
        fig.colorbar(sm, ax=ax, label="Boundary Probability")
        
        plt.title(f"{direction} POS Transition Graph with Attention")
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(filename)
        print(f"{direction} graph visualization saved to {filename}")
        plt.close()import numpy as np
from collections import defaultdict, Counter
import math
from typing import List, Dict, Tuple, Set, Optional, Any
import networkx as nx
import matplotlib.pyplot as plt
import random
import time

class POSGraphBidirectional:
    """
    Implements a graph-based structure for POS sequence processing with predictive coding,
    attention mechanisms, and bidirectional processing.
    """
    
    def __init__(self, predefined_boundaries: Optional[Dict[Tuple[str, str], float]] = None):
        """
        Initialize the POS graph with attention mechanisms and bidirectional processing.

        Creates empty forward, backward and chunk graphs. The forward and
        backward graphs start with special ``<START>``/``<END>`` nodes. Also
        sets up empty n-gram and boundary-probability tables, plus the default
        thresholds, learning rates, precision bounds and direction weights.

        Args:
            predefined_boundaries: Optional mapping from a (left_tag, right_tag)
                pair to a prior boundary probability. ``None`` (the default)
                selects the built-in rules below. Any other mapping is used as
                given, so passing an empty dict disables the predefined rules.
        """
        # Main transition graph (forward direction)
        self.forward_graph = nx.DiGraph()

        # Backward transition graph (for bidirectional processing)
        self.backward_graph = nx.DiGraph()

        # Higher-order chunk graph for learned patterns
        self.chunk_graph = nx.DiGraph()

        # Track n-gram counts for training
        self.unigram_counts = Counter()
        self.bigram_counts = defaultdict(Counter)
        self.backward_bigram_counts = defaultdict(Counter)  # For backward transitions
        self.trigram_counts = defaultdict(lambda: defaultdict(Counter))

        # Boundary probabilities
        self.forward_boundary_probs = defaultdict(float)  # Forward direction
        self.backward_boundary_probs = defaultdict(float)  # Backward direction
        self.combined_boundary_probs = defaultdict(float)  # Combined

        # Predefined linguistic rules. Test for None rather than truthiness so an
        # explicitly empty mapping is honored instead of being replaced by defaults.
        if predefined_boundaries is None:
            predefined_boundaries = {
                ('NOUN', 'VERB'): 0.9,      # NP to VP transition
                ('VERB', 'DET'): 0.8,       # VP to NP transition
                ('PUNCT', 'DET'): 0.95,     # Punctuation followed by determiner
                ('NOUN', 'PREP'): 0.7,      # NP to PP transition
                ('VERB', 'PREP'): 0.6,      # VP to PP transition
                ('ADJ', 'NOUN'): 0.2,       # Within NP (low boundary probability)
                ('DET', 'ADJ'): 0.1,        # Within NP (very low boundary probability)
            }
        self.predefined_boundaries = predefined_boundaries

        # Thresholds
        self.hard_boundary_threshold = 0.75
        self.soft_boundary_threshold = 0.4

        # Discovered chunks
        self.common_chunks = {}

        # Add special start and end nodes
        self.forward_graph.add_node("<START>", pos_type="special")
        self.forward_graph.add_node("<END>", pos_type="special")
        self.backward_graph.add_node("<START>", pos_type="special")
        self.backward_graph.add_node("<END>", pos_type="special")

        # Attention mechanisms
        self.attention_weights = {}  # POS tag attention weights
        self.chunk_attention_weights = {}  # Chunk attention weights
        self.surprisal_history = []  # Average surprisal per epoch
        self.learning_rate = 0.1  # Base learning rate
        self.attention_learning_rate = 0.05  # For updating attention weights

        # Precision weighting parameters
        self.base_precision = 1.0
        self.max_precision = 5.0
        self.min_precision = 0.2

        # Context tracking
        self.context_history = []  # Track recent contexts for attention modulation

        # Bidirectional weights
        self.forward_weight = 0.6  # Weight for forward processing (typically higher)
        self.backward_weight = 0.4  # Weight for backward processing

    def train(self, pos_sequences: List[List[str]], epochs: int = 1):
        """
        Train the POS graph on a corpus of POS tag sequences with attention
        and bidirectional processing.

        Steps:

        1. Reset attention weights.
        2. Build the forward and backward graphs from the n-gram counts and
           turn the counts into transition probabilities.
        3. Run ``epochs`` passes. Each pass scores every sequence in both
           directions and updates attention weights and edge precisions.
        4. Recompute boundary probabilities after each epoch.
        5. Identify common chunks and build the chunk graph.

        Shuffling is done on a private copy, so the caller's ``pos_sequences``
        list is never reordered.

        Args:
            pos_sequences: List of POS tag sequences, each representing a sentence
            epochs: Number of training epochs
        """
        print(f"Training on {len(pos_sequences)} sequences for {epochs} epochs")
        start_time = time.time()

        # Private copy: per-epoch shuffling must not reorder the caller's list.
        sequences = list(pos_sequences)

        # 1. Initialize attention weights uniformly
        self._initialize_attention_weights(sequences)

        # 2. Build initial graphs (forward and backward) and collect statistics
        self._build_initial_graphs(sequences)

        # Edges are created with weight=0. Convert counts to probabilities now,
        # otherwise the first epoch scores every transition as unseen
        # (surprisal 10.0) and inflates edge precisions accordingly.
        self._calculate_edge_weights()

        # 3. Iterative training with attention modulation and bidirectional processing
        for epoch in range(epochs):
            print(f"Epoch {epoch+1}/{epochs}")

            # Shuffle sequences for stochastic training
            random.shuffle(sequences)

            epoch_surprisal = 0.0

            for sequence in sequences:
                # Forward pass - calculate predictions and errors
                sequence_surprisal, prediction_errors = self._forward_pass(sequence)
                epoch_surprisal += sequence_surprisal

                # Backward pass - process the sequence in reverse
                reversed_sequence = list(reversed(sequence))
                backward_surprisal, backward_errors = self._backward_pass(reversed_sequence)
                epoch_surprisal += backward_surprisal

                # Update attention weights based on both forward and backward errors
                self._update_attention_weights(sequence, prediction_errors, reversed_sequence, backward_errors)

                # Update graph weights with attention-modulated learning
                self._update_graph_weights(sequence, prediction_errors, reversed_sequence, backward_errors)

            # Average per pass (each sequence is scored once per direction);
            # an empty corpus yields 0.0 instead of a ZeroDivisionError.
            num_passes = 2 * len(sequences)
            avg_surprisal = epoch_surprisal / num_passes if num_passes else 0.0
            self.surprisal_history.append(avg_surprisal)
            print(f"  Average surprisal: {avg_surprisal:.4f}")

            # Recalculate edge weights from the current counts
            self._calculate_edge_weights()

            # After each epoch, update boundary probabilities (both directions)
            self._calculate_boundary_probabilities()

            # Combine boundary probabilities from both directions
            self._combine_boundary_probabilities()

        if epochs <= 0:
            # No epoch ran, so boundary probabilities were never derived;
            # compute them once before chunk identification.
            self._calculate_boundary_probabilities()
            self._combine_boundary_probabilities()

        # 4. Identify common chunks with attention influence
        self._identify_common_chunks(sequences)

        # 5. Build chunk graph with attention-weighted connections
        self._build_chunk_graph()

        training_time = time.time() - start_time
        print(f"Training complete in {training_time:.2f} seconds.")
        print(f"Forward graph has {len(self.forward_graph.nodes)} nodes and {len(self.forward_graph.edges)} edges")
        print(f"Backward graph has {len(self.backward_graph.nodes)} nodes and {len(self.backward_graph.edges)} edges")
        print(f"Chunk graph has {len(self.chunk_graph.nodes)} nodes and {len(self.chunk_graph.edges)} edges")

        # Report top attention weights
        self._report_attention_weights()

    def _initialize_attention_weights(self, pos_sequences: List[List[str]]):
        """Initialize attention weights for all POS tags."""
        # Extract all unique POS tags from sequences
        unique_pos = set()
        for sequence in pos_sequences:
            unique_pos.update(sequence)
        
        # Initialize with uniform weights
        for pos in unique_pos:
            self.attention_weights[pos] = 1.0
            
        # Special start and end nodes
        self.attention_weights["<START>"] = 1.0
        self.attention_weights["<END>"] = 1.0

    def _build_initial_graphs(self, pos_sequences: List[List[str]]):
        """
        Build the forward and backward transition graphs and collect n-gram statistics.

        First pass, over the sequences as given:

        * unigram, bigram and trigram counts are accumulated;
        * each sequence's transitions, plus ``<START>``/``<END>`` links, are
          added to ``self.forward_graph``.

        Second pass: the same is done for every reversed sequence in
        ``self.backward_graph``, with its bigrams counted in
        ``self.backward_bigram_counts``. Edge ``count`` attributes are
        incremented once per occurrence; ``weight`` stays 0 here.
        """
        # Forward graph plus unigram/bigram/trigram statistics
        for sequence in pos_sequences:
            tags = list(sequence)
            self.unigram_counts.update(tags)
            self._add_sequence_to_graph(self.forward_graph, tags, self.bigram_counts)
            for first, second, third in zip(tags, tags[1:], tags[2:]):
                self.trigram_counts[first][second][third] += 1

        # Backward graph built from the reversed sequences
        for sequence in pos_sequences:
            self._add_sequence_to_graph(
                self.backward_graph, list(reversed(sequence)), self.backward_bigram_counts
            )

    def _add_sequence_to_graph(self, graph, tags: List[str], bigram_counts) -> None:
        """Add ``tags``' nodes and transitions (with ``<START>``/``<END>`` links) to ``graph`` and count its bigrams."""
        for tag in tags:
            if not graph.has_node(tag):
                graph.add_node(tag, pos_type="basic", precision=self.base_precision)

        # Consecutive-tag transitions are counted as bigrams; the <START>/<END>
        # links are added to the graph but not to the bigram table.
        transitions = list(zip(tags, tags[1:]))
        for source, target in transitions:
            bigram_counts[source][target] += 1
        if tags:
            transitions.append(("<START>", tags[0]))
            transitions.append((tags[-1], "<END>"))

        for source, target in transitions:
            if not graph.has_edge(source, target):
                # weight is filled in later from the counts
                graph.add_edge(source, target, weight=0, count=0, boundary_prob=0, precision=self.base_precision)
            graph[source][target]["count"] += 1

    def _forward_pass(self, sequence: List[str]) -> Tuple[float, List[float]]:
        """
        Process a sequence in forward direction and calculate prediction errors.
        
        Args:
            sequence: A sequence of POS tags
            
        Returns:
            Tuple of (total_surprisal, list_of_prediction_errors)
        """
        total_surprisal = 0.0
        prediction_errors = []
        
        # Start with START node
        current_pos = "<START>"
        
        # Process each position in the sequence
        for pos in sequence:
            # Calculate prediction probability for this position
            prediction_prob = 0.0
            if self.forward_graph.has_edge(current_pos, pos):
                prediction_prob = self.forward_graph[current_pos][pos].get("weight", 0.0)
            
            # Calculate surprisal (-log probability)
            if prediction_prob > 0:
                surprisal = -math.log2(prediction_prob)
            else:
                surprisal = 10.0  # High surprisal for unseen transitions
                
            # Get current precision for this transition
            precision = self.base_precision
            if self.forward_graph.has_edge(current_pos, pos):
                precision = self.forward_graph[current_pos][pos].get("precision", self.base_precision)
            
            # Calculate precision-weighted prediction error
            prediction_error = surprisal * precision
            
            prediction_errors.append(prediction_error)
            total_surprisal += surprisal
            
            # Update current position
            current_pos = pos
        
        return total_surprisal, prediction_errors

    def _backward_pass(self, reversed_sequence: List[str]) -> Tuple[float, List[float]]:
        """
        Process a sequence in backward direction and calculate prediction errors.
        
        Args:
            reversed_sequence: A reversed sequence of POS tags
            
        Returns:
            Tuple of (total_surprisal, list_of_prediction_errors)
        """
        total_surprisal = 0.0
        prediction_errors = []
        
        # Start with START node
        current_pos = "<START>"
        
        # Process each position in the sequence
        for pos in reversed_sequence:
            # Calculate prediction probability for this position
            prediction_prob = 0.0
            if self.backward_graph.has_edge(current_pos, pos):
                prediction_prob = self.backward_graph[current_pos][pos].get("weight", 0.0)
            
            # Calculate surprisal (-log probability)
            if prediction_prob > 0:
                surprisal = -math.log2(prediction_prob)
            else:
                surprisal = 10.0  # High surprisal for unseen transitions
                
            # Get current precision for this transition
            precision = self.base_precision
            if self.backward_graph.has_edge(current_pos, pos):
                precision = self.backward_graph[current_pos][pos].get("precision", self.base_precision)
            
            # Calculate precision-weighted prediction error
            prediction_error = surprisal * precision
            
            prediction_errors.append(prediction_error)
            total_surprisal += surprisal
            
            # Update current position
            current_pos = pos
        
        return total_surprisal, prediction_errors

    def _update_attention_weights(self, sequence: List[str], forward_errors: List[float], 
                                 reversed_sequence: List[str], backward_errors: List[float]):
        """
        Update attention weights based on prediction errors from both directions.
        
        Args:
            sequence: The forward POS sequence
            forward_errors: Prediction errors for forward transitions
            reversed_sequence: The reversed POS sequence
            backward_errors: Prediction errors for backward transitions
        """
        # Combine forward and backward errors for each position
        combined_errors = {}
        
        # Process forward errors
        if forward_errors:
            max_forward = max(forward_errors)
            if max_forward > 0:
                for i, pos in enumerate(sequence):
                    if i < len(forward_errors):
                        error = forward_errors[i] / max_forward
                        combined_errors[pos] = combined_errors.get(pos, 0) + error * self.forward_weight
        
        # Process backward errors (need to re-reverse to align with original positions)
        if backward_errors:
            max_backward = max(backward_errors)
            if max_backward > 0:
                backward_positions = list(reversed(reversed_sequence))  # Re-reverse to match original
                for i, pos in enumerate(backward_positions):
                    if i < len(backward_errors):
                        error_idx = len(backward_errors) - i - 1  # Map to correct backward error
                        if error_idx >= 0 and error_idx < len(backward_errors):
                            error = backward_errors[error_idx] / max_backward
                            combined_errors[pos] = combined_errors.get(pos, 0) + error * self.backward_weight
        
        # Update attention weights based on combined errors
        for pos, error in combined_errors.items():
            current_attention = self.attention_weights.get(pos, 1.0)
            # Update with learning rate
            self.attention_weights[pos] = (1 - self.attention_learning_rate) * current_attention + \
                                      self.attention_learning_rate * error
            
            # Ensure attention weights stay in reasonable range
            self.attention_weights[pos] = max(0.2, min(3.0, self.attention_weights[pos]))
        
        # Update special nodes
        if "<START>" not in combined_errors:
            # Keep START node's attention stable
            self.attention_weights["<START>"] = self.attention_weights.get("<START>", 1.0)
        
        if "<END>" not in combined_errors:
            # Keep END node's attention stable
            self.attention_weights["<END>"] = self.attention_weights.get("<END>", 1.0)

    def _update_graph_weights(self, sequence: List[str], forward_errors: List[float],
                              reversed_sequence: List[str], backward_errors: List[float]):
        """
        Update edge precisions in both graphs with attention-modulated learning.

        Only the ``precision`` attribute of edges that already exist is
        changed; transition ``weight`` values are not modified here.

        For the transition into position ``i`` (from ``<START>`` for the first
        tag), a target precision ``base_precision * (1 + 0.1 * error)`` is
        computed and clamped to ``[min_precision, max_precision]``. It is
        blended into the edge's current precision at rate
        ``learning_rate * sqrt(attention(previous) * attention(tag))``,
        clamped to [0, 1]. Tags missing from ``attention_weights`` count as 1.0.

        NOTE(review): a higher error raises the target precision. Errors are
        themselves precision-weighted by the forward/backward passes, so the
        two can reinforce each other. Confirm that this is intended.

        Args:
            sequence: The forward POS sequence
            forward_errors: Prediction errors for forward transitions
            reversed_sequence: The reversed POS sequence
            backward_errors: Prediction errors for backward transitions
        """
        self._update_edge_precisions(self.forward_graph, sequence, forward_errors)
        self._update_edge_precisions(self.backward_graph, reversed_sequence, backward_errors)

    def _update_edge_precisions(self, graph, sequence: List[str], errors: List[float]) -> None:
        """Blend error-driven target precisions into ``graph``'s edges along ``sequence``, starting at ``<START>``."""
        previous = "<START>"
        # errors[i] belongs to the transition previous -> sequence[i]; zip stops
        # at the shorter list, so positions without an error are not updated.
        for pos, error in zip(sequence, errors):
            if graph.has_edge(previous, pos):
                combined_attention = math.sqrt(
                    self.attention_weights.get(previous, 1.0) * self.attention_weights.get(pos, 1.0)
                )
                # Attention-modulated rate; with neutral attention (1.0) it equals learning_rate.
                rate = max(0.0, min(1.0, self.learning_rate * combined_attention))

                target_precision = self.base_precision * (1 + 0.1 * error)
                target_precision = max(self.min_precision, min(self.max_precision, target_precision))

                edge = graph[previous][pos]
                old_precision = edge.get("precision", self.base_precision)
                edge["precision"] = (1 - rate) * old_precision + rate * target_precision
            previous = pos

    def _calculate_edge_weights(self):
        """
        Normalize outgoing edge counts into transition probabilities.

        For every node of the forward graph and then the backward graph, each
        outgoing edge's ``weight`` becomes its ``count`` divided by the sum of
        that node's outgoing counts. Nodes named ``<END>`` are skipped, and
        nodes whose outgoing counts do not sum to a positive value keep their
        current weights.
        """
        for graph in (self.forward_graph, self.backward_graph):
            for node in graph.nodes():
                if node == "<END>":
                    continue  # terminal node: nothing to normalize

                out_edges = list(graph.out_edges(node, data=True))
                denominator = sum(attrs["count"] for _, _, attrs in out_edges)
                if denominator > 0:
                    for _, target, attrs in out_edges:
                        graph[node][target]["weight"] = attrs["count"] / denominator

    def _calculate_boundary_probabilities(self):
        """
        Derive a boundary probability for every scored edge of both graphs.

        For each edge between two non-special tags whose transition
        probability ``weight`` is positive:

        * the surprisal ``s = -log2(weight)`` is squashed with the logistic
          ``1 / (1 + exp(-(s - 1)))``;
        * the result is scaled by the mean attention weight of the two tags
          (1.0 for tags without an entry);
        * if the tag pair has a predefined boundary probability ``p``, the
          value becomes ``0.3 * p + 0.7 * value``. The pair is looked up as
          (source, target) for the forward graph and as (target, source) for
          the backward graph, whose edges run right-to-left;
        * the value is clamped to [0, 1];
        * it is written to the edge's ``boundary_prob`` attribute and to
          ``forward_boundary_probs`` / ``backward_boundary_probs``, keyed by
          the edge's own (source, target).

        Edges touching ``<START>``/``<END>`` and edges with non-positive
        weight are left unchanged.
        """
        specials = ("<START>", "<END>")
        alpha = 0.3  # Weight for predefined rules
        directions = (
            (self.forward_graph, self.forward_boundary_probs, False),
            (self.backward_graph, self.backward_boundary_probs, True),
        )
        for graph, store, rule_key_reversed in directions:
            for source, target, data in graph.edges(data=True):
                if source in specials or target in specials:
                    continue

                prob = data.get("weight", 0)
                if not (prob > 0):
                    continue

                surprisal = -math.log2(prob)
                attention_factor = (
                    self.attention_weights.get(source, 1.0) + self.attention_weights.get(target, 1.0)
                ) / 2
                squashed = 1 / (1 + math.exp(-(surprisal - 1)))
                boundary_prob = squashed * attention_factor

                rule_key = (target, source) if rule_key_reversed else (source, target)
                if rule_key in self.predefined_boundaries:
                    boundary_prob = alpha * self.predefined_boundaries[rule_key] + (1 - alpha) * boundary_prob

                boundary_prob = max(0.0, min(1.0, boundary_prob))
                graph[source][target]["boundary_prob"] = boundary_prob
                store[(source, target)] = boundary_prob

    def _combine_boundary_probabilities(self):
        """
        Combine forward and backward boundary probabilities into one score per transition.

        ``self.forward_boundary_probs`` is keyed by ``(source, target)`` in reading
        order. ``self.backward_boundary_probs`` is keyed in the reversed orientation:
        the backward entry for the forward transition ``(a, b)`` is stored under
        ``(b, a)``. Every backward key is therefore flipped into forward orientation
        before the set of transitions is formed. Otherwise each transition would also
        produce a spurious reversed entry ``(b, a)``.

        For each transition, the stored value is
        ``forward_weight * forward_prob + backward_weight * backward_prob``. A
        direction with no recorded probability contributes 0.0. Results are written
        to ``self.combined_boundary_probs``, keyed by forward-orientation
        ``(source, target)``.
        """
        forward_probs = self.forward_boundary_probs
        backward_probs = self.backward_boundary_probs

        # Express every backward transition in forward orientation before merging.
        all_transitions = set(forward_probs) | {
            (target, source) for source, target in backward_probs
        }

        for source, target in all_transitions:
            forward_prob = forward_probs.get((source, target), 0.0)
            # The backward graph stores the forward transition (source, target) reversed.
            backward_prob = backward_probs.get((target, source), 0.0)

            # Weighted combination of both directions.
            self.combined_boundary_probs[(source, target)] = (
                self.forward_weight * forward_prob
                + self.backward_weight * backward_prob
            )

    def _identify_common_chunks(
        self,
        pos_sequences: List[List[str]],
        min_size: int = 2,
        max_size: int = 4,
    ):
        """
        Identify frequent, cohesive POS n-grams and register them as chunks.

        Every contiguous n-gram whose length lies in ``[min_size, max_size]`` is
        counted across ``pos_sequences``, including overlapping occurrences. An
        n-gram passes the frequency test when it occurs at least
        ``max(2, int(0.05 * len(pos_sequences)))`` times.

        For each n-gram that passes:
        - Cohesion is ``1 - mean(self.combined_boundary_probs[pair])`` over the
          chunk's internal transitions. Transitions with no combined score
          default to 0.5.
        - Cohesion is then scaled by ``0.8 + 0.2 * chunk_attention``.
        - ``chunk_attention`` is the product, over every position in the chunk,
          of the damped weight ``0.5 * self.attention_weights.get(pos, 1.0) + 0.5``.

        Chunks whose adjusted cohesion exceeds the threshold (0.6 for fewer than
        20 sentences, otherwise 0.7) are stored in ``self.common_chunks``. Their
        attention is stored in ``self.chunk_attention_weights``. Candidate and
        qualifying counts are printed.

        Args:
            pos_sequences: Sentences, each a list of POS tags.
            min_size: Shortest n-gram to consider. Must be at least 2 so that
                every chunk has an internal transition. Defaults to 2.
            max_size: Longest n-gram to consider. Defaults to 4.

        Raises:
            ValueError: If ``min_size < 2`` or ``max_size < min_size``.
        """
        if min_size < 2 or max_size < min_size:
            raise ValueError(
                f"Invalid chunk size range: min_size={min_size}, max_size={max_size}"
            )

        # Sliding-window count of every contiguous n-gram, shortest sizes first.
        # Sequences shorter than `size` yield an empty range and are skipped.
        chunk_candidates = Counter()
        for size in range(min_size, max_size + 1):
            for sequence in pos_sequences:
                for start in range(len(sequence) - size + 1):
                    chunk_candidates[tuple(sequence[start:start + size])] += 1

        print(f"Found {len(chunk_candidates)} potential chunks")

        # Relax the frequency and cohesion requirements for small training sets.
        total_sentences = len(pos_sequences)
        min_occurrences = max(2, int(total_sentences * 0.05))
        cohesion_threshold = 0.6 if total_sentences < 20 else 0.7

        qualifying_chunks = 0
        for chunk, count in chunk_candidates.items():
            if count < min_occurrences:
                continue
            qualifying_chunks += 1

            # Mean bidirectional (combined) boundary probability over the chunk's
            # internal transitions. Unscored transitions default to 0.5.
            internal_boundaries = sum(
                self.combined_boundary_probs.get(pair, 0.5)
                for pair in zip(chunk, chunk[1:])
            )
            cohesion = 1 - internal_boundaries / (len(chunk) - 1)

            # Damped product of attention over every position in the chunk.
            chunk_attention = 1.0
            for pos in chunk:
                chunk_attention *= self.attention_weights.get(pos, 1.0) * 0.5 + 0.5

            # Boost cohesion for chunks with high attention.
            attention_adjusted_cohesion = cohesion * (0.8 + 0.2 * chunk_attention)

            # Only keep reasonably cohesive chunks.
            if attention_adjusted_cohesion <= cohesion_threshold:
                continue

            self.common_chunks[chunk] = {
                "name": "_".join(chunk),
                "elements": chunk,
                "count": count,
                "cohesion": attention_adjusted_cohesion,
                "attention": chunk_attention,
                "activation": 0.0,  # Initial activation level
            }
            self.chunk_attention_weights[chunk] = chunk_attention

        print(f"{qualifying_chunks} chunks met frequency criteria, {len(self.common_chunks)} met cohesion criteria")

    def _build_chunk_graph(self):
        """Build higher-order graph representing transitions between chunks."""
        # Add nodes for each chunk
        for chunk_tuple, chunk_info in self.common_chunks.items():
            chunk_name = chunk_info["name"]
            self.chunk_graph.add_node(
                chunk_name, 
                pos_type="chunk", 
                elements=chunk_info["elements"],
                cohesion=chunk_info["cohesion"],
                attention=chunk_info["attention"]
            )
        
        # Connect chunks that can follow each other
        for chunk1_tuple, chunk1_info in self.common_chunks.items():
            for chunk2_tuple, chunk2_info in self.common_chunks.items():
                # Check if chunk2 can follow chunk1 (overlap or adjacency)
                if self._can_follow(chunk1_tuple, chunk2_tuple):
                    # Calculate transition probability
                    # This is simplified - would need corpus analysis for accurate probabilities
                    transition_prob = 0.1  # Default low probability
                    
                    # If we have trigram data, use it to estimate transition probability
                    if len(chunk1_tuple) >= 2 and len(chunk2_tuple) >= 1:
                        last1, last2 = chunk1_tuple[-2], chunk1_tuple[-1]
                        first = chunk2_tuple[0]
                        
                        if last2 in self.trigram_counts.get(last1, {}):
                            total = sum(self.trigram_counts[last1][last2].values())
                            if total > 0:
                                count = self.trigram_counts[last1][last2].get(first, 0)
                                transition_prob = count / total
                    
                    # Apply attention weighting to transition
                    chunk1_attention = self.chunk_attention_weights.get(chunk1_tuple, 1.0)
                    chunk2_attention = self.chunk_attention_weights.get(chunk2_tuple, 1.0)
                    
                    # Higher attention on both chunks strengthens their connection
                    attention_factor = (chunk1_attention + chunk2_attention) / 2
                    weighted_prob = transition_prob * attention_factor
                    
                    # Add edge with weight
                    chunk1_name = chunk1_info["name"]
                    chunk2_name = chunk2_info["name"]
                    self.chunk_graph.add_edge(
                        chunk1_name, 
                        chunk2_name, 
                        weight=weighted_prob,
                        attention=attention_factor
                    )