In [None]:
#| default_exp sampler.base

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
from qsample.sampler.tree import CountTree, CircuitCountNode, SubsetCountNode

from qsample.callbacks import CallbackList
from tqdm.auto import tqdm
import itertools as it
import numpy as np
import qsample.math as math
from qsample.noise import E0
import dill as pickle

from collections.abc import Iterable

Definitions: 
* A *circuit location* (`cloc`) is a tuple (tick_index,qubit(s)) which specifies a location in a circuit. In the simplest case this tuple can be understood as a Euclidean coordinate in tick-qubit-space, for example (1,2) for a circuit element at tick 1 on qubit 2, for example a single-qubit gate. However, the second entry in the tuple can also be a tuple of qubits, i.e. (tick_index, (qubitA, qubitB)), for example (1,(2,3)) for a 2-qubit gate in tick 1 from qubit 2 (control) to qubit 3 (target). It is important that (1,(2,3)) $\neq$ (1,(3,2)) $\neq$ (1,2)+(1,3). 
* Common circuit locations can be grouped into a *location group* (`locgrp`), for example the set of all 1-qubit gate locations, the set of all idle qubits, or the set of all qubits neighboring 1-qubit gates. A location group is defined by the error model, which sweeps through a circuit and extracts all elements belonging to a certain group. Note that a circuit location can appear in more than one location group, and the union of all location groups does not necessarily cover all possible circuit locations. This is, for example, the case for error models which only define errors on gates but not on idle qubits.
* By *location group weight* (`locgrp_wgt`) we refer to an integer specifying the number of location group elements on which to place faults. Note that the weight only specifies the number of faults, neither the specific element on which a fault is placed (this is chosen uniformly at random) nor the specific error which is placed (this is drawn uniformly at random from a set of group errors (`locgrp_errsets`) defined by the error model).
* By *location group probability* (`locgrp_prob`) on the other hand, we refer to the probability of placing faults on **each** element of a group. Thus, during circuit execution it will be decided for each element of a location group if an error is placed based on the group probability.
* Both the above can be written as vectors over all location groups an error model specifies. The vector of location group weights (`locgrp_wgts`) thereby characterizes a so-called fault-weight `subset`, i.e. a particular amount of faults on each location group defined in an error model. These subsets form the sampling space on which we apply importance sampling during subset sampling.
* For subset sampling we analytically calculate the *subset occurrence probability* (`Aws`) as a function of the location group probabilities `locgrp_probs` for each location group. As a subset occurrence probability results from a binomial distribution *binom(locgrp_wgts, locgrp_lens, locgrp_probs)*, we also refer to it as the *binomial weight* of a subset, or simply $A_w$, as this is the symbol used in the paper.
* If one of the above quantities is prefixed with `protocol_`, we have a dictionary with *circuit ids* (`cids`) as keys and the given quantity for a circuit as value. Note that `locgrp_probs` are defined identically for every circuit in the protocol, as they represent the physical error rates of a (faulty) device, which in reality do not differ between circuits either. To sample at more than one physical error rate, a range `locgrp_probs_range` can be specified which contains a range of error rates for each location group.
* `locgrp_wgts_combis` are all possible `locgrp_wgts` vectors up to the total number of elements (`locgrp_len`) in each location group.

In [None]:
#| export

def tolists(l):
    """Return a list in which every non-iterable element of `l` is wrapped
    in a single-element list.

    Elements that are already iterable (lists, tuples, arrays, strings,
    dicts, ...) are passed through untouched.
    """
    wrapped = []
    for element in l:
        if isinstance(element, Iterable):
            wrapped.append(element)
        else:
            wrapped.append([element])
    return wrapped

def lens(list_of_lists):
    """Return the length of every sublist of `list_of_lists`, as a list."""
    return [len(sublist) for sublist in list_of_lists]

def maxlen(list_of_lists):
    """Return the length of the longest sublist.

    Raises ValueError when `list_of_lists` is empty (there is no maximum).
    """
    return max(len(sublist) for sublist in list_of_lists)

