In [None]:
# default_exp sampler.base

# Sampler base class

> Super class for all sampler objects

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
#export
from qsam.sampler.tree import CountTree, Variable, Constant

from qsam.callbacks import CallbackList
from tqdm.auto import tqdm
import itertools as it
import numpy as np
import qsam.math as math
from qsam.noise import E0

from collections.abc import Iterable
from collections import OrderedDict

Definitions: 
* A *circuit location* (`cloc`) is a tuple (tick_index, qubit(s)) which specifies a location in a circuit. In the simplest case this tuple can be understood as a Euclidean coordinate in tick-qubit-space, e.g. (1,2) for a circuit element at tick 1 on qubit 2, such as a single-qubit gate. However, the second entry in the tuple can also be a tuple of qubits, i.e. (tick_index, (qubitA, qubitB)), for example (1,(2,3)) for a 2-qubit gate in tick 1 from qubit 2 (control) to qubit 3 (target). It is important that (1,(2,3)) $\neq$ (1,(3,2)) $\neq$ (1,2)+(1,3). 
* Common circuit locations can be grouped into a *location group* (`locgrp`), for example the set of all 1-qubit gate locations, the set of all idle qubits, or all neighboring qubits of 1-qubit gates. A location group is defined by the error model, which sweeps through a circuit and extracts all elements belonging to a certain group. Note that circuit locations can appear in more than one location group, and the union of all location groups need not cover all possible circuit locations. This is for example the case for error models which only define errors on gates but not on idle qubits.
* By *location group weight* (`locgrp_wgt`) we refer to an integer number specifying the amount of location group elements on which to place faults. Note that the weight only specifies the amount, neither the specific element on which the fault is placed (this will be chosen uniformly at random) nor the specific error which will be placed (this will be drawn uniformly at random from a set of group errors (`locgrp_errsets`) defined by the error model).
* By *location group probability* (`locgrp_prob`) on the other hand, we refer to the probability of placing faults on **each** element of a group. Thus, during circuit execution it will be decided for each element of a location group if an error is placed based on the group probability.
* Both the above can be written as vectors over all location groups an error model specifies. The vector of location group weights (`locgrp_wgts`) thereby characterizes a so-called fault-weight `subset`, i.e. a particular amount of faults on each location group defined in an error model. These subsets form the sampling space on which we apply importance sampling during subset sampling.
* For subset sampling we analytically calculate the *subset occurrence probability* (`Aws`) as a function of location group probabilities `locgrp_probs` for each location group. As a subset occurrence probability results from a binomial distribution *binom(locgrp_wgts, locgrp_lens, locgrp_probs)* we also refer to it as *binomial weight* of a subset, or simply $A_w$, as this is the symbol used in the paper.
* If one of the above quantities is preceded by a `protocol_` we have a dictionary with *circuit ids* (`cids`) as keys and the given quantity for a circuit as value. Note that `locgrp_probs` are defined equally for each circuit in the protocol, as those represent physical error rates of a (faulty) device which also won't be different for different circuits in reality. To sample from more than one physical error rate, a range `locgrp_probs_range` can be specified which contains a range of error rates for each location group.
* `locgrp_wgts_combis` are all possible `locgrp_wgts` vectors up to the total number of elements (`locgrp_len`) in each location group.

In [None]:
#export

def locgrp_lens(locgrps):
    """Return the size of every location group, in the mapping's iteration order."""
    return list(map(len, locgrps.values()))

def ranges_from_probs(locgrp_probs):
    """Turn per-group error rates into an array of probability vectors.

    `locgrp_probs` is a dict (only its values are used) or a sequence with one
    entry per location group. Each entry is either a single value or an iterable
    of values. Shorter entries are extended by repeating their last value until
    all entries have the length of the longest one.

    Returns the transposed array, i.e. one row per sampling point and one
    column per location group.
    """
    entries = locgrp_probs.values() if isinstance(locgrp_probs, dict) else locgrp_probs

    # Wrap non-iterable entries so every group is represented by a sequence.
    per_group = []
    for entry in entries:
        per_group.append(entry if isinstance(entry, Iterable) else [entry])

    longest = max([len(seq) for seq in per_group])

    padded = []
    for seq in per_group:
        tail = seq[-1]  # value that is repeated to fill up shorter sequences
        n_missing = longest - len(seq)
        if n_missing > 0:
            seq = np.append(seq, [tail] * n_missing)
        padded.append(seq)

    return np.array(padded).T

