# In [None]:
"""
Tesseract [[16,4,4]] Subsystem Color Code
Implementation following the Reichardt et al. paper and Entropica Labs style
"""

import numpy as np
from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass
from enum import Enum
import warnings

# ============================================================================
# Core Data Structures
# ============================================================================

class PauliType(Enum):
    """Single-qubit Pauli labels.

    The integer values mirror the per-qubit encoding documented for
    ``PauliOperator.operator``: 0=I, 1=X, 2=Y, 3=Z.
    """

    I = 0  # identity
    X = 1  # bit flip
    Y = 2  # combined bit and phase flip
    Z = 3  # phase flip

@dataclass
class PauliOperator:
    """Represents a Pauli operator on n qubits.

    ``operator`` stores one code per qubit (0=I, 1=X, 2=Y, 3=Z) and
    ``phase`` is the exponent k of the global factor i**k, interpreted
    mod 4 (0, 1, 2, 3 -> 1, i, -1, -i).

    Equality compares the Pauli string element-wise and the phase mod 4.
    An explicit ``__eq__`` is needed because the dataclass-generated one
    compares the ndarray field with ``==`` and raises "truth value of an
    array is ambiguous" for n > 1. Instances remain unhashable, exactly as
    with the generated dataclass (mutable, ``eq=True``).
    """

    operator: np.ndarray  # shape (n,) with values 0,1,2,3 for I,X,Y,Z
    phase: int = 0  # 0, 1, 2, 3 for I, i, -1, -i

    # Products of two distinct non-identity Paulis: (a, b) -> (c, k) with
    # a*b = i**k * c. XY=iZ, YZ=iX, ZX=iY; the reversed order gives -i (k=3).
    # Left unannotated on purpose so dataclass treats it as a plain class
    # attribute, not a field.
    _MULT_TABLE = {
        (1, 2): (3, 1), (2, 1): (3, 3),
        (2, 3): (1, 1), (3, 2): (1, 3),
        (3, 1): (2, 1), (1, 3): (2, 3),
    }

    def __mul__(self, other: 'PauliOperator') -> 'PauliOperator':
        """Return the operator product ``self * other`` with its phase.

        Raises:
            ValueError: if the two operators act on different numbers of qubits.
        """
        if len(self.operator) != len(other.operator):
            raise ValueError("Operators must have same length")

        result = np.zeros_like(self.operator)
        phase = self.phase + other.phase

        for i, (a, b) in enumerate(zip(self.operator, other.operator)):
            result[i], p = self._mult_single(a, b)
            phase += p

        return PauliOperator(result, phase % 4)

    def __eq__(self, other: object) -> bool:
        """Element-wise comparison of the Pauli strings plus phase mod 4."""
        if not isinstance(other, PauliOperator):
            return NotImplemented
        return ((self.phase - other.phase) % 4 == 0
                and bool(np.array_equal(self.operator, other.operator)))

    @staticmethod
    def _mult_single(a: int, b: int) -> Tuple[int, int]:
        """Multiply single-qubit Paulis, return (result, phase exponent)."""
        if a == 0:
            return b, 0
        if b == 0:
            return a, 0
        if a == b:
            return 0, 0
        return PauliOperator._MULT_TABLE[(a, b)]

    def weight(self) -> int:
        """Return weight (number of non-identity Paulis) as a Python int."""
        return int(np.count_nonzero(self.operator))

    def __repr__(self) -> str:
        letters = 'IXYZ'
        body = ''.join(letters[int(p)] for p in self.operator)
        # Reduce mod 4 so an un-normalised phase (e.g. 4) still renders.
        prefix = ('', 'i', '-', '-i')[self.phase % 4]
        return prefix + body

@dataclass
class MeasurementResult:
    """Single syndrome-measurement record.

    Attributes:
        outcome: the measured syndrome bit (0 or 1).
        flag: whether the measurement was flagged as possibly carrying a
            correlated error; defaults to False.
    """

    outcome: int  # measured syndrome bit: 0 or 1
    flag: bool = False  # set when a correlated error was flagged

class ErrorType(Enum):
    """Classification used to report the result of a decoding step."""

    NO_ERROR = 0
    CORRECTABLE = 1
    UNCORRECTABLE = 2

# ============================================================================
# Tesseract Code Structure
# ============================================================================

