<a href="https://colab.research.google.com/github/jiveshj/SeniorThesis/blob/main/entropy_and_trellis_visualization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from collections import defaultdict
import scipy.special

class EntropyTrellisAnalyzer:
    """Record decoding paths (e.g. Viterbi, beam, greedy) and compare them.

    Each recorded path stores its tokens, visited states, per-step scores and,
    optionally, per-step logits from which conditional entropies are derived.
    The comparison methods treat the path recorded under the name 'viterbi'
    as the reference and compare every other path against it.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        # Only used by visualize_trellis to turn integer states into readable tokens.
        self.tokenizer = tokenizer
        self.paths = {}      # name -> tokens
        self.scores = {}     # name -> per-step scores (log probabilities)
        self.states = {}     # name -> per-step states
        self.entropies = {}  # name -> per-step conditional entropies (only when logits given)
        self.logits = {}     # name -> per-step logits (only when given)

    @staticmethod
    def _value_at(seq, pos):
        """Return ``seq[pos]`` when ``pos`` is a valid non-negative index, else None."""
        if seq is None or pos is None or not 0 <= pos < len(seq):
            return None
        return seq[pos]

    def record_path(self, name, tokens, states, scores, logits=None):
        """
        Record a decoding path for later analysis.

        Args:
            name: Name of the decoding algorithm (e.g., 'viterbi', 'beam', 'greedy')
            tokens: List of token IDs or strings representing the output
            states: List of states visited at each time step
            scores: List of scores (log probabilities) at each time step
            logits: List of logits at each time step (optional, for entropy calculation)
        """
        self.paths[name] = tokens
        self.states[name] = states
        self.scores[name] = scores

        if logits is not None:
            self.logits[name] = logits
            # One entropy value per time step.
            self.entropies[name] = [self.calculate_conditional_entropy(logit) for logit in logits]

    def calculate_conditional_entropy(self, logits):
        """
        Calculate conditional entropy (in nats) from token logits.
        Higher values indicate more uncertainty in the next token prediction.

        Args:
            logits: Raw logits from your model for the next token

        Returns:
            Conditional entropy value
        """
        # Convert logits to probabilities using a numerically stable softmax.
        probs = scipy.special.softmax(np.asarray(logits, dtype=np.float64), axis=-1)

        # entr(p) = -p * log(p), with entr(0) == 0 exactly. This handles masked
        # (-inf) logits without the bias an epsilon inside the log introduces.
        return np.sum(scipy.special.entr(probs))

    def analyze_state_overlap(self):
        """
        Check how many states of each non-Viterbi path also occur in the Viterbi path.

        Returns:
            Dict name -> {'overlap_count', 'total_states', 'overlap_percent'}.
        """
        results = {}
        viterbi_states = set(self.states.get('viterbi', []))

        for name, states in self.states.items():
            if name == 'viterbi':
                continue

            overlap_count = sum(1 for state in states if state in viterbi_states)
            results[name] = {
                'overlap_count': overlap_count,
                'total_states': len(states),
                'overlap_percent': overlap_count / len(states) * 100 if states else 0
            }

        return results

    def find_divergence_points(self):
        """
        Find the first position where each path's tokens differ from Viterbi's.

        Only the common prefix is compared, so a path that is a strict prefix of
        the Viterbi path (or vice versa) is not reported. Scores and entropies
        at the divergence position are included when available (None otherwise).

        Returns:
            Dict name -> divergence info for every path that diverges.
        """
        divergences = {}
        viterbi_tokens = self.paths.get('viterbi', [])
        viterbi_scores = self.scores.get('viterbi')
        viterbi_entropies = self.entropies.get('viterbi')

        for name, tokens in self.paths.items():
            if name == 'viterbi':
                continue

            diverge_pos = next(
                (i for i, (vt, t) in enumerate(zip(viterbi_tokens, tokens)) if vt != t),
                None,
            )
            if diverge_pos is None:
                continue

            viterbi_score = self._value_at(viterbi_scores, diverge_pos)
            other_score = self._value_at(self.scores.get(name), diverge_pos)

            divergences[name] = {
                'position': diverge_pos,
                'viterbi_token': self._value_at(viterbi_tokens, diverge_pos),
                'alternate_token': self._value_at(tokens, diverge_pos),
                'viterbi_score': viterbi_score,
                'alternate_score': other_score,
                'score_difference': (viterbi_score - other_score
                                     if viterbi_score is not None and other_score is not None
                                     else None),
                'viterbi_entropy': self._value_at(viterbi_entropies, diverge_pos),
                'alternate_entropy': self._value_at(self.entropies.get(name), diverge_pos)
            }

        return divergences

    def validate_paths(self):
        """
        Compare the summed scores of each non-Viterbi path with the Viterbi total.

        Returns:
            Dict name -> {'total_score', 'viterbi_score', 'difference', 'is_valid'}.
        """
        # This depends on your trellis implementation
        results = {}
        viterbi_total = sum(self.scores.get('viterbi', [0]))

        for name, scores in self.scores.items():
            if name == 'viterbi':
                continue

            total_score = sum(scores)
            results[name] = {
                'total_score': total_score,
                'viterbi_score': viterbi_total,
                'difference': viterbi_total - total_score,
                'is_valid': True  # Assuming all paths are valid; adjust based on your implementation
            }

        return results

    def visualize_paths_with_entropy(self, output_file=None):
        """
        Plot per-step scores (and per-step entropies when available) for every path,
        marking the Viterbi divergence points with red dashed lines.

        Args:
            output_file: If provided, save the figure to this file
        """
        if not self.paths:
            print("No paths recorded to visualize.")
            return

        has_entropy = any(len(self.entropies.get(name, [])) > 0 for name in self.paths)

        if has_entropy:
            fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 12), sharex=True)
        else:
            fig, ax1 = plt.subplots(figsize=(14, 6))
            ax2 = None

        # Plot scores for each path (missing/None scores are dropped).
        for name, scores in self.scores.items():
            valid_scores = [s for s in scores if s is not None]
            ax1.plot(range(len(valid_scores)), valid_scores, marker='o', label=name)

        ax1.set_ylabel('Log probability score')
        ax1.set_title('Decoding Paths Comparison')
        ax1.legend()
        ax1.grid(True)

        # Plot entropy if available
        if ax2 is not None:
            for name, entropies in self.entropies.items():
                valid_entropies = [e for e in entropies if e is not None]
                ax2.plot(range(len(valid_entropies)), valid_entropies, marker='s',
                         linestyle='--', label=f"{name} entropy")

            ax2.set_xlabel('Position in sequence')
            ax2.set_ylabel('Conditional entropy')
            ax2.set_title('Conditional Entropy at Each Position')
            ax2.legend()
            ax2.grid(True)
        else:
            ax1.set_xlabel('Position in sequence')

        # Flatten every score of every path to place the divergence labels near
        # the bottom of the score axis.
        all_scores = [s for scores in self.scores.values() for s in scores if s is not None]
        label_y = min(all_scores) * 0.95 if all_scores else 0.0

        # Highlight divergence points
        for info in self.find_divergence_points().values():
            pos = info['position']
            if pos is None:
                continue
            ax1.axvline(x=pos, color='r', linestyle='--', alpha=0.5)
            ax1.text(pos, label_y, f"Divergence at {pos}", rotation=90, color='red')
            if ax2 is not None:
                ax2.axvline(x=pos, color='r', linestyle='--', alpha=0.5)

        # Lay out before saving so the saved file matches what is shown.
        plt.tight_layout()
        if output_file:
            plt.savefig(output_file)
        plt.show()

    def analyze_entropy_at_divergence(self):
        """
        Compare the Viterbi entropy at each divergence point with the Viterbi
        path's average entropy.

        Returns:
            Dict name -> entropy statistics, for paths whose divergence point has
            a Viterbi entropy value.
        """
        results = {}
        divergences = self.find_divergence_points()
        viterbi_entropies = [e for e in self.entropies.get('viterbi', []) if e is not None]

        for name, info in divergences.items():
            pos = info['position']
            if pos is None or info['viterbi_entropy'] is None:
                continue

            avg_entropy = np.mean(viterbi_entropies)
            entropy_at_divergence = info['viterbi_entropy']

            results[name] = {
                'position': pos,
                'entropy_at_divergence': entropy_at_divergence,
                'average_entropy': avg_entropy,
                'relative_entropy': entropy_at_divergence / avg_entropy if avg_entropy > 0 else None,
                'is_high_entropy': entropy_at_divergence > avg_entropy
            }

        return results

    def correlate_entropy_with_divergence(self):
        """
        Point-biserial correlation between Viterbi per-step entropy and a binary
        "some path differs from Viterbi at this position" indicator.

        Returns:
            Dict with the correlation and mean entropies, or None when there is
            no Viterbi entropy data, no divergence, or a degenerate sample.
        """
        viterbi_tokens = self.paths.get('viterbi', [])
        divergence_positions = set()

        for name, tokens in self.paths.items():
            if name == 'viterbi':
                continue
            divergence_positions.update(
                i for i, (vt, t) in enumerate(zip(viterbi_tokens, tokens)) if vt != t
            )

        if 'viterbi' not in self.entropies or not divergence_positions:
            return None

        entropies = self.entropies['viterbi']
        max_pos = min(len(entropies), len(viterbi_tokens))

        # Binary mask: 1 if position has divergence, 0 otherwise
        divergence_mask = np.zeros(max_pos)
        for pos in divergence_positions:
            if pos < max_pos:
                divergence_mask[pos] = 1

        entropy_values = np.array(entropies[:max_pos])

        n_total = len(divergence_mask)
        n_diverge = np.sum(divergence_mask)
        n_no_diverge = n_total - n_diverge
        std_entropy = np.std(entropy_values)  # population std, as the formula requires

        if not (std_entropy > 0 and n_diverge > 0 and n_no_diverge > 0):
            return None

        mean_entropy_diverge = np.mean(entropy_values[divergence_mask == 1])
        mean_entropy_no_diverge = np.mean(entropy_values[divergence_mask == 0])

        # Point-biserial formula: (M1 - M0) / s_n * sqrt(n1 * n0 / n^2)
        correlation = ((mean_entropy_diverge - mean_entropy_no_diverge) / std_entropy) * \
            np.sqrt((n_diverge * n_no_diverge) / (n_total * n_total))

        return {
            'correlation': correlation,
            'mean_entropy_at_divergence': mean_entropy_diverge,
            'mean_entropy_elsewhere': mean_entropy_no_diverge,
            'entropy_ratio': (mean_entropy_diverge / mean_entropy_no_diverge
                              if mean_entropy_no_diverge > 0 else None)
        }

    def _collect_top_states(self, timesteps, k_best):
        """Return, for each of the first ``timesteps`` steps, the top ``k_best``
        (state, best_score) pairs across all recorded paths, best first.

        A state's score is the maximum score any path assigned when visiting it.
        Steps without a recorded score count as -inf, so they never outrank real
        log-probabilities.
        """
        best = [{} for _ in range(timesteps)]
        for name, states_path in self.states.items():
            path_scores = self.scores.get(name, [])
            for t, state in zip(range(timesteps), states_path):
                score = path_scores[t] if t < len(path_scores) else None
                if score is None:
                    score = float('-inf')
                previous = best[t].get(state)
                best[t][state] = score if previous is None else max(previous, score)

        return [sorted(step.items(), key=lambda item: item[1], reverse=True)[:k_best]
                for step in best]

    def _state_label(self, state):
        """Readable label for a state: decoded token for integer states, else str()."""
        if isinstance(state, (int, np.integer)):
            try:
                return self.tokenizer.decode([int(state)])
            except Exception:
                return str(state)
        return str(state)

    def visualize_trellis(self, timesteps=10, k_best=5, output_file=None):
        """
        Visualize the trellis structure with the different paths.

        Args:
            timesteps: Number of time steps to visualize
            k_best: Number of best states to show at each time step
            output_file: If provided, save the figure to this file
        """
        top_states = self._collect_top_states(timesteps, k_best)
        if not any(top_states):
            print("No states recorded to visualize.")
            return

        # NaN marks empty cells so seaborn leaves them blank instead of drawing a
        # 0 that would outrank every (negative) log-probability.
        matrix = np.full((k_best, timesteps), np.nan)
        state_labels = [[] for _ in range(timesteps)]
        row_of = []  # per time step: state -> row index in the matrix

        for t, ranked in enumerate(top_states):
            row_of.append({state: i for i, (state, _) in enumerate(ranked)})
            for i, (state, score) in enumerate(ranked):
                if np.isfinite(score):
                    matrix[i, t] = score
                state_labels[t].append(self._state_label(state))

        # Create a heatmap
        plt.figure(figsize=(16, 10))
        sns.heatmap(matrix, cmap="YlGnBu", annot=False)

        # Add path markers (+0.5 puts the marker at the center of the cell)
        for name, states_path in self.states.items():
            path_x, path_y = [], []
            for t, state in zip(range(timesteps), states_path):
                row = row_of[t].get(state)
                if row is None:
                    continue  # state not among the top-k at this step
                path_x.append(t + 0.5)
                path_y.append(row + 0.5)

            plt.plot(path_x, path_y, marker='o', linewidth=2,
                     label=name, alpha=0.7)

        # Add labels and other details
        plt.title("Trellis Paths Visualization")
        plt.ylabel("Top-k states at each time step")
        plt.xlabel("Time step")
        plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=3)

        # Hide numeric y-tick labels; states are written inside the cells instead.
        plt.yticks(np.arange(k_best) + 0.5, [''] * k_best)

        for t, labels in enumerate(state_labels):
            for i, label in enumerate(labels[:k_best]):
                plt.text(t + 0.5, i + 0.5, label,
                         ha='center', va='center', fontsize=8)

        plt.tight_layout()
        if output_file:
            plt.savefig(output_file)
        plt.show()

# Example usage function - now with entropy analysis
def compare_decoding_paths_with_entropy(input_text, model, tokenizer, viterbi_fn, greedy_fn, beam_fn):
    """
    Compare different decoding paths for the same input, with entropy analysis.

    Runs the three decoders, records their outputs in an EntropyTrellisAnalyzer,
    prints overlap/divergence/validation/entropy reports and shows both plots.

    Args:
        input_text: The input text to decode
        model: Your language model
        tokenizer: Your tokenizer
        viterbi_fn: Function that returns (tokens, states, scores, logits) for Viterbi decoding
        greedy_fn: Function that returns (tokens, states, scores, logits) for greedy decoding
        beam_fn: Function that returns (tokens, states, scores, logits) for beam search

    Returns:
        The populated EntropyTrellisAnalyzer.
    """
    def _fmt(value):
        # Several report fields are None when data is missing (a score list shorter
        # than the token list, a zero mean entropy, ...). Formatting None with
        # ':.4f' raises TypeError, so render it as 'n/a' instead.
        return "n/a" if value is None else f"{value:.4f}"

    analyzer = EntropyTrellisAnalyzer(model, tokenizer)

    # Run each decoding method and record results (with logits for entropy calculation)
    for name, decode in (('viterbi', viterbi_fn), ('greedy', greedy_fn), ('beam', beam_fn)):
        tokens, states, scores, logits = decode(input_text)
        analyzer.record_path(name, tokens, states, scores, logits)

    # Run basic analysis
    print("=== State Overlap Analysis ===")
    overlap = analyzer.analyze_state_overlap()
    for name, results in overlap.items():
        print(f"{name.capitalize()} search: {results['overlap_percent']:.2f}% of states appear in Viterbi trellis")

    print("\n=== Divergence Points Analysis ===")
    divergences = analyzer.find_divergence_points()
    for name, info in divergences.items():
        if info['position'] is None:
            continue
        print(f"{name.capitalize()} diverges from Viterbi at position {info['position']}:")
        print(f"  Viterbi chose '{info['viterbi_token']}' (score: {_fmt(info['viterbi_score'])})")
        print(f"  {name.capitalize()} chose '{info['alternate_token']}' (score: {_fmt(info['alternate_score'])})")
        print(f"  Score difference: {_fmt(info['score_difference'])}")

        # Print entropy information if available
        if info['viterbi_entropy'] is not None:
            print(f"  Entropy at divergence point: {info['viterbi_entropy']:.4f}")

    print("\n=== Path Validation ===")
    validations = analyzer.validate_paths()
    for name, results in validations.items():
        print(f"{name.capitalize()} path total score: {results['total_score']:.4f}")
        print(f"Viterbi path total score: {results['viterbi_score']:.4f}")
        print(f"Difference: {results['difference']:.4f}")
        print(f"Path exists in trellis: {'Yes' if results['is_valid'] else 'No'}")

    # Run entropy-specific analysis
    print("\n=== Entropy Analysis at Divergence Points ===")
    entropy_analysis = analyzer.analyze_entropy_at_divergence()
    if entropy_analysis:
        for name, results in entropy_analysis.items():
            print(f"{name.capitalize()} divergence entropy analysis:")
            print(f"  Entropy at divergence point: {results['entropy_at_divergence']:.4f}")
            print(f"  Average entropy across sequence: {results['average_entropy']:.4f}")
            print(f"  Relative entropy (divergence/average): {_fmt(results['relative_entropy'])}")
            print(f"  Is high entropy point: {'Yes' if results['is_high_entropy'] else 'No'}")
    else:
        print("No entropy data available for analysis.")

    # Calculate correlation between entropy and divergence
    print("\n=== Entropy-Divergence Correlation ===")
    correlation = analyzer.correlate_entropy_with_divergence()
    if correlation:
        print(f"Correlation between entropy and path divergence: {correlation['correlation']:.4f}")
        print(f"Mean entropy at divergence points: {correlation['mean_entropy_at_divergence']:.4f}")
        print(f"Mean entropy elsewhere: {correlation['mean_entropy_elsewhere']:.4f}")
        print(f"Ratio of entropy at divergence vs. elsewhere: {_fmt(correlation['entropy_ratio'])}")

        # |r| > 0.3 is used as the threshold for a "strong" association.
        if correlation['correlation'] > 0.3:
            print("CONCLUSION: Strong positive correlation suggests higher entropy regions are associated with path divergence.")
        elif correlation['correlation'] < -0.3:
            print("CONCLUSION: Strong negative correlation suggests lower entropy regions are associated with path divergence.")
        else:
            print("CONCLUSION: No strong correlation between entropy and path divergence.")
    else:
        print("Insufficient data to calculate entropy-divergence correlation.")

    # Visualize the results
    print("\nGenerating visualizations...")
    analyzer.visualize_paths_with_entropy()
    analyzer.visualize_trellis()

    return analyzer

# Helper function to adapt code to your specific trellis implementation
def extract_trellis_data_with_logits(trellis, path_indices, token_logits=None):
    """
    Pull the token/state/score triple of each chosen trellis cell along a path.

    Time steps beyond the trellis length and cell indices beyond a step's width
    are skipped silently. Logits are collected only for kept steps and only when
    ``token_logits`` covers that step.

    Args:
        trellis: Per-step sequences of cells exposing ``.token``, ``.state`` and
            ``.score`` attributes (adapt to your trellis structure).
        path_indices: Cell index chosen at each time step.
        token_logits: Optional per-step logits (for entropy calculation).

    Returns:
        (tokens, states, scores, logits) lists.
    """
    tokens, states, scores, logits = [], [], [], []

    for step, cell_idx in enumerate(path_indices):
        if step >= len(trellis) or cell_idx >= len(trellis[step]):
            continue

        cell = trellis[step][cell_idx]
        tokens.append(cell.token)
        states.append(cell.state)
        scores.append(cell.score)

        have_logits_for_step = token_logits is not None and step < len(token_logits)
        if have_logits_for_step:
            logits.append(token_logits[step])

    return tokens, states, scores, logits

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from collections import defaultdict
import scipy.special

class IterativeTrellisAnalyzer:
    """Track how decoding paths change as the number of generated tokens grows.

    For each iteration length the Viterbi, beam and greedy paths are stored as
    lists of state indices. Paths can then be compared across lengths (does the
    prefix change?) and across methods (where do beam/greedy leave Viterbi?).
    """

    METHODS = ('viterbi', 'beam', 'greedy')

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.iteration_paths = {}  # iteration length -> {method: [state indices]}
        self.entropies = {}        # not used by the methods below; kept for compatibility
        self.trellis_paths = []    # last all_trellis_path_ids passed to add_iteration_result
        self.scores = {}           # not used by the methods below; kept for compatibility

    @staticmethod
    def _first_mismatch(path_a, path_b):
        """Return the first index where the paths differ over their common prefix, else None."""
        for pos, (a, b) in enumerate(zip(path_a, path_b)):
            if a != b:
                return pos
        return None

    @staticmethod
    def _visible_points(path, limit):
        """Return (positions, states) for entries of ``path`` whose state is below ``limit``.

        The original positions are kept, so a filtered-out state leaves a gap
        instead of shifting later points to the left.
        """
        xs, ys = [], []
        for pos, state in enumerate(path):
            if state < limit:
                xs.append(pos)
                ys.append(state)
        return xs, ys

    def add_iteration_result(self, iteration_length, viterbi_path, beam_path, greedy_path,
                             all_trellis_path_ids=None):
        """
        Record decoding paths for a specific iteration length.

        Args:
            iteration_length: Number of tokens generated
            viterbi_path: List of state indices for the Viterbi path
            beam_path: List of state indices for the beam search path
            greedy_path: List of state indices for the greedy search path
            all_trellis_path_ids: Optional collection of all trellis paths; when
                given it replaces ``self.trellis_paths`` (None leaves it unchanged)
        """
        self.iteration_paths[iteration_length] = {
            'viterbi': viterbi_path,
            'beam': beam_path,
            'greedy': greedy_path
        }
        if all_trellis_path_ids is not None:
            self.trellis_paths = all_trellis_path_ids

    def find_path_shift_points(self):
        """
        Find points where paths shift between consecutive iteration lengths.
        For example, when going from 3 tokens to 4 tokens, does the path for the
        first 3 tokens change?

        Returns:
            Dict keyed "<method>_<from>_to_<to>" with divergence information.
        """
        iterations = sorted(self.iteration_paths.keys())
        print("iterations in find_path_shift points: ", iterations)
        shifts = {}

        for curr_iter, next_iter in zip(iterations, iterations[1:]):
            for method in self.METHODS:
                curr_path = self.iteration_paths[curr_iter].get(method, [])
                next_path = self.iteration_paths[next_iter].get(method, [])

                # Compare paths up to the length of the shorter path
                min_len = min(len(curr_path), len(next_path))
                if min_len == 0:
                    continue

                diverge_pos = self._first_mismatch(curr_path, next_path)
                shifts[f"{method}_{curr_iter}_to_{next_iter}"] = {
                    'method': method,
                    'from_iteration': curr_iter,
                    'to_iteration': next_iter,
                    'diverges': diverge_pos is not None,
                    'diverge_position': diverge_pos,
                    'path_before': curr_path[:min_len],
                    'path_after': next_path[:min_len]
                }

        return shifts

    def compare_paths_at_iteration(self, iteration):
        """
        Compare beam and greedy paths with the Viterbi path at one iteration length.

        Returns:
            Dict keyed "viterbi_vs_<method>", or None if the iteration is unknown.
        """
        if iteration not in self.iteration_paths:
            return None

        paths = self.iteration_paths[iteration]
        comparisons = {}
        viterbi_path = paths.get('viterbi', [])

        for method in ('beam', 'greedy'):
            other_path = paths.get(method, [])
            min_len = min(len(viterbi_path), len(other_path))
            if min_len == 0:
                continue

            diverge_pos = self._first_mismatch(viterbi_path, other_path)
            comparisons[f"viterbi_vs_{method}"] = {
                'iteration': iteration,
                'diverges': diverge_pos is not None,
                'diverge_position': diverge_pos,
                'viterbi_path': viterbi_path[:min_len],
                'other_path': other_path[:min_len],
                'other_method': method
            }

        return comparisons

    def decoder_in_trellis(self, iteration, beam_path, greedy_path, trellis_paths):
        """Return True if the beam or the greedy path is a member of ``trellis_paths``.

        ``iteration`` is accepted for interface compatibility but not used.
        """
        return beam_path in trellis_paths or greedy_path in trellis_paths

    def analyze_all_iterations(self):
        """
        Analyze path divergence across all recorded iterations.

        Returns:
            {'path_shifts': ..., 'method_comparisons': {iteration: comparisons}}
        """
        results = {
            'path_shifts': self.find_path_shift_points(),
            'method_comparisons': {}
        }

        for iteration in sorted(self.iteration_paths.keys()):
            comparison = self.compare_paths_at_iteration(iteration)
            if comparison:
                results['method_comparisons'][iteration] = comparison

        return results

    def visualize_path_stability(self, max_states=10, output_file=None):
        """
        Visualize how paths change across iterations (one heatmap per method).

        Args:
            max_states: States >= this value are left blank in the heatmap
            output_file: If provided, save the figure to this file
        """
        iterations = sorted(self.iteration_paths.keys())
        if not iterations:
            print("No iteration data to visualize.")
            return

        methods = list(self.METHODS)
        fig, axes = plt.subplots(len(methods), 1, figsize=(14, 4 * len(methods)), sharex=True)
        axes = np.atleast_1d(axes)

        # Shift points do not depend on the subplot, so compute them once.
        shifts = self.find_path_shift_points()

        for ax, method in zip(axes, methods):
            ax.set_ylabel('Iteration length')
            ax.set_title(f'{method.capitalize()} paths across iterations')

            # Each row is an iteration, each column is a position in the path
            max_iteration_len = max(len(self.iteration_paths[it].get(method, []))
                                    for it in iterations)
            if max_iteration_len == 0:
                continue  # nothing recorded for this method; imshow needs a non-empty matrix

            path_matrix = np.full((len(iterations), max_iteration_len), np.nan)
            for row, iteration in enumerate(iterations):
                for col, state in enumerate(self.iteration_paths[iteration].get(method, [])):
                    if state < max_states:  # Only include states below max_states
                        path_matrix[row, col] = state

            im = ax.imshow(path_matrix, aspect='auto', cmap='viridis',
                           interpolation='nearest', vmin=0, vmax=max_states - 1)
            cbar = fig.colorbar(im, ax=ax, orientation='vertical')
            cbar.set_label('State index')

            ax.set_yticks(range(len(iterations)))
            ax.set_yticklabels(iterations)

            # Outline the cell where the longer iteration's path first differs.
            for info in shifts.values():
                if info['method'] != method or not info['diverges']:
                    continue
                pos = info['diverge_position']
                if pos is None or pos >= max_iteration_len:
                    continue
                to_idx = iterations.index(info['to_iteration'])
                ax.add_patch(plt.Rectangle((pos - 0.5, to_idx - 0.5), 1, 1,
                                           fill=False, edgecolor='red', linewidth=2))

        # Common x-label on the bottom (shared-x) axis so tight_layout accounts for it.
        axes[-1].set_xlabel('Position in sequence')

        plt.tight_layout()
        if output_file:
            plt.savefig(output_file)
        plt.show()

    def visualize_method_comparison(self, iteration, output_file=None):
        """
        Visualize path comparison between methods at a specific iteration length.

        Args:
            iteration: Iteration length to visualize
            output_file: If provided, save the figure to this file
        """
        if iteration not in self.iteration_paths:
            print(f"No data for iteration length {iteration}")
            return

        paths = self.iteration_paths[iteration]
        methods = [m for m in self.METHODS if m in paths]

        if not methods:
            print(f"No method data for iteration length {iteration}")
            return

        plt.figure(figsize=(14, 6))

        for method in methods:
            path = paths[method]
            plt.plot(range(len(path)), path, 'o-', label=method)

        plt.xlabel('Position in sequence')
        plt.ylabel('State index')
        plt.title(f'Path comparison at iteration length {iteration}')
        plt.legend()
        plt.grid(True)

        # Add annotations for divergence points
        comparisons = self.compare_paths_at_iteration(iteration) or {}
        for info in comparisons.values():
            pos = info['diverge_position']
            if not info['diverges'] or pos is None:
                continue
            plt.axvline(x=pos, color='r', linestyle='--', alpha=0.5)
            plt.text(pos, plt.ylim()[0] * 0.9, f"Divergence at {pos}",
                     rotation=90, color='red')

        plt.tight_layout()
        if output_file:
            plt.savefig(output_file)
        plt.show()

    def visualize_trellis_with_multiple_paths(self, iteration, top_k=10, output_file=None):
        """
        Visualize the paths of every method on a synthetic top_k x length grid.

        Args:
            iteration: Iteration length to visualize
            top_k: Number of top states to show at each position
            output_file: If provided, save the figure to this file
        """
        if iteration not in self.iteration_paths:
            print(f"No data for iteration length {iteration}")
            return

        paths = self.iteration_paths[iteration]
        methods = [m for m in self.METHODS if m in paths]

        if not methods:
            print(f"No method data for iteration length {iteration}")
            return

        max_len = max(len(paths[m]) for m in methods)

        # Synthetic grid: 1 where some method visited (state, position). A real
        # implementation would use the actual trellis data instead.
        synthetic_trellis = np.zeros((top_k, max_len))
        for method in methods:
            xs, ys = self._visible_points(paths[method], top_k)
            for t, state in zip(xs, ys):
                synthetic_trellis[state, t] = 1

        plt.figure(figsize=(16, 10))
        plt.imshow(synthetic_trellis, cmap='Blues', alpha=0.3, aspect='auto')

        colors = {'viterbi': 'blue', 'beam': 'green', 'greedy': 'red'}
        for method in methods:
            # Keep true positions so states >= top_k leave gaps, not shifted points.
            xs, ys = self._visible_points(paths[method], top_k)
            plt.plot(xs, ys, 'o-', label=method, color=colors.get(method, 'black'))

        plt.xlabel('Position in sequence')
        plt.ylabel('State index')
        plt.title(f'Trellis paths at iteration length {iteration}')
        plt.legend()
        plt.grid(True)
        plt.yticks(range(top_k))

        plt.tight_layout()
        if output_file:
            plt.savefig(output_file)
        plt.show()

def analyze_paths_across_iterations(model, tokenizer, decode_fn, input_text, max_length=10):
    """
    Analyze how decoded paths change as the number of generated tokens grows.

    Runs `decode_fn` for every length in 1..max_length, records the Viterbi,
    beam and greedy paths in an IterativeTrellisAnalyzer, prints the path
    shifts between lengths and the per-length divergences from Viterbi, then
    draws the stability plot plus method/trellis plots for the middle and the
    final length.

    Args:
        model: Your language model
        tokenizer: Your tokenizer
        decode_fn: Function that takes (input_text, length) and returns
            (viterbi_path, beam_path, greedy_path)
        input_text: The input text to decode
        max_length: Maximum number of tokens to generate

    Returns:
        The populated IterativeTrellisAnalyzer.
    """
    analyzer = IterativeTrellisAnalyzer(model, tokenizer)

    # Run decoding for each length (iterations are recorded starting at 1).
    for length in range(1, max_length + 1):
        print(f"Decoding iteration {length}...")
        viterbi_path, beam_path, greedy_path = decode_fn(input_text, length)
        analyzer.add_iteration_result(length, viterbi_path, beam_path, greedy_path)

    print("\n=== Path Stability Analysis ===")
    analysis = analyzer.analyze_all_iterations()

    # Report path shifts between consecutive iteration lengths.
    print("\nPath shifts between iterations:")
    for info in analysis['path_shifts'].values():
        if info['diverges']:
            print(f"{info['method'].capitalize()} path changes when going from {info['from_iteration']} to {info['to_iteration']} tokens")
            print(f"  Divergence at position: {info['diverge_position']}")
            print(f"  Path before: {info['path_before']}")
            print(f"  Path after:  {info['path_after']}")

    # Report where each method departs from Viterbi at every length.
    print("\nMethod comparisons at each iteration:")
    for iteration, comparisons in analysis['method_comparisons'].items():
        print(f"\nIteration length {iteration}:")
        for info in comparisons.values():
            if info['diverges']:
                print(f"  {info['other_method'].capitalize()} diverges from Viterbi at position {info['diverge_position']}")
                print(f"    Viterbi path: {info['viterbi_path']}")
                print(f"    {info['other_method'].capitalize()} path: {info['other_path']}")

    print("\nGenerating visualizations...")
    analyzer.visualize_path_stability()

    # Iterations start at 1, so clamp the midpoint: max_length // 2 is 0 when
    # max_length == 1, and iteration 0 never has data.
    mid_point = max(1, max_length // 2)
    analyzer.visualize_method_comparison(mid_point)
    analyzer.visualize_trellis_with_multiple_paths(mid_point)

    # Also show the final length, unless it is the midpoint already drawn.
    if max_length > mid_point:
        analyzer.visualize_method_comparison(max_length)
        analyzer.visualize_trellis_with_multiple_paths(max_length)

    return analyzer

# Example wrapper function to adapt to your implementation
def run_decoding_for_length(input_text, length):
    """
    Placeholder adapter that decodes `input_text` for `length` tokens with every method.

    Swap the empty lists below for the state-index paths produced by your own
    Viterbi, beam-search and greedy decoders.

    Args:
        input_text: Input text to decode
        length: Number of tokens to generate

    Returns:
        Tuple ``(viterbi_path, beam_path, greedy_path)``; each is a list of
        state indices (currently three fresh, empty lists).
    """
    # TODO: call the real decoders here.
    viterbi_path, beam_path, greedy_path = [], [], []
    return viterbi_path, beam_path, greedy_path

In [None]:
import math
class SearchTree:
    """Node of the decoding trellis / search tree.

    Each node stores one generated token (`token_id`, stripped text `context`),
    the log-probability of the transition into it (`probability`) and the
    cumulative log-probability from the root (`cached_prob` = parent's
    cumulative value + own `probability`). The root node's `token_id` is the
    list of prompt token ids.

    `model` is accepted for interface compatibility but not used. `tokenizer`
    is stored and used by `build_Context`.
    """

    def __init__(self, context, probability, token_id, model, tokenizer, parent=None, child=None, parent_index=None):
        self.token_id = token_id
        self.context = context.strip()
        self.probability = probability
        self.parent = parent
        self.child = []
        self.parent_index = parent_index
        # Kept on the node so build_Context decodes with this tokenizer rather
        # than whatever module-level name `tokenizer` happens to exist.
        self.tokenizer = tokenizer
        if child is not None:
            self.child.append(child)

        # Cache cumulative log-probability (and parent/own token ids) at creation.
        if parent is not None:
            self.cached_prob = parent.calcProbTillNow() + probability
            self.cached_tokenids = [parent.token_id, self.token_id]
        else:
            self.cached_prob = probability
            self.cached_tokenids = [self.token_id]

    def build_Context(self):
        """Decode the text from the root prompt through this node."""
        path_ids = []
        node = self
        while node.parent is not None:
            path_ids.append(node.token_id)
            node = node.parent
        # `node` is now the root, whose token_id is the list of prompt ids.
        full_context = list(node.token_id) + path_ids[::-1]
        return self.tokenizer.decode(full_context, skip_special_tokens=True)

    def create_child(self):
        """Register this node in its parent's `child` list (no-op for the root)."""
        if self.parent is not None:
            self.parent.child.append(self)

    def replace_parent(self, new_parent):
        """Re-link this node under `new_parent` and refresh its cached data."""
        self.parent = new_parent
        self.cached_prob = new_parent.calcProbTillNow() + self.probability
        # Previously left pointing at the old parent's token, so
        # token_idsTillNow() reported a path the node no longer belongs to.
        self.cached_tokenids = [new_parent.token_id, self.token_id]

    def calcProbTillNow(self):
        """Return the cached cumulative log-probability from the root."""
        return self.cached_prob

    def token_idsTillNow(self):
        """Return the cached token ids: ``[parent.token_id, token_id]``, or ``[token_id]`` for the root.

        NOTE(review): despite the name, this is not the full root-to-node path;
        use build_Context for the full decoded text.
        """
        return self.cached_tokenids

    def change_probability(self, new_probability, new_cached_prob):
        """Overwrite the transition and cumulative log-probabilities."""
        self.cached_prob = new_cached_prob
        self.probability = new_probability

    def assign_parent_index(self, parent_index):
        """Record the parent's position within the previous trellis layer."""
        self.parent_index = parent_index


