In [None]:
# default_exp protocol

# Protocol

> Representation of a quantum error correction protocol.

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
#export
import os
import pickle
from functools import lru_cache
from re import sub

import networkx as nx

from qsam.circuit import make_hash

We model a protocol as a directed acyclic graph (a `DiGraph` without cycles) in which each `node` represents one `circuit` and every `edge` represents a transition rule between circuits, which we call a `check`. We give each `node` a **unique** label and a `circuit`. A protocol must always *START* at **one** point and *EXIT* at **one or several** points. The next circuit in a protocol must always be uniquely identifiable. Thus, for each circuit at most one outgoing check may evaluate to *True* at any time; if more than one does, an error is raised. If no check evaluates to *True*, the protocol terminates implicitly, i.e. with status *no logical error occurred*. If, on the other hand, an *EXIT* node is reached, the protocol terminates explicitly, i.e. with status *logical error occurred*. These events are later tracked explicitly in the samplers.

In [None]:
#export
class Protocol(nx.DiGraph):
    """Representation of a Quantum Error Correction protocol.

    A directed graph whose nodes are labelled circuits and whose edges carry a
    ``check`` attribute (an expression string deciding whether the transition
    to the edge's target node is taken).

    Circuits are stored once in ``self._circuits`` (``make_hash(circuit)`` ->
    circuit); each node only keeps the hash in its ``circuit_hash`` attribute.

    ``add_node``/``add_nodes_from`` stay call-compatible with networkx
    (``circuit``/``circuits`` are optional), so networkx operations that rebuild
    a graph of the same class (e.g. ``copy``, ``reverse``) keep working. Nodes
    added that way get the attributes networkx passes, but the new graph's
    ``_circuits`` table is not populated.
    """

    # Sentinel distinguishing "no circuit given" from an explicit ``None`` circuit.
    _NO_CIRCUIT = object()

    def __init__(self, *args, **kwargs):
        # Create the circuit table first: ``nx.DiGraph.__init__`` may populate the
        # graph from ``incoming_graph_data`` and thereby call the overrides below.
        self._circuits = {}  # hash table: make_hash(circuit) -> circuit
        super().__init__(*args, **kwargs)

    def add_node(self, name, circuit=_NO_CIRCUIT, **attr):
        """Add node ``name`` running ``circuit``.

        The circuit is stored in the protocol's hash table and its hash is set as
        the node's ``circuit_hash`` attribute. Without ``circuit`` this behaves
        like ``nx.DiGraph.add_node``; extra keyword attributes are passed through.
        """
        if circuit is not self._NO_CIRCUIT:
            circuit_hash = make_hash(circuit)
            self._circuits[circuit_hash] = circuit
            attr['circuit_hash'] = circuit_hash
        super().add_node(name, **attr)

    def add_nodes_from(self, names, circuits=_NO_CIRCUIT, **attr):
        """Add several nodes at once.

        Parameters
        ----------
        names : iterable of node labels.
        circuits : a list/tuple/set with one circuit per name, or a single circuit
            shared by every name. If omitted, this falls back to
            ``nx.DiGraph.add_nodes_from`` (which also accepts ``(node, attrs)`` pairs).

        Raises
        ------
        ValueError
            If a collection of circuits does not match ``names`` in length.
        """
        if circuits is self._NO_CIRCUIT:
            super().add_nodes_from(names, **attr)
            return
        names = list(names)
        if isinstance(circuits, (list, tuple, set)):
            # NOTE(review): a ``set`` has no defined order, so pairing names with
            # the circuits of a set is arbitrary -- prefer a list or tuple.
            if len(circuits) != len(names):
                raise ValueError(
                    f"Got {len(circuits)} circuits for {len(names)} node names.")
        else:
            circuits = [circuits] * len(names)
        for name, circuit in zip(names, circuits):
            self.add_node(name, circuit, **attr)

    def circuit_hash(self, node):
        """Return the ``circuit_hash`` attribute of ``node`` (``None`` if unset).

        Not memoized: a method-level cache would go stale when a node is re-added
        with another circuit, and would keep the graph object alive.
        Raises ``KeyError`` if ``node`` is not in the graph.
        """
        return self.nodes[node].get('circuit_hash')

    def checks(self, node):
        """Return ``{successor: check}`` for every edge leaving ``node``.

        Computed on each call (not memoized) so edges added or removed after a
        previous call are always reflected.
        """
        return {succ: data['check'] for _, succ, data in self.out_edges(node, data=True)}

Helper functions

In [None]:
#export
def draw_protocol(protocol):
    """Draw a graph representation of `protocol`.

    Nodes are placed with the Kamada-Kawai layout and labelled with their names.
    *START* is drawn light blue, *EXIT* light red, and every other node orange.
    Each edge is labelled with its `check` expression. networkx draws onto the
    current matplotlib axes.
    """
    layout = nx.kamada_kawai_layout(protocol)

    special_colors = {'START': '#99ccff', 'EXIT': '#ff9999'}
    default_color = '#ffb266'
    node_colors = [
        special_colors[name] if name in special_colors else default_color
        for name in protocol.nodes
    ]

    nx.draw(protocol, layout, node_color=node_colors, with_labels=True, node_size=1200)
    check_labels = nx.get_edge_attributes(protocol, 'check')
    nx.draw_networkx_edge_labels(protocol, layout, edge_labels=check_labels)