class TesseractCode:
    """
    [[16,4,4]] Tesseract subsystem color code

    Encodes 4 logical qubits into 16 physical qubits with distance 4.
    The underlying stabilizer code has 6 logical qubits
    (``n_total_logical``); 2 of them (logical qubits 1,2) are used as
    gauge qubits for workspace.

    Qubit layout in 4x4 grid (qubit index = 4 * row + col):
      0  1  2  3
      4  5  6  7
      8  9 10 11
     12 13 14 15

    Stabilizers: X^8 and Z^8 on every pair of rows and on every pair of
    columns. Syndrome extraction uses the weight-4 row and column operators
    in ``row_measurements`` / ``col_measurements``.
    """

    def __init__(self):
        self.n_physical = 16
        self.n_logical = 4  # Active logical qubits (not counting gauge)
        self.n_total_logical = 6  # Including gauge qubits
        self.distance = 4

        # Physical qubits arranged in 4x4 grid
        self.grid_shape = (4, 4)

        # Initialize stabilizers and logical operators
        self._init_stabilizers()
        self._init_logical_operators()
        self._init_syndrome_circuits()

    def _grid_lines(self) -> Tuple[List[List[int]], List[List[int]]]:
        """Return ``(rows, cols)``: qubit indices of each grid row and column.

        Fresh lists are built on every call, so callers may mutate them.
        """
        n_rows, n_cols = self.grid_shape
        rows = [[n_cols * r + c for c in range(n_cols)] for r in range(n_rows)]
        cols = [[n_cols * r + c for r in range(n_rows)] for c in range(n_cols)]
        return rows, cols

    def _line_pair_operators(self, lines: List[List[int]],
                             pauli: int) -> List['PauliOperator']:
        """One operator per unordered pair of lines.

        Each operator applies ``pauli`` (1=X, 3=Z) to every qubit of both
        lines and identity elsewhere. Pairs are enumerated (0,1), (0,2),
        (0,3), (1,2), (1,3), (2,3).
        """
        ops = []
        for i in range(len(lines)):
            for j in range(i + 1, len(lines)):
                support = np.zeros(self.n_physical, dtype=int)
                support[lines[i]] = pauli
                support[lines[j]] = pauli
                ops.append(PauliOperator(support))
        return ops

    def _init_stabilizers(self):
        """Initialize X and Z stabilizers (weight 8 each).

        Both Pauli types are generated by every pair of rows AND every pair
        of columns. Row pairs alone span only 3 independent X generators,
        and column pairs alone only 3 independent Z generators. That would
        leave 10 logical qubits and admit weight-2 logicals (e.g. Z on two
        qubits of one row). With both families, each type spans 5
        independent generators: 16 - 2*5 = 6 logical qubits at distance 4.
        This matches ``n_total_logical`` and ``distance``.

        Ordering keeps the earlier layout as a prefix: X stabilizers are row
        pairs then column pairs; Z stabilizers are column pairs then row
        pairs.
        """
        rows, cols = self._grid_lines()
        self.x_stabilizers = (self._line_pair_operators(rows, 1)
                              + self._line_pair_operators(cols, 1))
        self.z_stabilizers = (self._line_pair_operators(cols, 3)
                              + self._line_pair_operators(rows, 3))

    def _init_logical_operators(self):
        """Initialize logical X and Z operators.

        Logical qubit 3+k (k = 0..3) is represented by X on grid row k and
        Z on grid column k (weight-4 representatives). They are stored in
        ``logical_x[k]`` / ``logical_z[k]``.

        NOTE(review): X on row k and X on row k' differ by a row-pair X
        stabilizer, so the four X representatives are mutually equivalent.
        Likewise, the four Z column representatives differ by column-pair
        Z stabilizers. Every X row also anticommutes with every Z column,
        since they overlap on one qubit. These are therefore not four
        independent conjugate pairs. TODO: replace them with the paper's
        logical representatives.
        """
        rows, cols = self._grid_lines()
        self.logical_x = []
        self.logical_z = []
        for row, col in zip(rows, cols):
            x_op = np.zeros(self.n_physical, dtype=int)
            x_op[row] = 1
            self.logical_x.append(PauliOperator(x_op))

            z_op = np.zeros(self.n_physical, dtype=int)
            z_op[col] = 3
            self.logical_z.append(PauliOperator(z_op))

    def _init_syndrome_circuits(self):
        """Record the weight-4 row and column supports used for syndrome extraction."""
        self.row_measurements, self.col_measurements = self._grid_lines()

    def get_syndrome_weight4_ops(self) -> Dict[str, List[List[int]]]:
        """
        Get weight-4 operators for syndrome extraction.

        Returns a dict with keys 'X_rows', 'Z_rows', 'X_cols', 'Z_cols'.
        Each maps to the qubit lists of the rows/columns; the X and Z
        entries share the same supports and differ only in basis.
        """
        return {
            'X_rows': self.row_measurements,
            'Z_rows': self.row_measurements,  # Same qubits, Z basis
            'X_cols': self.col_measurements,
            'Z_cols': self.col_measurements
        }

