In [1]:
import sys
sys.path.append("../squash")

In [2]:
import pygame
from pygame import DOUBLEBUF, HWSURFACE
import numpy as np

from qiskit import QuantumCircuit, QuantumRegister
from qiskit import BasicAer, execute, ClassicalRegister
from copy import deepcopy
import math
import random
from pygame.locals import *
import os
from pygame.constants import RLEACCEL

pygame 2.1.2 (SDL 2.0.18, Python 3.7.6)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
# containers
class HBox(pygame.sprite.RenderPlain):
    """Sprite group that lays its members out left-to-right in one row.

    Every member's top edge is placed at ``ypos``; the first member starts at
    ``xpos`` and each following member starts where the previous one ends.
    """
    def __init__(self, xpos, ypos, *sprites):
        super().__init__(sprites)
        self.xpos, self.ypos = xpos, ypos
        self.arrange()

    def arrange(self):
        """Re-position the members in a row starting at (xpos, ypos)."""
        cursor_x = self.xpos
        for member in self.sprites():
            member.rect.left, member.rect.top = cursor_x, self.ypos
            cursor_x += member.rect.width
            
class VBox(pygame.sprite.RenderPlain):
    """Sprite group that stacks its members top-to-bottom in one column.

    Every member's left edge is placed at ``xpos``; the first member starts at
    ``ypos`` and each following member starts where the previous one ends.
    """
    def __init__(self, xpos, ypos, *sprites):
        super().__init__(sprites)
        self.xpos, self.ypos = xpos, ypos
        self.arrange()

    def arrange(self):
        """Re-position the members in a column starting at (xpos, ypos)."""
        cursor_y = self.ypos
        for member in self.sprites():
            member.rect.left, member.rect.top = self.xpos, cursor_y
            cursor_y += member.rect.height

