# This notebook showcases special or exotic instances of determinantal point processes (DPPs)
See the [exotic section](https://dppy.readthedocs.io/en/latest/exotic_dpps/index.html) of the DPPy documentation for background.

In [None]:
%pylab inline

%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(0, os.path.abspath('../dppy'))

from exotic_dpps import *

## Uniform Spanning Trees

A special projection DPP associated with the uniform measure on the spanning trees of a graph, hence the name Uniform Spanning Tree (UST).

In [None]:
from itertools import combinations
from collections import Counter

### Initial graph

In [None]:
# Edge list of a small graph on the nodes 0..4
edges = [(0, 2), (0, 3), (1, 2), (1, 4), (2, 3), (2, 4), (3, 4)]
# Build the undirected networkx graph directly from its edge list
g = nx.Graph(edges)

# Wrap the graph in a UST object and display it
ust = UST(g)
ust.plot_graph()

#### Uniform Spanning Tree object
The `UST` object is parametrized by a `networkx` graph, which must be undirected and connected.

In [None]:
# Reuse the graph `g` built in the "Initial graph" cell rather than re-declaring
# the same edge list a second time.
# UST expects an undirected and connected networkx graph: check this up front.
assert not g.is_directed(), 'UST requires an undirected graph'
assert nx.is_connected(g), 'UST requires a connected graph'

ust = UST(g)

ust.plot_graph()

In [None]:
# Compute the kernel of the UST projection DPP (it populates ust.kernel, which is read
# further below as the transfer current matrix), then display it.
# Order matters: the kernel must be computed before it can be plotted.
ust.compute_kernel()
ust.plot_kernel()

In [None]:
# Draw one spanning tree with each sampling procedure, calling ust.plot() after each draw
for md in ('Aldous-Broder', 'Wilson', 'DPP_exact'):
    ust.sample(md)
    ust.plot()

### Check uniformity of samples from each procedure

###### Compute the list of spanning trees of the graph

In [None]:
# Candidate spanning trees: every subset of |V|-1 edge labels (spanning trees have |V|-1 edges)
potential_st = np.array(list(combinations(np.arange(ust.nb_edges), ust.nb_nodes - 1)))

# Recomputed here so this cell does not depend on the kernel-display cell above
ust.compute_kernel()

# The principal minor of the transfer current matrix indexed by |V|-1 edges is non-zero
# iff these edges form a spanning tree; it is then the probability of that tree under the UST.
det_tol = 1e-8  # threshold separating genuinely non-zero minors from round-off noise
minors = np.array([np.linalg.det(ust.kernel[np.ix_(st, st)]) for st in potential_st])
is_st_mask = minors > det_tol
list_st_by_edge_label = potential_st[is_st_mask]

nb_st = len(list_st_by_edge_label)
# Sanity check: the UST is uniform, so every spanning tree must have probability 1 / nb_st
np.testing.assert_allclose(minors[is_st_mask], 1 / nb_st, rtol=1e-6)
print('This graph has a total of {} spanning trees'.format(nb_st))

###### Sample from each sampling procedure and count the number of times each spanning tree has been sampled

In [None]:
nb_iter = 10000  # number of spanning trees drawn with each sampling procedure
modes = ('Aldous-Broder', 'Wilson', 'DPP_exact')
nb_edges_per_st = ust.nb_nodes - 1  # a spanning tree has |V|-1 edges

# For each algorithm (mode), count the number of occurrences of each spanning tree.
# A tree is encoded by the sorted tuple of its edge labels, where the label of an edge
# is its (0-based) position in ust.edges. For example, if
# ust.edges = [(0, 2), (0, 3), (2, 1), (2, 3), (2, 4), (3, 4), (1, 4)]
# then edge {0, 2} has label 0 and edge {2, 3} has label 3.

count_ust = Counter({tuple(st_lab): 0 for st_lab in list_st_by_edge_label})
dict_count_sampled_st = {md: count_ust.copy() for md in modes}

for md in modes:

    ust.flush_samples()  # reset the list_of_samples attribute
    for _ in range(nb_iter):  # sample nb_iter spanning trees
        ust.sample(md)

    # Stack the edges of all sampled trees into a (nb_iter * (|V|-1), 2) array.
    # Each edge view is turned into a list explicitly so the array conversion does not
    # depend on how NumPy handles networkx views.
    sampled_st_edges = np.array([list(sampled_st.edges()) for sampled_st in ust.list_of_samples])
    stacked_edges = sampled_st_edges.reshape((nb_iter * nb_edges_per_st, 2))

    # -1 marks edges not matched with any label; starting from 0 would silently
    # count an unmatched edge as edge label 0.
    edge_labels = np.full(stacked_edges.shape[0], -1, dtype=int)

    # For an undirected graph edge {x,y} = {y,x},
    # but networkx reports it as the tuple (x,y) or (y,x), so match both orientations.
    for ind, ed in enumerate(ust.edges):
        matches = ((stacked_edges == ed) | (stacked_edges == ed[::-1])).all(axis=1)
        edge_labels[matches] = ind
    assert (edge_labels >= 0).all(), 'some sampled edges are not edges of the graph'

    edge_labels = edge_labels.reshape(nb_iter, nb_edges_per_st)  # regroup edge labels of the same tree
    edge_labels.sort(axis=1)  # sorted labels match the keys of the Counter object

    dict_count_sampled_st[md].update(map(tuple, edge_labels))  # update the counts of spanning trees

###### Display the histogram

In [None]:
save_figure = False  # set to True to export the histogram (saving must happen before plt.show())

fig, ax = plt.subplots(figsize=(16, 4))

bar_width = 0.25
pos = np.arange(len(list_st_by_edge_label))  # one group of bars per spanning tree
st_keys = [tuple(st_lab) for st_lab in list_st_by_edge_label]

for i, md in enumerate(modes):
    # Read counts in the order of list_st_by_edge_label explicitly,
    # instead of relying on the insertion order of the Counter.
    counts = [dict_count_sampled_st[md][key] for key in st_keys]
    assert sum(counts) == nb_iter, '{}: some samples are not spanning trees of the graph'.format(md)
    ax.bar(pos + i * bar_width, counts, width=bar_width, edgecolor='white', label=md)

# Expected count of each tree under the uniform distribution
ax.axhline(y=nb_iter / nb_st, color='k', linestyle='--', label='uniform')

# Center the tick of each group under its bars
ax.set_xticks(pos + bar_width * (len(modes) - 1) / 2)
ax.set_xticklabels(pos)
ax.set_xlabel('spanning tree (index in list_st_by_edge_label)')
ax.set_ylabel('number of samples')
ax.set_title('Check uniformity of spanning trees generated after {} samples of each procedure'.format(nb_iter))
ax.legend(loc='best')

if save_figure:
    fig.savefig('ust_histo.png')
    fig.savefig('ust_histo.eps')
plt.show()

## Carries Process

###### Choose a base $b$ to sample i.i.d. digits in $\{0, \dots, b-1\}$

In [None]:
# Digits are drawn i.i.d. in {0, ..., base-1}; size is passed to cp.sample
# (presumably the number of digits drawn — TODO confirm)
base = 10
size = 100

cp = CarriesProcess(base)
cp.sample(size)

In [None]:
# Plot the Carries Process sample drawn in the previous cell
cp.plot()

In [None]:
# NOTE(review): the name suggests the carries are plotted against independent Bernoulli
# draws for comparison — confirm in exotic_dpps.CarriesProcess
cp.plot_vs_bernoullis()

## Poissonized Plancherel measure

###### Choose a $\theta$ to sample a permutation $\sigma \in \mathfrak{S}_N$ with $N \sim \mathcal{P}(\theta)$

In [None]:
# Poisson parameter: the permutation size is drawn as N ~ Poisson(theta)
theta = 150

pp_dpp = PoissonizedPlancherel(theta=theta)
pp_dpp.sample()
pp_dpp.plot()