# ============================================================================
# Error Correction Decoder
# ============================================================================

class TesseractDecoder:
    """
    Decoder for tesseract code implementing the paper's algorithm.

    Corrections are not applied to a state. They are accumulated in a
    Pauli frame: ``pauli_frame_x`` / ``pauli_frame_z`` hold one bit per
    physical qubit, indexed as ``4 * row + col`` on the 4x4 grid.

    Decoding one error type reads the weight-4 row and column outcomes
    measured in the opposite basis. If exactly one row (column) disagrees
    with the other three, that row (column) is remembered as flagged. The
    next column (row) outcomes then locate the error inside it, and the
    correction is added to the frame.

    NOTE(review): ``flagged_row`` / ``flagged_col`` are shared by
    ``decode_z_errors`` and ``decode_x_errors``. A flag raised while
    decoding one error type is therefore consumed (and cleared) by the next
    call of either type. Confirm callers never interleave them, or keep
    separate flag state per error type.
    NOTE(review): the ``*_flags`` arguments are accepted but not read.
    """

    def __init__(self, code: TesseractCode):
        self.code = code
        # Frame and flag state start empty; reset_pauli_frame owns that logic.
        self.reset_pauli_frame()

    def reset_pauli_frame(self):
        """Reset Pauli frame tracking and clear any pending row/column flag."""
        self.pauli_frame_x = np.zeros(16, dtype=int)
        self.pauli_frame_z = np.zeros(16, dtype=int)
        self.flagged_row = -1  # -1: no row flagged
        self.flagged_col = -1  # -1: no column flagged

    def decode_z_errors(self,
                        row_x_outcomes: List[int],
                        row_x_flags: List[bool],
                        col_x_outcomes: List[int],
                        col_x_flags: List[bool]) -> ErrorType:
        """
        Decode Z errors using X-basis measurements.

        Args:
            row_x_outcomes: X measurement outcomes for 4 rows [0-1]
            row_x_flags: Flag bits for row measurements (currently unused)
            col_x_outcomes: X measurement outcomes for 4 columns
            col_x_flags: Flag bits for column measurements (currently unused)

        Returns:
            ErrorType.UNCORRECTABLE if either step rejects the syndrome,
            otherwise ErrorType.CORRECTABLE.
        """
        status = self._process_rows(row_x_outcomes, row_x_flags, 'Z')
        if status == ErrorType.UNCORRECTABLE:
            return status
        return self._process_cols(col_x_outcomes, col_x_flags, 'Z')

    def decode_x_errors(self,
                        row_z_outcomes: List[int],
                        row_z_flags: List[bool],
                        col_z_outcomes: List[int],
                        col_z_flags: List[bool]) -> ErrorType:
        """
        Decode X errors using Z-basis measurements.
        Mirror of decode_z_errors; corrections go to ``pauli_frame_x``.
        """
        status = self._process_rows(row_z_outcomes, row_z_flags, 'X')
        if status == ErrorType.UNCORRECTABLE:
            return status
        return self._process_cols(col_z_outcomes, col_z_flags, 'X')

    @staticmethod
    def _odd_one_out(outcomes: List[int]) -> int:
        """Index of the single line whose outcome differs from the other three.

        Only meaningful when exactly one or exactly three outcomes are 1.
        """
        return outcomes.index(1) if sum(outcomes) == 1 else outcomes.index(0)

    def _flip(self, qubits, error_type: str) -> None:
        """Toggle frame bit(s) for ``qubits``: Z frame for 'Z', else X frame."""
        frame = self.pauli_frame_z if error_type == 'Z' else self.pauli_frame_x
        frame[qubits] ^= 1

    def _process_rows(self, outcomes: List[int], flags: List[bool],
                      error_type: str) -> ErrorType:
        """Process row measurements for error correction.

        Without a pending column flag: two disagreeing rows are
        uncorrectable, and one odd row becomes ``flagged_row``. With a
        pending column flag, the row outcomes locate the error inside that
        column; the flag is cleared afterwards.
        """
        # Normalise so tuples / numpy arrays support .index and list comparison.
        outcomes = [int(o) for o in outcomes]
        total = sum(outcomes)

        if self.flagged_col == -1:  # No column flagged
            if total == 2:
                return ErrorType.UNCORRECTABLE
            if total in (1, 3):
                self.flagged_row = self._odd_one_out(outcomes)
            return ErrorType.CORRECTABLE

        # Column was flagged for correlated error
        col = self.flagged_col
        if total in (1, 3):
            # Single error at (disagreeing row, flagged column)
            row = self._odd_one_out(outcomes)
            self._flip(4 * row + col, error_type)
        elif total == 2:
            if outcomes in ([0, 0, 1, 1], [1, 1, 0, 0]):
                # Weight-2 correlated error: ZZII / XXII down the flagged column
                self._flip([col, 4 + col], error_type)
            else:
                return ErrorType.UNCORRECTABLE

        self.flagged_col = -1
        return ErrorType.CORRECTABLE

    def _process_cols(self, outcomes: List[int], flags: List[bool],
                      error_type: str) -> ErrorType:
        """Process column measurements for error correction.

        Mirror of ``_process_rows`` with rows and columns exchanged. An odd
        column becomes ``flagged_col``. With a pending row flag, the column
        outcomes locate the error inside that row.
        """
        outcomes = [int(o) for o in outcomes]
        total = sum(outcomes)

        if self.flagged_row == -1:
            if total == 2:
                return ErrorType.UNCORRECTABLE
            if total in (1, 3):
                self.flagged_col = self._odd_one_out(outcomes)
            return ErrorType.CORRECTABLE

        # Row was flagged
        row = self.flagged_row
        if total in (1, 3):
            # Single error at (flagged row, disagreeing column)
            col = self._odd_one_out(outcomes)
            self._flip(4 * row + col, error_type)
        elif total == 2:
            if outcomes in ([0, 0, 1, 1], [1, 1, 0, 0]):
                # Weight-2 correlated error: ZZII / XXII along the flagged row
                self._flip([4 * row, 4 * row + 1], error_type)
            else:
                return ErrorType.UNCORRECTABLE

        self.flagged_row = -1
        return ErrorType.CORRECTABLE

