# In [6]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from collections import Counter
import itertools

class QAOAComparisonAnalyzer:
    """
    QAOA energy-spectrum comparison for 4-state vs 16-state MaxCut problems.

    Holds a small catalogue of example graphs (``self.problems``) and provides
    helpers to enumerate every computational basis state of a graph, score it
    with the MaxCut objective (number of cut edges), print a text report and
    draw a comparison figure.

    Conventions used throughout:
      * Basis-state index ``i`` maps to the bit list
        ``[(i >> j) & 1 for j in range(n)]``: position ``j`` is the side of
        vertex ``j`` (least-significant bit first).
      * "Energy" is the cut size, and the "ground"/optimal states are those
        with the MAXIMUM energy.
    """

    def __init__(self):
        # Example problems: adjacency matrix plus display metadata.
        # 'num_states' is informational only; the spectrum routines derive the
        # state count from the matrix size.
        self.problems = {
            '2_vertex': {
                'adj': np.array([[0, 1],
                                 [1, 0]]),
                'name': '2-Vertex Graph (4 states)',
                'description': 'Simple edge: 2 qubits, 4 computational basis states',
                'num_states': 4
            },
            '4_vertex_cycle': {
                'adj': np.array([[0, 1, 0, 1],
                                 [1, 0, 1, 0],
                                 [0, 1, 0, 1],
                                 [1, 0, 1, 0]]),
                'name': '4-Vertex Cycle (16 states)',
                'description': 'Square graph: 4 qubits, 16 computational basis states',
                'num_states': 16
            },
            '4_vertex_complete': {
                'adj': np.array([[0, 1, 1, 1],
                                 [1, 0, 1, 1],
                                 [1, 1, 0, 1],
                                 [1, 1, 1, 0]]),
                'name': '4-Vertex Complete (16 states)',
                'description': 'Complete graph K4: 4 qubits, 16 states, highly connected',
                'num_states': 16
            }
        }

    @staticmethod
    def _index_to_bits(index, n):
        """Return the n-bit list for basis-state ``index`` (bit j = vertex j, LSB first)."""
        return [(index >> j) & 1 for j in range(n)]

    @staticmethod
    def _difficulty_label(success):
        """Map a random-guess success fraction to a coarse difficulty label."""
        if success > 0.25:
            return 'EASY'
        if success > 0.1:
            return 'MODERATE'
        return 'HARD'

    def calculate_maxcut_energy(self, bit_string, adjacency_matrix):
        """Return the MaxCut value (number of cut edges) of a vertex partition.

        Args:
            bit_string: Sequence of 0/1 labels, one per vertex.
            adjacency_matrix: Square matrix indexed as ``[i, j]``; only
                upper-triangle entries (i < j) equal to 1 count as edges.

        Returns:
            int: number of edges whose endpoints carry different labels.
        """
        n = len(bit_string)
        energy = 0

        for i in range(n):
            for j in range(i + 1, n):
                if adjacency_matrix[i, j] == 1:
                    energy += int(bit_string[i] != bit_string[j])

        return energy

    def generate_energy_spectrum(self, adjacency_matrix):
        """Enumerate all 2**n basis states of the graph and score each one.

        Args:
            adjacency_matrix: n x n adjacency matrix (see
                ``calculate_maxcut_energy``).

        Returns:
            dict with keys:
              'states'        list of dicts (index, bit_string, bit_list, energy)
                              sorted by energy, descending (stable sort);
              'energies'      energies in basis-index order (NOT sorted);
              'max_energy', 'min_energy';
              'ground_states' states whose energy equals 'max_energy';
              'energy_counts' Counter of energy -> number of states;
              'num_states', 'num_vertices'.
        """
        n = adjacency_matrix.shape[0]
        num_states = 2**n

        states = []
        energies = []

        for i in range(num_states):
            bit_string = self._index_to_bits(i, n)
            energy = self.calculate_maxcut_energy(bit_string, adjacency_matrix)

            states.append({
                'index': i,
                'bit_string': ''.join(map(str, bit_string)),
                'bit_list': bit_string,
                'energy': energy
            })
            energies.append(energy)

        # Sort states by energy (descending); 'energies' keeps index order.
        states.sort(key=lambda x: x['energy'], reverse=True)

        max_energy = max(energies)
        min_energy = min(energies)
        ground_states = [s for s in states if s['energy'] == max_energy]
        energy_counts = Counter(energies)

        return {
            'states': states,
            'energies': energies,
            'max_energy': max_energy,
            'min_energy': min_energy,
            'ground_states': ground_states,
            'energy_counts': energy_counts,
            'num_states': num_states,
            'num_vertices': n
        }

    def create_detailed_comparison(self, figsize=(16, 12)):
        """Build the side-by-side figure for the 2-vertex and 4-vertex-cycle problems.

        Args:
            figsize: Matplotlib figure size in inches.

        Returns:
            tuple: (fig, small_spectrum, large_spectrum).
        """
        small_problem = self.problems['2_vertex']
        large_problem = self.problems['4_vertex_cycle']  # Use cycle for interesting structure

        small_spectrum = self.generate_energy_spectrum(small_problem['adj'])
        large_spectrum = self.generate_energy_spectrum(large_problem['adj'])

        fig = plt.figure(figsize=figsize)

        # 4x4 grid: left half for the small problem, right half for the large one.
        gs = fig.add_gridspec(4, 4, hspace=0.6, wspace=0.3)

        fig.suptitle('QAOA Energy Spectrum Comparison: 4 States vs 16 States',
                     fontsize=16, fontweight='bold', y=0.95)

        # === 4-STATE PROBLEM ANALYSIS ===
        ax1 = fig.add_subplot(gs[0, 0:2])
        self._plot_energy_spectrum(ax1, small_spectrum, small_problem, '4-State Problem')

        ax2 = fig.add_subplot(gs[1, 0:2])
        self._plot_energy_histogram(ax2, small_spectrum, small_problem)

        ax3 = fig.add_subplot(gs[2:4, 0:2])
        self._plot_state_details(ax3, small_spectrum, small_problem)

        # === 16-STATE PROBLEM ANALYSIS ===
        ax4 = fig.add_subplot(gs[0, 2:4])
        self._plot_energy_spectrum(ax4, large_spectrum, large_problem, '16-State Problem')

        ax5 = fig.add_subplot(gs[1, 2:4])
        self._plot_energy_histogram(ax5, large_spectrum, large_problem)

        ax6 = fig.add_subplot(gs[2:4, 2:4])
        self._plot_energy_levels(ax6, large_spectrum, large_problem)

        # Lay out the figure we just built, not whatever pyplot considers current.
        fig.tight_layout()
        return fig, small_spectrum, large_spectrum

    def _plot_energy_spectrum(self, ax, spectrum, problem, title):
        """Scatter every state's energy in energy-sorted order; highlight optimal states."""
        state_indices = range(len(spectrum['states']))
        state_energies = [s['energy'] for s in spectrum['states']]
        ground_energy = spectrum['max_energy']

        colors = ['red' if energy == ground_energy else 'blue' for energy in state_energies]
        sizes = [100 if energy == ground_energy else 60 for energy in state_energies]

        ax.scatter(state_indices, state_energies, c=colors, s=sizes,
                   alpha=0.7, edgecolors='black', linewidth=0.5)

        ax.axhline(y=ground_energy, color='red', linestyle='--', alpha=0.5,
                   label=f'Ground Energy = {ground_energy}')

        ax.set_xlabel('State Index (sorted by energy)')
        ax.set_ylabel('Energy')
        ax.set_title(f'{title}\n{problem["name"]}')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Ket labels are only readable for very small problems.
        if spectrum['num_states'] <= 4:
            for i, state in enumerate(spectrum['states']):
                ax.annotate(f"|{state['bit_string']}⟩",
                            (i, state['energy']),
                            xytext=(5, 5), textcoords='offset points',
                            fontsize=10, ha='left')

    def _plot_energy_histogram(self, ax, spectrum, problem):
        """Bar chart of each state's energy in basis-index order.

        ``problem`` is unused (kept for signature compatibility); energies come
        from ``spectrum['energies']``, which is stored in basis-index order.
        """
        n = spectrum['num_vertices']
        individual_states = [''.join(map(str, self._index_to_bits(i, n)))
                             for i in range(2**n)]
        individual_energies = list(spectrum['energies'])

        ground_energy = spectrum['max_energy']
        state_indices = range(len(individual_states))

        # One bar per state so each can carry its own colour and alpha.
        for i, energy in enumerate(individual_energies):
            is_optimal = energy == ground_energy
            ax.bar(i, energy,
                   color='red' if is_optimal else 'lightblue',
                   alpha=0.8 if is_optimal else 0.6,
                   edgecolor='black', linewidth=0.5, width=0.8)

        ax.set_xlabel('State Index')
        ax.set_ylabel('Energy Level')
        ax.set_title('Individual State Energies')
        ax.grid(True, alpha=0.3, axis='y')

        small = len(individual_states) <= 8
        if small:
            ax.set_xticks(state_indices)
            ax.set_xticklabels([f"|{state}⟩" for state in individual_states],
                               rotation=45, ha='right', fontsize=9)
        else:
            # Label every other state to avoid overlapping tick labels.
            ax.set_xticks(range(0, len(individual_states), 2))
            ax.set_xticklabels([f"|{individual_states[i]}⟩" for i in range(0, len(individual_states), 2)],
                               rotation=45, ha='right', fontsize=8)

        # Energy value labels on bars for small problems.
        if small:
            for i, energy in enumerate(individual_energies):
                ax.text(i, energy + 0.05, f'{energy}', ha='center', va='bottom',
                        fontweight='bold', fontsize=10)

    def _plot_state_details(self, ax, spectrum, problem):
        """Render a text table of the (energy-sorted) states plus a summary box.

        Rows advance 0.12 axes units each from y=0.78, so this is laid out for
        the 4-state problem only; larger spectra would overlap the summary box.
        """
        ax.axis('off')

        ax.text(0.5, 0.95, 'Complete State Analysis', ha='center', va='top',
                fontsize=14, fontweight='bold', transform=ax.transAxes)

        # Table header.
        y_pos = 0.85
        for x, header in ((0.1, 'State'), (0.3, 'Binary'), (0.5, 'Energy'), (0.7, 'Type')):
            ax.text(x, y_pos, header, ha='center', va='center', fontweight='bold',
                    transform=ax.transAxes)

        ax.plot([0.05, 0.95], [0.82, 0.82], color='black', linewidth=1,
                transform=ax.transAxes)

        y_pos = 0.78
        for i, state in enumerate(spectrum['states']):
            is_optimal = state['energy'] == spectrum['max_energy']
            color = 'red' if is_optimal else 'blue'
            state_type = 'OPTIMAL' if is_optimal else 'Excited'

            ax.text(0.1, y_pos, f"|ψ{i}⟩", ha='center', va='center',
                    transform=ax.transAxes, color=color, fontweight='bold')
            ax.text(0.3, y_pos, f"|{state['bit_string']}⟩", ha='center', va='center',
                    transform=ax.transAxes, family='monospace', color=color)
            ax.text(0.5, y_pos, f"{state['energy']}", ha='center', va='center',
                    transform=ax.transAxes, color=color, fontweight='bold')
            ax.text(0.7, y_pos, state_type, ha='center', va='center',
                    transform=ax.transAxes, color=color, fontsize=9)

            y_pos -= 0.12

        analysis_text = f"""
Problem Analysis:
• Total states: {spectrum['num_states']}
• Optimal solutions: {len(spectrum['ground_states'])}
• Success probability: {len(spectrum['ground_states'])/spectrum['num_states']:.1%}
• Energy range: {spectrum['min_energy']} to {spectrum['max_energy']}

QAOA Implications:
• {"Easy problem - high success rate" if len(spectrum['ground_states']) > 1 else "Challenging - single solution"}
• {"Large energy gap helps convergence" if spectrum['max_energy'] - spectrum['min_energy'] > 1 else "Small energy gap"}
        """

        ax.text(0.05, 0.4, analysis_text, ha='left', va='top', fontsize=9,
                transform=ax.transAxes,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="lightyellow", alpha=0.8))

    def _plot_energy_levels(self, ax, spectrum, problem):
        """Draw one row per energy level (highest first) with one box per state.

        ``problem`` is unused (kept for signature compatibility); states come
        from ``spectrum['states']``.
        """
        # Highest energy first; within a level, ascending basis index.
        all_states = sorted(spectrum['states'], key=lambda s: (-s['energy'], s['index']))

        energy_groups = {}
        for state in all_states:
            energy_groups.setdefault(state['energy'], []).append(state)

        y_pos = 0
        for energy in sorted(energy_groups, reverse=True):
            states_at_energy = energy_groups[energy]
            is_optimal = energy == spectrum['max_energy']
            color = 'red' if is_optimal else 'lightblue'
            alpha = 0.8 if is_optimal else 0.6

            for i, state in enumerate(states_at_energy):
                ax.add_patch(Rectangle((i, y_pos - 0.15), 0.8, 0.3, facecolor=color,
                                       edgecolor='black', alpha=alpha, linewidth=1))

                # Only label individual states when the row is not too crowded.
                if len(states_at_energy) <= 8:
                    ax.text(i + 0.4, y_pos, f"|{state['bit_string']}⟩",
                            ha='center', va='center', fontsize=8, fontweight='bold')

            ax.text(-0.5, y_pos, f'E={energy}', ha='right', va='center',
                    fontweight='bold', fontsize=11)

            ax.text(len(states_at_energy) + 0.2, y_pos,
                    f'({len(states_at_energy)} states)',
                    ha='left', va='center', fontsize=9, color='gray')

            y_pos -= 1

        # Leave room left of x=-0.5 for the "E=..." labels and right of the
        # widest row for its "(k states)" label, so neither spills outside the axes.
        max_width = max(len(group) for group in energy_groups.values())
        ax.set_xlim(-2, max_width + 3)
        ax.set_ylim(y_pos, 1)
        ax.set_xlabel('Individual States at Each Energy Level')
        ax.set_title('Energy Level Distribution (Individual States)')
        ax.set_yticks([])
        ax.set_xticks([])

        optimal_patch = Rectangle((0, 0), 1, 1, facecolor='red', alpha=0.8, label='Optimal States')
        excited_patch = Rectangle((0, 0), 1, 1, facecolor='lightblue', alpha=0.6, label='Excited States')
        ax.legend(handles=[optimal_patch, excited_patch], loc='upper right')

    def print_comparison_analysis(self):
        """Print a text report comparing the 2-vertex and 4-vertex-cycle spectra.

        "Success probability" is the fraction of basis states that are optimal,
        i.e. the chance a uniformly random bit string is optimal.
        """
        small_problem = self.problems['2_vertex']
        large_problem = self.problems['4_vertex_cycle']

        small_spectrum = self.generate_energy_spectrum(small_problem['adj'])
        large_spectrum = self.generate_energy_spectrum(large_problem['adj'])

        small_success = len(small_spectrum['ground_states']) / small_spectrum['num_states']
        large_success = len(large_spectrum['ground_states']) / large_spectrum['num_states']

        banner = "=" * 80
        section = "=" * 40

        def print_summary(heading, problem, spectrum, success):
            # Header and summary lines shared by both problems.
            print(f"\n{section}")
            print(f"{heading}: {problem['name']}")
            print(section)
            print(f"Description: {problem['description']}")
            print(f"Vertices: {spectrum['num_vertices']}")
            print(f"Total states: {spectrum['num_states']}")
            print(f"Energy range: {spectrum['min_energy']} to {spectrum['max_energy']}")
            print(f"Optimal solutions: {len(spectrum['ground_states'])}")
            print(f"Success probability: {success:.1%}")

        print(banner)
        print("QAOA ENERGY SPECTRUM COMPARISON: 4 STATES vs 16 STATES")
        print(banner)

        print_summary("4-STATE PROBLEM", small_problem, small_spectrum, small_success)
        print("\nAll states:")
        for state in small_spectrum['states']:
            marker = "★" if state['energy'] == small_spectrum['max_energy'] else " "
            print(f"  {marker} |{state['bit_string']}⟩ → Energy = {state['energy']}")

        print_summary("16-STATE PROBLEM", large_problem, large_spectrum, large_success)
        print("\nEnergy distribution:")
        for energy in sorted(large_spectrum['energy_counts'], reverse=True):
            count = large_spectrum['energy_counts'][energy]
            percentage = count / large_spectrum['num_states'] * 100
            marker = "★" if energy == large_spectrum['max_energy'] else " "
            print(f"  {marker} Energy {energy}: {count:2d} states ({percentage:5.1f}%)")

        print("\nOptimal solutions for 16-state problem:")
        for state in large_spectrum['ground_states']:
            print(f"  ★ |{state['bit_string']}⟩")

        print(f"\n{section}")
        print("QAOA IMPLICATIONS")
        print(section)

        print(f"Small problem difficulty: {self._difficulty_label(small_success)}")
        print(f"Large problem difficulty: {self._difficulty_label(large_success)}")

        # Derive the growth factor instead of hard-coding it; ':g' prints 4.0 as "4".
        growth = large_spectrum['num_states'] / small_spectrum['num_states']
        print("\nKey differences:")
        print(f"• State space size: {small_spectrum['num_states']} → {large_spectrum['num_states']} ({growth:g}x larger)")
        print(f"• Success probability: {small_success:.1%} → {large_success:.1%}")
        print("• Problem complexity: Higher dimensional search space for QAOA")
        print("• Parameter sensitivity: 16-state problem likely more sensitive to QAOA angles")

        print(banner)


def main(save_path=None):
    """Run the comparison: print the text report, then draw and show the figure.

    Args:
        save_path: Optional output path. When given, the figure is also
            written there (300 dpi, tight bounding box) before it is shown.
    """
    analyzer = QAOAComparisonAnalyzer()

    # Print detailed text analysis
    analyzer.print_comparison_analysis()

    # Create visual comparison
    print("\nGenerating visual comparison...")
    fig, _, _ = analyzer.create_detailed_comparison()

    # Save before plt.show(), which blocks until the window is closed, so the
    # file is written regardless of how the backend tears the figure down.
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Comparison saved as '{save_path}'")

    plt.show()


if __name__ == "__main__":
    main()