def generate_token_and_probability(model, tokenizer, batch_prompts,max_length=1, top_k=4):
    """
    Greedily generate from each prompt and return its top-k next-token candidates.

    Args:
        model: Causal LM exposing ``.device`` and ``.generate`` (Hugging Face style).
        tokenizer: Matching tokenizer. Side effect: its ``pad_token`` is
            overwritten with ``eos_token`` on every call.
        batch_prompts: List of prompt strings, tokenized together with padding.
        max_length: Number of new tokens requested from ``generate``; only the
            first generated step is inspected below.
        top_k: Number of candidates returned per prompt.

    Returns:
        One list per prompt, each holding ``top_k`` tuples
        ``(generated_token, token_id, token_text, log_prob)``:
        ``generated_token`` is the decoded greedy choice (repeated in every
        tuple), ``token_id`` a 0-d tensor from ``topk``, ``token_text`` its
        decoded string, and ``log_prob`` a float from ``log_softmax`` over the
        first step's scores.
    """
    tokenizer.pad_token= tokenizer.eos_token
    tokenized_result = tokenizer(batch_prompts, return_tensors="pt",padding = True,truncation = True)
    input_ids = tokenized_result["input_ids"].to(model.device)
    attn_mask = tokenized_result["attention_mask"].to(model.device)
    # Newline patch. The ids are presumably GPT-2 vocabulary (the model
    # runViterbiTransformerPipeline loads): 50256 = <|endoftext|> (the pad here),
    # 628 = "\n\n", 198 = "\n". Confirm before using another model.
    # A right-padded row ending in [628, pad] has its tail rewritten to two "\n"
    # tokens, which are then unmasked.
    # NOTE(review): this only handles exactly one trailing pad, and it relies on
    # right padding. With right padding, shorter prompts in a batch are continued
    # after pad tokens. Consider tokenizer.padding_side = "left", which would
    # require reworking this patch.
    for sentence_id in range(len(input_ids)):
        if (input_ids[sentence_id][-1] == 50256 and input_ids[sentence_id][-2] == 628):
            input_ids[sentence_id][-2] = 198
            input_ids[sentence_id][-1] = 198
            attn_mask[sentence_id][-1] = 1
            attn_mask[sentence_id][-2] = 1

    # Single (unpadded) prompt ending in 628: drop it and append 198 198 instead,
    # i.e. split the double newline into two single-newline tokens.
    num_sentences = len(input_ids)
    if (num_sentences == 1):
        if (input_ids[0][-1] == 628):
            input_ids = input_ids[:, :-1]
            attn_mask = attn_mask[:, :-1]
            input_ids = torch.cat((input_ids,torch.tensor([[198,198]]).to(model.device)),dim = 1)
            attn_mask = torch.cat((attn_mask,torch.tensor([[1,1]]).to(model.device)),dim = 1)

    with torch.no_grad():
      outputs = model.generate(
         input_ids=input_ids,
         attention_mask=attn_mask,
         max_length=input_ids.size(-1) + max_length,
         do_sample=False,  # Greedy decoding
         output_scores=True,
         return_dict_in_generate=True
        )

    # `scores` holds one (batch, vocab) logits tensor per generated step; only step 0 is used.
    sequences, scores = outputs.sequences, outputs.scores
    predictions = []
    # print("generated_token_id",generated_token_id)
    for i in range(len(batch_prompts)):
        generated_token_id = sequences[i][input_ids.size(-1):].tolist()[0]  # First newly generated token ID
        generated_token = tokenizer.decode(generated_token_id, skip_special_tokens=True)


        # Log probabilities of all possible tokens at the generated step
        log_probs = torch.nn.functional.log_softmax(scores[0][i], dim=-1)  # scores[0] is the first generation step
        # NOTE: valid_predictions / curr_top_k are leftovers. There is no retry loop,
        # and exactly top_k candidates are taken.
        valid_predictions = []
        curr_top_k = top_k
        topk_logprobs, topk_ids = log_probs.topk(curr_top_k)  # Get top-k log probabilities
        topk_tokens = tokenizer.batch_decode(topk_ids, skip_special_tokens=True)
        predictions.append([(generated_token,tid, tok, lp.item()) for tid, tok, lp in zip(topk_ids, topk_tokens, topk_logprobs)])


    return predictions