# controls          
class CircuitGrid(pygame.sprite.RenderPlain):
    """Enables interaction with circuit

    Sprite group made of the grid background, one CircuitGridGate tile per
    (wire, column) node of ``circuit_grid_model`` and a cursor sprite.  The
    ``handle_input_*`` methods edit the model at the cursor position
    (``selected_wire``, ``selected_column``) and then refresh the sprites.
    """

    def __init__(self, xpos, ypos, circuit_grid_model):
        """Create the background, gate tiles and cursor, then lay them out.

        Args:
            xpos: x pixel coordinate of the grid's top-left corner.
            ypos: y pixel coordinate of the grid's top-left corner.
            circuit_grid_model: CircuitGridModel displayed and edited by this grid.
        """
        self.xpos = xpos
        self.ypos = ypos
        self.circuit_grid_model = circuit_grid_model
        self.selected_wire = 0
        self.selected_column = 0
        self.circuit_grid_background = CircuitGridBackground(circuit_grid_model)
        self.circuit_grid_cursor = CircuitGridCursor()
        self.gate_tiles = np.empty((circuit_grid_model.max_wires, circuit_grid_model.max_columns),
                                   dtype=CircuitGridGate)

        for row_idx in range(self.circuit_grid_model.max_wires):
            for col_idx in range(self.circuit_grid_model.max_columns):
                self.gate_tiles[row_idx][col_idx] = \
                    CircuitGridGate(circuit_grid_model, row_idx, col_idx)

        pygame.sprite.RenderPlain.__init__(self, self.circuit_grid_background,
                                           self.gate_tiles,
                                           self.circuit_grid_cursor)
        self.update()

    def update(self, *args):
        """Refresh every sprite from the model and re-position the tiles and cursor."""
        for sprite in self.sprites():
            sprite.update()

        self.circuit_grid_background.rect.left = self.xpos
        self.circuit_grid_background.rect.top = self.ypos

        # Tile centres line up with the wires drawn by CircuitGridBackground.
        for row_idx in range(self.circuit_grid_model.max_wires):
            for col_idx in range(self.circuit_grid_model.max_columns):
                self.gate_tiles[row_idx][col_idx].rect.centerx = \
                    self.xpos + GRID_WIDTH * (col_idx + 1.5)
                self.gate_tiles[row_idx][col_idx].rect.centery = \
                    self.ypos + GRID_HEIGHT * (row_idx + 1.0)

        self.highlight_selected_node(self.selected_wire, self.selected_column)

    def highlight_selected_node(self, wire_num, column_num):
        """Select node (wire_num, column_num) and move the cursor sprite onto it."""
        self.selected_wire = wire_num
        self.selected_column = column_num
        self.circuit_grid_cursor.rect.left = self.xpos + GRID_WIDTH * (self.selected_column + 1) + round(
            0.375 * WIDTH_UNIT)
        self.circuit_grid_cursor.rect.top = self.ypos + GRID_HEIGHT * (self.selected_wire + 0.5) + round(
            0.375 * WIDTH_UNIT)

    def reset_cursor(self):
        """Move the selection back to wire 0, column 0."""
        self.highlight_selected_node(0, 0)

    def display_exceptional_condition(self):
        # TODO: Make cursor appearance indicate condition such as unable to place a gate
        return

    def move_to_adjacent_node(self, direction):
        """Move the selection one node in ``direction``, staying inside the grid."""
        if direction == MOVE_LEFT and self.selected_column > 0:
            self.selected_column -= 1
        elif direction == MOVE_RIGHT and self.selected_column < self.circuit_grid_model.max_columns - 1:
            self.selected_column += 1
        elif direction == MOVE_UP and self.selected_wire > 0:
            self.selected_wire -= 1
        elif direction == MOVE_DOWN and self.selected_wire < self.circuit_grid_model.max_wires - 1:
            self.selected_wire += 1

        self.highlight_selected_node(self.selected_wire, self.selected_column)

    def get_selected_node_gate_part(self):
        """Return the gate part (gate type, CTRL, SWAP or EMPTY) at the selected node."""
        return self.circuit_grid_model.get_node_gate_part(self.selected_wire, self.selected_column)

    def _toggle_gate(self, gate_type):
        """Place ``gate_type`` on an empty selected node, or delete it if the node holds that same gate.

        A node holding a different gate is left alone, so a gate can only be removed
        with its own key.
        """
        selected_node_gate_part = self.get_selected_node_gate_part()
        if selected_node_gate_part == EMPTY:
            self.circuit_grid_model.set_node(self.selected_wire, self.selected_column,
                                             CircuitGridNode(gate_type))
        elif selected_node_gate_part == gate_type:
            self.handle_input_delete()
        self.update()

    def handle_input_x(self):
        """Toggle an X gate on the selected node."""
        self._toggle_gate(X)

    def handle_input_y(self):
        """Toggle a Y gate on the selected node."""
        self._toggle_gate(Y)

    def handle_input_z(self):
        """Toggle a Z gate on the selected node."""
        self._toggle_gate(Z)

    def handle_input_h(self):
        """Toggle an H gate on the selected node."""
        self._toggle_gate(H)

    def handle_input_delete(self):
        """Delete the gate (with its controls) or the control part at the selected node.

        SWAP and TRACE parts are not deleted directly.
        """
        selected_node_gate_part = self.get_selected_node_gate_part()
        if selected_node_gate_part in (X, Y, Z, H):
            self.delete_controls_for_gate(self.selected_wire, self.selected_column)

        if selected_node_gate_part == CTRL:
            # A control is not stored as a node; remove the gate it belongs to instead.
            gate_wire_num = \
                self.circuit_grid_model.get_gate_wire_for_control_node(self.selected_wire,
                                                                       self.selected_column)
            if gate_wire_num >= 0:
                self.delete_controls_for_gate(gate_wire_num, self.selected_column)
        elif selected_node_gate_part not in (SWAP, TRACE):
            self.circuit_grid_model.set_node(self.selected_wire, self.selected_column,
                                             CircuitGridNode(EMPTY))

        self.update()

    def handle_input_ctrl(self):
        """Toggle a control qubit for the X/Y/Z/H gate on the selected node.

        If the gate already has a control (``ctrl_a``), it is removed together with the
        TRACE nodes between gate and control.  Otherwise a control is placed on the wire
        above, or failing that the wire below.
        """
        # TODO: Handle Toffoli gates. For now, control qubit is assumed to be in ctrl_a variable
        #       with ctrl_b variable reserved for Toffoli gates
        selected_node_gate_part = self.get_selected_node_gate_part()
        if selected_node_gate_part not in (X, Y, Z, H):
            return

        circuit_grid_node = self.circuit_grid_model.get_node(self.selected_wire, self.selected_column)
        if circuit_grid_node.ctrl_a >= 0:
            # Gate already has a control qubit so remove it
            orig_ctrl_a = circuit_grid_node.ctrl_a
            circuit_grid_node.ctrl_a = -1
            self.circuit_grid_model.set_node(self.selected_wire, self.selected_column, circuit_grid_node)

            # Remove TRACE nodes
            for wire_num in range(min(self.selected_wire, orig_ctrl_a) + 1,
                                  max(self.selected_wire, orig_ctrl_a)):
                if self.circuit_grid_model.get_node_gate_part(wire_num,
                                                              self.selected_column) == TRACE:
                    self.circuit_grid_model.set_node(wire_num, self.selected_column,
                                                     CircuitGridNode(EMPTY))
            self.update()
        # Try the wire above first, then the wire below.  place_ctrl_qubit itself rejects
        # wires outside the grid, so no separate bounds checks are needed here.
        elif self.place_ctrl_qubit(self.selected_wire, self.selected_wire - 1) == -1 and \
                self.place_ctrl_qubit(self.selected_wire, self.selected_wire + 1) == -1:
            print("Can't place control qubit")
            self.display_exceptional_condition()

    def handle_input_move_ctrl(self, direction):
        """Move the control qubit of the selected gate one wire up or down, skipping the gate's own wire.

        A TRACE node is left on the wire the control moved away from when it is now empty and
        lies between the gate and the control.
        """
        # TODO: Handle Toffoli gates. For now, control qubit is assumed to be in ctrl_a variable
        #       with ctrl_b variable reserved for Toffoli gates
        # TODO: Simplify the logic in this method, including considering not actually ever
        #       placing a TRACE, but rather always dynamically calculating if a TRACE s/b displayed
        selected_node_gate_part = self.get_selected_node_gate_part()
        if selected_node_gate_part not in (X, Y, Z, H):
            return

        circuit_grid_node = self.circuit_grid_model.get_node(self.selected_wire, self.selected_column)
        if not 0 <= circuit_grid_node.ctrl_a < self.circuit_grid_model.max_wires:
            return

        # Gate already has a control qubit so try to move it
        if direction == MOVE_UP:
            candidate_wire_num = circuit_grid_node.ctrl_a - 1
            if candidate_wire_num == self.selected_wire:
                candidate_wire_num -= 1
        else:
            candidate_wire_num = circuit_grid_node.ctrl_a + 1
            if candidate_wire_num == self.selected_wire:
                candidate_wire_num += 1

        if not 0 <= candidate_wire_num < self.circuit_grid_model.max_wires:
            return

        if self.place_ctrl_qubit(self.selected_wire, candidate_wire_num) == candidate_wire_num:
            print("control qubit successfully placed on wire ", candidate_wire_num)
            if direction == MOVE_UP and candidate_wire_num < self.selected_wire:
                if self.circuit_grid_model.get_node_gate_part(candidate_wire_num + 1,
                                                              self.selected_column) == EMPTY:
                    self.circuit_grid_model.set_node(candidate_wire_num + 1, self.selected_column,
                                                     CircuitGridNode(TRACE))
            elif direction == MOVE_DOWN and candidate_wire_num > self.selected_wire:
                if self.circuit_grid_model.get_node_gate_part(candidate_wire_num - 1,
                                                              self.selected_column) == EMPTY:
                    self.circuit_grid_model.set_node(candidate_wire_num - 1, self.selected_column,
                                                     CircuitGridNode(TRACE))
            self.update()
        else:
            print("control qubit could not be placed on wire ", candidate_wire_num)

    def handle_input_rotate(self, radians):
        """Add ``radians`` to the rotation angle of the selected X/Y/Z gate (kept in [0, 2*pi))."""
        selected_node_gate_part = self.get_selected_node_gate_part()
        if selected_node_gate_part in (X, Y, Z):
            circuit_grid_node = self.circuit_grid_model.get_node(self.selected_wire, self.selected_column)
            circuit_grid_node.radians = (circuit_grid_node.radians + radians) % (2 * np.pi)
            self.circuit_grid_model.set_node(self.selected_wire, self.selected_column, circuit_grid_node)

        self.update()

    def place_ctrl_qubit(self, gate_wire_num, candidate_ctrl_wire_num):
        """Attempt to place a control qubit on a wire.
        If successful, return the wire number. If not, return -1

        The candidate wire must be inside the grid and hold EMPTY or TRACE in the
        selected column.
        """
        if candidate_ctrl_wire_num < 0 or candidate_ctrl_wire_num >= self.circuit_grid_model.max_wires:
            return -1
        candidate_wire_gate_part = \
            self.circuit_grid_model.get_node_gate_part(candidate_ctrl_wire_num,
                                                       self.selected_column)
        if candidate_wire_gate_part not in (EMPTY, TRACE):
            print("Can't place control qubit on wire: ", candidate_ctrl_wire_num)
            return -1

        circuit_grid_node = self.circuit_grid_model.get_node(gate_wire_num, self.selected_column)
        circuit_grid_node.ctrl_a = candidate_ctrl_wire_num
        self.circuit_grid_model.set_node(gate_wire_num, self.selected_column, circuit_grid_node)
        # Controls are derived from the gate's ctrl_a, so the control wire itself holds an EMPTY node
        # (this also clears a TRACE that was there).
        self.circuit_grid_model.set_node(candidate_ctrl_wire_num, self.selected_column,
                                         CircuitGridNode(EMPTY))
        self.update()
        return candidate_ctrl_wire_num

    def delete_controls_for_gate(self, gate_wire_num, column_num):
        """Clear a controlled gate together with its controls and the TRACE nodes between them.

        Every node from the outermost control above the gate to the outermost control
        below it, including the gate itself, is replaced with an EMPTY node.  Nothing
        happens if the gate has no control qubits.  Unlike picking only the single
        furthest control, this also handles controls on both sides of the gate,
        including equidistant ones.
        """
        gate_node = self.circuit_grid_model.get_node(gate_wire_num, column_num)
        if gate_node is None:
            return

        control_wire_nums = [wire_num for wire_num in (gate_node.ctrl_a, gate_node.ctrl_b)
                             if wire_num >= 0 and wire_num != gate_wire_num]
        if not control_wire_nums:
            return

        first_wire = min(gate_wire_num, *control_wire_nums)
        last_wire = max(gate_wire_num, *control_wire_nums)
        for wire_idx in range(first_wire, last_wire + 1):
            print("Replacing wire ", wire_idx, " in column ", column_num)
            self.circuit_grid_model.set_node(wire_idx, column_num, CircuitGridNode(EMPTY))