In [None]:
#export 
#hide
def cached_eval(eval_str):
    """Evaluate the Python expression `eval_str` and memoize the result.

    Up to 256 distinct expression strings are cached. The expression is evaluated
    with this function's globals (the defining module's namespace), so helper
    functions used inside checks must be visible there.

    NOTE(review): this uses `eval`; only evaluate check strings that come from
    trusted protocols.
    """
    return eval(eval_str)

# Wrap with an LRU cache holding up to 256 distinct expressions.
cached_eval = lru_cache(maxsize=256)(cached_eval)

In [None]:
#export
def iterate(protocol):
    """Iterator over protocol.

    Generator protocol: prime it with `next()` to get the first node after
    *START*, then `.send(measurement)` the measurement result of the yielded
    node to get the next node. When no outgoing check of the current node
    evaluates to true, `None` is yielded (and yielded again if resumed).

    In a check, a node label wrapped in backticks is replaced by the last
    measurement sent for that node, or by `None` if none was sent yet.
    Measurements are inserted as `str(measurement)`, so both strings such as
    `'0'` and integers such as `0` can be sent.

    Raises
    ------
    Exception
        If more than one check of the current node evaluates to true.
    """
    hist = {}  # node label -> last measurement sent for that node
    node = "START"

    def repl_fn(match):
        # re.sub requires the replacement to be a str, so convert e.g. ints or None.
        return str(hist.get(match.group(1), "None"))

    def eval_check(check):
        # NOTE(review): checks are eval'd via cached_eval; only run trusted protocols.
        return cached_eval(sub("`(.*?)`", repl_fn, check))

    while True:
        checks = protocol.checks(node)
        next_nodes = [n for n, c in checks.items() if eval_check(c)]
        if not next_nodes:
            yield None
        elif len(next_nodes) == 1:
            node = next_nodes[0]
            hist[node] = yield node
        else:
            raise Exception(f"Too many checks True for node {node}.")

We can generate a protocol by adding nodes and edges with their corresponding `circuit`s and `check`s. Note that we also add a check from the *START* node to the first node of the protocol which always evaluates to *True*, so this transition is always made. Furthermore, we refer to the current or past measurement result of a node's circuit by writing the node label in backticks. A check can be any expression that evaluates to a boolean *True* or *False*. Variables of nodes that have not been measured yet are replaced by `None`, so we only need to make sure that, whenever a measurement has not occurred for a certain "string of events", the part of the expression involving this variable evaluates to *False*.

In [None]:
p = Protocol()
p.add_nodes_from(['c1', 'c2', 'c3'], circuits=['a', 'b', 'c'])
p.add_edge('START', 'c1', check='True')
p.add_edge('c1', 'c2', check='`c1`==0')
p.add_edge('c2', 'c3', check='`c2`==0')

# custom check function: True iff the parity of measurement `x` equals `y`.
# `x` is None while the node has not been measured yet -> never take the edge.
# (Only None is excluded: a measurement of 0 has parity 0.)
def parity(x, y):
    if x is None:
        return False
    return bin(x).count('1') % 2 == y

p.add_edge('c1', 'EXIT', check='parity(`c1`,1)')
p.add_edge('c2', 'EXIT', check='`c2`==1')
p.add_edge('c3', 'EXIT', check='`c3`==1')

draw_protocol(p)

NameError: name 'draw_protocol' is not defined

We can simply iterate over the protocol with the `iterate` generator function. With each iteration we have to provide a measurement, which is used to evaluate the checks and finally to yield the following node. At the end of the protocol the iterator yields `None`, which we can use as the exit condition for the `while` loop. Note that we always need to call `next()` on the iterator before we can `.send()` values to it.

In [None]:
protocol_iter = iterate(p)
# prime the generator: this yields the first node after START
current = next(protocol_iter)

while current:
    print(current)
    # send the measurement of the current node's circuit, receive the next node
    current = protocol_iter.send('0')

c1
c2
c3


In [None]:
#export
def save_protocol(protocol, fname, path='.'):
    """Saves a protocol to `path` with file name `fname`.

    The protocol is pickled to `<path>/<fname>.proto`, overwriting any existing
    file of that name. The file is closed even if pickling fails.
    """
    with open(os.path.join(path, f'{fname}.proto'), 'wb') as file:
        pickle.dump(protocol, file)

In [None]:
#export
def load_protocol(fname, path='.'):
    """Loads a protocol from `path` with file name `fname`.

    Reads and unpickles `<path>/<fname>.proto` and returns the stored object.
    The file is closed even if unpickling fails.

    NOTE: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(os.path.join(path, f'{fname}.proto'), 'rb') as file:
        return pickle.load(file)

We can also save and load protocols. They are serialized with `pickle`, so only load protocol files from trusted sources.

In [None]:
import tempfile

# Round-trip through a temporary directory so running the notebook (or its
# tests) does not leave a `test_protocol.proto` file behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    save_protocol(p, 'test_protocol', path=tmp_dir)
    p2 = load_protocol('test_protocol', path=tmp_dir)

assert set(p2.nodes) == set(p.nodes)
assert nx.get_edge_attributes(p2, 'check') == nx.get_edge_attributes(p, 'check')
print(p2)

Protocol with 5 nodes and 6 edges
