# Exercise 4 - Tensor Networks
In this exercise, we will inspect the canonical parameterization of a graphical model and calculate the normalization constant to answer inference queries.

Later, we will compare the speed of calculating the normalization constant using different orders of tensor contractions.

If a problem persists, do not hesitate to contact the course instructors at
- paul.kahlmeyer@uni-jena.de

### Submission

- Deadline of submission: 27.11.2022
- Submission on [moodle page](https://moodle.uni-jena.de/course/view.php?id=34630)

### Help
In case you cannot solve a task, you can use the saved values in the `help` directory to continue working on the other tasks:
- Load arrays with [Numpy](https://numpy.org/doc/stable/reference/generated/numpy.load.html)
```
import numpy as np
np.load('help/array_name.npy')
```
- Load functions with [Dill](https://dill.readthedocs.io/en/latest/dill.html)
```
import dill
with open('help/some_func.pkl', 'rb') as f:
    func = dill.load(f)
```

## Graphical Models
Let $p(x)$ be a multivariate categorical on the sample space $\mathcal{X}$.
In the canonical parameterization we define $p$ to be an exponentiated sum of interaction order parameters:
\begin{align}
p(x) = \exp\left(q(x)\right)\,,
\end{align}
where $q(x)$ is a sum over all possible interaction orders
\begin{align}
q(x) = \sum\limits_{k=1}^n\sum\limits_{i=(i_1,\dots,i_k)}q_i(x_{i_1}, \dots, x_{i_k})\,.
\end{align}
In graphical models, we reduce the number of parameters by setting specific interactions $q_i$ to 0.

This notation is a little confusing, so let's work through a **concrete example**.

Consider a multivariate categorical $p(x_0,x_1,x_2,x_3)$.
Furthermore, we restrict ourselves to unary and pairwise interactions (all interactions of order > 2 have been set to 0).

This means that we have unary interaction parameter vectors $q_0, q_1, q_2, q_3$ and pairwise interaction parameter matrices $q_{01}, q_{02}, q_{03}, q_{12}, q_{13}, q_{23}$.
Each $q_i$ holds the unary interaction parameters for $x_i$, and each $q_{ij}$ holds the pairwise interaction parameters for $x_i$ and $x_j$.

With these parameters, the canonical parameterization from above looks like this:
\begin{align}
q(x = [v_0, v_1, v_2, v_3]^T) &=\sum_{i=0}^3 q_i[v_i] + \sum_{i=0}^{3}\sum_{j=i+1}^{3} q_{ij}[v_i, v_j]\\
&=q_0[v_0] + q_1[v_1] + q_2[v_2] + q_3[v_3]\\
&+q_{01}[v_0, v_1] + q_{02}[v_0, v_2] + q_{03}[v_0, v_3]\\
&+q_{12}[v_1, v_2]+q_{13}[v_1, v_3]\\
&+q_{23}[v_2, v_3]\,.
\end{align}
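
To make the expansion above concrete, here is a minimal sketch that evaluates $q(x)$ for a toy model with three variables. The sample space sizes, the random parameter tables, and the names `q_unary`, `q_pair` and `q_of_x` are assumptions made for illustration; they are not the data or the storage layout of the exercise files.

```python
import numpy as np

rng = np.random.default_rng(0)
sizes = [2, 3, 4]  # hypothetical sample space sizes of x_0, x_1, x_2

# hypothetical parameters: one vector per variable, one matrix per pair i < j
q_unary = [rng.normal(size=n) for n in sizes]
q_pair = {
    (i, j): rng.normal(size=(sizes[i], sizes[j]))
    for i in range(len(sizes))
    for j in range(i + 1, len(sizes))
}


def q_of_x(x, q_unary, q_pair):
    """Sum the unary and pairwise parameters selected by the assignment x."""
    total = sum(q_unary[i][v] for i, v in enumerate(x))
    total += sum(table[x[i], x[j]] for (i, j), table in q_pair.items())
    return float(total)


x = (1, 2, 0)  # one concrete assignment [v_0, v_1, v_2]
value = q_of_x(x, q_unary, q_pair)
print(value)
```

Each term of the sum is a single table lookup, so evaluating $q(x)$ for one assignment is cheap; the cost only explodes once we sum over all assignments.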



### Task 1

Load $q_i$ and $q_{ij}$ from the pickled files `q_i.p` and `q_ij.p` respectively.
How large are the sample spaces for each $x_i$?

In [29]:
import pickle
with open("q_i.p", 'rb') as f:
    q_i = pickle.load(f)
with open("q_ij.p", 'rb') as f:
    q_ij = pickle.load(f)
SAMPLE_SPACE_SIZES = [len(prob_table) for prob_table in q_i]
SAMPLE_SPACE = lambda i: range(SAMPLE_SPACE_SIZES[i])
N_VARS = len(SAMPLE_SPACE_SIZES)
SAMPLE_SPACE_SIZES

[15, 50, 100, 10]

## Normalization Constant

Here we have unnormalized probabilities, so we need to calculate the normalization constant first
\begin{align}
K &= \sum_{x}p(x)\\
&= \sum_{x}\exp\left(q(x)\right)\\
&= \sum_{x}\prod_{i} \exp(q_i[x_i])\prod_{j > i} \exp(q_{ij}[x_i, x_j])\\
&= \sum_{x}\prod_{i} t_i[x_i]\prod_{j > i} t_{ij}[x_i, x_j]\,,
\end{align}
where $t_i = \exp(q_i)$ and $t_{ij} = \exp(q_{ij})$ with the elementwise exponential function.
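
The step from the exponentiated sum to the product of $t$ tables can be checked numerically. The following sketch uses a toy model with two variables and random parameters (all values and names are assumptions for illustration) and computes $K$ both ways.

```python
import itertools

import numpy as np

rng = np.random.default_rng(1)
n0, n1 = 2, 3  # hypothetical sample space sizes

# hypothetical interaction parameters
q0, q1 = rng.normal(size=n0), rng.normal(size=n1)
q01 = rng.normal(size=(n0, n1))

# elementwise exponentials
t0, t1, t01 = np.exp(q0), np.exp(q1), np.exp(q01)

assignments = list(itertools.product(range(n0), range(n1)))

# K as a sum of exp(q(x))
K_from_q = sum(np.exp(q0[a] + q1[b] + q01[a, b]) for a, b in assignments)
# K as a sum of products of the exponentiated tables
K_from_t = sum(t0[a] * t1[b] * t01[a, b] for a, b in assignments)

print(K_from_q, K_from_t)
```

Both values agree up to floating point error, because the exponential of a sum equals the product of the exponentials.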

### Task 2

A straightforward way to calculate this constant is iterating over every $x$ and summing up the $p(x)$.

Calculate $K$ using for loops.

In [30]:
import itertools
import numpy as np

# calculate normalization constant

pairwise_product = lambda x, i, t_ij: np.prod([t_ij[i][j][x[i], x[j]] for j in range(i + 1, N_VARS)])
prob_unnormalized = lambda x, t_i, t_ij: np.prod([t_i[i][x[i]] * pairwise_product(x, i, t_ij) for i in range(N_VARS)])

def norm_const_naive(t_i: list, t_ij: list) -> float:
    '''
    Calculates normalization constant by iterating over each x.

    @Params:
        t_i... unary interaction parameters (exponentiated)
        t_ij... binary interaction parameters (exponentiated)

    @Returns:
        normalization constant
    '''

    norm = 0
    sample_spaces = [SAMPLE_SPACE(i) for i in range(N_VARS)]
    for x in itertools.product(*sample_spaces):
        norm += prob_unnormalized(x, t_i, t_ij)
    return norm


t_i = [np.exp(param_table) for param_table in q_i]
t_ij = [[np.exp(param_table) for param_table in param_tables] for param_tables in q_ij]
norm = norm_const_naive(t_i, t_ij)
norm


159744720.16636336

## Inference Queries

With this normalization constant, we can now actually calculate probabilities and answer inference queries.
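
As a small illustration of what such queries look like, the sketch below builds the full normalized probability table of a toy model with two variables and reads a marginal and a tail probability from it. The sizes and random tables are assumptions for illustration, and materializing the full table is only practical for small models.

```python
import numpy as np

rng = np.random.default_rng(2)
n0, n1 = 3, 4  # hypothetical sample space sizes

# hypothetical exponentiated parameter tables
t0, t1 = np.exp(rng.normal(size=n0)), np.exp(rng.normal(size=n1))
t01 = np.exp(rng.normal(size=(n0, n1)))

# unnormalized table via broadcasting: entry [a, b] = t0[a] * t1[b] * t01[a, b]
joint = t0[:, None] * t1[None, :] * t01
K = joint.sum()   # normalization constant
p = joint / K     # normalized probabilities p(x_0, x_1)

p_x1 = p.sum(axis=0)       # prior marginal p(x_1): sum out x_0
p_x1_gt1 = p_x1[2:].sum()  # p(x_1 > 1) with values counted from 0

print(p_x1, p_x1_gt1)
```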

### Task 3
Calculate the prior marginal 
\begin{align}
p(x_3)\,.
\end{align}

In [31]:
def prior_marginal(i, t_i, t_ij, norm):
    marginal_probs = np.zeros(SAMPLE_SPACE_SIZES[i])
    for x in itertools.product(*[SAMPLE_SPACE(j) for j in range(N_VARS)]):
        marginal_probs[x[i]] += prob_unnormalized(x, t_i, t_ij) / norm
    return marginal_probs
        
        
p_x_3 = prior_marginal(3, t_i, t_ij, norm)
p_x_3

array([0.17600392, 0.07294889, 0.08339296, 0.10914227, 0.07890277,
       0.10590401, 0.06383303, 0.08919156, 0.07370976, 0.14697082])

### Task 4

Calculate the probability 
\begin{equation}
p(x_2>20)\,.
\end{equation}

In [32]:
p_x_2 = prior_marginal(2, t_i, t_ij, norm)
p_x_2_g20 = 0
for x_2 in range(21, SAMPLE_SPACE_SIZES[2]):
    p_x_2_g20 += p_x_2[x_2]
p_x_2_g20

0.8110435319468282

## Tensor Contraction
Calculating $K$ by iterating over every $x$ is quite slow.
Let's look at how we can speed up this calculation.

We can rewrite the calculation of $K$ as

\begin{align}
K &= \sum_{x}p(x)\\
&= \sum_{x}\prod_{i} \exp(q_i[x_i])\prod_{j > i} \exp(q_{ij}[x_i, x_j])\\
&= \sum_{x}\prod_{i} t_i[x_i]\prod_{j > i} t_{ij}[x_i, x_j]\\
&= \sum_{v_0=1}^{n_0}\sum_{v_1=1}^{n_1}\sum_{v_2=1}^{n_2}\sum_{v_3=1}^{n_3}\prod_{i} t_i[v_i]\prod_{j > i} t_{ij}[v_i, v_j]\,.
\end{align}

In this form, calculating the normalization constant boils down to a single tensor contraction. 

Since contracting tensors in numpy is implemented in C under the hood, we can expect a significant speedup.
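
The equivalence between the nested sums and a single contraction can be seen on a toy model with three variables. The sizes and random tensors in this sketch are assumptions for illustration, not the exercise data.

```python
import itertools

import numpy as np

rng = np.random.default_rng(0)
sizes = (3, 4, 5)  # hypothetical sample space sizes

# hypothetical unary and pairwise tensors
t = [rng.random(n) for n in sizes]
t01 = rng.random((3, 4))
t02 = rng.random((3, 5))
t12 = rng.random((4, 5))

# nested sums written out as an explicit loop over all assignments
K_loop = 0.0
for a, b, c in itertools.product(*(range(n) for n in sizes)):
    K_loop += t[0][a] * t[1][b] * t[2][c] * t01[a, b] * t02[a, c] * t12[b, c]

# the same sum as one contraction: no index survives on the output side
K_einsum = np.einsum('i,j,k,ij,ik,jk->', *t, t01, t02, t12)

print(K_loop, K_einsum)
```

Every index that appears on the input side but not after `->` is summed over, which is exactly what the nested sums do.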

### Task 5
Calculate the normalization constant with a **single** contraction using [Einstein summation](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).

For a brief introduction to `einsum`, see [here](https://ajcr.net/Basic-guide-to-einsum/) and [here](https://medium.com/ibm-data-ai/einsum-an-easy-intuitive-way-to-write-tensor-operation-9e12b8a80570).

Make sure that your result is correct by comparing it to the naive implementation.

In [33]:
# the indices we want to sum and multiply over
indices = list(range(N_VARS))
pairwise_indices = [(i, j) for i in range(N_VARS) for j in range(i + 1, N_VARS)]
# the names of the indices
name = ['i', 'j', 'k', 'l']
# contract over all variables, and multiply pairwise
einsum_notation = ", ".join([name[i] for i in indices] + [name[i] + name[j] for i, j in pairwise_indices]) + " ->"
einsum_args = [t_i[i] for i in indices] + [t_ij[i][j] for i, j in pairwise_indices]
print(einsum_notation)
norm_einsum = np.einsum(einsum_notation, *einsum_args)
assert np.isclose(norm_einsum, norm)


i, j, k, l, ij, ik, il, jk, jl, kl ->


### Task 6

Compare the execution times of calculating $K$ the naive way vs. using `einsum`.

In [34]:
# TODO: compare execution times
%timeit norm_const_naive(t_i, t_ij)
%timeit np.einsum(einsum_notation, *einsum_args)

The slowest run took 5.33 times longer than the fastest. This could mean that an intermediate result is being cached.
1min 46s ± 1min 5s per loop (mean ± std. dev. of 7 runs, 1 loop each)
26.9 ms ± 6.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Contraction order

We see that using tensor contraction speeds up the calculation. This, however, is not the end of optimization: the order of contraction can be permuted, potentially reducing the number of calculations. Here we want to permute the order in which the variables are marginalized out.

For example for two variables $x_0, x_1$:
\begin{align}
K &= \sum_{v_0=1}^{n_0}\sum_{v_1=1}^{n_1} t_0[v_0]t_1[v_1]t_{01}[v_0, v_1]\\
(1) &= \sum_{v_0=1}^{n_0}t_0[v_0]\sum_{v_1=1}^{n_1}t_1[v_1]t_{01}[v_0, v_1]\\
(2) &= \sum_{v_1=1}^{n_1}t_1[v_1]\sum_{v_0=1}^{n_0}t_0[v_0]t_{01}[v_0, v_1]
\end{align}

This sum can be calculated either as (1)
1. Contracting $t_{01}$ and $t_{1}$ over the index of $x_1$
2. Contracting the result from step 1 with $t_0$ over the index of $x_0$

or as (2)
1. Contracting $t_{01}$ and $t_{0}$ over the index of $x_0$
2. Contracting the result from step 1 with $t_1$ over the index of $x_1$

Depending on the tensor dimensions, one calculation can be faster than the other.
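
The two orders for the two-variable case can be sketched with `einsum` as follows. The tensor sizes and random values are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(3)
n0, n1 = 5, 7  # hypothetical sample space sizes

t0, t1 = rng.random(n0), rng.random(n1)
t01 = rng.random((n0, n1))

# order (1): sum out x_1 first, then x_0
tmp1 = np.einsum('ij,j->i', t01, t1)   # intermediate of length n0
K1 = np.einsum('i,i->', t0, tmp1)

# order (2): sum out x_0 first, then x_1
tmp2 = np.einsum('ij,i->j', t01, t0)   # intermediate of length n1
K2 = np.einsum('j,j->', t1, tmp2)

# reference: everything in one contraction
K_direct = np.einsum('i,j,ij->', t0, t1, t01)

print(K1, K2, K_direct)
```

Both orders give the same $K$, but the intermediate results have different shapes. With more variables, these intermediate shapes are what makes one order cheaper than another.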


### Task 7

Implement the following function that contracts the tensors in a given order.

As an example for three variables, the order

```
['i', 'j', 'k']
```

with the tensor dictionary

```
tensor_dict = {
'i' : t_i,
'j' : t_j,
'k' : t_k,
'ij' : t_ij,
'ik' : t_ik,
'jk' : t_jk
}
```
will perform the following contractions

1. `tmp = np.einsum('i, ij, ik -> jk', t_i, t_ij, t_ik) # marginalize out i`
2. `tmp = np.einsum('j, jk, jk -> k', t_j, t_jk, tmp) # marginalize out j`
3. `tmp = np.einsum('k, k -> ', t_k, tmp) # marginalize out k`
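
The three contractions listed above can be tried out on random tensors. The sizes and values in this sketch are assumptions for illustration; it only demonstrates the individual steps, not the general function asked for in this task.

```python
import numpy as np

rng = np.random.default_rng(4)
n_i, n_j, n_k = 3, 4, 5  # hypothetical sample space sizes

t_i, t_j, t_k = rng.random(n_i), rng.random(n_j), rng.random(n_k)
t_ij = rng.random((n_i, n_j))
t_ik = rng.random((n_i, n_k))
t_jk = rng.random((n_j, n_k))

# marginalize out i: all tensors containing i are merged into a (j, k) tensor
tmp = np.einsum('i,ij,ik->jk', t_i, t_ij, t_ik)
shape_after_i = tmp.shape
# marginalize out j: the previous result enters as an additional 'jk' tensor
tmp = np.einsum('j,jk,jk->k', t_j, t_jk, tmp)
shape_after_j = tmp.shape
# marginalize out k: only a scalar remains
K_stepwise = np.einsum('k,k->', t_k, tmp)

# reference: single contraction over all indices
K_single = np.einsum('i,j,k,ij,ik,jk->', t_i, t_j, t_k, t_ij, t_ik, t_jk)

print(shape_after_i, shape_after_j, K_stepwise, K_single)
```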

Make sure that the results are correct and compare the times of different marginalization orders to those from Task 6.

In [35]:
def norm_const_order(order: list, tensor_dict: dict, show_notations=False) -> float:
    '''
    Calculates the normalization constant using tensor contraction with a specific order.

    @Params:
        order... list of variables in the order of their marginalization
        tensor_dict... dict that maps each index string (e.g. 'i' or 'ij') to its tensor
        show_notations... if True, print the einsum notation of each contraction

    @Returns:
        normalization constant K

    '''

    for order_index, value_index in enumerate(order):
        # what indices will be left after marginalization
        remaining = order[order_index + 1:]
        # indices should always be sorted in the notation
        sum_over = [value_index] + [value_index + r if value_index < r else r + value_index for r in remaining]
        sum_into = "".join(sorted(remaining))
        einsum_args = [tensor_dict[value_indices] for value_indices in sum_over]
        # if we're past the first marginalization, we have to include the last result
        if order_index > 0:
            last_sum_into = "".join(sorted([value_index] + remaining))
            sum_over.append(last_sum_into)
            einsum_args.append(contracted)
        # build and execute einsum
        einsum_notation = f"{', '.join(sum_over)} -> {sum_into}"
        if show_notations:
            print(einsum_notation)
        contracted = np.einsum(einsum_notation, *einsum_args)
    return contracted


tensor_dict = {name[i]: t_i[i] for i in range(N_VARS)}
tensor_dict |= {name[i] + name[j]: t_ij[i][j] for i, j in pairwise_indices}
# order = ["i", "j", "k", "l"]
order = ["i", "j", "l", "k"]
norm_einsum_ordered = norm_const_order(order, tensor_dict, show_notations=True)
assert np.isclose(norm_einsum_ordered, norm_einsum)


i, ij, il, ik -> jkl
j, jl, jk, jkl -> kl
l, kl, kl -> k
k, k -> 


## Optimal contraction order

We see that the contraction order has a noticeable effect on the computation time.

In fact, the problem of finding the best contraction order is generally NP-hard and an active area of research.
In Python, the package [opt_einsum](https://optimized-einsum.readthedocs.io/en/stable/) provides heuristics to find a (near-)optimal contraction order.

### Task 8

Use `opt_einsum` to calculate $K$ and make sure that the result is correct.
Again measure the execution time and compare to the other methods.

Note: if you are interested, you can use `opt_einsum.contract_path` to have a look at the optimal contraction order that was used.
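
If `opt_einsum` is not at hand, NumPy itself ships a related helper, `np.einsum_path`, which reports a contraction order chosen by a greedy heuristic. The sketch below uses it on a toy model with random tensors (sizes and values are assumptions for illustration). It is a stand-in to get a feeling for such reports, not a replacement for the `opt_einsum` call asked for in this task.

```python
import numpy as np

rng = np.random.default_rng(5)
sizes = {'i': 3, 'j': 4, 'k': 5}  # hypothetical sample space sizes
terms = ['i', 'j', 'k', 'ij', 'ik', 'jk']

# one random tensor per term, shaped according to its indices
tensors = [rng.random(tuple(sizes[c] for c in term)) for term in terms]
subscripts = ','.join(terms) + '->'

# ask NumPy for a contraction order and a human-readable report
path, report = np.einsum_path(subscripts, *tensors, optimize='greedy')
print(report)

# reuse the found path for the actual contraction
K_opt = np.einsum(subscripts, *tensors, optimize=path)
K_ref = np.einsum(subscripts, *tensors)
print(K_opt, K_ref)
```

The report lists which tensors are contracted in which step, similar in spirit to the output of `opt_einsum.contract_path`.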

In [36]:
# TODO: use opt_einsum to compute K
import opt_einsum
assert np.isclose(norm, opt_einsum.contract(einsum_notation, *einsum_args))

# TODO: timing
%timeit opt_einsum.contract(einsum_notation, *einsum_args)
%timeit norm_const_order(order, tensor_dict)
opt_einsum.contract_path(einsum_notation, *einsum_args)

5.59 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.9 ms ± 337 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


([(1, 8), (1, 7), (0, 4), (0, 6), (0, 3), (3, 4), (0, 2), (1, 2), (0, 1)],
   Complete contraction:  i,j,k,l,ij,ik,il,jk,jl,kl->
          Naive scaling:  4
      Optimized scaling:  4
       Naive FLOP count:  7.500e+6
   Optimized FLOP count:  1.542e+6
    Theoretical speedup:  4.864e+0
   Largest intermediate:  1.500e+4 elements
 --------------------------------------------------------------------------------
 scaling        BLAS                current                             remaining
 --------------------------------------------------------------------------------
    2              0               jl,j->jl             i,k,l,ij,ik,il,jk,kl,jl->
    2              0               kl,k->kl               i,l,ij,ik,il,jk,jl,kl->
    2              0               il,i->il                 l,ij,ik,jk,jl,kl,il->
    2              0               il,l->il                   ij,ik,jk,jl,kl,il->
    3              0             jl,ij->jli                     ik,jk,kl,il,jli->
    3     