# ============================================================================
# Quantum Circuit Components
# ============================================================================

class TesseractCircuits:
    """Circuit primitives for tesseract code.

    Every method returns a plain ``dict`` describing a circuit (gate tuples
    and qubit labels); nothing is simulated or executed here.
    """

    def __init__(self, code: 'TesseractCode'):
        # String annotation: no definition-time dependency on TesseractCode.
        self.code = code

    @staticmethod
    def measure_weight4_x_with_flag(qubits: List[int]) -> Dict:
        """
        Circuit to measure X^4 with one flag qubit (Fig 4c in paper).
        Returns circuit structure.
        """
        return {
            'type': 'weight4_X_flagged',
            'data_qubits': qubits,
            'ancilla': 'syndrome',
            'flag': 'flag',
            'gates': [
                ('H', 'syndrome'),
                ('H', 'flag'),
                ('CNOT', 'syndrome', qubits[0]),
                ('CNOT', 'flag', qubits[1]),
                ('CNOT', 'syndrome', qubits[2]),
                ('CNOT', 'flag', qubits[3]),
                ('CNOT', 'syndrome', 'flag'),
                ('H', 'flag'),
                ('H', 'syndrome'),
                ('MEASURE', 'flag'),
                ('MEASURE', 'syndrome')
            ]
        }

    @staticmethod
    def measure_weight4_xz_simultaneous(qubits: List[int]) -> Dict:
        """
        Circuit to measure X^4 and Z^4 simultaneously (Fig 4d).
        """
        return {
            'type': 'weight4_XZ_simultaneous',
            'data_qubits': qubits,
            'ancilla_x': 'syndrome_x',
            'ancilla_z': 'syndrome_z',
            'gates': [
                ('H', 'syndrome_x'),
                ('CNOT', 'syndrome_x', qubits[0]),
                ('CNOT', 'syndrome_x', qubits[1]),
                ('CNOT', 'syndrome_x', qubits[2]),
                ('CNOT', 'syndrome_x', qubits[3]),
                ('CNOT', qubits[0], 'syndrome_z'),
                ('CNOT', qubits[1], 'syndrome_z'),
                ('CNOT', qubits[2], 'syndrome_z'),
                ('CNOT', qubits[3], 'syndrome_z'),
                ('H', 'syndrome_x'),
                ('MEASURE', 'syndrome_x'),
                ('MEASURE', 'syndrome_z')
            ]
        }

    def initialization_circuit(self, state: str = '++0000') -> Dict:
        """
        Generate initialization circuit for encoded states.

        Args:
            state: '++0000' (default) or '+0+0+0'. '00++++' is not
                implemented yet.

        Raises:
            ValueError: for any other ``state``, including '00++++'.
        """
        builders = {
            '++0000': self._init_plus_plus_zero_zero,
            '+0+0+0': self._init_alternating,
        }
        builder = builders.get(state)
        if builder is None:
            raise ValueError(f"Unknown state: {state}")
        return builder()

    def _init_plus_plus_zero_zero(self) -> Dict:
        """Initialize |++0000⟩ encoded state (Fig 9a)"""
        return {
            'type': 'init_++0000',
            'circuit': [
                # Prepare |+⟩ states
                ('H', [0,1,2,3,4,5,6,7]),
                # Entangle for [[8,3,2]] code structure
                ('CNOT', [0,4], [1,5]),
                ('CNOT', [0,4], [2,6]),
                ('CNOT', [0,4], [3,7]),
                # Flag measurements (postselected)
                ('MEASURE_FLAG', 'stabilizers')
            ]
        }

    def _init_alternating(self) -> Dict:
        """Initialize |+0+0+0⟩ using two [[8,3,2]] codes"""
        return {
            'type': 'init_+0+0+0',
            'circuit': [
                ('INIT_BLOCK', 'block1', '|000⟩'),
                ('INIT_BLOCK', 'block2', '|+++⟩'),
                ('TRANSVERSAL_GATES', 'combine')
            ]
        }

    def error_correction_round(self) -> Dict:
        """
        One round of error correction.

        Measures every row, then every column, each with simultaneous
        X^4/Z^4 measurement; reads ``self.code.row_measurements`` and
        ``self.code.col_measurements``.
        """
        circuits = [self.measure_weight4_xz_simultaneous(row)
                    for row in self.code.row_measurements]
        circuits += [self.measure_weight4_xz_simultaneous(col)
                     for col in self.code.col_measurements]

        return {
            'type': 'error_correction_round',
            'circuits': circuits,
            'decoder': 'tesseract_decoder'
        }