def circuit_subset_occurence(locgrps: dict, locgrps_wgts_combis: list, locgrp_probs: (dict,list)) -> np.ndarray:
    """Return the subset occurrence probabilities (binomial weights) of one circuit.

    Parameters
    ----------
    locgrps : dict
        Location groups of the circuit; only the size of each group is used.
    locgrps_wgts_combis : list
        Location group weight vectors for which the weights are computed.
    locgrp_probs : dict or list
        Per-group error rates, expanded via `ranges_from_probs`.

    Returns
    -------
    np.ndarray
        One entry along the first axis per weight vector in `locgrps_wgts_combis`.
        For each of them `math.binom(weights, group sizes, probabilities)` is
        multiplied over its last axis.
        NOTE(review): assuming `math.binom` broadcasts elementwise, this is a matrix
        of dims (weight vectors) x (error-rate points) -- confirm against `qsam.math`.
    """
    probs_ranges = ranges_from_probs(locgrp_probs)
    # Group sizes are the same for every weight vector, so build the array once.
    lens = np.array(locgrp_lens(locgrps))
    Aws = np.array([math.binom(np.array(locgrp_wgts), lens, probs_ranges) for locgrp_wgts in locgrps_wgts_combis])
    # `np.product` was deprecated in NumPy 1.25 and removed in 2.0; `np.prod` is the supported name.
    return np.prod(Aws, axis=-1)

def protocol_subset_occurence(protocol_locgrps: dict, protocol_locgrp_wgts_combis: dict, locgrp_probs: (dict,list)) -> dict:
    """Compute the subset occurence probabilities for every circuit of a protocol.

    For each circuit id in `protocol_locgrp_wgts_combis` the result holds a dict
    which maps every weight vector of that circuit to the corresponding entry
    returned by `circuit_subset_occurence`.
    """
    occurence_by_cid = {}
    for cid, wgts_combis in protocol_locgrp_wgts_combis.items():
        Aws = circuit_subset_occurence(protocol_locgrps[cid], wgts_combis, locgrp_probs)
        # Pair every weight vector with its probability, in order.
        occurence_by_cid[cid] = dict(zip(wgts_combis, Aws))
    return occurence_by_cid

def locgrp_wgts_combis(locgrps: dict) -> list:
    """Enumerate all location group weight vectors of a circuit.

    A group with N locations can carry the weights 0..N. The result is the
    Cartesian product of these ranges over all groups, as a list of tuples.
    """
    weight_axes = [tuple(range(n_locs + 1)) for n_locs in locgrp_lens(locgrps)]
    all_combis = it.product(*weight_axes)
    return list(all_combis)
    
def protocol_locgrp_wgts_combis(protocol_locgrps: dict) -> dict:
    """Map every circuit id of a protocol to all weight vectors of that
    circuit's location groups (see `locgrp_wgts_combis`)."""
    combis_by_cid = {}
    for cid, locgrps in protocol_locgrps.items():
        combis_by_cid[cid] = locgrp_wgts_combis(locgrps)
    return combis_by_cid

In [None]:
# Two scalar rates (one per location group) give a single probability vector.
prob_vecs = ranges_from_probs([0.1, 0.1])
print(*prob_vecs, sep="\n")

[0.1 0.1]


In [None]:
# Toy protocol: a single circuit 'a' with two location groups of one location each.
d = {'a': {'p': {(0,1)}, 'p2': {(0,(0,1))}}}
proto_ws = protocol_locgrp_wgts_combis(d)
proto_Aws = protocol_subset_occurence(d,proto_ws,{'p': 0.1, 'p2': 0.2})

# `protocol_subset_occurence` already returns {cid: {wgts_vec: Aw}}, so show it directly.
# Zipping over the inner dicts again would iterate their keys and pair each weight vector with itself.
proto_Aws

{'a': {(0, 0): (0, 0), (0, 1): (0, 1), (1, 0): (1, 0), (1, 1): (1, 1)}}