class CircuitGridBackground(pygame.sprite.Sprite):
    """Background for circuit grid

    White surface with a black border and one horizontal line per wire.  Wire k is
    drawn at y = (k + 1) * GRID_HEIGHT, which is where CircuitGrid centres the gate
    tiles of that wire.
    """

    def __init__(self, circuit_grid_model):
        pygame.sprite.Sprite.__init__(self)

        # Size from the model rather than a hard-coded 18 columns x 2 wires.  Otherwise a
        # model with more wires would draw its lower wires on or past the bottom edge.
        self.image = pygame.Surface([GRID_WIDTH * (circuit_grid_model.max_columns + 2),
                                     GRID_HEIGHT * (circuit_grid_model.max_wires + 1)])
        # convert() returns a new Surface instead of converting in place, so keep its result.
        self.image = self.image.convert()
        self.image.fill(WHITE)
        self.rect = self.image.get_rect()
        pygame.draw.rect(self.image, BLACK, self.rect, LINE_WIDTH)

        for wire_num in range(circuit_grid_model.max_wires):
            wire_y = (wire_num + 1) * GRID_HEIGHT
            pygame.draw.line(self.image, BLACK,
                             (GRID_WIDTH * 0.5, wire_y),
                             (self.rect.width - (GRID_WIDTH * 0.5), wire_y),
                             LINE_WIDTH)


class CircuitGridGate(pygame.sprite.Sprite):
    """Sprite that shows the gate part occupying one (wire, column) node of the grid.

    The picture is re-derived from ``circuit_grid_model`` every time ``update`` runs.
    """

    def __init__(self, circuit_grid_model, wire_num, column_num):
        pygame.sprite.Sprite.__init__(self)
        self.circuit_grid_model = circuit_grid_model
        self.wire_num = wire_num
        self.column_num = column_num

        self.update()

    def _draw_rotation_gate(self, image_path, radians):
        """Load a rotation-gate picture and overlay MAGENTA arcs showing ``radians``.

        A thick arc covers [0, angle) and a thin arc covers the rest of the circle.
        """
        self.image, self.rect = load_image(image_path, -1)
        self.rect = self.image.get_rect()
        angle = radians % (2 * np.pi)
        pygame.draw.arc(self.image, MAGENTA, self.rect, 0, angle, 6)
        pygame.draw.arc(self.image, MAGENTA, self.rect, angle, 2 * np.pi, 1)

    def update(self):
        """Reload ``image``/``rect`` for the gate part currently at this node."""
        gate_part = self.circuit_grid_model.get_node_gate_part(self.wire_num, self.column_num)

        # Gate parts whose picture depends only on their type.  Built here rather than at
        # class level because the gate-type constants are defined after this class.
        fixed_images = {
            H: 'gate_images/h_gate.png',
            S: 'gate_images/s_gate.png',
            SDG: 'gate_images/sdg_gate.png',
            T: 'gate_images/t_gate.png',
            TDG: 'gate_images/tdg_gate.png',
            # a completely transparent PNG is used to place at the end of the circuit to prevent crash
            # the game crashes if the circuit is empty
            IDEN: 'gate_images/transparent.png',
            TRACE: 'gate_images/trace_gate.png',
            SWAP: 'gate_images/swap_gate.png',
        }
        # Pauli gates: (plain picture, rotation picture)
        pauli_images = {
            X: ('gate_images/x_gate.png', 'gate_images/rx_gate.png'),
            Y: ('gate_images/y_gate.png', 'gate_images/ry_gate.png'),
            Z: ('gate_images/z_gate.png', 'gate_images/rz_gate.png'),
        }

        if gate_part in fixed_images:
            self.image, self.rect = load_image(fixed_images[gate_part], -1)
        elif gate_part in pauli_images:
            node = self.circuit_grid_model.get_node(self.wire_num, self.column_num)
            plain_path, rotation_path = pauli_images[gate_part]
            if gate_part == X and (node.ctrl_a >= 0 or node.ctrl_b >= 0):
                # Control-X or Toffoli gate: the picture depends on whether the target
                # wire lies below its controls.
                # TODO: Handle Toffoli gates more completely
                if self.wire_num > max(node.ctrl_a, node.ctrl_b):
                    target_path = 'gate_images/not_gate_below_ctrl.png'
                else:
                    target_path = 'gate_images/not_gate_above_ctrl.png'
                self.image, self.rect = load_image(target_path, -1)
            elif node.radians != 0:
                self._draw_rotation_gate(rotation_path, node.radians)
            else:
                self.image, self.rect = load_image(plain_path, -1)
        elif gate_part == CTRL:
            # TODO: Handle Toffoli gates correctly
            owner_wire = self.circuit_grid_model.get_gate_wire_for_control_node(self.wire_num,
                                                                                self.column_num)
            if self.wire_num > owner_wire:
                ctrl_path = 'gate_images/ctrl_gate_bottom_wire.png'
            else:
                ctrl_path = 'gate_images/ctrl_gate_top_wire.png'
            self.image, self.rect = load_image(ctrl_path, -1)
        else:
            # EMPTY (or any unrecognised part): a fully transparent tile of the standard size
            blank_tile = pygame.Surface([GATE_TILE_WIDTH, GATE_TILE_HEIGHT])
            blank_tile.set_alpha(0)
            self.image = blank_tile
            self.rect = self.image.get_rect()

        # NOTE(review): Surface.convert() returns a new Surface; its result is discarded here,
        # so this call leaves self.image unchanged.
        self.image.convert()