# ============================================================================
# Permutation Automorphisms
# ============================================================================

class TesseractPermutations:
    """
    Logical operations via qubit permutations.

    A permutation of physical qubits that maps the code space to itself
    can still act non-trivially on the encoded logical qubits.
    """

    @staticmethod
    def get_cnot_permutations() -> Dict[Tuple[int,int], List[int]]:
        """
        Return permutations intended to implement logical CNOT gates.

        Returns:
            dict mapping (control, target) logical-qubit labels to a
            length-16 list ``perm``; ``apply_permutation`` maps
            ``new[i] = state[perm[i]]``.

        NOTE(review): the single entry swaps physical qubits 3<->5 and
        4<->6. That sends the rows-{0,2} X stabilizer (qubits 0-3, 8-11)
        to {0,1,2,5,8,9,10,11}. That set is neither a pair of rows nor a
        pair of columns and is not closed under 3-way XOR of indices (0^1^2=3).
        So this swap does not appear to preserve the stabilizer group.
        Confirm the qubit labelling used for "(3,5)(4,6)".
        """
        swaps = ((3, 5), (4, 6))  # "(3,5)(4,6)" as two transpositions
        perm = list(range(16))
        for a, b in swaps:
            perm[a], perm[b] = perm[b], perm[a]

        # More (control, target) entries can be added here.
        return {(5, 6): perm}

    @staticmethod
    def apply_permutation(state: np.ndarray, perm: List[int]) -> np.ndarray:
        """Return ``state`` reindexed by ``perm`` (fancy indexing, a new array)."""
        permuted = state[perm]
        return permuted

# ============================================================================
# Full Stabilizer Anticommutation Table
# ============================================================================