def pad(l, targ_len):
    """Extend `l` to length `targ_len` by repeating its last element.

    The result is a NumPy array (built with `np.append`). If `l` already has
    at least `targ_len` elements, no elements are added. Padding an empty `l`
    to a positive length raises IndexError, since there is no element to repeat.
    """
    n_missing = targ_len - len(l)
    # Only read l[-1] when padding is actually required, so an empty `l`
    # that needs no padding (e.g. targ_len == 0) does not raise IndexError.
    fill = [l[-1]] * n_missing if n_missing > 0 else []
    return np.append(l, fill)

def equalize_lens(mixed_list: list):
    """Convert a mixed list of sequences and scalars into a list of
    equal-length sequences.

    Scalars are first wrapped in one-element lists (`tolists`). Then every
    entry is padded with copies of its own last element (`pad`) up to the
    length of the longest entry. The padded entries are NumPy arrays.
    Empty input yields an empty list.
    """
    list_of_lists = tolists(mixed_list)
    if not list_of_lists:
        return []
    # The target length is identical for every entry; compute it once
    # instead of once per element.
    target_len = maxlen(list_of_lists)
    return [pad(e, target_len) for e in list_of_lists]

def err_probs_tomatrix(err_probs: dict, groups: list) -> np.ndarray:
    """Arrange per-group error probabilities into a 2-D array.

    The entries of `err_probs` are ordered like `groups`, each entry is
    padded to a common length, and the stacked result is transposed, so
    that row ``i`` holds the i-th probability of every group and column
    ``j`` corresponds to ``groups[j]``.
    """
    ordered = sort_by_list(err_probs, groups)
    padded_columns = equalize_lens(ordered.values())
    matrix = np.array(padded_columns)
    return matrix.T

def sort_by_list(d: dict, l: list) -> dict:
    """Return a copy of `d` whose keys follow their order of appearance in `l`.

    Every key of `d` must occur in `l` (otherwise `l.index` raises ValueError).
    """
    ordered_keys = sorted(d, key=l.index)
    return {key: d[key] for key in ordered_keys}

def subset_occurence(groups: list, subsets: list, group_prob_range: np.ndarray) -> np.ndarray:
    """Calculate the occurrence probability (binomial weight) A_w of each subset.

    Parameters
    ----------
    groups : iterable of sized collections
        Location groups of one circuit; only their lengths are used.
    subsets : iterable of tuples
        Location-group-weight vectors (one weight per group).
    group_prob_range : array_like
        Location group probabilities, one entry per group.

    Returns
    -------
    np.ndarray
        For each subset, the product over the last axis of
        ``math.binom(subset, group_lens, group_prob_range)`` (presumably one
        binomial factor per location group).
    """
    # Group sizes are the same for every subset: compute them once. This also
    # avoids re-consuming `groups` when it is a one-shot iterator.
    group_lens = lens(groups)
    Aws = [math.binom(subset, group_lens, group_prob_range) for subset in subsets]
    # `np.product` was deprecated in NumPy 1.25 and removed in NumPy 2.0.
    return np.prod(Aws, axis=-1)

def all_subsets(groups: list) -> list:
    """Enumerate every possible subset (weight tuple) for the given groups.

    A group with ``N`` elements contributes a weight from ``0`` to ``N``
    (inclusive). The result is the Cartesian product of these ranges, in
    lexicographic order, as a list of tuples.
    """
    weight_ranges = [range(len(group) + 1) for group in groups]
    return list(it.product(*weight_ranges))

def protocol_all_subsets(protocol_groups: dict) -> dict:
    """Map every circuit id to the list of all its possible subset tuples.

    `protocol_groups` maps circuit ids to mappings whose values are the
    circuit's location groups.
    """
    subsets_per_circuit = {}
    for cid, groups_dict in protocol_groups.items():
        subsets_per_circuit[cid] = all_subsets(groups_dict.values())
    return subsets_per_circuit