class CircuitGridCursor(pygame.sprite.Sprite):
    """Sprite drawn over the currently selected circuit-grid node."""

    def __init__(self):
        super().__init__()
        cursor_image, cursor_rect = load_image('cursor_images/circuit-grid-cursor-medium.png', -1)
        self.image = cursor_image
        self.rect = cursor_rect
        # NOTE(review): convert_alpha() returns a new Surface and the result is discarded,
        # so this call leaves self.image unchanged.
        self.image.convert_alpha()
   
# NOTE(review): not referenced anywhere in this chunk; presumably an upper bound on the
# number of qubits (wires) the game supports — confirm against its users.
MAX_NUM_QUBITS = 10
# model 
class CircuitGridModel:
    """Grid-based model that is built when user interacts with circuit

    ``nodes`` is a (max_wires, max_columns) object array; each cell holds a
    CircuitGridNode or None when it was never set.  Control and swap parts are not
    stored as their own nodes.  Instead, ``get_node_gate_part`` derives them from the
    ``ctrl_a``, ``ctrl_b`` and ``swap`` fields of the other nodes in the same column.
    """
    def __init__(self, max_wires, max_columns):
        self.max_wires = max_wires
        self.max_columns = max_columns
        self.nodes = np.empty((max_wires, max_columns), dtype=CircuitGridNode)

    def __str__(self):
        retval = ''
        for wire_num in range(self.max_wires):
            retval += '\n'
            for column_num in range(self.max_columns):
                retval += str(self.get_node_gate_part(wire_num, column_num)) + ', '
        return 'CircuitGridModel: ' + retval

    def set_node(self, wire_num, column_num, circuit_grid_node):
        """Store a copy of ``circuit_grid_node`` at (wire_num, column_num)."""
        self.nodes[wire_num][column_num] = \
            CircuitGridNode(circuit_grid_node.node_type,
                            circuit_grid_node.radians,
                            circuit_grid_node.ctrl_a,
                            circuit_grid_node.ctrl_b,
                            circuit_grid_node.swap)

    def get_node(self, wire_num, column_num):
        """Return the stored node object itself (not a copy), or None if unset."""
        return self.nodes[wire_num][column_num]

    def get_node_gate_part(self, wire_num, column_num):
        """Return what occupies (wire_num, column_num).

        This is the stored node's type if it is set and not EMPTY.  Otherwise it is CTRL
        when another node in the column uses this wire as a control, SWAP when another
        node swaps with it, and EMPTY if neither applies.
        """
        requested_node = self.nodes[wire_num][column_num]
        if requested_node and requested_node.node_type != EMPTY:
            # Node is occupied so return its gate
            return requested_node.node_type

        # Check for control nodes from gates in other nodes in this column
        nodes_in_column = self.nodes[:, column_num]
        for idx in range(self.max_wires):
            if idx != wire_num:
                other_node = nodes_in_column[idx]
                if other_node:
                    if other_node.ctrl_a == wire_num or other_node.ctrl_b == wire_num:
                        return CTRL
                    elif other_node.swap == wire_num:
                        return SWAP

        return EMPTY

    def get_gate_wire_for_control_node(self, control_wire_num, column_num):
        """Get wire for gate that belongs to a control node on the given wire

        Returns -1 if no gate in the column uses ``control_wire_num`` as a control.  If
        several do, the highest-numbered wire is returned.
        """
        gate_wire_num = -1
        nodes_in_column = self.nodes[:, column_num]
        for wire_idx in range(self.max_wires):
            if wire_idx != control_wire_num:
                other_node = nodes_in_column[wire_idx]
                if other_node:
                    if other_node.ctrl_a == control_wire_num or \
                            other_node.ctrl_b == control_wire_num:
                        gate_wire_num = wire_idx
                        print("Found gate: ",
                              self.get_node_gate_part(gate_wire_num, column_num),
                              " on wire: ", gate_wire_num)
        return gate_wire_num

    def compute_circuit(self):
        """Translate the grid into a QuantumCircuit.

        Columns are processed left to right, and within each column wires are processed
        from 0 downwards.  A rotated X/Y/Z gate with a control becomes the corresponding
        controlled rotation (crx/cry/crz).

        Returns:
            QuantumCircuit over a ``max_wires``-qubit register named 'q'.
        """
        qr = QuantumRegister(self.max_wires, 'q')
        qc = QuantumCircuit(qr)

        for column_num in range(self.max_columns):
            for wire_num in range(self.max_wires):
                node = self.nodes[wire_num][column_num]
                if not node:
                    continue
                if node.node_type == IDEN:
                    # Identity gate
                    qc.i(qr[wire_num])
                elif node.node_type == X:
                    if node.radians == 0:
                        if node.ctrl_a != -1:
                            if node.ctrl_b != -1:
                                # Toffoli gate
                                qc.ccx(qr[node.ctrl_a], qr[node.ctrl_b], qr[wire_num])
                            else:
                                # Controlled X gate
                                qc.cx(qr[node.ctrl_a], qr[wire_num])
                        else:
                            # Pauli-X gate
                            qc.x(qr[wire_num])
                    elif node.ctrl_a != -1:
                        # Controlled rotation around the X axis (previously the control was dropped)
                        qc.crx(node.radians, qr[node.ctrl_a], qr[wire_num])
                    else:
                        # Rotation around X axis
                        qc.rx(node.radians, qr[wire_num])
                elif node.node_type == Y:
                    if node.radians == 0:
                        if node.ctrl_a != -1:
                            # Controlled Y gate
                            qc.cy(qr[node.ctrl_a], qr[wire_num])
                        else:
                            # Pauli-Y gate
                            qc.y(qr[wire_num])
                    elif node.ctrl_a != -1:
                        # Controlled rotation around the Y axis (previously the control was dropped)
                        qc.cry(node.radians, qr[node.ctrl_a], qr[wire_num])
                    else:
                        # Rotation around Y axis
                        qc.ry(node.radians, qr[wire_num])
                elif node.node_type == Z:
                    if node.radians == 0:
                        if node.ctrl_a != -1:
                            # Controlled Z gate
                            qc.cz(qr[node.ctrl_a], qr[wire_num])
                        else:
                            # Pauli-Z gate
                            qc.z(qr[wire_num])
                    elif node.ctrl_a != -1:
                        # Controlled rotation around the Z axis
                        qc.crz(node.radians, qr[node.ctrl_a], qr[wire_num])
                    else:
                        # Rotation around Z axis
                        qc.rz(node.radians, qr[wire_num])
                elif node.node_type == S:
                    # S gate
                    qc.s(qr[wire_num])
                elif node.node_type == SDG:
                    # S dagger gate
                    qc.sdg(qr[wire_num])
                elif node.node_type == T:
                    # T gate
                    qc.t(qr[wire_num])
                elif node.node_type == TDG:
                    # T dagger gate
                    qc.tdg(qr[wire_num])
                elif node.node_type == H:
                    if node.ctrl_a != -1:
                        # Controlled Hadamard
                        qc.ch(qr[node.ctrl_a], qr[wire_num])
                    else:
                        # Hadamard gate
                        qc.h(qr[wire_num])
                elif node.node_type == SWAP:
                    if node.ctrl_a != -1:
                        # Controlled Swap
                        qc.cswap(qr[node.ctrl_a], qr[wire_num], qr[node.swap])
                    else:
                        # Swap gate
                        qc.swap(qr[wire_num], qr[node.swap])

        return qc

    def reset_circuit(self):
        """Clear every node, then put an identity gate on each wire in the model's last column.

        The game crashes if the circuit is empty, so every wire ends with an IDEN gate,
        which is displayed as a completely transparent PNG.  The last column is taken
        from ``self.max_columns`` rather than an external constant, so it always lies
        inside this model's grid.
        """
        self.nodes = np.empty((self.max_wires, self.max_columns),
                              dtype=CircuitGridNode)
        last_column = self.max_columns - 1
        for wire_num in range(self.max_wires):
            self.set_node(wire_num, last_column, CircuitGridNode(IDEN))