def check_bad_predictions(text):
    """
    Return True if `text` looks like a degenerate token prediction.

    A prediction counts as bad when it contains two or more identical
    consecutive characters from ``= ! ? , ; | ~ & -``, or any non-ASCII
    character (code point above 127). Otherwise returns False.
    """
    # A single back-reference pattern covers every "same punctuation mark twice
    # in a row" case.
    repeated_punctuation = r'([=!?,;|~&-])\1'
    if re.search(repeated_punctuation, text) is not None:
        return True

    # str.isascii() is True exactly when every code point is <= 127.
    return not text.isascii()

def generateIntermediates(root, model, tokenizer, numTokens=3, loop_runner=4, **kwargs):
    """Build a token trellis from `root`, re-linking each layer to its best parent.

    Layer 1 holds the top-`numTokens` next tokens of the prompt. Each further
    layer (for layer counter i in ``range(2, loop_runner)``) works as follows:
    every node of the previous layer is expanded by its top-`numTokens`
    continuations, and only the first node for each distinct decoded token
    string is kept. Each kept node is then re-parented to the previous-layer
    node that maximizes (parent cumulative score + score from `findProbability`).

    Args:
        root: Prompt text.
        model: Model passed to `generate_token_and_probability` and `findProbability`.
        tokenizer: Tokenizer passed to the same helpers and to `SearchTree`.
        numTokens: Number of top continuations taken per node.
        loop_runner: Exclusive upper bound of the layer counter (which starts at 2).
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        dict with keys:
            "probabilityMatrix": per layer, a list with, for each kept new token,
                its scores from every previous-layer node;
            "initialStateProbability": layer-1 log-probabilities;
            "content": list of layers, each a list of SearchTree nodes;
            "uniqueTokenLength": number of nodes per layer;
            "generated_sentence_GI": decoded text of the best final node (or of
                the best layer-1 node if no later layer was built/kept);
            "entropies_array": per-layer entropies;
            "trellis_paths": `token_idsTillNow()` of each final-layer node.
    """
    import itertools  # used to flatten findProbability results

    batch_size = 15  # max prompts per generate_token_and_probability call

    def _predict(sentences):
        """Run generate_token_and_probability over `sentences` in chunks of batch_size."""
        predictions = []
        for start in range(0, len(sentences), batch_size):
            with torch.no_grad():
                predictions.extend(generate_token_and_probability(
                    model, tokenizer, sentences[start:start + batch_size], top_k=numTokens))
        return predictions

    num_tokens = numTokens
    root_token_id = tokenizer.encode(root)
    sentence = SearchTree(root, 0, token_id=root_token_id, model=model, tokenizer=tokenizer)

    entropy_array = []
    content = []
    probability = []
    probabilityMatrix = []
    uniqueTokensList = []
    uniqueTokenLength = []
    lastTokens_probability = []
    trellis_paths = []
    # Bug fix: this list was appended to below but never initialised (NameError).
    initial_loop_probability = []

    with torch.no_grad():
        tokens_50K = generate_token_and_probability(model, tokenizer, [root], top_k=numTokens)

    # ---- Layer 1: direct children of the root (no de-duplication) ----
    new_content = []
    for i in range(num_tokens):
        _, token_id, context, prob = tokens_50K[0][i]
        initial_loop_probability.append(prob)
        probability.append(prob)
        node = SearchTree(context, prob, token_id=token_id.item(), model=model,
                          tokenizer=tokenizer, parent=sentence, parent_index=0)
        new_content.append(node)
        node.create_child()
        uniqueTokensList.append(node)

    # `entropy` is defined elsewhere in the file (presumably scipy.stats.entropy).
    entropy_array.append([entropy(np.array([math.exp(p) for p in probability]))])
    content.append(new_content)
    previousUniqueLength = num_tokens
    initialStateProbability = probability
    uniqueTokenLength.append(num_tokens)
    max_index = initial_loop_probability.index(max(initial_loop_probability))
    generated_sentence_GI = uniqueTokensList[max_index].build_Context()

    for i in range(2, loop_runner):
        unique_tokens = set()
        probability = []
        entropies = []
        new_content = []
        previousSetLength = 0
        total_predictions = _predict([node.build_Context() for node in uniqueTokensList])

        # Expand every previous-layer node; keep only the first node created for
        # each distinct decoded token string. The range is fixed before the loop,
        # so nodes appended here are not expanded again in this layer.
        for j in range(len(uniqueTokensList)):
            for s in range(num_tokens):
                _, token_id, context, prob = total_predictions[j][s]
                unique_tokens.add(context)
                node = SearchTree(context, prob, token_id=token_id.item(), model=model,
                                  tokenizer=tokenizer, parent=uniqueTokensList[j])
                if len(unique_tokens) > previousSetLength:
                    previousSetLength = len(unique_tokens)
                    uniqueTokensList.append(node)
                    new_content.append(node)

        content.append(new_content)

        prev_layer = uniqueTokensList[:previousUniqueLength]
        new_layer = uniqueTokensList[previousUniqueLength:]
        n_new = len(new_layer)

        # Row-major: scores from prev_layer[a] to every new token, for a = 0, 1, ...
        comb_prob = list(itertools.chain.from_iterable(
            findProbability(prev_token, new_layer, model, tokenizer) for prev_token in prev_layer))

        for tokenumber, newToken in enumerate(new_layer):
            # Scores into this token from every previous-layer node.
            probs = [comb_prob[a * n_new + tokenumber] for a in range(len(prev_layer))]
            probs2 = [p + parent.calcProbTillNow() for p, parent in zip(probs, prev_layer)]
            # Bug fix: `entropy` returns a scalar for a 1-D distribution (compare the
            # layer-1 call above), and list() of a scalar raises TypeError.
            # np.atleast_1d accepts either a scalar or an array.
            entropies.append(list(np.atleast_1d(entropy(np.array([math.exp(p) for p in probs])))))
            if not probs2:
                continue
            max_value = max(probs2)
            max_index = probs2.index(max_value)
            newToken.replace_parent(prev_layer[max_index])
            newToken.change_probability(probs[max_index], max_value)
            newToken.assign_parent_index(max_index)
            if i == loop_runner - 1:
                lastTokens_probability.append(max_value)
                trellis_paths.append(newToken.token_idsTillNow())
            probability.append(probs)

        probabilityMatrix.append(probability)
        entropy_array.append(entropies)
        uniqueTokenLength.append(n_new)

        # The new layer becomes the "previous" layer for the next iteration.
        previousUniqueLength = n_new
        uniqueTokensList = new_layer

        # Guard against an empty final layer (max() of an empty list raises).
        if i == loop_runner - 1 and lastTokens_probability:
            max_lastToken = max(lastTokens_probability)
            max_lastTokenIndex = lastTokens_probability.index(max_lastToken)
            generated_sentence_GI = uniqueTokensList[max_lastTokenIndex].build_Context()

    return {"probabilityMatrix": probabilityMatrix, "initialStateProbability": initialStateProbability,
            "content": content, "uniqueTokenLength": uniqueTokenLength,
            "generated_sentence_GI": generated_sentence_GI, "entropies_array": entropy_array,
            "trellis_paths": trellis_paths}