In [None]:
#export
class Sampler:
    """Base class for all samplers.

    A sampler repeatedly executes the circuits of a protocol on a simulator
    under an error model and records the visited circuits and the chosen
    location group weights in one `CountTree` per physical-error-rate vector.
    Child classes implement `optimize`, which decides where faults are placed.
    """
    
    def __init__(self, protocol, simulator, err_model=E0):
        """Store protocol, simulator and error model and precompute, per circuit
        id, the location groups and all location group weight vectors."""
        self.protocol = protocol
        self.simulator = simulator
        self.n_qubits = protocol.n_qubits
        self.err_model = err_model
        # One tree per sampled probability vector, keyed by `tuple(prob_vec)`.
        self.trees = OrderedDict()
        self.protocol_locgrps = {cid: err_model.group(circuit) for cid, circuit in self.protocol._circuits.items()}
        self.protocol_wgts_combis = protocol_locgrp_wgts_combis(self.protocol_locgrps)
    
    def optimize(self, tree_node, circuit, prob_vec=None):
        """Choose the faults for `circuit`. Must be overwritten by child class.

        `run` calls this as `optimize(tree_node, circuit, prob_vec)` and expects a
        dict containing at least the keys 'grp_wgts' and 'flocs'.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        # `NotImplemented` is a comparison sentinel, not an exception; raising it
        # would produce an unrelated TypeError.
        raise NotImplementedError("`optimize` must be overwritten by child class.")
            
    def run(self, n_samples: int, err_params: (dict,list), callbacks=None) -> None:
        """Sample the protocol `n_samples` times for every probability vector.

        Parameters
        ----------
        n_samples : int
            Number of protocol runs per probability vector.
        err_params : dict or list
            One entry (value or range of values) per group of the error model;
            expanded into probability vectors by `ranges_from_probs`.
        callbacks : list or CallbackList, optional
            Callbacks to invoke during sampling. A plain list (or None) is
            wrapped into a `CallbackList`.

        A callback can end the sampling for the current probability vector by
        setting `self.stop_sampling = True`; this is checked after every protocol run.
        """
        assert isinstance(err_params, (dict,list)), "`err_params` must be either dict or list."
        assert len(err_params) == len(self.err_model.groups), "`err_params` must have same length as `groups` defined in error model."
        prob_vecs = ranges_from_probs(err_params)
            
        # Avoid a shared mutable default argument.
        if callbacks is None:
            callbacks = []
        if not isinstance(callbacks, CallbackList):
            callbacks = CallbackList(sampler=self, callbacks=callbacks)
        callbacks.on_sampler_begin()
        
        for prob_vec in tqdm(prob_vecs, desc="Total"):
            self.stop_sampling = False
            
            # Reuse the tree of a previous call for the same probabilities, else create one.
            self.tree_idx = tuple(prob_vec)
            if self.tree_idx not in self.trees:
                self.trees[self.tree_idx] = CountTree(min_path_weight=2 if self.protocol._ft else 1)
            tree = self.trees[self.tree_idx]
            tree.constants = protocol_subset_occurence(self.protocol_locgrps, self.protocol_wgts_combis, prob_vec)
                
            for _ in tqdm(range(n_samples), desc=f'p_phy={",".join(list(f"{p:.2E}" for p in prob_vec))}', leave=True):
                callbacks.on_protocol_begin()

                # Every protocol run starts from a fresh simulator state at the tree root.
                state = self.simulator(self.n_qubits)
                tree_node = None

                for name, circuit in self.protocol:
                    callbacks.on_circuit_begin()

                    tree_node = tree.add(name=name, parent=tree_node, nodetype=Variable)
                    tree_node.counts += 1
                    opt_out = dict()

                    if circuit:
                        if not circuit._noisy:
                            msmt = state.run(circuit)
                        else:
                            opt_out = self.optimize(tree_node, circuit, prob_vec)
                            tree_node = tree.add(name=opt_out['grp_wgts'], parent=tree_node, nodetype=Constant, cid=circuit.id,
                                                 is_deterministic=bool(circuit._ff_det and not any(opt_out['grp_wgts'])))
                            tree_node.counts += 1
                            fault_circuit = self.err_model.run(circuit, opt_out['flocs'])

                            msmt = state.run(circuit, fault_circuit)
                        # Feed the measurement result back into the protocol.
                        self.protocol.send(msmt)

                    elif name is not None:
                        tree_node.marked = True
                    
                    # The local variables of this method are handed to the callbacks by
                    # name, so they must not be renamed.
                    callbacks.on_circuit_end(locals() | opt_out)

                callbacks.on_protocol_end()
                if self.stop_sampling: break
            
        callbacks.on_sampler_end()