class CircuitGridNode:
    """Represents a node in the circuit grid

    Attributes:
        node_type: gate-type code (EMPTY, IDEN, X, Y, Z, S, SDG, T, TDG, H, SWAP, ...).
        radians: rotation angle; 0 means the plain (non-rotated) gate.
        ctrl_a: wire of the first control qubit, or -1 for none.
        ctrl_b: wire of the second control qubit (Toffoli), or -1 for none.
        swap: wire of the other half of a SWAP gate, or -1 for none.
    """
    def __init__(self, node_type, radians=0.0, ctrl_a=-1, ctrl_b=-1, swap=-1):
        self.node_type = node_type
        self.radians = radians
        self.ctrl_a = ctrl_a
        self.ctrl_b = ctrl_b
        self.swap = swap

    def __str__(self):
        """Return e.g. ``'type: 3, radians: 0.5, ctrl_a: 0'``; fields at their default are omitted."""
        parts = ['type: ' + str(self.node_type)]
        if self.radians != 0:
            parts.append('radians: ' + str(self.radians))
        if self.ctrl_a != -1:
            parts.append('ctrl_a: ' + str(self.ctrl_a))
        if self.ctrl_b != -1:
            parts.append('ctrl_b: ' + str(self.ctrl_b))
        # Include swap so the partner wire of a SWAP node is visible too.
        if self.swap != -1:
            parts.append('swap: ' + str(self.swap))
        return ', '.join(parts)

# Gate-part type codes stored in CircuitGridNode.node_type and returned by
# CircuitGridModel.get_node_gate_part.  They are defined after the classes that use
# them, which works because those classes only read them inside method bodies at call time.
EMPTY = -1  # nothing on this node
IDEN = 0  # identity gate (shown as a transparent tile)
X = 1
Y = 2
Z = 3
S = 4
SDG = 5  # S-dagger
T = 6
TDG = 7  # T-dagger
H = 8
SWAP = 9
# B = 10
CTRL = 11  # "control" part of multi-qubit gate
TRACE = 12  # In the path between a gate part and a "control" or "swap" part


class StatevectorGrid(pygame.sprite.Sprite):
    """Displays a statevector grid

    The image is split into one horizontal band per computational basis state, each
    labelled ``|basis>``.  A paddle is drawn in the bands: before measurement its
    opacity is the amplitude magnitude of that state, and after measurement a single
    opaque paddle marks the measured state.
    """
    def __init__(self, circuit, qubit_num, num_shots):
        pygame.sprite.Sprite.__init__(self)
        self.image = None
        self.rect = None
        # Ball supplies the playfield height (screenheight) used to size the bands.
        self.ball = Ball()
        self.block_size = int(round(self.ball.screenheight / 2 ** qubit_num))
        self.basis_states = comp_basis_states(circuit.width())
        self.circuit = circuit

        self.paddle = pygame.Surface([WIDTH_UNIT, self.block_size])
        self.paddle.fill(WHITE)
        # convert() returns a new Surface instead of converting in place, so keep its result.
        self.paddle = self.paddle.convert()

        self.paddle_before_measurement(circuit, qubit_num, num_shots)

    def display_statevector(self, qubit_num):
        """Draw the ``|basis>`` label in each of the 2**qubit_num bands."""
        # Build the font once rather than once per label.
        font = pygame.font.Font(None, 60)
        for y in range(2 ** qubit_num):
            text = font.render("|" + self.basis_states[y] + ">", 1, WHITE)
            y_offset = self.block_size * 0.5 - text.get_height() * 0.5
            self.image.blit(text, (2 * WIDTH_UNIT, y * self.block_size + y_offset))

    def paddle_before_measurement(self, circuit, qubit_num, shot_num):
        """Redraw the grid with one paddle per basis state, opacity = |amplitude| * 255."""
        self.update()
        self.display_statevector(qubit_num)

        backend_sv_sim = BasicAer.get_backend('statevector_simulator')
        job_sim = execute(circuit, backend_sv_sim, shots=shot_num)
        quantum_state = job_sim.result().get_statevector(circuit, decimals=3)

        for y, amplitude in enumerate(quantum_state):
            magnitude = abs(amplitude)
            if magnitude > 0:
                self.paddle.set_alpha(int(round(magnitude * 255)))
                self.image.blit(self.paddle, (0, y * self.block_size))

    def paddle_after_measurement(self, circuit, qubit_num, shot_num):
        """Measure every qubit, draw an opaque paddle on the outcome and return it.

        Returns:
            int: index of the measured basis state (the counts key read as binary).
            With ``shot_num > 1`` this is the first key of the counts dict, not a
            most-frequent outcome.
        """
        self.update()
        self.display_statevector(qubit_num)

        backend_qasm_sim = BasicAer.get_backend('qasm_simulator')
        cr = ClassicalRegister(qubit_num)
        measure_circuit = deepcopy(circuit)  # make a copy of circuit
        measure_circuit.add_register(cr)    # add classical registers for measurement readout
        measure_circuit.measure(measure_circuit.qregs[0], measure_circuit.cregs[0])
        job_sim = execute(measure_circuit, backend_qasm_sim, shots=shot_num)
        # Ask for the counts of the circuit that was actually executed.
        counts = job_sim.result().get_counts(measure_circuit)

        measured_state = int(next(iter(counts)), 2)
        self.paddle.set_alpha(255)
        self.image.blit(self.paddle, (0, measured_state * self.block_size))

        return measured_state

    def update(self, *args):
        """Replace ``image`` with a fresh black surface sized for the circuit and reset ``rect``.

        Extra positional arguments (as passed by ``pygame.sprite.Group.update``) are ignored.
        """
        self.image = pygame.Surface([(self.circuit.width() + 1) * 3 * WIDTH_UNIT, self.ball.screenheight])
        # convert() returns a new Surface instead of converting in place, so keep its result.
        self.image = self.image.convert()
        self.image.fill(BLACK)
        self.rect = self.image.get_rect()
        