def runViterbiTransformerPipeline(rootSentence, numTokens = 3, loop_runner=3,**kwargs):
    """Build the trellis for `rootSentence` with GPT-2 and decode its Viterbi path.

    Args:
        rootSentence: Prompt text.
        numTokens: Top-k continuations per trellis node.
        loop_runner: Passed to generateIntermediates as ``loop_runner + 1``.
        **kwargs: Forwarded to generateIntermediates.

    Returns:
        (decodedString, best_path_prob, generated_sentence_GI, entropy_array)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    result = generateIntermediates(rootSentence, model, tokenizer, numTokens=numTokens,
                                   loop_runner=loop_runner + 1, **kwargs)

    # Bug fix: the keys were bare (unbound) names instead of strings, which
    # raised UnboundLocalError. The entropy key is "entropies_array", as
    # returned by generateIntermediates.
    probabilityMatrix = result["probabilityMatrix"]
    initialStateProbability = result["initialStateProbability"]
    content = result["content"]
    uniqueTokenLength = result["uniqueTokenLength"]
    generated_sentence_GI = result["generated_sentence_GI"]
    entropy_array = result["entropies_array"]

    best_path, _viterbi_mat, best_path_prob = VITERBI_Lists(probabilityMatrix, initialStateProbability, device)
    print("uniqueTokenLength: ", uniqueTokenLength)
    print("best_path: ", best_path)
    decodedString = decodePath(best_path, content, rootSentence, tokenizer)
    return decodedString, best_path_prob, generated_sentence_GI, entropy_array