def build_stabilizer_anticommutation_table(code: TesseractCode) -> np.ndarray:
    """
    Build full anticommutation table for stabilizers and logical operators.

    Rows and columns are indexed by the concatenation
    ``code.x_stabilizers + code.z_stabilizers + code.logical_x + code.logical_z``.
    Entry (i, j) is 1 if operators i and j anticommute, 0 otherwise. The
    table is symmetric with a zero diagonal.
    """
    all_ops = (code.x_stabilizers + code.z_stabilizers
               + code.logical_x + code.logical_z)
    n_ops = len(all_ops)

    table = np.zeros((n_ops, n_ops), dtype=int)

    # Only the upper triangle is tested; symmetry fills the lower one.
    for i, op_i in enumerate(all_ops):
        for j in range(i + 1, n_ops):
            if anticommute(op_i, all_ops[j]):
                table[i, j] = table[j, i] = 1

    return table

def anticommute(op1: PauliOperator, op2: PauliOperator) -> bool:
    """Return True when ``op1`` and ``op2`` anticommute.

    Counts positions where both operators are non-identity and differ;
    the operators anticommute iff that count is odd. Positions are paired
    with ``zip``, so only the common prefix of unequal-length operators
    is examined.
    """
    clashes = sum(
        1
        for a, b in zip(op1.operator, op2.operator)
        if a != 0 and b != 0 and a != b
    )
    return clashes % 2 == 1

# ============================================================================
# Example Usage and Tests
# ============================================================================

def demonstrate_tesseract_code():
    """Print a walkthrough of the tesseract code components to stdout.

    Shows the code parameters, the first three X and Z stabilizers, the
    logical operator representatives, the anticommutation table shape, one
    Z-error decoding example and the size of the circuit descriptions.
    """
    divider = "=" * 70

    def section(title, lead=""):
        # Title framed by two 70-character divider lines.
        print(lead + divider)
        print(title)
        print(divider)

    section("TESSERACT [[16,4,4]] SUBSYSTEM COLOR CODE")

    code = TesseractCode()
    params = f"[[{code.n_physical}, {code.n_logical}, {code.distance}]]"
    print(f"\nCode parameters: {params}")
    print(f"Physical qubits: {code.n_physical}")
    print(f"Logical qubits: {code.n_logical} (+ 2 gauge qubits)")
    print(f"Distance: {code.distance}")

    # Only the first three stabilizers of each type are listed.
    for kind, stabilizers in (("X", code.x_stabilizers), ("Z", code.z_stabilizers)):
        print(f"\n{kind} Stabilizers ({len(stabilizers)}):")
        for idx, stab in enumerate(stabilizers[:3]):
            print(f"  {idx}: {stab} (weight {stab.weight()})")

    print("\nLogical Operators:")
    for i in range(code.n_logical):
        label = i + 3  # data logical qubits are labelled from 3
        for name, ops in (("X", code.logical_x), ("Z", code.logical_z)):
            print(f"  {name}_{label}: {ops[i]} (weight {ops[i].weight()})")

    print("\nBuilding stabilizer anticommutation table...")
    table = build_stabilizer_anticommutation_table(code)
    print(f"Table shape: {table.shape}")

    section("ERROR CORRECTION EXAMPLE", lead="\n")

    decoder = TesseractDecoder(code)

    # Syndrome for a Z error at row 1, column 2: row 1 and column 2 disagree.
    print("\nSimulating Z error at qubit 6 (row 1, col 2)")
    status = decoder.decode_z_errors([0, 1, 0, 0], [False] * 4,
                                     [0, 0, 1, 0], [False] * 4)

    print(f"Decoding status: {status}")
    print(f"Z corrections applied: {np.where(decoder.pauli_frame_z)[0]}")

    section("CIRCUIT STRUCTURES", lead="\n")

    circuits = TesseractCircuits(code)

    flagged = circuits.measure_weight4_x_with_flag([0, 1, 2, 3])
    print("\nWeight-4 X measurement circuit:")
    print(f"  Data qubits: {flagged['data_qubits']}")
    print(f"  Number of gates: {len(flagged['gates'])}")

    ec_round = circuits.error_correction_round()
    print("\nError correction round:")
    print(f"  Type: {ec_round['type']}")
    print(f"  Number of measurement circuits: {len(ec_round['circuits'])}")

    section("TESSERACT CODE DEMONSTRATION COMPLETE", lead="\n")

# Script entry point: run the demonstration only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    demonstrate_tesseract_code()