def protocol_subset_occurence(protocol_groups: dict, protocol_subsets: dict, group_probs: np.ndarray) -> dict:
    """Calculate the occurrence probability of every subset of every circuit.

    Parameters
    ----------
    protocol_groups : dict
        Circuit id -> mapping whose values are that circuit's location groups.
    protocol_subsets : dict
        Circuit id -> list of subset tuples (as built by `protocol_all_subsets`).
    group_probs : array_like
        One probability per location group, shared by all circuits. `Sampler`
        passes a single row of the matrix returned by `err_probs_tomatrix`,
        i.e. a vector, not a dict.

    Returns
    -------
    dict
        Circuit id -> {subset tuple: occurrence probability A_w}.
    """
    occurrences = {}
    for cid, subsets in protocol_subsets.items():
        Aws = subset_occurence(protocol_groups[cid].values(), subsets, group_probs)
        occurrences[cid] = dict(zip(subsets, Aws))
    return occurrences

In [None]:
#| export
class Sampler:
    """Base class for other Sampler classes to inherit from.

    Subclasses implement `optimize`, which `run` calls for every noisy circuit
    and which must return a dict containing at least the fault locations
    (``'flocs'``) and the sampled subset (``'subset'``). Results are
    accumulated in one `CountTree` per error-probability vector.

    Attributes
    ----------
    protocol : Protocol
        Protocol to sample "marked" events from
    simulator : StabilizerSimulator or StatevectorSimulator
        Simulator used to simulate circuit execution on
    n_qubits : int
        Number of qubits used in protocol
    err_model : ErrorModel
        Error model instance used during circuit simulation
    protocol_groups : dict
        Circuit id -> location groups of that circuit (set by `_set_subsets`)
    protocol_subsets : dict
        Circuit id -> all possible subsets of that circuit (set by `_set_subsets`)
    trees : dict of CountTree
        `CountTree`s used to accumulate sampling information, one per sample
        (error-probability vector), keyed by that vector as a tuple.
    stop_sampling : bool
        Reset to False for each sample during `run`; if it is True after a
        protocol run, sampling of the current sample stops early.
    tree_idx : tuple
        Key into `trees` of the sample currently being run (set during `run`).
    """

    def __init__(self, protocol, simulator, err_probs=None, err_model=None):
        """
        Parameters
        ----------
        protocol : Protocol
            Protocol to sample "marked" events from
        simulator : StabilizerSimulator or StatevectorSimulator
            Simulator used to simulate circuit execution on
        err_probs : dict, optional
            Error probabilities per location group; keys must match the location
            group names of the error model. A value may be a single probability
            or a sequence of probabilities; shorter sequences are padded with
            their last value. Defaults to ``{"0": {}}``.
        err_model : callable, optional
            Error model class; it is called without arguments to build the
            instance stored in `self.err_model`. Defaults to `E0`.

        Raises
        ------
        TypeError
            If `err_probs` is not a dict.
        ValueError
            If the keys of `err_probs` differ from the error model's groups.
        """
        # Fresh default per call instead of a shared mutable default argument.
        if err_probs is None:
            err_probs = {"0": {}}

        self.protocol = protocol
        self.simulator = simulator
        self.n_qubits = protocol.n_qubits
        self.err_model = err_model() if err_model else E0()

        self._set_subsets()

        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if not isinstance(err_probs, dict):
            raise TypeError(f"err_probs must be a dict, got {type(err_probs).__name__}")
        if set(err_probs.keys()) != set(self.err_model.groups):
            raise ValueError(
                f"err_probs keys {sorted(map(str, err_probs.keys()))} do not match "
                f"the error model's location groups {sorted(map(str, self.err_model.groups))}"
            )

        # One tree per row of the (error probability) x (location group) matrix;
        # each tree stores the subset occurrence probabilities for that row.
        # NOTE: a `fault_tolerance_level` argument (derived from
        # `protocol.fault_tolerant`) used to be passed here and is disabled.
        self.trees = dict()
        for prob_vec in err_probs_tomatrix(err_probs, self.err_model.groups):
            constants = protocol_subset_occurence(self.protocol_groups, self.protocol_subsets, prob_vec)
            self.trees[tuple(prob_vec)] = CountTree(constants=constants)

    def _set_subsets(self):
        """(Re)calculate location groups per circuit in protocol and its possible subsets."""
        self.protocol_groups = {cid: self.err_model.group(circuit) for cid, circuit in self.protocol._circuits.items()}
        self.protocol_subsets = protocol_all_subsets(self.protocol_groups)

    def save(self, path):
        """Serialize the whole sampler object to `path` (using dill)."""
        with open(path, 'wb') as fp:
            pickle.dump(self, fp)

    @staticmethod
    def load(path):
        """Load and return a sampler object previously written by `save`.

        WARNING: unpickling (dill) can execute arbitrary code; only load files
        from trusted sources.
        """
        with open(path, 'rb') as fp:
            data = pickle.load(fp)
        return data

    def optimize(self, tree_node, circuit, prob_vec=None):
        """Choose the faults to place in a noisy circuit; must be overridden.

        `run` calls this as ``self.optimize(tree_node, circuit, prob_vec)`` and
        reads the keys ``'flocs'`` (passed to ``err_model.run``) and ``'subset'``
        from the returned dict. The whole dict is also forwarded to the
        ``on_circuit_end`` callbacks.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        # `NotImplemented` is a constant, not an exception: raising it is a TypeError.
        raise NotImplementedError(f"{type(self).__name__} must implement optimize()")

    def run(self, n_shots: int, callbacks=None) -> None:
        """Run `n_shots` protocol simulations per sample (i.e. physical error rate).

        Parameters
        ----------
        n_shots : int
            Number of shots (i.e. protocol runs) obtained per sample
        callbacks : list or CallbackList, optional
            Callbacks executed during the sampling process; a plain list is
            wrapped in a `CallbackList`. Sampling of the current sample stops
            early when `self.stop_sampling` is True after a protocol run
            (presumably set by a callback).
        """
        if callbacks is None:
            callbacks = []
        if not isinstance(callbacks, CallbackList):
            callbacks = CallbackList(sampler=self, callbacks=callbacks)
        callbacks.on_sampler_begin()

        for prob_vec, tree in self.trees.items():
            self.stop_sampling = False
            self.tree_idx = prob_vec

            for _ in tqdm(range(n_shots), desc=f'p_phy={",".join(f"{p:.2E}" for p in prob_vec)}', leave=True):
                callbacks.on_protocol_begin()

                state = self.simulator(self.n_qubits)
                tree_node = None

                for name, circuit in self.protocol:
                    callbacks.on_circuit_begin()

                    tree_node = tree.add(name=name, parent=tree_node, node_type=CircuitCountNode)
                    # TODO: set tree_node.invariant = True if path has max weight 0.
                    tree_node.count += 1
                    opt_out = dict()

                    if circuit:
                        if not circuit._noisy:
                            msmt = state.run(circuit)
                            tree_node.invariant = True
                            # NOTE: `subset` is not assigned on this branch, so the
                            # locals() forwarded to callbacks may hold the subset of an
                            # earlier noisy circuit (or none at all).
                        else:
                            opt_out = self.optimize(tree_node, circuit, prob_vec)
                            fault_circuit = self.err_model.run(circuit, opt_out['flocs'])
                            msmt = state.run(circuit, fault_circuit)
                            subset = opt_out['subset']
                            tree_node = tree.add(name=subset, parent=tree_node, node_type=SubsetCountNode, circuit_id=circuit.id,
                                                 det=bool(circuit._ff_det and not any(subset)))
                            tree_node.count += 1

                        # Feed the measurement back so the protocol can pick the next circuit.
                        self.protocol.send(msmt)

                    elif name is not None:
                        # A named but falsy circuit: record the current node as a marked leaf.
                        tree.marked_leaves.add(tree_node)
                    # TODO: set tree_node.invariant = True if path is invariant;
                    # update tree.constants at the end of each circuit if necessary.

                    # Callbacks receive every local of `run` merged with the optimizer
                    # output, so the local names above are part of the callback interface.
                    callbacks.on_circuit_end(locals() | opt_out)

                callbacks.on_protocol_end()
                if self.stop_sampling:
                    break

        callbacks.on_sampler_end()