#utils 

class Ball(pygame.sprite.Sprite):
    """The squash ball: position, motion, edge bounces and zone-triggered actions.

    Motion is polar. ``direction`` is in degrees, where 0 moves up the screen
    and 90 moves right (see ``update``). ``speed`` is the distance in pixels
    moved per ``update`` call.
    """

    def __init__(self):
        super().__init__()

        # Play-field dimensions: 70% of the window height, full window width.
        self.screenheight = round(WINDOW_HEIGHT * 0.7)
        self.screenwidth = WINDOW_WIDTH
        self.width_unit = WIDTH_UNIT

        self.left_edge = self.width_unit
        self.right_edge = self.screenwidth - self.left_edge

        self.top_edge = self.width_unit * 0
        self.bottom_edge = self.screenheight - self.top_edge

        # The ball is a square one width unit on each side.
        self.height = self.width_unit
        self.width = self.width_unit

        # pygame.Surface takes (width, height).
        self.image = pygame.Surface([self.width, self.height])

        self.image.fill(WHITE)

        self.rect = self.image.get_rect()

        self.x = 0
        self.y = 0
        self.speed = 0
        self.initial_speed_factor = 0.8
        self.direction = 0

        # Action requested for the current frame, and whether the ball has already
        # been measured while inside the current measurement zone.
        self.ball_action = NOTHING
        self.measure_flag = NO

        # The first reset puts the ball on the left; later resets alternate sides.
        self.reset_position = LEFT
        self.reset()

        self.score = Score()

    def update(self):
        """Advance the ball one step and reflect it off the top and bottom edges."""
        radians = math.radians(self.direction)

        self.x += self.speed * math.sin(radians)
        self.y -= self.speed * math.cos(radians)

        # Keep the sprite rect in sync with the float position.
        self.rect.x = self.x
        self.rect.y = self.y

        # Mirror the vertical component of motion at the top / bottom edges.
        if self.y <= self.top_edge:
            self.direction = (180 - self.direction) % 360

        if self.y > self.bottom_edge - 1 * self.height:
            self.direction = (180 - self.direction) % 360

    def reset(self):
        """Re-serve the ball from mid-height.

        Each call alternates between serving from the left and the right. Speed is
        restored to ``width_unit * initial_speed_factor``, and a random direction
        is picked for the serve side.
        """
        self.y = self.screenheight / 2
        self.speed = self.width_unit * self.initial_speed_factor

        if self.reset_position == LEFT:
            self.x = self.left_edge + self.width_unit * 15
            self.direction = random.randrange(30, 120)
            self.reset_position = RIGHT
        else:
            self.x = self.right_edge - self.width_unit * 15
            self.direction = random.randrange(-120, -30)
            self.reset_position = LEFT

    def bounce_edge(self):
        """Reflect the horizontal component of motion and speed the ball up by 10%."""
        self.direction = (360 - self.direction) % 360
        self.speed *= 1.1

    def get_xpos(self):
        """Return the ball's current x position."""
        return self.x

    def get_ypos(self):
        """Return the ball's current y position."""
        return self.y

    def action(self):
        """Decide what the ball triggers this frame, based on its x position.

        * Near the left wall: ``self.score.update(1)``. ``Score.update`` credits
          the player for this code.
        * In the left or right measurement zone: set ``ball_action`` to
          ``MEASURE_LEFT`` / ``MEASURE_RIGHT`` on the first frame in the zone,
          and to ``NOTHING`` on later frames (tracked by ``measure_flag``).
        * Past the right edge: re-serve the ball, then ``self.score.update(0)``.
          ``Score.update`` credits the computer for this code.
        * Anywhere else: ``ball_action = NOTHING`` and the measurement flag is
          cleared.
        """
        if self.x - 57 < self.left_edge:
            # NOTE(review): the 57 px offset appears tied to the left paddle's x
            # position set in Level.setup; confirm before changing either.
            self.score.update(1)

        elif self.left_edge + 10 * self.width_unit <= self.x < self.left_edge + 12 * self.width_unit:
            # left measurement zone: measure only once per pass
            if self.measure_flag == NO:
                self.ball_action = MEASURE_LEFT
                self.measure_flag = YES
            else:
                self.ball_action = NOTHING

        elif self.right_edge - 12 * self.width_unit <= self.x < self.right_edge - 10 * self.width_unit:
            # right measurement zone: measure only once per pass
            if self.measure_flag == NO:
                self.ball_action = MEASURE_RIGHT
                self.measure_flag = YES
            else:
                self.ball_action = NOTHING

        elif self.x > self.right_edge:
            # the player missed: re-serve and credit the computer
            self.reset()

            self.score.update(0)

        else:
            # outside every zone: clear the flag so the next zone can measure again
            self.ball_action = NOTHING
            self.measure_flag = NO

    def check_score(self, player):
        """Return ``self.score.get_score(player)``."""
        return self.score.get_score(player)

# RGB colour constants used by the drawing code.
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
CYAN = (0, 255, 255)
MAGENTA = (255, 0, 255)
BLUE = (0, 0, 255)
GREEN = (0, 255, 0)
YELLOW = (255, 255, 0)
GRAY = (128, 128, 128)


