In [None]:
# default_exp samplers.subset_helper

# Subset Helper

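> Helper functions for subset sampling: partitioning circuits, enumerating weight vectors, computing subset occurrence probabilities, and selecting the next subset to sample.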

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
#export
import qsam.math as math
from qsam.circuit import partition
from qsam.samplers.constants import *

import itertools as it
import functools
import numpy as np

## Subset analytics

In [None]:
#export
def circuit_partitions(circ, partition_names):
    """Split `circ` into one partition per entry of `partition_names`.

    Each name is looked up in `GATE_GROUPS` and the resulting gate group is
    passed to `partition` together with the circuit. The partitions are
    returned as a list in the same order as `partition_names`.
    """
    partitions = []
    for group_name in partition_names:
        gate_group = GATE_GROUPS[group_name]
        partitions.append(partition(circ, gate_group))
    return partitions

In [None]:
#export
def circuit_weight_vectors(w_max, w_exclude=()):
    """Enumerate all weight vectors up to `w_max`, leaving out `w_exclude`.

    Args:
        w_max: Maximum weight. Either a single integer (one partition) or a
            sequence of integers (one maximum per partition). NumPy integer
            scalars are accepted as well as Python ints.
        w_exclude: Weight vectors to skip. An integer entry `w` is treated as
            the 1-tuple `(w,)`. Any other entry is converted with `tuple(...)`,
            so lists such as `[0, 0]` work as well as tuples.

    Returns:
        List of tuples `(w_1, ..., w_k)` with `0 <= w_i <= w_max_i`, in
        `itertools.product` order, with every excluded vector removed.
    """
    # Local import: `numbers.Integral` also matches NumPy integer scalars,
    # which a plain `isinstance(x, int)` check rejects.
    import numbers

    def _as_vec(w):
        return (w,) if isinstance(w, numbers.Integral) else tuple(w)

    excluded = {_as_vec(w) for w in w_exclude}
    w_maxs = [w_max] if isinstance(w_max, numbers.Integral) else w_max

    w_ranges = [range(m + 1) for m in w_maxs]
    return [w_vec for w_vec in it.product(*w_ranges) if w_vec not in excluded]

In [None]:
#export
def circuit_subset_occurence(partitions, partition_w_vecs, p_phy_per_partition):
    """Return the (weight) x (p_phys) subset occurrence matrix that maps a p_SS vector to a p_L vector.

    For every weight vector in `partition_w_vecs`, the call
    `math.binom(w_vec, n_partition_elems, p_phy_per_partition)` is evaluated.
    Here `math` is `qsam.math`; presumably it returns element-wise binomial
    probabilities, one per partition (TODO confirm). The per-partition values
    are then multiplied over the last axis, which combines the partitions of a
    multi-parameter (multi-partition) circuit.

    Args:
        partitions: Sequence of partitions. Only `len` of each partition is used.
        partition_w_vecs: Sequence of weight vectors, one weight per partition.
        p_phy_per_partition: Physical error rate(s) per partition. Passed
            through unchanged to `math.binom`.

    Returns:
        np.ndarray with one leading entry per weight vector, reduced over the
        partition axis.
    """
    n_partition_elems = np.array([len(p) for p in partitions])
    Aws = np.array([math.binom(w_vec, n_partition_elems, p_phy_per_partition)
                    for w_vec in partition_w_vecs])
    # Multiply over partitions. `np.product` was deprecated in NumPy 1.25 and
    # removed in 2.0, so use `np.prod`.
    return np.prod(Aws, axis=-1)

In [None]:
#export
def protocol_partitions(circuits_dict, partition_names):
    """Partition every noisy circuit of a protocol.

    Circuits whose `_noisy` attribute is falsy are skipped. The result maps
    each remaining circuit hash to `circuit_partitions(circ, partition_names)`.
    """
    partitions_by_hash = {}
    for c_hash, circ in circuits_dict.items():
        if not circ._noisy:
            continue
        partitions_by_hash[c_hash] = circuit_partitions(circ, partition_names)
    return partitions_by_hash

In [None]:
#export
def protocol_weight_vectors(partition_dict):
    """Enumerate weight vectors for every circuit in `partition_dict`.

    Each partition's weight ranges from 0 up to the partition's size.
    """
    w_vecs_by_hash = {}
    for c_hash, partitions in partition_dict.items():
        partition_sizes = [len(part) for part in partitions]
        w_vecs_by_hash[c_hash] = circuit_weight_vectors(partition_sizes)
    return w_vecs_by_hash

In [None]:
#export
def protocol_subset_occurence(partition_dict, w_vecs_dict, p_phys):
    """Compute the subset occurrence matrix for every circuit in `w_vecs_dict`.

    The partitions for each circuit hash come from `partition_dict`, so every
    key of `w_vecs_dict` must also be present there.
    """
    occurrence_by_hash = {}
    for c_hash, w_vecs in w_vecs_dict.items():
        circuit_parts = partition_dict[c_hash]
        occurrence_by_hash[c_hash] = circuit_subset_occurence(circuit_parts, w_vecs, p_phys)
    return occurrence_by_hash

## Subset selection

In [None]:
#export
def w_plus1_filter(sampler, w_ids, circuit_hash, **kwargs):
    """Add the most important subset not yet in `w_ids` as a new candidate.

    The entries of `sampler.Aws_pmax[circuit_hash]` at the indices in `w_ids`
    are masked. The index of the largest remaining entry is then appended to
    `w_ids`.

    Args:
        sampler: Object with an `Aws_pmax` mapping from circuit hash to an
            array of subset occurrence values.
        w_ids: List of subset indices that are already candidates. It is
            modified in place.
        circuit_hash: Key into `sampler.Aws_pmax`.
        **kwargs: Ignored.

    Returns:
        `w_ids` (the same list object). The new index is appended unless every
        subset is already listed, in which case the list is returned unchanged.
    """
    n_Aws = np.ma.array(sampler.Aws_pmax[circuit_hash])
    n_Aws[w_ids,] = np.ma.masked  # mask existing (already sampled) subsets
    if n_Aws.count() == 0:
        # Nothing is left to add. argmax on a fully masked array returns 0,
        # which would append a duplicate index.
        return w_ids
    w_ids.append(np.argmax(n_Aws))  # add next most important subset as possible candidate
    return w_ids

In [None]:
#export
def random_sel(w_ids, **_ignored):
    """Pick one candidate subset index from `w_ids` uniformly at random.

    The pick uses NumPy's global random state (`np.random.choice`), so
    `np.random.seed` makes it reproducible. Extra keyword arguments are ignored.
    """
    chosen = np.random.choice(w_ids)
    return chosen

In [None]:
#export
def ERV_sel(self, sampler, w_ids, circuit_hash, tree_node, **kwargs):
    """Select the candidate subset with the largest ERV score.

    ERV presumably stands for "expected reduction in variance".

    For each candidate index in `w_ids`:
    1. The corresponding weight vector is temporarily added to `sampler.tree`
       as a child of `tree_node`, together with a `'FAIL'` child.
    2. The tree variance is evaluated after incrementing the subset node's
       counts (`v_L_minus`).
    3. It is evaluated again after also incrementing the FAIL node's counts
       (`v_L_plus`).
    4. The two variances are weighted by the FAIL node's `rate`.
    5. The score is `|v_L - weighted_v_L| + (delta - delta_after_add)`.
    6. The increments are undone, and nodes whose counts return to zero are
       detached.

    Args:
        self: Not used by the body. NOTE(review): looks like a leftover from a
            method signature, but it is still a required positional parameter.
        sampler: Object providing `tree` (with `add`, `var`, `delta`, `detach`),
            `Aws_pmax` and `w_vecs`.
        w_ids: Indexable sequence of candidate subset indices into
            `sampler.w_vecs[circuit_hash]`.
        circuit_hash: Key selecting the circuit in `sampler.w_vecs`.
        tree_node: Parent node under which the trial nodes are attached.
        **kwargs: Ignored.

    Returns:
        The element of `w_ids` with the largest score. On ties it is the first
        one, because `np.argmax` returns the first maximum.

    Raises:
        ValueError: From `np.argmax` if `w_ids` is empty.

    NOTE(review): if any tree call raises partway through, the trial counts and
    nodes are not reverted.
    """
    erv_deltas = []
    # Baseline variance and delta of the tree before any trial modification.
    v_L = sampler.tree.var(sampler.Aws_pmax)
    delta = sampler.tree.delta(sampler.Aws_pmax)
    for idx in w_ids:
        w_vec = sampler.w_vecs[circuit_hash][idx]

        # Attach the trial subset node and its FAIL child. Presumably `add`
        # returns the existing node if one is already present (TODO confirm).
        _tree_node = sampler.tree.add(w_vec, parent=tree_node, ckey=(circuit_hash, idx))
        __tree_node = sampler.tree.add('FAIL', parent=_tree_node, is_fail=True)
        # This delta is taken after the nodes are attached but before any counts change.
        _delta = sampler.tree.delta(sampler.Aws_pmax)
        # The FAIL node's rate is read before the hypothetical sample is counted.
        _rate = __tree_node.rate

        # Variance after counting one more sample in the subset node only.
        _tree_node.counts += 1        
        v_L_minus = sampler.tree.var(sampler.Aws_pmax)

        # Variance after also counting that sample in the FAIL node.
        __tree_node.counts += 1
        v_L_plus = sampler.tree.var(sampler.Aws_pmax)

        # Weighted variance: the FAIL-counted case is weighted by the FAIL node's rate.
        _v_L = _rate * v_L_plus + (1 - _rate) * v_L_minus
        erv_delta = np.abs(v_L - _v_L) + (delta - _delta) 
        # erv_delta = v_L - _v_L + (delta - _delta) # Check with new def of var
        erv_deltas.append( erv_delta )

        # Revert the change. Nodes whose counts drop back to zero are detached.
        # Presumably these are exactly the nodes newly created by the `add` calls
        # above, which assumes fresh nodes start at counts == 0 (TODO confirm).
        _tree_node.counts -= 1
        __tree_node.counts -= 1
        if _tree_node.counts == 0: sampler.tree.detach(_tree_node)
        if __tree_node.counts == 0: sampler.tree.detach(__tree_node)
    
    # Pick the candidate with the highest score (first one on ties).
    idx = np.argmax(erv_deltas)
    return w_ids[idx]