class Input:
    """Translate pygame events into circuit-grid edits and paddle redraws."""

    def __init__(self):
        self.running = True
        pygame.init()

        # Gamepad auto-repeat bookkeeping (not read by any method of this class).
        self.gamepad_repeat_delay = 200
        self.gamepad_neutral = True
        self.gamepad_pressed_timer = 0
        self.gamepad_last_update = pygame.time.get_ticks()

    def handle_input(self, level, screen, scene):
        """Drain the pygame event queue and apply each event to ``level``.

        * Window close or Esc sets ``self.running`` to False.
        * W/A/S/D move the circuit-grid cursor, redraw the grid and flip.
        * X, Y, Z, H, Space (delete), C (control), Up/Down (move control) and
          Left/Right (rotate by -/+ pi/8) edit the circuit. Each then redraws the
          grid, re-simulates the paddle via ``update_paddle`` and flips.
        * Tab only re-simulates the paddle.
        """
        circuit_grid = level.circuit_grid

        # Cursor-movement keys -> direction passed to move_to_adjacent_node.
        cursor_moves = {
            K_a: MOVE_LEFT,
            K_d: MOVE_RIGHT,
            K_w: MOVE_UP,
            K_s: MOVE_DOWN,
        }

        # Circuit-editing keys. The lambdas defer the method lookup to key-press
        # time, so a missing handler only matters when its key is actually used.
        circuit_edits = {
            K_x: lambda: circuit_grid.handle_input_x(),
            K_y: lambda: circuit_grid.handle_input_y(),
            K_z: lambda: circuit_grid.handle_input_z(),
            K_h: lambda: circuit_grid.handle_input_h(),
            K_SPACE: lambda: circuit_grid.handle_input_delete(),
            K_c: lambda: circuit_grid.handle_input_ctrl(),
            K_UP: lambda: circuit_grid.handle_input_move_ctrl(MOVE_UP),
            K_DOWN: lambda: circuit_grid.handle_input_move_ctrl(MOVE_DOWN),
            K_LEFT: lambda: circuit_grid.handle_input_rotate(-np.pi / 8),
            K_RIGHT: lambda: circuit_grid.handle_input_rotate(np.pi / 8),
        }

        # pygame.event.get() already pumps the event queue internally.
        for event in pygame.event.get():
            if event.type == QUIT:
                self.running = False

            elif event.type == KEYDOWN:
                key = event.key
                if key == K_ESCAPE:
                    self.running = False
                elif key in cursor_moves:
                    self.move_update_circuit_grid_display(screen, circuit_grid, cursor_moves[key])
                elif key in circuit_edits:
                    circuit_edits[key]()
                    circuit_grid.draw(screen)
                    self.update_paddle(level, screen, scene)
                    pygame.display.flip()
                elif key == K_TAB:
                    self.update_paddle(level, screen, scene)

    def update_paddle(self, level, screen, scene):
        """Recompute the circuit and redraw the pre-measurement paddle.

        Calls ``paddle_before_measurement`` with 100 shots, re-arranges the
        statevector container, redraws the circuit grid and flips the display.
        """
        circuit_grid_model = level.circuit_grid_model
        right_statevector = level.right_statevector
        circuit_grid = level.circuit_grid
        statevector_grid = level.statevector_grid

        circuit = circuit_grid_model.compute_circuit()
        statevector_grid.paddle_before_measurement(
            circuit, scene.qubit_num, 100)
        right_statevector.arrange()
        circuit_grid.draw(screen)
        pygame.display.flip()

    def move_update_circuit_grid_display(self, screen, circuit_grid, direction):
        """Move the grid cursor one node in ``direction``, redraw the grid and flip."""
        circuit_grid.move_to_adjacent_node(direction)
        circuit_grid.draw(screen)
        pygame.display.flip()
        

class Level:
    """Per-level game objects: circuit model, statevector display and both paddles."""

    def __init__(self):
        self.level = 2  # qubit count applied to the scene by setup()
        self.win = False  # flag for winning the game
        self.left_paddle = pygame.sprite.Sprite()
        self.right_paddle = pygame.sprite.Sprite()

    def setup(self, scene, ball):
        """Build the circuit grid, statevector display and both paddles.

        Sets ``scene.qubit_num`` to ``self.level`` and sizes everything from it
        and from ``ball.screenheight``.
        """
        scene.qubit_num = self.level
        n_qubits = scene.qubit_num
        last_col = CIRCUIT_DEPTH - 1

        model = CircuitGridModel(n_qubits, CIRCUIT_DEPTH)
        self.circuit_grid_model = model
        # An empty circuit crashes the game, so every wire ends with an identity
        # gate (drawn with a completely transparent PNG).
        for wire in range(n_qubits):
            model.set_node(wire, last_col, CircuitGridNode(IDEN))

        self.circuit = model.compute_circuit()
        self.statevector_grid = StatevectorGrid(self.circuit, n_qubits, 100)
        self.right_statevector = VBox(WIDTH_UNIT * 90, 0, self.statevector_grid)
        self.circuit_grid = CircuitGrid(0, ball.screenheight, model)

        field_height = ball.screenheight

        # Computer paddle: opaque white bar spanning the whole play-field height.
        left = self.left_paddle
        left.image = pygame.Surface([WIDTH_UNIT, int(round(field_height))])
        left.image.fill((255, 255, 255))
        left.image.set_alpha(255)
        left.rect = left.image.get_rect()
        left.rect.x = 9 * WIDTH_UNIT - 50

        # Player paddle: 1/2**n_qubits of the field height and fully transparent;
        # it only exists for collision detection.
        right = self.right_paddle
        right.image = pygame.Surface([WIDTH_UNIT, int(round(field_height / 2 ** n_qubits))])
        right.image.fill((255, 0, 255))
        right.image.set_alpha(0)
        right.rect = right.image.get_rect()
        right.rect.x = self.right_statevector.xpos
# Define global parameters

# For main.py
# Window size presets: keep exactly one WINDOW_WIDTH / WINDOW_HEIGHT pair active.

# For 15-inch MacBook Pro
#WINDOW_WIDTH=2880
#WINDOW_HEIGHT=1800

# For 13-inch MacBook Pro
#WINDOW_WIDTH=2560
#WINDOW_HEIGHT=1600

# For lower resolution screen
WINDOW_WIDTH = 1200
WINDOW_HEIGHT = 750

# Base layout unit in pixels: 1/100 of the window width.
WIDTH_UNIT = round(WINDOW_WIDTH/100)
WINDOW_SIZE = WINDOW_WIDTH, WINDOW_HEIGHT
QUBIT_NUM = 2
CIRCUIT_DEPTH = 18

WIN_SCORE = 7

# For ball.py
LEFT = 0
RIGHT = 1

MEASURE_RIGHT = 1
MEASURE_LEFT = 2
NOTHING = 0

YES = 1
NO = 0

# For circuit_grid.py
GRID_WIDTH = WIDTH_UNIT * 4.96
GRID_HEIGHT = GRID_WIDTH

GATE_TILE_WIDTH = GRID_WIDTH * 0.76
GATE_TILE_HEIGHT = GATE_TILE_WIDTH

LINE_WIDTH = round(WIDTH_UNIT * 0.15)

# For scene.py
CLASSICAL_COMPUTER = 0
QUANTUM_COMPUTER = 1

# Ball speed factors (assigned to Ball.initial_speed_factor).
EASY = 0.3
NORMAL = 0.6
EXPERT = 1.5

# Cursor directions for the circuit grid (defined once).
MOVE_LEFT = 1
MOVE_RIGHT = 2
MOVE_UP = 3
MOVE_DOWN = 4

# Asset directory, resolved relative to the notebook's working directory.
data_dir = "../data"

def load_image(name, colorkey=None, scale=WIDTH_UNIT/13):
    """Load ``<data_dir>/images/<name>``, convert it to display format and scale it.

    Args:
        name: file name inside ``<data_dir>/images``.
        colorkey: ``None`` for no colour key, ``-1`` to use the colour of the
            top-left pixel, or an explicit colour to treat as transparent.
        scale: factor applied to both dimensions; sizes are rounded to whole pixels.

    Returns:
        (Surface, Rect): the scaled image and its bounding rect.

    Raises:
        SystemExit: if pygame cannot load the file.
    """
    fullname = os.path.join(data_dir, 'images', name)
    try:
        image = pygame.image.load(fullname)
    except pygame.error as err:
        print('Cannot load image:', fullname)
        # Report the caught pygame error; there is no `geterror` helper in pygame 2.
        raise SystemExit(str(err)) from err
    image = image.convert()
    if colorkey is not None:
        if colorkey == -1:
            colorkey = image.get_at((0, 0))
        image.set_colorkey(colorkey, RLEACCEL)
    image = pygame.transform.scale(image, tuple(round(scale*x) for x in image.get_rect().size))
    return image, image.get_rect()





class Scene:
    """Game-flow state (start flag, qubit count) and per-frame HUD drawing."""

    def __init__(self):
        super().__init__()

        self.begin = False
        self.restart = False
        self.qubit_num = 2
        # Font for score(). It is built on first use because a pygame Font can
        # only be created after pygame.font has been initialised.
        self._score_font = None

    def start(self, screen, ball):
        """Prepare the start screen and return the caller's running flag.

        There is no interactive start screen yet. While ``self.begin`` is False
        (its initial value), this fills ``screen`` with black, sets the ball's
        speed factor to NORMAL and returns True. If ``self.begin`` is already
        True, it fills the screen, clears the flag and returns None.
        """
        screen.fill(BLACK)

        if not self.begin:
            ball.initial_speed_factor = NORMAL
            return True

        # reset the begin flag
        self.begin = False
        return None

    def dashed_line(self, screen, ball):
        """Draw a vertical dashed gray line down the middle of the play field."""
        for i in range(10, ball.screenheight, 2 * WIDTH_UNIT):
            pygame.draw.rect(screen, GRAY, (WINDOW_WIDTH // 2 - 5, i, 0.5 * WIDTH_UNIT, WIDTH_UNIT), 0)

    def score(self, screen, ball):
        """Draw the 'Score' label and, below it, the player's score (``check_score(1)``)."""
        if self._score_font is None:
            self._score_font = pygame.font.Font(None, 74)
        center_x = round(WINDOW_WIDTH * 0.75) - WIDTH_UNIT * 4.5
        self._blit_centered(screen, 'Score', (center_x, WIDTH_UNIT * 2))
        self._blit_centered(screen, str(ball.check_score(1)), (center_x, WIDTH_UNIT * 6))

    def _blit_centered(self, screen, label, center):
        """Render ``label`` in gray with the score font and blit it centred at ``center``."""
        text = self._score_font.render(label, 1, GRAY)
        screen.blit(text, text.get_rect(center=center))


class Score(pygame.sprite.Sprite):
    """Point tally for the player and the computer."""

    def __init__(self):
        super().__init__()

        self.player = 0
        self.computer = 0

    def update(self, score):
        """Record a point.

        ``score == 0`` credits the computer and resets the player's count to 0.
        ``score == 1`` credits the player. Any other value is ignored.
        """
        if score == 0:
            self.computer += 1
            self.player = 0
        elif score == 1:
            self.player += 1

    def get_score(self, player):
        """Return the computer's score for ``player == 0``, the player's for ``1``, else None."""
        if player == 0:
            return self.computer
        if player == 1:
            return self.player
        return None

    def reset_score(self):
        """Set both scores back to zero."""
        self.computer = 0
        self.player = 0


# Upper bound on the number of qubits for which basis-state labels are generated.
MAX_NUM_QUBITS = 10


def comp_basis_states(num_qubits):
    """Return computational-basis labels for ``num_qubits`` qubits.

    Labels are zero-padded binary strings in index order, e.g. for 2 qubits
    ``['00', '01', '10', '11']``. The qubit count is capped at MAX_NUM_QUBITS.
    """
    num_qb = min(num_qubits, MAX_NUM_QUBITS)
    label_format = '0' + str(num_qb) + 'b'
    return [format(idx, label_format) for idx in range(2**num_qb)]

In [4]:
pygame.init()

# NOTE(review): `flags` is not passed to set_mode. In pygame 2, HWSURFACE is
# obsolete for windowed mode and DOUBLEBUF only applies to OPENGL displays.
flags = DOUBLEBUF | HWSURFACE
screen = pygame.display.set_mode(WINDOW_SIZE)

pygame.display.set_caption('Squash')

# clock for timing
clock = pygame.time.Clock()
old_clock = pygame.time.get_ticks()

# initialize scene, level and input classes
# NOTE: `input` shadows the builtin input() in this kernel; name kept for compatibility.
scene = Scene()
level = Level()
input = Input()

# define ball
ball = Ball()
balls = pygame.sprite.Group()   # sprite group type is needed for sprite collide function in pygame
balls.add(ball)

# start screen returns the running flag
input.running = scene.start(screen, ball)
level.setup(scene, ball)

# Put all moving sprites in a group so that they can be drawn together
moving_sprites = pygame.sprite.Group()
moving_sprites.add(ball)
moving_sprites.add(level.left_paddle)
moving_sprites.add(level.right_paddle)

# update the screen
pygame.display.flip()

# reset the ball
ball.reset()

# Delay (ms) after a measurement before the un-measured paddle is redrawn.
MEASUREMENT_DISPLAY_MS = 400
# Tick of the most recent measurement; None when no redraw is pending.
measure_time = None

try:
    while input.running:
        # set maximum frame rate
        clock.tick(60)
        # refill whole screen with black color at each frame
        screen.fill(BLACK)

        ball.update()  # update ball position
        scene.score(screen, ball)   # print score

        level.right_statevector.draw(screen)  # draw right paddle together with statevector grid
        level.circuit_grid.draw(screen)  # draw circuit grid
        moving_sprites.draw(screen)  # draw moving sprites

        # handle input events
        input.handle_input(level, screen, scene)

        # check ball location and decide what to do
        ball.action()

        if ball.ball_action == MEASURE_RIGHT:
            circuit = level.circuit_grid_model.compute_circuit()
            pos = level.statevector_grid.paddle_after_measurement(circuit, scene.qubit_num, 1)
            level.right_statevector.arrange()

            # move the invisible collision paddle onto the measured basis-state row
            level.right_paddle.rect.y = pos * ball.screenheight / (2 ** scene.qubit_num)
            measure_time = pygame.time.get_ticks()

        if pygame.sprite.spritecollide(level.right_paddle, balls, False):
            ball.bounce_edge()

        if pygame.sprite.spritecollide(level.left_paddle, balls, False):
            ball.bounce_edge()

        if measure_time is not None and pygame.time.get_ticks() - measure_time > MEASUREMENT_DISPLAY_MS:
            # show the measured paddle briefly, then redraw the superposed one
            input.update_paddle(level, screen, scene)
            measure_time = None

        # Update the screen
        pygame.display.flip()
finally:
    # Close the window even if the loop raises, so the kernel is not left with a dead window.
    pygame.quit()

Found gate:  1  on wire:  1
Replacing wire  0  in column  1
Replacing wire  1  in column  1
Found gate:  1  on wire:  1
Found gate:  1  on wire:  1
Found gate:  1  on wire:  1
Found gate:  1  on wire:  1
Found gate:  1  on wire:  1
Found gate:  1  on wire:  1
Found gate:  1  on wire:  1
Replacing wire  0  in column  1
Replacing wire  1  in column  1
